运行时间约 4.5 s，有一定概率更长，从而被判 TLE。
参考 Efficient and easy segment trees, By Al.Cash
注意：网页上懒标记部分（range increment and sum queries）的 calc 和 apply 写法有问题，正确写法参考 P___ 的回复。
更新：精简后的代码如下。
def get_h(n):
    """Return the bit length of the non-negative integer n (0 for n == 0).

    This is the tree "height" used by pushdown() to bound how many ancestor
    levels are visited. Uses int.bit_length instead of a manual shift loop,
    which also cannot spin forever on a negative argument.
    """
    return n.bit_length()
class Node:
    """One segment-tree node.

    ``__slots__`` removes the per-instance ``__dict__``: the tree allocates
    2n nodes and reads/writes their attributes in every operation, so this
    saves memory and speeds up attribute access in a time-limited setting.
    """

    __slots__ = ("sum_v", "add")

    def __init__(self):
        # Sum of the node's range, not counting additions still pending in
        # its ancestors (pushdown adds those in later).
        self.sum_v = 0
        # Pending addition for the node's whole range (a total, not a
        # per-element value); only internal nodes ever receive one.
        self.add = 0
def pushup(u, val):
    """Apply an increment ``val`` (already totalled over the node's range) to node u.

    The sum is always updated; nodes with index below n are internal, so they
    also remember ``val`` as pending for their children.
    """
    node = t[u]
    node.sum_v += val
    if u < n:
        node.add += val
def build(l, r):
    """Initial bottom-up build of the internal sums above leaves l .. r-1.

    Each pass moves one level up and recomputes the parents of the previous
    pass's nodes from their children (pending ``add`` values are ignored;
    this is only for the initial build).

    Nodes in a pass are visited from the highest index down, as in the
    reference implementation. A node's left child can lie in the same pass
    at a higher index. For example, with n = 3 the first pass covers nodes
    [1, 2], and node 1's left child is node 2. Computing in ascending order
    would read that child before it has been computed.
    """
    l += n
    r += n - 1
    while l > 1:
        l >>= 1
        r >>= 1
        for u in range(r, l - 1, -1):
            t[u].sum_v = t[u << 1].sum_v + t[u << 1 | 1].sum_v
def build_u(u):
    """Recompute the sums of every strict ancestor of leaf u, bottom-up.

    A node's sum is its two children's sums plus its own pending ``add``
    (which is a total over the node's range).
    """
    p = (u + n) >> 1
    while p >= 1:
        node = t[p]
        node.sum_v = t[2 * p].sum_v + t[2 * p + 1].sum_v + node.add
        p >>= 1
def pushdown(u):
    """Push pending additions down the ancestor chain of leaf ``u``.

    Ancestors are visited from level h - 1 down to the leaf's parent. A
    node's ``add`` is the pending total for its whole range (modify_range
    doubles the per-node amount at each level up), so each child receives
    half of it. Internal children (index below n) keep their share pending.
    """
    leaf = u + n
    for shift in reversed(range(1, h)):
        p = leaf >> shift
        node = t[p]
        if not node.add:
            continue
        for child in (p << 1, p << 1 | 1):
            share = node.add >> 1
            t[child].sum_v += share
            if child < n:
                t[child].add += share
        node.add = 0
def modify_range(l, r, val):
    """Add ``val`` to every element of the half-open range [l, r).

    The bottom-up walk tags the covering nodes via pushup(). Because nodes
    store totals for their whole range, the amount handed to a node doubles
    with each level climbed. Afterwards, the ancestors of the two boundary
    leaves are recomputed with build_u(). A zero increment is a no-op.
    """
    if val == 0:
        return
    lo, hi = l + n, r + n
    amount = val
    while lo < hi:
        if lo & 1:
            pushup(lo, amount)
            lo += 1
        if hi & 1:
            hi -= 1
            pushup(hi, amount)
        lo, hi, amount = lo >> 1, hi >> 1, amount << 1
    build_u(l)
    build_u(r - 1)
def query(l, r):
    """Return the sum of the elements in the half-open range [l, r).

    Pending additions on the ancestor chains of the two boundary leaves are
    pushed down first; then the standard bottom-up walk adds up the nodes
    that exactly cover the range.
    """
    pushdown(l)
    pushdown(r - 1)
    total = 0
    lo, hi = l + n, r + n
    while lo < hi:
        if lo & 1:
            total += t[lo].sum_v
            lo += 1
        if hi & 1:
            hi -= 1
            total += t[hi].sum_v
        lo, hi = lo >> 1, hi >> 1
    return total
import sys

# Read all of stdin at once and walk the tokens. Per-line input() calls are
# comparatively slow in CPython and dominate runtime for many operations.
# Layout: "n m", then the n initial values, then m operations:
#   "Q l r"       -> print the sum of positions l..r (1-based, inclusive)
#   any other op  -> "<op> l r val": add val to positions l..r
tokens = sys.stdin.buffer.read().split()
n, m = int(tokens[0]), int(tokens[1])
w = list(map(int, tokens[2:2 + n]))
h = get_h(n)
t = [Node() for _ in range(n << 1)]
for i, v in enumerate(w):
    t[n + i].sum_v = v
build(0, n)
res = []
pos = 2 + n
for _ in range(m):
    if tokens[pos] == b"Q":
        l, r = int(tokens[pos + 1]), int(tokens[pos + 2])
        pos += 3
        # Convert 1-based inclusive [l, r] to 0-based half-open [l-1, r).
        res.append(str(query(l - 1, r)))
    else:
        l, r, val = int(tokens[pos + 1]), int(tokens[pos + 2]), int(tokens[pos + 3])
        pos += 4
        modify_range(l - 1, r, val)
print("\n".join(res))
def get_h(n):
    """Return the tree height h for a non-negative element count n.

    This is the bit length of n, with a minimum of 1 (so n == 0 gives 1, as
    before). int.bit_length replaces the manual search, which silently
    returned None for n >= 2**128.
    """
    return max(n.bit_length(), 1)
class Node:
    """One segment-tree node.

    ``__slots__`` removes the per-instance ``__dict__``: the tree allocates
    about 2n nodes and touches their attributes in every operation, so this
    saves memory and speeds up attribute access.
    """

    __slots__ = ("sum_v", "add")

    def __init__(self):
        # Sum of the node's range, not counting additions still pending in
        # its ancestors.
        self.sum_v = 0
        # Pending per-element addition not yet pushed to the children.
        self.add = 0

    def __repr__(self):
        return str(self.sum_v) + ":" + str(self.add)
def calc(u, k):
    """Recompute node u's sum for range increments.

    The sum is the children's sums plus the node's pending per-element
    ``add`` counted over its k leaves.
    """
    left, right = t[2 * u], t[2 * u + 1]
    t[u].sum_v = left.sum_v + right.sum_v + t[u].add * k
def add_value(u, val, k):
    """Add ``val`` to each of the k elements under node u.

    The node's sum grows by ``val * k``. Internal nodes (index below n) also
    record ``val`` as a pending per-element addition.
    """
    node = t[u]
    node.sum_v += k * val
    if u < n:
        node.add += val
def build(l, r):
    """Recompute, level by level, every ancestor of leaves l .. r-1.

    ``size`` is the leaf count of a node on the current level. It starts at
    2 just above the leaves and doubles per level. Within a level, nodes are
    visited from the highest index down, so a child in the same level's
    range is recomputed before its parent. Each node becomes its children's
    sum plus its pending per-element ``add`` times ``size``.
    """
    lo, hi = l + n, r + n - 1
    size = 2
    while lo > 1:
        lo >>= 1
        hi >>= 1
        for i in reversed(range(lo, hi + 1)):
            node = t[i]
            node.sum_v = t[2 * i].sum_v + t[2 * i + 1].sum_v + node.add * size
        size <<= 1
def pushdown(l, r):
    """Push pending per-element additions down toward leaves l .. r-1.

    Levels s = h .. 1 are processed from the top. At level s, the ancestors
    of the leaf range are nodes ``(l+n) >> s`` .. ``(r+n-1) >> s``. Each
    non-zero ``add`` is applied to both children, which span k leaves
    (k = 2**(s-1), halving per level), and is then cleared.
    """
    first, last = l + n, r + n - 1
    k = 1 << (h - 1)
    for s in range(h, 0, -1):
        for i in range(first >> s, (last >> s) + 1):
            node = t[i]
            if node.add:
                for c in (i << 1, i << 1 | 1):
                    child = t[c]
                    child.sum_v += node.add * k
                    if c < n:
                        child.add += node.add
                node.add = 0
        k >>= 1
def modify_range(l, r, val):
    """Add ``val`` to every element of the half-open range [l, r).

    Pending additions above both boundary leaves are pushed down first.
    The bottom-up walk then tags each covering node through add_value(),
    with ``span`` being that node's leaf count. Finally, the ancestors of
    both boundary leaves are rebuilt. A zero increment is a no-op.
    """
    if val == 0:
        return
    pushdown(l, l + 1)
    pushdown(r - 1, r)
    lo, hi, span = l + n, r + n, 1
    while lo < hi:
        if lo & 1:
            add_value(lo, val, span)
            lo += 1
        if hi & 1:
            hi -= 1
            add_value(hi, val, span)
        lo, hi, span = lo >> 1, hi >> 1, span << 1
    build(l, l + 1)
    build(r - 1, r)
def query(l, r):
    """Return the sum of the elements in the half-open range [l, r).

    Pending additions above the two boundary leaves are pushed down first
    via pushdown(). Then the standard bottom-up walk adds up the nodes that
    exactly cover the range. (The unused ``ll, rr`` copies and the leftover
    debug comments were removed.)
    """
    pushdown(l, l + 1)
    pushdown(r - 1, r)
    total = 0
    l += n
    r += n
    while l < r:
        if l & 1:
            total += t[l].sum_v
            l += 1
        if r & 1:
            r -= 1
            total += t[r].sum_v
        l >>= 1
        r >>= 1
    return total
import sys

# Read all of stdin at once and walk the tokens. Per-line input() calls are
# comparatively slow in CPython and dominate runtime for many operations.
# Layout: "n m", then the n initial values, then m operations:
#   "Q l r"       -> print the sum of positions l..r (1-based, inclusive)
#   any other op  -> "<op> l r val": add val to positions l..r
tokens = sys.stdin.buffer.read().split()
n, m = int(tokens[0]), int(tokens[1])
w = list(map(int, tokens[2:2 + n]))
h = get_h(n)
t = [Node() for _ in range((n + 1) << 1)]
for i, v in enumerate(w):
    t[n + i].sum_v = v
# build() takes a half-open leaf range, so the whole array is [0, n).
# (The previous build(0, n - 1) only produced the same nodes by coincidence.)
build(0, n)
res = []
pos = 2 + n
for _ in range(m):
    if tokens[pos] == b"Q":
        l, r = int(tokens[pos + 1]), int(tokens[pos + 2])
        pos += 3
        # Convert 1-based inclusive [l, r] to 0-based half-open [l-1, r).
        res.append(str(query(l - 1, r)))
    else:
        l, r, val = int(tokens[pos + 1]), int(tokens[pos + 2]), int(tokens[pos + 3])
        pos += 4
        modify_range(l - 1, r, val)
print("\n".join(res))