题目描述
给定长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:
1 x y,查询区间 [x,y] 中的最大连续子段和,即 $\max\limits_{x \le l \le r \le y}\left\{\sum_{i=l}^{r} A[i]\right\}$。
2 x y,把 A[x] 改成 y。
对于每个查询指令,输出一个整数表示答案。
输入格式
第一行两个整数 N,M。
第二行 N 个整数 A[i]。
接下来 M 行每行 3 个整数 k,x,y,k=1 表示查询(此时如果 x>y,请交换 x,y),k=2 表示修改。
输出格式
对于每个查询指令输出一个整数表示答案。
每个答案占一行。
数据范围
N≤500000,M≤100000,
−1000≤A[i]≤1000
样例
输入样例:
5 3
1 2 -3 4 5
1 2 3
2 2 -1
1 3 2
输出样例:
2
-1
算法1
(线段树(分情况讨论)) $O(N + M\log N)$
1. lmax 表示必须包含区间左端点的最大连续子段和(最大前缀和)
2. rmax 表示必须包含区间右端点的最大连续子段和(最大后缀和)
3.线段树本身就是dfs的灵活运用,无论是边界还是分情况讨论
4.python3 TLE
时间复杂度
建树 $O(N)$,每次查询或单点修改 $O(\log N)$,总时间复杂度 $O(N + M\log N)$。
参考文献
python3 代码
class SegTreeNode:
    """Summary of one closed index range [l, r] of the sequence.

    Attributes:
        l, r: bounds of the range the node covers.
        sum:  total of all elements in the range.
        max:  largest sum of a non-empty contiguous run inside the range.
        lmax: largest sum of a run that starts at ``l`` (best prefix).
        rmax: largest sum of a run that ends at ``r`` (best suffix).
    """

    # The tree holds 4 * N of these objects; a fixed attribute set avoids a
    # per-instance __dict__ and keeps the memory footprint down.
    __slots__ = ("l", "r", "sum", "max", "lmax", "rmax")

    def __init__(self, l=0, r=0, sum=0, max=0, lmax=0, rmax=0):
        self.l = l
        self.r = r
        self.sum = sum
        self.max = max
        self.lmax = lmax
        self.rmax = rmax
# Module-level state shared by build/update/query; main() rebinds both.
a = []  # input values (main() stores them 1-indexed, slot 0 is padding)
tree = []  # segment-tree nodes addressed heap-style, root at index 1
def push_up(root: int, ll: int, rr: int) -> None:
    """Recompute the summary stored at ``tree[root]`` from its two children.

    Args:
        root: index of the parent node in ``tree``.
        ll: index of the left child.
        rr: index of the right child.
    """
    parent, left, right = tree[root], tree[ll], tree[rr]
    parent.sum = left.sum + right.sum
    # Best prefix: stay inside the left child, or take all of it and continue
    # with the best prefix of the right child.
    parent.lmax = max(left.lmax, left.sum + right.lmax)
    # Best suffix: the mirror image of the prefix rule.
    parent.rmax = max(right.rmax, left.rmax + right.sum)
    # Best run: entirely in one child, or crossing the boundary between them.
    parent.max = max(left.max, right.max, left.rmax + right.lmax)
def build(root: int, l: int, r: int) -> None:
    """Initialise the subtree rooted at ``tree[root]`` to cover ``a[l..r]``.

    Args:
        root: index of the node to fill in.
        l: left bound of the covered range (inclusive).
        r: right bound of the covered range (inclusive).
    """
    node = tree[root]
    node.l, node.r = l, r
    if l == r:
        # A leaf holds exactly one element, so every summary equals it.
        node.sum = node.max = node.lmax = node.rmax = a[l]
        return
    ll = root << 1
    rr = ll | 1
    mid = (l + r) >> 1
    build(ll, l, mid)
    build(rr, mid + 1, r)
    push_up(root, ll, rr)
def update(root: int, idx: int, val: int) -> None:
    """Overwrite position ``idx`` with ``val`` inside the subtree at ``root``.

    Descends to the leaf covering ``idx``, stores ``val`` there and refreshes
    each ancestor's summary on the way back up.

    Args:
        root: index of the subtree root in ``tree``.
        idx: position of the element to change.
        val: new value for that element.
    """
    node = tree[root]
    if node.l == node.r:
        node.sum = node.max = node.lmax = node.rmax = val
        return
    ll = root << 1
    rr = ll | 1
    mid = (node.l + node.r) >> 1
    update(ll if idx <= mid else rr, idx, val)
    push_up(root, ll, rr)
def query(root: int, ql: int, qr: int) -> "SegTreeNode":
    """Summarise the part of ``[ql, qr]`` that lies inside the subtree at ``root``.

    Args:
        root: index of the subtree root in ``tree``.
        ql: left bound of the query (inclusive); callers pass ``ql <= qr``.
        qr: right bound of the query (inclusive).

    Returns:
        A node whose ``sum``/``max``/``lmax``/``rmax`` describe the queried
        range. It is either a node stored in ``tree`` (do not mutate it) or a
        freshly built node when the range straddles a child boundary.
    """
    node = tree[root]
    if ql <= node.l and node.r <= qr:
        return node
    ll = root << 1
    rr = ll | 1
    mid = (node.l + node.r) >> 1
    # If the range lies in a single child, that child's answer is the answer;
    # no placeholder node or sentinel value is needed.
    if qr <= mid:
        return query(ll, ql, qr)
    if ql > mid:
        return query(rr, ql, qr)
    # The range straddles mid: combine both halves with the push_up rules.
    left = query(ll, ql, qr)
    right = query(rr, ql, qr)
    return SegTreeNode(
        left.l,
        right.r,
        left.sum + right.sum,
        max(left.max, right.max, left.rmax + right.lmax),
        max(left.lmax, left.sum + right.lmax),
        max(right.rmax, left.rmax + right.sum),
    )
def main():
    """Read the sequence and the M commands from stdin and answer every query.

    Input is consumed as one whitespace-separated token stream and answers are
    written in a single call: reading and printing line by line is the
    dominant cost at N = 500000, M = 100000.
    """
    import sys

    global a
    global tree
    tokens = sys.stdin.buffer.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    a = [0] + [int(tok) for tok in tokens[2:2 + n]]  # 1-indexed, slot 0 unused
    tree = [SegTreeNode() for _ in range(4 * n)]
    build(1, 1, n)

    answers = []
    pos = 2 + n
    for _ in range(m):
        op, x, y = int(tokens[pos]), int(tokens[pos + 1]), int(tokens[pos + 2])
        pos += 3
        if op == 1:
            # The statement allows the query bounds in either order.
            if x > y:
                x, y = y, x
            answers.append(str(query(1, x, y).max))
        else:
            update(1, x, y)
    if answers:
        sys.stdout.write("\n".join(answers) + "\n")


if __name__ == "__main__":
    main()
C++ 代码
#include <iostream>
#include <string.h>
#include <algorithm>
using namespace std;
/// Summary of one closed index range [l, r] of the sequence.
class SegTreeNode
{
public:
    int l;    ///< left bound of the covered range
    int r;    ///< right bound of the covered range
    int sum;  ///< total of all elements in the range
    int max;  ///< largest sum of a non-empty contiguous run in the range
    int lmax; ///< largest sum of a run starting at l (best prefix)
    int rmax; ///< largest sum of a run ending at r (best suffix)

    /// Zero-initialises every field so a default-constructed node can be
    /// copied and returned by value without reading indeterminate ints.
    SegTreeNode() : l(0), r(0), sum(0), max(0), lmax(0), rmax(0) {}

    /// Builds a node with every field given explicitly.
    SegTreeNode(int l_, int r_, int sum_, int max_, int lmax_, int rmax_)
        : l(l_), r(r_), sum(sum_), max(max_), lmax(lmax_), rmax(rmax_)
    {
    }
};
// Storage shared by build/update/query; main() allocates both arrays.
int *a = nullptr;            // input values, read by build() as a[l]
SegTreeNode *tree = nullptr; // segment-tree nodes, root at index 1
void push_up(int root, int ll, int rr)
{
tree[root].sum = tree[ll].sum + tree[rr].sum;
tree[root].lmax = max(tree[ll].lmax, tree[ll].sum + tree[rr].lmax);
tree[root].rmax = max(tree[rr].rmax, tree[ll].rmax + tree[rr].sum);
tree[root].max = max(max(tree[ll].max, tree[rr].max), tree[ll].rmax + tree[rr].lmax);
}
void build(int root, int l, int r)
{
tree[root].l = l;
tree[root].r = r;
if (l == r)
{
tree[root].sum = a[l];
tree[root].max = a[l];
tree[root].lmax = a[l];
tree[root].rmax = a[l];
return ;
}
int ll = root << 1;
int rr = root << 1 | 1;
int mid = l + r >> 1;
// push_down();
build(ll, l, mid);
build(rr, mid + 1, r);
push_up(root, ll, rr);
}
void update(int root, int idx, int val)
{
if (tree[root].l == tree[root].r)
{
tree[root].sum = val;
tree[root].max = val;
tree[root].lmax = val;
tree[root].rmax = val;
return ;
}
int ll = root << 1;
int rr = root << 1 | 1;
int mid = tree[root].l + tree[root].r >> 1;
// push_down();
if (idx <= mid)
{
update(ll, idx, val);
}
else
{
update(rr, idx, val);
}
push_up(root, ll, rr);
}
/// Summarises the part of [ql, qr] that lies inside the subtree at root.
///
/// @param root index of the subtree root in tree
/// @param ql   left bound of the query (inclusive); callers pass ql <= qr
/// @param qr   right bound of the query (inclusive)
/// @return a node whose sum/max/lmax/rmax describe the queried range and
///         whose l/r are the bounds of the range it actually covers
SegTreeNode query(int root, int ql, int qr)
{
    const SegTreeNode &node = tree[root];
    if (ql <= node.l && node.r <= qr)
    {
        return node;
    }

    const int left_child = root << 1;
    const int right_child = left_child | 1;
    const int mid = (node.l + node.r) >> 1;

    // If the range lies in a single child, that child's answer is the answer;
    // no placeholder node or sentinel arithmetic is needed.
    if (qr <= mid)
    {
        return query(left_child, ql, qr);
    }
    if (ql > mid)
    {
        return query(right_child, ql, qr);
    }

    // The range straddles mid: combine both halves with the push_up rules.
    // Every field of the result is set explicitly.
    const SegTreeNode lhs = query(left_child, ql, qr);
    const SegTreeNode rhs = query(right_child, ql, qr);
    return SegTreeNode(lhs.l,
                       rhs.r,
                       lhs.sum + rhs.sum,
                       max(max(lhs.max, rhs.max), lhs.rmax + rhs.lmax),
                       max(lhs.lmax, lhs.sum + rhs.lmax),
                       max(rhs.rmax, lhs.rmax + rhs.sum));
}
int main()
{
std::ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n; cin >> n;
int m; cin >> m;
a = new int[n + 1];
for (int i = 1; i < n + 1; i ++)
{
cin >> a[i];
}
tree = new SegTreeNode[4 * n];
build(1, 1, n);
for (int _ = 0; _ < m; _ ++)
{
int op; cin >> op;
if (op == 1)
{
int l; cin >> l;
int r; cin >> r;
if (l > r)
{
swap(l, r);
}
int cur = query(1, l, r).max;
cout << cur << endl;
}
else
{
int idx; cin >> idx;
int val; cin >> val;
update(1, idx, val);
}
}
return 0;
}
java 代码
import java.util.Scanner;
/** Summary of one closed index range [l, r] of the sequence. */
class SegTreeNode
{
    /** Bounds of the covered range. */
    int l, r;
    /** Total of the range and its best contiguous-run sum. */
    int sum, max;
    /** Best sum of a run starting at l, and of a run ending at r. */
    int lmax, rmax;

    /** Creates a node with every field at its default of zero. */
    SegTreeNode()
    {
    }

    /** Creates a node with every field given explicitly. */
    SegTreeNode(int l_, int r_, int sum_, int max_, int lmax_, int rmax_)
    {
        this.l = l_;
        this.r = r_;
        this.sum = sum_;
        this.max = max_;
        this.lmax = lmax_;
        this.rmax = rmax_;
    }
}
/**
 * Segment tree answering "maximum contiguous-run sum on [x, y]" queries under
 * point updates. Reads the commands from standard input and prints one line
 * per query.
 */
public class Main
{
    /** Input values, stored 1-indexed (slot 0 unused). */
    static int [] a;
    /** Segment-tree nodes addressed heap-style, root at index 1. */
    static SegTreeNode [] tree;

    /** Recomputes tree[root] from its left child tree[ll] and right child tree[rr]. */
    static void push_up(int root, int ll, int rr)
    {
        SegTreeNode parent = tree[root];
        SegTreeNode lhs = tree[ll];
        SegTreeNode rhs = tree[rr];
        parent.sum = lhs.sum + rhs.sum;
        // Best prefix: inside the left child, or all of it plus a right prefix.
        parent.lmax = Math.max(lhs.lmax, lhs.sum + rhs.lmax);
        // Best suffix: the mirror image of the prefix rule.
        parent.rmax = Math.max(rhs.rmax, lhs.rmax + rhs.sum);
        // Best run: within one child, or crossing the boundary between them.
        parent.max = Math.max(Math.max(lhs.max, rhs.max), lhs.rmax + rhs.lmax);
    }

    /** Allocates and initialises the subtree rooted at tree[root] to cover a[l..r]. */
    static void build(int root, int l, int r)
    {
        SegTreeNode node = new SegTreeNode();
        tree[root] = node;
        node.l = l;
        node.r = r;
        if (l == r)
        {
            // A leaf holds exactly one element, so every summary equals it.
            node.sum = node.max = node.lmax = node.rmax = a[l];
            return ;
        }
        int ll = root << 1;
        int rr = ll | 1;
        int mid = (l + r) >> 1;
        build(ll, l, mid);
        build(rr, mid + 1, r);
        push_up(root, ll, rr);
    }

    /** Overwrites position idx with val and refreshes the summaries above it. */
    static void update(int root, int idx, int val)
    {
        SegTreeNode node = tree[root];
        if (node.l == node.r)
        {
            node.sum = node.max = node.lmax = node.rmax = val;
            return ;
        }
        int ll = root << 1;
        int rr = ll | 1;
        int mid = (node.l + node.r) >> 1;
        // Exactly one child covers idx.
        update(idx <= mid ? ll : rr, idx, val);
        push_up(root, ll, rr);
    }

    /**
     * Summarises the part of [ql, qr] that lies inside the subtree at root.
     *
     * @param root index of the subtree root in tree
     * @param ql   left bound of the query (inclusive); callers pass ql <= qr
     * @param qr   right bound of the query (inclusive)
     * @return a node stored in tree (do not mutate it) or a fresh node whose
     *         sum/max/lmax/rmax describe the queried range
     */
    static SegTreeNode query(int root, int ql, int qr)
    {
        SegTreeNode node = tree[root];
        if (ql <= node.l && node.r <= qr)
        {
            return node;
        }
        int ll = root << 1;
        int rr = ll | 1;
        int mid = (node.l + node.r) >> 1;
        // If the range lies in a single child, that child's answer is the
        // answer; no placeholder node or sentinel arithmetic is needed.
        if (qr <= mid)
        {
            return query(ll, ql, qr);
        }
        if (ql > mid)
        {
            return query(rr, ql, qr);
        }
        // The range straddles mid: combine both halves with the push_up rules.
        SegTreeNode lhs = query(ll, ql, qr);
        SegTreeNode rhs = query(rr, ql, qr);
        return new SegTreeNode(
            lhs.l,
            rhs.r,
            lhs.sum + rhs.sum,
            Math.max(Math.max(lhs.max, rhs.max), lhs.rmax + rhs.lmax),
            Math.max(lhs.lmax, lhs.sum + rhs.lmax),
            Math.max(rhs.rmax, lhs.rmax + rhs.sum));
    }

    /**
     * Reads the next integer token.
     *
     * @throws java.util.NoSuchElementException if the next token is not a number
     * @throws java.io.UncheckedIOException if the underlying stream fails
     */
    private static int nextInt(java.io.StreamTokenizer in)
    {
        try
        {
            if (in.nextToken() != java.io.StreamTokenizer.TT_NUMBER)
            {
                throw new java.util.NoSuchElementException("expected an integer token");
            }
            // nval is a double; the values handled here are small integers.
            return (int) in.nval;
        }
        catch (java.io.IOException e)
        {
            throw new java.io.UncheckedIOException(e);
        }
    }

    public static void main(String [] args)
    {
        // Tokenised, buffered input and buffered output: token-by-token
        // Scanner reads and a println per answer dominate the run time here.
        java.io.StreamTokenizer in = new java.io.StreamTokenizer(
            new java.io.BufferedReader(new java.io.InputStreamReader(System.in)));
        java.io.PrintWriter out = new java.io.PrintWriter(
            new java.io.BufferedWriter(new java.io.OutputStreamWriter(System.out)));

        int n = nextInt(in);
        int m = nextInt(in);
        a = new int[n + 1];
        for (int i = 1; i <= n; i ++)
        {
            a[i] = nextInt(in);
        }
        tree = new SegTreeNode[4 * n];
        build(1, 1, n);
        for (int i = 0; i < m; i ++)
        {
            int op = nextInt(in);
            int x = nextInt(in);
            int y = nextInt(in);
            if (op == 1)
            {
                // The statement allows the query bounds in either order.
                int lo = Math.min(x, y);
                int hi = Math.max(x, y);
                out.println(query(1, lo, hi).max);
            }
            else
            {
                update(1, x, y);
            }
        }
        out.flush();
    }
}
算法2
(暴力枚举) $O(n^2)$
blablabla
时间复杂度
参考文献
C++ 代码
blablabla