题目描述
给定长度为 $n$ 的序列 $a$,支持以下操作:
1 l r x
将区间 $[l,r]$ 中的 $x$ 得倍数都除以 $x$,2 l r
查询区间 $[l,r]$ 内元素的和。$1\leq n,m \leq 10^5, 1\leq l,r \leq 10^5, 0\leq a_i \leq 5 \times 10^5, 1\leq x\leq 5\times 10^5$。
思路
$d$ 表示 $[1,n]$ 中因数个数最多的数得因数个数,大约为 $200$,$M$ 表示值域。
对与每一个数 $x\in[2,5\times 10^5]$,建议棵平衡树,动态维护能被 $x$ 整除的数。
初始化如果一个一个加入复杂度 $O(nd \log n+n\sqrt{M})$,我们先存下来,然后整体加入,复杂度 $O(\sqrt{M}+nd)$
对于每一次修改,我们在 $x$ 的平衡树里找下标在 $[l,r]$ 中的数,一个一个修改,用树状数组或线段树维护区间和,由于所有数只会减小,所以总共只会执行 $O(n\log n)$ 次,由于修改时需要在平衡树上修改,所以复杂度 $O(nd\log^2 n+n\log n \sqrt{M})$。
我们考虑每次修改不动态维护平衡树信息,等到下次一个一个查找时再判断删除,由于平衡树少总共只有 $O(nd)$ 个点,平衡树删除 $O(\log n)$,所以我们得到了一个 $O(nd\log n)$ 的复杂度。
如果我们先将 $[l,r]$ 范围内的数整体删除,再将所有可以整除的数加入,加入的点只有 $O(n\log n)$ 个,在树状数组或线段树上修改复杂度 $O(\log n)$,总复杂度 $O(n\sqrt{M}+n\log^2 n)$。
实现
实现用 FHQ 和树状数组,由于有区间删除操作,所以不能用 set
,并且用了空间回收。
#include <bits/stdc++.h>
using namespace std;
const int N = 100010, M = 500010, K = 2e7 + 10;
int a[N], rt[M], cur[N], tot;
long long tr[M];
vector<int> ru, v, d[M];
struct node {
int l, r, rnd, key;
} T[K];
int newnode(int x) {
if (true) {
T[++ tot] = {0, 0, rand(), x};
return tot;
} else {
int id = ru.back();
ru.pop_back();
T[id] = {0, 0, rand(), x};
return id;
}
}
int build(int l, int r) {
if (l > r) return 0;
int mid = l + r >> 1;
int id = newnode(cur[mid]);
if (l == r) return id;
T[id].l = build(l, mid - 1);
T[id].r = build(mid + 1, r);
T[id].rnd = max(T[id].rnd, max(T[T[id].l].rnd, T[T[id].r].rnd) + 1);
return id;
}
int merge(int x, int y) {
if (!x || !y) return x | y;
if (T[x].rnd < T[y].rnd) {
T[y].l = merge(x, T[y].l);
return y;
} else {
T[x].r = merge(T[x].r, y);
return x;
}
}
void split(int x, int y, int &l, int &r) {
if (!x) return l = 0, r = 0, void();
if (T[x].key <= y) {
l = x;
split(T[x].r, y, T[x].r, r);
} else {
r = x;
split(T[x].l, y, l, T[x].l);
}
}
void dfs(int u) {
if (!u) return;
v.push_back(T[u].key);
ru.push_back(u);
dfs(T[u].l), dfs(T[u].r);
}
void add(int x, int k) {
for (int i = x; i < M; i += i & -i)
tr[i] += k;
}
long long qry(int x) {
long long res = 0;
for (int i = x; i; i -= i & -i) res += tr[i];
return res;
}
int main() {
srand(time(0));
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i ++ ) scanf("%d", &a[i]), add(i, a[i]);
for (int i = 1; i <= n; i ++ )
for (int j = 1; j * j <= a[i]; j ++ ) {
if (a[i] % j == 0) {
if (j > 1) d[j].push_back(i);
if (j * j < a[i] && a[i] / j > 1) d[a[i] / j].push_back(i);
}
}
for (int i = 2; i < M; i ++ ) {
if (d[i].empty()) continue;
int len = d[i].size();
for (int j = 1; j <= len; j ++ ) cur[j] = d[i][j - 1];
rt[i] = build(1, len);
}
while (m -- ) {
int op;
scanf("%d", &op);
if (op == 1) {
int l, r, x;
scanf("%d%d%d", &l, &r, &x);
if (x == 1 || !rt[x]) continue;
int h, y, z;
split(rt[x], r, y, z);
split(y, l - 1, h, y);
rt[x] = merge(h, z);
v.clear();
dfs(y);
for (int j : v) {
if (a[j] % x == 0) {
add(j, a[j] / x - a[j]);
a[j] /= x;
if (a[j] % x == 0) {
int h1 = 0, h2 = 0;
split(rt[x], j - 1, h1, h2);
h1 = merge(h1, newnode(j));
rt[x] = merge(h1, h2);
}
}
}
} else {
int l, r;
scanf("%d%d", &l, &r);
printf("%lld\n", qry(r) - qry(l - 1));
}
}
return 0;
}