线段树
一道调了整个下午的模板题
。。。。
题目描述
如题,已知一个数列,你需要进行下面三种操作:
将某区间每一个数乘上 x
将某区间每一个数加上 x
求出某区间每一个数的和
输入格式
第一行包含三个整数 n,m,p分别表示该数列数字的个数、操作的总个数和模数。
第二行包含 n个用空格分隔的整数,其中第 i 个数字表示数列第 i 项的初始值。
接下来 m 行每行包含若干个整数,表示一个操作,具体如下:
操作 1: 格式:1 x y k 含义:将区间 [x,y] 内每个数乘上 k
操作 2: 格式:2 x y k 含义:将区间 [x,y] 内每个数加上 k
操作 3: 格式:3 x y 含义:输出区间 [x,y] 内每个数的和对 p 取模所得的结果
输入样例
5 5 38
1 5 4 2 3
2 1 4 1
3 2 5
1 2 4 2
2 3 5 5
3 1 4
输出样例
17
2
代码
#include <cstring>
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 200010;
struct Node{
int l, r;
ll sum, add1, add2;//add1为加的懒标记 add2为乘的懒标记
}tr[N*4];
int n, m, p;
ll w[N];
void pushdown(int u) //pushdown 操作 对当前结点的所有懒标记向下传递,先乘后加
{
Node &x = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
left.sum = (left.sum * x.add2 + (left.r - left.l + 1) * x.add1) % p;
right.sum = (right.sum * x.add2 + (right.r - right.l+1) * x.add1) % p;
left.add2 = (left.add2 * x.add2 ) % p;
right.add2 = (right.add2 % p * x.add2 % p) % p;
left.add1 = (left.add1 * x.add2 + x.add1) % p;
right.add1 = (right.add1 * x.add2 + x.add1) % p;
x.add1=0;
x.add2=1;
}
void pushup(int u)
{
tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}
void build(int u, int l, int r)
{
tr[u] = {l, r};
tr[u].add2 = 1;
if(l == r){
tr[u].sum = w[l] % p;
}else{
int mid = l + r >> 1 ;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1,r);
pushup(u);
}
}
void modify1(int u,int l,int r,int v)//区间加
{
if(l <= tr[u].l && tr[u].r <= r){
tr[u].sum = (tr[u].sum + (tr[u].r - tr[u].l + 1) * v) % p;
tr[u].add1 = (tr[u].add1 + v) % p;
}else{
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if(l <= mid) modify1(u << 1, l, r, v);
if(r > mid) modify1(u << 1 | 1, l, r, v);
pushup(u);
}
}
void modify2(int u, int l, int r, int v)//区间乘
{
if(l <= tr[u].l && tr[u].r <= r){
tr[u].sum = (tr[u].sum * v ) % p;
tr[u].add2 = (tr[u].add2 * v) % p;
tr[u].add1 = (tr[u].add1 * v) % p;
}else{
pushdown(u);
int mid=tr[u].l+tr[u].r >> 1;
if(l <= mid) modify2(u << 1, l, r, v);
if(r > mid) modify2(u << 1 | 1, l, r, v);
pushup(u);
}
}
ll query(int u, int l, int r)
{
if(l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
ll res = 0;
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) res = (res + query(u << 1, l, r))% p;
if(r > mid) res = (res + query(u << 1 | 1, l, r)) % p;
pushup(u);
return res;
}
int main()
{
scanf("%d%d%d", &n, &m, &p);
for(int i = 1;i <= n; i ++ ) scanf("%d",&w[i]);
build(1, 1, n);
int tmp;
int x, y, k;
while(m -- ){
scanf("%d%d%d", &tmp, &x, &y);
if(tmp != 3){
scanf("%d", &k);
if(tmp == 2)
modify1(1, x, y, k);
else modify2(1, x, y, k);
}else{
printf("%lld\n", query(1, x, y) % p);
}
}
return 0;
}
建议修改下题面,有点变形了~
好哒~
这样看好多了~多谢~