树状数组
主要操作:
一维树状数组:
int tr[N];
int lowbit(int x){
return x & -x;
}
void add(int x, int k){ // 第 x 个数加上 k
for(int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}
int query(int x){ // 查询 [1, x] 的总和
int res = 0;
for(int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
二维树状数组:
int tr[N][N];
int lowbit(int x){
return x & -x;
}
void add(int x, int y, int k){
for(int i = x; i < N; i += lowbit(i)){
for(int j = y; j < N; j += lowbit(j)){
tr[i][j] += k;
}
}
}
int query(int x, int y){
int res = 0;
for(int i = x; i; i -= lowbit(i)){
for(int j = y; j; j -= lowbit(j)){
res += tr[i][j];
}
}
return res;
}
模板 1:(前缀和)
对已知数列 ${A_i}$:
- 1 :将某一个数加上 $x$
- 2 :求指定区间内的数的和
int mian(){
cin >> n >> m;
for(int i = 1; i <= n; i ++){
cin >> a[i];
add(i, a[i]);
}
while(m --){
int op, x, k;
cin >> op >> x >> k;
if(op == 1) add(x, k); // 操作 1
else{ // 操作 2
int res = query(k) - query(x - 1);
cout << res << '\n';
}
}
return 0;
}
模板 2:(差分)
对已知数列 ${A_i}$:
- 1 :将某区间内的每一个数加上 $x$
- 2 :求出某一个数的值
int main(){
cin >> n >> m;
for(int i = 1; i <= n; i ++){
cin >> a[i];
add(i, a[i] - a[i - 1]);
}
while(m --){
int op;
cin >> op;
if(op == 1){ // 操作 1
int x, y, k;
cin >> x >> y >> k;
add(x, k);
add(y + 1, -k);
}
else{ // 操作 2
int x;
cin >> x;
int res = query(x);
cout << res << '\n';
}
}
return 0;
}
模板3
对已知数列 ${A_i}$:
- 1 :将某区间内的每一个数加上 $k$
- 2 :求指定区间
[l, r]
内的数的和
方式:差分
设求 a[i]
在 [1, x]
区间上的和,公式为 :
$\sum_{i = 1}^{x} a[i] = \sum_{i = 1}^{x}\sum_{j = 1}^{i}b[j] = = (1 + x)\sum_{i = 1}^{x}b[i] - \sum_{i = 1}^{x}b[i] * i$
$\sum_{i = 1}^{x}\sum_{j = 1}^{i}b[j] = (1 + x)\sum_{i = 1}^{x}b[i] - \sum_{i = 1}^{x}b[i] * i$
int n, m;
int a[N];
int tr1[N]; // 维护b[i]的前缀和
int tr2[N]; // 维护b[i] * i的前缀和
int lowbit(int x){
return x & -x;
}
void add(int tr[], int x, int k){
for(int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}
int sum(int tr[], int x){
int res = 0;
for(int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
int pre_sum(int x){
int a = sum(tr1, x) * (x + 1);
int b = sum(tr2, x);
return a - b;
}
int main(){
cin >> n >> m;
for(int i = 1; i <= n; i ++){
cin >> a[i];
int b = a[i] - a[i - 1];
add(tr1, i, b);
add(tr2, i, b * i);
}
while(m --){
int op;
int l, r, d;
cin >> op;
if(op == 1){
cin >> l >> r >> d;
add(tr1, l, d), add(tr1, r + 1, -d);
add(tr2, l, l * d), add(tr2, r + 1, -(r + 1) * d);
}
else{
cin >> l >> r;
cout << pre_sum(r) - pre_sum(l - 1) << '\n';
}
}
return 0;
}
模板 4:
已知三个数列 {$a_i$}、{$b_i$}、{$c_i$},且满足 $b_i$=$\sum_{j = 1}^{i}a_j$,$c_i$=$\sum_{j = 1}^{i}b_j$,
- 1:将数列 {$a_i$} 的某一个数加上 $k$;
- 2:求数列 {$c_i$} 的指定区间内的数的和。
与 模板3 类似,设求 c[i]
在 [1, x]
区间上的和,公式为 :
$\sum_{i = 1}^{x}c_i = c_1 + c_2 + c_3… + c_x = b_1 + (b_1 + b_2) + (b_1 + b_2 + b_3) + … + (b_1 + b_2 + … + b_x)$
$ = x \cdot b1 + (x - 1) \cdot b_2 + (x - 2) \cdot b_3 + … + 1 \cdot b_x$
$ = x \cdot a_1 + (x - 1) \cdot (a_1 + a_2)+ … + 1 \cdot (a_1 + a_2 + … + a_x)$
$ = (x + … + 1) \cdot a_1 + ((x - 1) + … + 1) \cdot a_2 + … + 1 \cdot a_x$
$ = \sum_{i = 1}^{x} \frac{(x - i + 1)(x - i + 2)}{2}a_i$
$ = \frac{(x + 1)(x + 2)}{2}\sum_{i = 1}^{x} a_i - \frac{2x + 3}{2}\sum_{i = 1}^{x} a_i \cdot i + \frac{1}{2}\sum_{i = 1}^{x} a_i \cdot i^2$
显然,需要维护 a[i]
,a[i] * i
,a[i] * i * i
三个前缀和树状数组
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e5 + 5;
int n, m;
int a[N];
int tr1[N]; // 维护a[i]的前缀和
int tr2[N]; // 维护a[i] * i的前缀和
int tr3[N]; // 维护a[i] * i * i的前缀和
int lowbit(int x){
return x & -x;
}
void add(int tr[], int x, int k){
for(int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}
int sum(int tr[], int x){
int res = 0;
for(int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
int pre_sum(int k){
int a = (k + 1) * (k + 2) * sum(tr1, k);
int b = (2 * k + 3) * sum(tr2, k);
int c = sum(tr3, k);
return (a - b + c) / 2;
}
signed main(){
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
cin >> n >> m;
for(int i = 1; i <= n; i ++){
cin >> a[i];
add(tr1, i, a[i]);
add(tr2, i, a[i] * i);
add(tr3, i, a[i] * i * i);
}
while(m --){
int op;
int x, y, k;
cin >> op;
if(op == 1){
cin >> x >> k;
add(tr1, x, k);
add(tr2, x, x * k);
add(tr3, x, x * x * k);
}
else{
cin >> x >> y;
cout << pre_sum(y) - pre_sum(x - 1) << '\n';
}
}
return 0;
}