普通NTT + 任意模数 NTT
/**
 * Polynomial convolution via the number-theoretic transform (NTT), plus an
 * arbitrary-modulus variant that combines three NTT-friendly primes with
 * Garner's algorithm.
 */
class Convolution {
    /**
     * Find a primitive root.
     *
     * @param m A prime number.
     * @return Primitive root.
     */
    private static int primitiveRoot(int m) {
        if (m == 2) return 1;
        if (m == 167772161) return 3;
        if (m == 469762049) return 3;
        if (m == 754974721) return 11;
        if (m == 998244353) return 3;
        // Collect the distinct prime factors of m - 1; g is a primitive root
        // iff g^((m-1)/q) != 1 (mod m) for every such prime factor q.
        int[] divs = new int[20];
        divs[0] = 2;
        int cnt = 1;
        int x = (m - 1) / 2;
        while (x % 2 == 0) x /= 2;
        for (int i = 3; (long) (i) * i <= x; i += 2) {
            if (x % i == 0) {
                divs[cnt++] = i;
                while (x % i == 0) {
                    x /= i;
                }
            }
        }
        if (x > 1) {
            divs[cnt++] = x;
        }
        for (int g = 2; ; g++) {
            boolean ok = true;
            for (int i = 0; i < cnt; i++) {
                if (pow(g, (m - 1) / divs[i], m) == 1) {
                    ok = false;
                    break;
                }
            }
            if (ok) return g;
        }
    }

    /**
     * Power.
     *
     * @param x Parameter x.
     * @param n Parameter n.
     * @param m Mod.
     * @return n-th power of x mod m.
     */
    private static long pow(long x, long n, int m) {
        if (m == 1) return 0;
        long r = 1;
        long y = x % m;
        while (n > 0) {
            if ((n & 1) != 0) r = (r * y) % m;
            y = (y * y) % m;
            n >>= 1;
        }
        return r;
    }

    /**
     * Ceil of power 2.
     *
     * @param n Value.
     * @return Smallest x such that 2^x >= n.
     */
    private static int ceilPow2(int n) {
        int x = 0;
        while ((1L << x) < n) x++;
        return x;
    }

    /** Precomputed roots of unity and twiddle-factor rates for one NTT prime. */
    private static class FftInfo {
        /** Number of trailing zero bits of n (n must be non-zero). */
        private static int bsfConstexpr(int n) {
            int x = 0;
            while ((n & (1 << x)) == 0) x++;
            return x;
        }

        /** Modular inverse of a via the extended Euclidean algorithm. */
        private static long inv(long a, long mod) {
            long b = mod;
            long p = 1, q = 0;
            while (b > 0) {
                long c = a / b;
                long d;
                d = a;
                a = b;
                b = d % b;
                d = p;
                p = q;
                q = d - c * q;
            }
            return p < 0 ? p + mod : p;
        }

        private final int rank2;
        public final long[] root;
        public final long[] iroot;
        public final long[] rate2;
        public final long[] irate2;
        public final long[] rate3;
        public final long[] irate3;

        public FftInfo(int g, int mod) {
            // rank2 = largest k with 2^k | (mod - 1): the longest transform length is 2^rank2.
            rank2 = bsfConstexpr(mod - 1);
            root = new long[rank2 + 1];
            iroot = new long[rank2 + 1];
            rate2 = new long[Math.max(0, rank2 - 2 + 1)];
            irate2 = new long[Math.max(0, rank2 - 2 + 1)];
            rate3 = new long[Math.max(0, rank2 - 3 + 1)];
            irate3 = new long[Math.max(0, rank2 - 3 + 1)];
            root[rank2] = pow(g, (mod - 1) >> rank2, mod);
            iroot[rank2] = inv(root[rank2], mod);
            for (int i = rank2 - 1; i >= 0; i--) {
                root[i] = root[i + 1] * root[i + 1] % mod;
                iroot[i] = iroot[i + 1] * iroot[i + 1] % mod;
            }
            {
                long prod = 1, iprod = 1;
                for (int i = 0; i <= rank2 - 2; i++) {
                    rate2[i] = root[i + 2] * prod % mod;
                    irate2[i] = iroot[i + 2] * iprod % mod;
                    prod = prod * iroot[i + 2] % mod;
                    iprod = iprod * root[i + 2] % mod;
                }
            }
            {
                long prod = 1, iprod = 1;
                for (int i = 0; i <= rank2 - 3; i++) {
                    rate3[i] = root[i + 3] * prod % mod;
                    irate3[i] = iroot[i + 3] * iprod % mod;
                    prod = prod * iroot[i + 3] % mod;
                    iprod = iprod * root[i + 3] % mod;
                }
            }
        }
    }

    /**
     * Garner's algorithm.
     *
     * <p>Combines the residues {@code c[i]} modulo the primes {@code mods[i]}
     * ({@code i < c.length}) and returns the reconstructed value reduced modulo
     * {@code mods[c.length]}.
     *
     * @param c Mod convolution results.
     * @param mods Mods; the first {@code c.length} entries must be distinct primes.
     * @return Result modulo {@code mods[c.length]}.
     */
    private static long garner(long[] c, int[] mods) {
        int n = c.length + 1;
        long[] cnst = new long[n];
        long[] coef = new long[n];
        java.util.Arrays.fill(coef, 1);
        for (int i = 0; i < n - 1; i++) {
            int m1 = mods[i];
            long v = (c[i] - cnst[i] + m1) % m1;
            // Fermat inverse: m1 is prime.
            v = v * pow(coef[i], m1 - 2, m1) % m1;
            for (int j = i + 1; j < n; j++) {
                long m2 = mods[j];
                cnst[j] = (cnst[j] + coef[j] * v) % m2;
                coef[j] = (coef[j] * m1) % m2;
            }
        }
        return cnst[n - 1];
    }

    /**
     * Inverse NTT (in place, without the final division by the length).
     *
     * @param a Target array; length must be a power of two.
     * @param g Primitive root of mod.
     * @param mod NTT Prime.
     */
    private static void butterflyInv(long[] a, int g, int mod) {
        int n = a.length;
        int h = ceilPow2(n);
        FftInfo info = new FftInfo(g, mod);
        int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
        while (len > 0) {
            if (len == 1) {
                int p = 1 << (h - len);
                long irot = 1;
                for (int s = 0; s < (1 << (len - 1)); s++) {
                    int offset = s << (h - len + 1);
                    for (int i = 0; i < p; i++) {
                        long l = a[i + offset];
                        long r = a[i + offset + p];
                        a[i + offset] = (l + r) % mod;
                        a[i + offset + p] = (mod + l - r) % mod * irot % mod;
                    }
                    if (s + 1 != (1 << (len - 1))) {
                        irot *= info.irate2[Integer.numberOfTrailingZeros(~s)];
                        irot %= mod;
                    }
                }
                len--;
            } else {
                // 4-base
                int p = 1 << (h - len);
                long irot = 1, iimag = info.iroot[2];
                for (int s = 0; s < (1 << (len - 2)); s++) {
                    long irot2 = irot * irot % mod;
                    long irot3 = irot2 * irot % mod;
                    int offset = s << (h - len + 2);
                    for (int i = 0; i < p; i++) {
                        long a0 = 1L * a[i + offset + 0 * p];
                        long a1 = 1L * a[i + offset + 1 * p];
                        long a2 = 1L * a[i + offset + 2 * p];
                        long a3 = 1L * a[i + offset + 3 * p];
                        long a2na3iimag = 1L * (mod + a2 - a3) % mod * iimag % mod;
                        a[i + offset] = (a0 + a1 + a2 + a3) % mod;
                        a[i + offset + 1 * p] = (a0 + (mod - a1) + a2na3iimag) % mod * irot % mod;
                        a[i + offset + 2 * p] = (a0 + a1 + (mod - a2) + (mod - a3)) % mod * irot2 % mod;
                        a[i + offset + 3 * p] = (a0 + (mod - a1) + (mod - a2na3iimag)) % mod * irot3 % mod;
                    }
                    if (s + 1 != (1 << (len - 2))) {
                        irot *= info.irate3[Integer.numberOfTrailingZeros(~s)];
                        irot %= mod;
                    }
                }
                len -= 2;
            }
        }
    }

    /**
     * Forward NTT (in place).
     *
     * @param a Target array; length must be a power of two.
     * @param g Primitive root of mod.
     * @param mod NTT Prime.
     */
    private static void butterfly(long[] a, int g, int mod) {
        int n = a.length;
        int h = ceilPow2(n);
        FftInfo info = new FftInfo(g, mod);
        int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
        while (len < h) {
            if (h - len == 1) {
                int p = 1 << (h - len - 1);
                long rot = 1;
                for (int s = 0; s < (1 << len); s++) {
                    int offset = s << (h - len);
                    for (int i = 0; i < p; i++) {
                        long l = a[i + offset];
                        long r = a[i + offset + p] * rot % mod;
                        a[i + offset] = (l + r) % mod;
                        a[i + offset + p] = (l + mod - r) % mod;
                    }
                    if (s + 1 != (1 << len)) {
                        rot *= info.rate2[Integer.numberOfTrailingZeros(~s)];
                        rot %= mod;
                    }
                }
                len++;
            } else {
                // 4-base
                int p = 1 << (h - len - 2);
                long rot = 1, imag = info.root[2];
                for (int s = 0; s < (1 << len); s++) {
                    long rot2 = rot * rot % mod;
                    long rot3 = rot2 * rot % mod;
                    int offset = s << (h - len);
                    for (int i = 0; i < p; i++) {
                        long mod2 = 1L * mod * mod;
                        long a0 = 1L * a[i + offset];
                        long a1 = 1L * a[i + offset + p] * rot % mod;
                        long a2 = 1L * a[i + offset + 2 * p] * rot2 % mod;
                        long a3 = 1L * a[i + offset + 3 * p] * rot3 % mod;
                        long a1na3imag = 1L * (a1 + mod2 - a3) % mod * imag % mod;
                        long na2 = mod2 - a2;
                        a[i + offset] = (a0 + a2 + a1 + a3) % mod;
                        a[i + offset + 1 * p] = (a0 + a2 + (2 * mod2 - (a1 + a3))) % mod;
                        a[i + offset + 2 * p] = (a0 + na2 + a1na3imag) % mod;
                        a[i + offset + 3 * p] = (a0 + na2 + (mod2 - a1na3imag)) % mod;
                    }
                    if (s + 1 != (1 << len)) {
                        rot *= info.rate3[Integer.numberOfTrailingZeros(~s)];
                        rot %= mod;
                    }
                }
                len += 2;
            }
        }
    }

    /**
     * Convolution modulo an NTT-friendly prime.
     *
     * <p>Inputs may hold any long values (including negatives); they are first
     * reduced into {@code [0, mod)}. The caller's arrays are not modified.
     * NOTE(review): the padded length (next power of two >= n + m - 1) must
     * divide {@code mod - 1}; this is not checked here.
     *
     * @param a Target array 1.
     * @param b Target array 2.
     * @param mod NTT Prime.
     * @return c with c[k] = sum_{i+j=k} a[i]*b[j] mod mod (length n + m - 1), or an
     *     empty array if either input is empty.
     */
    public static long[] convolution(long[] a, long[] b, int mod) {
        int n = a.length;
        int m = b.length;
        if (n == 0 || m == 0) return new long[0];
        int z = 1 << ceilPow2(n + m - 1);
        long[] na = new long[z];
        long[] nb = new long[z];
        // The butterflies assume canonical residues in [0, mod): unreduced
        // negative or very large values would produce wrong residues or overflow.
        for (int i = 0; i < n; i++) na[i] = Math.floorMod(a[i], (long) mod);
        for (int i = 0; i < m; i++) nb[i] = Math.floorMod(b[i], (long) mod);
        int g = primitiveRoot(mod);
        butterfly(na, g, mod);
        butterfly(nb, g, mod);
        for (int i = 0; i < z; i++) {
            na[i] = na[i] * nb[i] % mod;
        }
        butterflyInv(na, g, mod);
        long[] ret = java.util.Arrays.copyOf(na, n + m - 1);
        // butterflyInv omits the 1/z normalisation; apply it here.
        long iz = pow(z, mod - 2, mod);
        for (int i = 0; i < ret.length; i++) ret[i] = ret[i] * iz % mod;
        return ret;
    }

    /**
     * Convolution modulo an arbitrary modulus.
     *
     * <p>Inputs are first reduced into {@code [0, mod)} (this does not change the
     * answer modulo {@code mod}). The exact integer convolution is then recovered
     * from three NTT primes by Garner's algorithm. This is exact while
     * min(n, m) * (mod - 1)^2 < 754974721 * 167772161 * 469762049 (about 5.9e25).
     *
     * @param a Target array 1.
     * @param b Target array 2.
     * @param mod Any mod (>= 1).
     * @return Answer, or an empty array if either input is empty.
     */
    public static long[] convolutionLL(long[] a, long[] b, int mod) {
        int n = a.length;
        int m = b.length;
        if (n == 0 || m == 0) return new long[0];
        long[] ra = new long[n];
        long[] rb = new long[m];
        for (int i = 0; i < n; i++) ra[i] = Math.floorMod(a[i], (long) mod);
        for (int i = 0; i < m; i++) rb[i] = Math.floorMod(b[i], (long) mod);
        int mod1 = 754974721;
        int mod2 = 167772161;
        int mod3 = 469762049;
        long[] c1 = convolution(ra, rb, mod1);
        long[] c2 = convolution(ra, rb, mod2);
        long[] c3 = convolution(ra, rb, mod3);
        int retSize = c1.length;
        long[] ret = new long[retSize];
        int[] mods = {mod1, mod2, mod3, mod};
        for (int i = 0; i < retSize; ++i) {
            ret[i] = garner(new long[]{c1[i], c2[i], c3[i]}, mods);
        }
        return ret;
    }

    /**
     * Convolution by ModInt.
     *
     * @param a Target array 1.
     * @param b Target array 2.
     * @return Answer; empty if either input is empty.
     */
    public static java.util.List<ModIntFactory.ModInt> convolution(
            java.util.List<ModIntFactory.ModInt> a,
            java.util.List<ModIntFactory.ModInt> b
    ) {
        // The modulus is read from a's first element, so an empty a must be
        // handled before touching a.get(0).
        if (a.isEmpty() || b.isEmpty()) return new java.util.ArrayList<>();
        int mod = a.get(0).mod();
        long[] va = a.stream().mapToLong(ModIntFactory.ModInt::value).toArray();
        long[] vb = b.stream().mapToLong(ModIntFactory.ModInt::value).toArray();
        long[] c = convolutionLL(va, vb, mod);
        ModIntFactory factory = new ModIntFactory(mod);
        return java.util.Arrays.stream(c).mapToObj(factory::create).collect(java.util.stream.Collectors.toList());
    }

    /**
     * Naive convolution. (Complexity is O(N^2)!!)
     *
     * @param a Target array 1.
     * @param b Target array 2.
     * @param mod Mod.
     * @return Answer (length n + m - 1), or an empty array if either input is empty.
     */
    public static long[] convolutionNaive(long[] a, long[] b, int mod) {
        int n = a.length;
        int m = b.length;
        if (n == 0 || m == 0) return new long[0];
        int k = n + m - 1;
        long[] ret = new long[k];
        for (int i = 0; i < n; i++) {
            long ai = Math.floorMod(a[i], (long) mod);
            for (int j = 0; j < m; j++) {
                long bj = Math.floorMod(b[j], (long) mod);
                ret[i + j] = (ret[i + j] + ai * bj % mod) % mod;
            }
        }
        return ret;
    }
}
字符串算法
import java.util.Arrays;
/**
 * String algorithms: suffix array (SA-IS), LCP array (Kasai) and Z-algorithm.
 */
class StringAlgorithm {
    /** O(n^2 log n) suffix sort by direct comparison; used for short inputs. */
    private static int[] saNaive(int[] s) {
        int n = s.length;
        int[] sa = new int[n];
        for (int i = 0; i < n; i++) {
            sa[i] = i;
        }
        insertionsortUsingComparator(sa, (l, r) -> {
            while (l < n && r < n) {
                if (s[l] != s[r]) return s[l] - s[r];
                l++;
                r++;
            }
            // The suffix that ran out first is the shorter one, hence smaller.
            return -(l - r);
        });
        return sa;
    }

    /** Prefix-doubling suffix sort (currently unused: THRESHOLD_DOUBLING is 0). */
    private static int[] saDoubling(int[] s) {
        int n = s.length;
        int[] sa = new int[n];
        for (int i = 0; i < n; i++) {
            sa[i] = i;
        }
        int[] rnk = java.util.Arrays.copyOf(s, n);
        int[] tmp = new int[n];
        for (int k = 1; k < n; k *= 2) {
            final int _k = k;
            final int[] _rnk = rnk;
            java.util.function.IntBinaryOperator cmp = (x, y) -> {
                if (_rnk[x] != _rnk[y]) return _rnk[x] - _rnk[y];
                int rx = x + _k < n ? _rnk[x + _k] : -1;
                int ry = y + _k < n ? _rnk[y + _k] : -1;
                return rx - ry;
            };
            mergesortUsingComparator(sa, cmp);
            tmp[sa[0]] = 0;
            for (int i = 1; i < n; i++) {
                tmp[sa[i]] = tmp[sa[i - 1]] + (cmp.applyAsInt(sa[i - 1], sa[i]) < 0 ? 1 : 0);
            }
            int[] buf = tmp;
            tmp = rnk;
            rnk = buf;
        }
        return sa;
    }

    private static void insertionsortUsingComparator(int[] a, java.util.function.IntBinaryOperator comparator) {
        final int n = a.length;
        for (int i = 1; i < n; i++) {
            final int tmp = a[i];
            if (comparator.applyAsInt(a[i - 1], tmp) > 0) {
                int j = i;
                do {
                    a[j] = a[j - 1];
                    j--;
                } while (j > 0 && comparator.applyAsInt(a[j - 1], tmp) > 0);
                a[j] = tmp;
            }
        }
    }

    /** Bottom-up stable merge sort of an int array under a custom comparator. */
    private static void mergesortUsingComparator(int[] a, java.util.function.IntBinaryOperator comparator) {
        final int n = a.length;
        final int[] work = new int[n];
        for (int block = 1; block <= n; block <<= 1) {
            final int block2 = block << 1;
            for (int l = 0, max = n - block; l < max; l += block2) {
                int m = l + block;
                int r = Math.min(l + block2, n);
                System.arraycopy(a, l, work, 0, block);
                for (int i = l, wi = 0, ti = m; ; i++) {
                    if (ti == r) {
                        System.arraycopy(work, wi, a, i, block - wi);
                        break;
                    }
                    if (comparator.applyAsInt(work[wi], a[ti]) > 0) {
                        a[i] = a[ti++];
                    } else {
                        a[i] = work[wi++];
                        if (wi == block) break;
                    }
                }
            }
        }
    }

    private static final int THRESHOLD_NAIVE = 50;
    private static final int THRESHOLD_DOUBLING = 0;

    /**
     * SA-IS suffix array construction.
     *
     * @param s input with every value in [0, upper]
     * @param upper inclusive upper bound of the values in s
     */
    private static int[] sais(int[] s, int upper) {
        int n = s.length;
        if (n == 0) return new int[0];
        if (n == 1) return new int[]{0};
        if (n == 2) {
            if (s[0] < s[1]) {
                return new int[]{0, 1};
            } else {
                return new int[]{1, 0};
            }
        }
        if (n < THRESHOLD_NAIVE) {
            return saNaive(s);
        }
        // Prefix doubling is disabled (THRESHOLD_DOUBLING == 0):
        // if (n < THRESHOLD_DOUBLING) {
        //     return saDoubling(s);
        // }
        int[] sa = new int[n];
        // ls[i] == true  <=>  suffix i is S-type (smaller than suffix i + 1).
        boolean[] ls = new boolean[n];
        for (int i = n - 2; i >= 0; i--) {
            ls[i] = s[i] == s[i + 1] ? ls[i + 1] : s[i] < s[i + 1];
        }
        int[] sumL = new int[upper + 1];
        int[] sumS = new int[upper + 1];
        for (int i = 0; i < n; i++) {
            if (ls[i]) {
                sumL[s[i] + 1]++;
            } else {
                sumS[s[i]]++;
            }
        }
        for (int i = 0; i <= upper; i++) {
            sumS[i] += sumL[i];
            if (i < upper) sumL[i + 1] += sumS[i];
        }
        // Induced sorting: place LMS suffixes, then induce L-type, then S-type.
        java.util.function.Consumer<int[]> induce = lms -> {
            java.util.Arrays.fill(sa, -1);
            int[] buf = new int[upper + 1];
            System.arraycopy(sumS, 0, buf, 0, upper + 1);
            for (int d : lms) {
                if (d == n) continue;
                sa[buf[s[d]]++] = d;
            }
            System.arraycopy(sumL, 0, buf, 0, upper + 1);
            sa[buf[s[n - 1]]++] = n - 1;
            for (int i = 0; i < n; i++) {
                int v = sa[i];
                if (v >= 1 && !ls[v - 1]) {
                    sa[buf[s[v - 1]]++] = v - 1;
                }
            }
            System.arraycopy(sumL, 0, buf, 0, upper + 1);
            for (int i = n - 1; i >= 0; i--) {
                int v = sa[i];
                if (v >= 1 && ls[v - 1]) {
                    sa[--buf[s[v - 1] + 1]] = v - 1;
                }
            }
        };
        int[] lmsMap = new int[n + 1];
        java.util.Arrays.fill(lmsMap, -1);
        int m = 0;
        for (int i = 1; i < n; i++) {
            if (!ls[i - 1] && ls[i]) {
                lmsMap[i] = m++;
            }
        }
        int[] lms = new int[m];
        {
            int p = 0;
            for (int i = 1; i < n; i++) {
                if (!ls[i - 1] && ls[i]) {
                    lms[p++] = i;
                }
            }
        }
        induce.accept(lms);
        if (m > 0) {
            int[] sortedLms = new int[m];
            {
                int p = 0;
                for (int v : sa) {
                    if (lmsMap[v] != -1) {
                        sortedLms[p++] = v;
                    }
                }
            }
            // Name the LMS substrings; equal substrings share a name.
            int[] recS = new int[m];
            int recUpper = 0;
            recS[lmsMap[sortedLms[0]]] = 0;
            for (int i = 1; i < m; i++) {
                int l = sortedLms[i - 1], r = sortedLms[i];
                int endL = (lmsMap[l] + 1 < m) ? lms[lmsMap[l] + 1] : n;
                int endR = (lmsMap[r] + 1 < m) ? lms[lmsMap[r] + 1] : n;
                boolean same = true;
                if (endL - l != endR - r) {
                    same = false;
                } else {
                    while (l < endL && s[l] == s[r]) {
                        l++;
                        r++;
                    }
                    if (l == n || s[l] != s[r]) same = false;
                }
                if (!same) {
                    recUpper++;
                }
                recS[lmsMap[sortedLms[i]]] = recUpper;
            }
            int[] recSA = sais(recS, recUpper);
            for (int i = 0; i < m; i++) {
                sortedLms[i] = lms[recSA[i]];
            }
            induce.accept(sortedLms);
        }
        return sa;
    }

    /**
     * Suffix array of s whose values all lie in [0, upper].
     */
    public static int[] suffixArray(int[] s, int upper) {
        assert (0 <= upper);
        for (int d : s) {
            assert (0 <= d && d <= upper);
        }
        return sais(s, upper);
    }

    /**
     * Suffix array of an arbitrary int sequence (values are coordinate-compressed first).
     */
    public static int[] suffixArray(int[] s) {
        int n = s.length;
        int[] vals = Arrays.copyOf(s, n);
        java.util.Arrays.sort(vals);
        // Deduplicate the sorted values in place; vals[0..p) holds the distinct ones.
        int p = 1;
        for (int i = 1; i < n; i++) {
            if (vals[i] != vals[i - 1]) {
                vals[p++] = vals[i];
            }
        }
        int[] s2 = new int[n];
        for (int i = 0; i < n; i++) {
            s2[i] = java.util.Arrays.binarySearch(vals, 0, p, s[i]);
        }
        // Compressed values lie in [0, p - 1].
        return sais(s2, p - 1);
    }

    /**
     * Suffix array of a char sequence. Works for any UTF-16 code unit, not only
     * ASCII: the alphabet bound is taken from the largest char present.
     */
    public static int[] suffixArray(char[] s) {
        int n = s.length;
        int[] s2 = new int[n];
        int upper = 0;
        for (int i = 0; i < n; i++) {
            s2[i] = s[i];
            upper = Math.max(upper, s2[i]);
        }
        return sais(s2, upper);
    }

    public static int[] suffixArray(java.lang.String s) {
        return suffixArray(s.toCharArray());
    }

    /**
     * Kasai's LCP array: lcp[i] is the longest common prefix of suffixes sa[i] and sa[i + 1].
     *
     * @param s non-empty sequence
     * @param sa suffix array of s
     * @return array of length s.length - 1
     */
    public static int[] lcpArray(int[] s, int[] sa) {
        int n = s.length;
        assert (n >= 1);
        int[] rnk = new int[n];
        for (int i = 0; i < n; i++) {
            rnk[sa[i]] = i;
        }
        int[] lcp = new int[n - 1];
        int h = 0;
        for (int i = 0; i < n; i++) {
            if (h > 0) h--;
            if (rnk[i] == 0) {
                continue;
            }
            int j = sa[rnk[i] - 1];
            for (; j + h < n && i + h < n; h++) {
                if (s[j + h] != s[i + h]) break;
            }
            lcp[rnk[i] - 1] = h;
        }
        return lcp;
    }

    public static int[] lcpArray(char[] s, int[] sa) {
        int n = s.length;
        int[] s2 = new int[n];
        for (int i = 0; i < n; i++) {
            s2[i] = s[i];
        }
        return lcpArray(s2, sa);
    }

    public static int[] lcpArray(java.lang.String s, int[] sa) {
        return lcpArray(s.toCharArray(), sa);
    }

    /**
     * Z-array: z[i] is the length of the longest common prefix of s and s[i..]; z[0] = n.
     */
    public static int[] zAlgorithm(int[] s) {
        int n = s.length;
        if (n == 0) return new int[0];
        int[] z = new int[n];
        for (int i = 1, j = 0; i < n; i++) {
            int k = j + z[j] <= i ? 0 : Math.min(j + z[j] - i, z[i - j]);
            while (i + k < n && s[k] == s[i + k]) k++;
            z[i] = k;
            if (j + z[j] < i + z[i]) j = i;
        }
        z[0] = n;
        return z;
    }

    public static int[] zAlgorithm(char[] s) {
        int n = s.length;
        if (n == 0) return new int[0];
        int[] z = new int[n];
        for (int i = 1, j = 0; i < n; i++) {
            int k = j + z[j] <= i ? 0 : Math.min(j + z[j] - i, z[i - j]);
            while (i + k < n && s[k] == s[i + k]) k++;
            z[i] = k;
            if (j + z[j] < i + z[i]) j = i;
        }
        z[0] = n;
        return z;
    }

    public static int[] zAlgorithm(String s) {
        return zAlgorithm(s.toCharArray());
    }
}
数学
/**
 * Number-theory helpers: gcd/lcm, modular exponentiation, CRT, floor sum and divisors.
 */
class MathLib {
    /** x mod m in [0, m) for m >= 1. */
    private static long safe_mod(long x, long m) {
        x %= m;
        if (x < 0) x += m;
        return x;
    }

    /**
     * Extended Euclid for b >= 1.
     *
     * @return {g, x} with g = gcd(a, b) and x*a == g (mod b), 0 <= x < b / g
     */
    private static long[] inv_gcd(long a, long b) {
        a = safe_mod(a, b);
        if (a == 0) return new long[]{b, 0};
        long s = b, t = a;
        long m0 = 0, m1 = 1;
        while (t > 0) {
            long u = s / t;
            s -= t * u;
            m0 -= m1 * u;
            long tmp = s; s = t; t = tmp;
            tmp = m0; m0 = m1; m1 = tmp;
        }
        if (m0 < 0) m0 += b / s;
        return new long[]{s, m0};
    }

    /** Non-negative gcd of all arguments; 0 for no arguments or all zeros. */
    public static long gcd(long... a) {
        if (a.length == 0) return 0;
        long r = java.lang.Math.abs(a[0]);
        for (int i = 1; i < a.length; i++) {
            if (a[i] != 0) {
                if (r == 0) r = java.lang.Math.abs(a[i]);
                else r = inv_gcd(r, java.lang.Math.abs(a[i]))[0];
            }
        }
        return r;
    }

    /** Non-negative lcm of all arguments; 0 if any argument is 0 or none are given. */
    public static long lcm(long... a) {
        if (a.length == 0) return 0;
        long r = java.lang.Math.abs(a[0]);
        for (int i = 1; i < a.length; i++) {
            long v = java.lang.Math.abs(a[i]);
            if (r == 0 || v == 0) {
                // lcm with 0 is 0; also avoids dividing by gcd(0, 0) == 0.
                r = 0;
                continue;
            }
            r = r / gcd(r, v) * v;
        }
        return r;
    }

    /** x^n mod m for n >= 0 and m >= 1. */
    public static long pow_mod(long x, long n, int m) {
        assert n >= 0;
        assert m >= 1;
        if (m == 1) return 0L;
        x = safe_mod(x, m);
        long ans = 1L;
        while (n > 0) {
            if ((n & 1) == 1) ans = (ans * x) % m;
            x = (x * x) % m;
            n >>>= 1;
        }
        return ans;
    }

    /**
     * Chinese remainder theorem for x == r[i] (mod m[i]).
     *
     * @return {x, lcm(m)} with 0 <= x < lcm(m), or {0, 0} if there is no solution
     *     ({0, 1} for empty input)
     */
    public static long[] crt(long[] r, long[] m) {
        assert (r.length == m.length);
        int n = r.length;
        long r0 = 0, m0 = 1;
        for (int i = 0; i < n; i++) {
            assert (1 <= m[i]);
            long r1 = safe_mod(r[i], m[i]), m1 = m[i];
            // Keep m0 >= m1.
            if (m0 < m1) {
                long tmp = r0; r0 = r1; r1 = tmp;
                tmp = m0; m0 = m1; m1 = tmp;
            }
            if (m0 % m1 == 0) {
                if (r0 % m1 != r1) return new long[]{0, 0};
                continue;
            }
            long[] ig = inv_gcd(m0, m1);
            long g = ig[0], im = ig[1];
            long u1 = m1 / g;
            if ((r1 - r0) % g != 0) return new long[]{0, 0};
            long x = (r1 - r0) / g % u1 * im % u1;
            r0 += x * m0;
            m0 *= u1;
            if (r0 < 0) r0 += m0;
        }
        return new long[]{r0, m0};
    }

    /**
     * sum_{i=0}^{n-1} floor((a*i + b) / m).
     * NOTE(review): assumes n >= 0, m >= 1, a >= 0, b >= 0; negative a or b are not handled.
     */
    public static long floor_sum(long n, long m, long a, long b) {
        long ans = 0;
        if (a >= m) {
            ans += (n - 1) * n * (a / m) / 2;
            a %= m;
        }
        if (b >= m) {
            ans += n * (b / m);
            b %= m;
        }
        long y_max = (a * n + b) / m;
        long x_max = y_max * m - b;
        if (y_max == 0) return ans;
        ans += (n - (x_max + a - 1) / a) * y_max;
        ans += floor_sum(y_max, a, m, (a - x_max % a) % a);
        return ans;
    }

    /** Positive divisors of n in ascending order (empty for n <= 0). */
    public static java.util.ArrayList<Long> divisors(long n) {
        java.util.ArrayList<Long> divisors = new java.util.ArrayList<>();
        java.util.ArrayList<Long> large = new java.util.ArrayList<>();
        // i <= n / i instead of i * i <= n: avoids overflow for n near Long.MAX_VALUE.
        for (long i = 1; i <= n / i; i++) {
            if (n % i == 0) {
                divisors.add(i);
                if (i != n / i) large.add(n / i);
            }
        }
        for (int p = large.size() - 1; p >= 0; p--) {
            divisors.add(large.get(p));
        }
        return divisors;
    }
}
并查集
/**
 * Disjoint-set union with union by size and full path compression.
 * parentOrSize[v] < 0 marks v as a root whose set has -parentOrSize[v] members;
 * otherwise it holds v's parent.
 */
class DSU {
    private int n;
    private int[] parentOrSize;

    /** Creates n singleton sets {0}, {1}, ..., {n-1}. */
    public DSU(int n) {
        this.n = n;
        parentOrSize = new int[n];
        java.util.Arrays.fill(parentOrSize, -1);
    }

    /** Throws IndexOutOfBoundsException(label + v) unless 0 <= v < n. */
    private void requireIndex(int v, String label) {
        if (v < 0 || v >= n) {
            throw new IndexOutOfBoundsException(label + v);
        }
    }

    /**
     * Unites the sets containing a and b.
     *
     * @return the root of the merged set (the root of the larger set, or of a's set on ties)
     */
    int merge(int a, int b) {
        requireIndex(a, "a=");
        requireIndex(b, "b=");
        int big = leader(a);
        int small = leader(b);
        if (big == small) {
            return big;
        }
        // More negative entry == larger set; hang the smaller tree below the larger.
        if (parentOrSize[big] > parentOrSize[small]) {
            int swap = big;
            big = small;
            small = swap;
        }
        parentOrSize[big] += parentOrSize[small];
        parentOrSize[small] = big;
        return big;
    }

    /** Whether a and b are in the same set. */
    boolean same(int a, int b) {
        requireIndex(a, "a=");
        requireIndex(b, "b=");
        return leader(a) == leader(b);
    }

    /** Root of a's set; every node visited is re-pointed directly at the root. */
    int leader(int a) {
        int root = a;
        while (parentOrSize[root] >= 0) {
            root = parentOrSize[root];
        }
        int cur = a;
        while (parentOrSize[cur] >= 0) {
            int next = parentOrSize[cur];
            parentOrSize[cur] = root;
            cur = next;
        }
        return root;
    }

    /** Number of elements in a's set. */
    int size(int a) {
        requireIndex(a, "");
        return -parentOrSize[leader(a)];
    }

    /** All sets, ordered by root index; members of each set in ascending order. */
    java.util.ArrayList<java.util.ArrayList<Integer>> groups() {
        int[] rootOf = new int[n];
        int[] members = new int[n];
        for (int v = 0; v < n; v++) {
            int r = leader(v);
            rootOf[v] = r;
            members[r]++;
        }
        java.util.ArrayList<java.util.ArrayList<Integer>> buckets = new java.util.ArrayList<>(n);
        for (int r = 0; r < n; r++) {
            buckets.add(new java.util.ArrayList<>(members[r]));
        }
        for (int v = 0; v < n; v++) {
            buckets.get(rootOf[v]).add(v);
        }
        buckets.removeIf(bucket -> bucket.isEmpty());
        return buckets;
    }
}
树状数组
/**
 * Fenwick (binary indexed) tree over long values: point update and
 * half-open range sum, both O(log n). data[i - 1] stores the 1-based node i.
 */
class FenwickTree {
    private int _n;
    private long[] data;

    /** Tree of n zeros. */
    public FenwickTree(int n) {
        this._n = n;
        data = new long[n];
    }

    /** Tree initialised from the given values in O(n). */
    public FenwickTree(long[] data) {
        this(data.length);
        build(data);
    }

    /** Sets position p to exactly x. */
    public void set(int p, long x) {
        long delta = x - get(p);
        add(p, delta);
    }

    /** Adds x to position p (0-based). */
    public void add(int p, long x) {
        assert (0 <= p && p < _n);
        for (int node = p + 1; node <= _n; node += node & -node) {
            data[node - 1] += x;
        }
    }

    /** Sum of positions [l, r). */
    public long sum(int l, int r) {
        assert (0 <= l && l <= r && r <= _n);
        return sum(r) - sum(l);
    }

    /** Value at position p. */
    public long get(int p) {
        return sum(p, p + 1);
    }

    /** Prefix sum of positions [0, r). */
    private long sum(int r) {
        long total = 0;
        for (int node = r; node > 0; node -= node & -node) {
            total += data[node - 1];
        }
        return total;
    }

    /** Linear-time construction: each node pushes its partial sum to its parent. */
    private void build(long[] dat) {
        System.arraycopy(dat, 0, data, 0, _n);
        for (int node = 1; node <= _n; node++) {
            int parent = node + (node & -node);
            if (parent <= _n) {
                data[parent - 1] += data[node - 1];
            }
        }
    }
}
最大流
/**
 * Maximum flow on a directed graph with long capacities (Dinic's algorithm:
 * BFS level graph plus DFS with a per-vertex edge iterator).
 */
class MaxFlow {
    private static final class InternalCapEdge {
        final int to;
        final int rev; // index of the paired reverse edge in g[to]
        long cap;      // residual capacity
        InternalCapEdge(int to, int rev, long cap) { this.to = to; this.rev = rev; this.cap = cap; }
    }

    /** Snapshot of an added edge: endpoints, total capacity and current flow. */
    public static final class CapEdge {
        public final int from, to;
        public final long cap, flow;
        CapEdge(int from, int to, long cap, long flow) { this.from = from; this.to = to; this.cap = cap; this.flow = flow; }
        @Override
        public boolean equals(Object o) {
            if (o instanceof CapEdge) {
                CapEdge e = (CapEdge) o;
                return from == e.from && to == e.to && cap == e.cap && flow == e.flow;
            }
            return false;
        }
        @Override
        public int hashCode() {
            // Must agree with equals().
            return java.util.Objects.hash(from, to, cap, flow);
        }
    }

    private static final class IntPair {
        final int first, second;
        IntPair(int first, int second) { this.first = first; this.second = second; }
    }

    static final long INF = Long.MAX_VALUE;
    private final int n;
    // pos.get(i) = (vertex, index in g[vertex]) of the i-th added forward edge.
    private final java.util.ArrayList<IntPair> pos;
    private final java.util.ArrayList<InternalCapEdge>[] g;

    @SuppressWarnings("unchecked") // generic array creation
    public MaxFlow(int n) {
        this.n = n;
        this.pos = new java.util.ArrayList<>();
        this.g = new java.util.ArrayList[n];
        for (int i = 0; i < n; i++) {
            this.g[i] = new java.util.ArrayList<>();
        }
    }

    /**
     * Adds a directed edge from -> to with the given capacity.
     *
     * @return the id of the new edge (0, 1, 2, ... in insertion order)
     */
    public int addEdge(int from, int to, long cap) {
        rangeCheck(from, 0, n);
        rangeCheck(to, 0, n);
        nonNegativeCheck(cap, "Capacity");
        int m = pos.size();
        pos.add(new IntPair(from, g[from].size()));
        int fromId = g[from].size();
        int toId = g[to].size();
        // For a self-loop the reverse edge lands right after the forward one.
        if (from == to) toId++;
        g[from].add(new InternalCapEdge(to, toId, cap));
        g[to].add(new InternalCapEdge(from, fromId, 0L));
        return m;
    }

    private InternalCapEdge getInternalEdge(int i) {
        return g[pos.get(i).first].get(pos.get(i).second);
    }

    private InternalCapEdge getInternalEdgeReversed(InternalCapEdge e) {
        return g[e.to].get(e.rev);
    }

    /** Current state of edge i. */
    public CapEdge getEdge(int i) {
        int m = pos.size();
        rangeCheck(i, 0, m);
        InternalCapEdge e = getInternalEdge(i);
        InternalCapEdge re = getInternalEdgeReversed(e);
        return new CapEdge(re.to, e.to, e.cap + re.cap, re.cap);
    }

    /** Current state of all edges, in insertion order. */
    public CapEdge[] getEdges() {
        CapEdge[] res = new CapEdge[pos.size()];
        java.util.Arrays.setAll(res, this::getEdge);
        return res;
    }

    /**
     * Overwrites the capacity and flow of edge i.
     *
     * @throws IllegalArgumentException if newCap or newFlow is negative, or newFlow > newCap
     */
    public void changeEdge(int i, long newCap, long newFlow) {
        int m = pos.size();
        rangeCheck(i, 0, m);
        nonNegativeCheck(newCap, "Capacity");
        nonNegativeCheck(newFlow, "Flow");
        if (newFlow > newCap) {
            throw new IllegalArgumentException(
                String.format("Flow %d is greater than the capacity %d.", newFlow, newCap)
            );
        }
        InternalCapEdge e = getInternalEdge(i);
        InternalCapEdge re = getInternalEdgeReversed(e);
        e.cap = newCap - newFlow;
        re.cap = newFlow;
    }

    /** Maximum s-t flow. */
    public long maxFlow(int s, int t) {
        return flow(s, t, INF);
    }

    /**
     * Pushes as much s-t flow as possible, up to flowLimit, on top of the current flow.
     *
     * @return the amount of flow pushed
     * @throws IllegalArgumentException if s == t
     */
    public long flow(int s, int t, long flowLimit) {
        rangeCheck(s, 0, n);
        rangeCheck(t, 0, n);
        if (s == t) {
            throw new IllegalArgumentException(
                String.format("%d and %d is the same vertex.", s, t)
            );
        }
        long flow = 0L;
        int[] level = new int[n];
        int[] que = new int[n];
        int[] iter = new int[n];
        while (flow < flowLimit) {
            bfs(s, t, level, que);
            if (level[t] < 0) break;
            java.util.Arrays.fill(iter, 0);
            while (flow < flowLimit) {
                long d = dfs(t, s, flowLimit - flow, iter, level);
                if (d == 0) break;
                flow += d;
            }
        }
        return flow;
    }

    /** BFS distances from s over residual edges; stops as soon as t is labelled. */
    private void bfs(int s, int t, int[] level, int[] que) {
        java.util.Arrays.fill(level, -1);
        int hd = 0, tl = 0;
        que[tl++] = s;
        level[s] = 0;
        while (hd < tl) {
            int u = que[hd++];
            for (InternalCapEdge e : g[u]) {
                int v = e.to;
                if (e.cap == 0 || level[v] >= 0) continue;
                level[v] = level[u] + 1;
                if (v == t) return;
                que[tl++] = v;
            }
        }
    }

    /** Augments from t back towards s along strictly decreasing levels. */
    private long dfs(int cur, int s, long flowLimit, int[] iter, int[] level) {
        if (cur == s) return flowLimit;
        long res = 0;
        int curLevel = level[cur];
        for (int itMax = g[cur].size(); iter[cur] < itMax; iter[cur]++) {
            int i = iter[cur];
            InternalCapEdge e = g[cur].get(i);
            InternalCapEdge re = getInternalEdgeReversed(e);
            if (curLevel <= level[e.to] || re.cap == 0) continue;
            long d = dfs(e.to, s, Math.min(flowLimit - res, re.cap), iter, level);
            if (d <= 0) continue;
            e.cap += d;
            re.cap -= d;
            res += d;
            if (res == flowLimit) break;
        }
        return res;
    }

    /** Vertices reachable from s in the residual graph (the s side of a minimum cut). */
    public boolean[] minCut(int s) {
        rangeCheck(s, 0, n);
        boolean[] visited = new boolean[n];
        int[] stack = new int[n];
        int ptr = 0;
        stack[ptr++] = s;
        visited[s] = true;
        while (ptr > 0) {
            int u = stack[--ptr];
            for (InternalCapEdge e : g[u]) {
                int v = e.to;
                if (e.cap > 0 && !visited[v]) {
                    visited[v] = true;
                    stack[ptr++] = v;
                }
            }
        }
        return visited;
    }

    private void rangeCheck(int i, int minInclusive, int maxExclusive) {
        if (i < minInclusive || i >= maxExclusive) {
            throw new IndexOutOfBoundsException(
                String.format("Index %d out of bounds for length %d", i, maxExclusive)
            );
        }
    }

    private void nonNegativeCheck(long cap, String attribute) {
        if (cap < 0) {
            throw new IllegalArgumentException(
                String.format("%s %d is negative.", attribute, cap)
            );
        }
    }
}
最小费用流
/**
 * Min-cost flow with non-negative costs using successive shortest paths
 * (Dijkstra with potentials).
 */
class MinCostFlow {
    private static final class InternalWeightedCapEdge {
        final int to, rev; // rev: index of the paired reverse edge in g[to]
        long cap;          // residual capacity
        final long cost;
        InternalWeightedCapEdge(int to, int rev, long cap, long cost) { this.to = to; this.rev = rev; this.cap = cap; this.cost = cost; }
    }

    /** Snapshot of an added edge: endpoints, total capacity, current flow and unit cost. */
    public static final class WeightedCapEdge {
        public final int from, to;
        public final long cap, flow, cost;
        WeightedCapEdge(int from, int to, long cap, long flow, long cost) { this.from = from; this.to = to; this.cap = cap; this.flow = flow; this.cost = cost; }
        @Override
        public boolean equals(Object o) {
            if (o instanceof WeightedCapEdge) {
                WeightedCapEdge e = (WeightedCapEdge) o;
                return from == e.from && to == e.to && cap == e.cap && flow == e.flow && cost == e.cost;
            }
            return false;
        }
        @Override
        public int hashCode() {
            return java.util.Objects.hash(from, to, cap, flow, cost);
        }
    }

    private static final class IntPair {
        final int first, second;
        IntPair(int first, int second) { this.first = first; this.second = second; }
    }

    /** A (total flow, total cost) pair. */
    public static final class FlowAndCost {
        public final long flow, cost;
        FlowAndCost(long flow, long cost) { this.flow = flow; this.cost = cost; }
        @Override
        public boolean equals(Object o) {
            if (o instanceof FlowAndCost) {
                FlowAndCost c = (FlowAndCost) o;
                return flow == c.flow && cost == c.cost;
            }
            return false;
        }
        @Override
        public int hashCode() {
            return java.util.Objects.hash(flow, cost);
        }
    }

    static final long INF = Long.MAX_VALUE;
    private final int n;
    // pos.get(i) = (vertex, index in g[vertex]) of the i-th added forward edge.
    private final java.util.ArrayList<IntPair> pos;
    private final java.util.ArrayList<InternalWeightedCapEdge>[] g;

    @SuppressWarnings("unchecked") // generic array creation
    public MinCostFlow(int n) {
        this.n = n;
        this.pos = new java.util.ArrayList<>();
        this.g = new java.util.ArrayList[n];
        for (int i = 0; i < n; i++) {
            this.g[i] = new java.util.ArrayList<>();
        }
    }

    /**
     * Adds a directed edge from -> to with the given capacity and per-unit cost.
     *
     * @return the id of the new edge (0, 1, 2, ... in insertion order)
     */
    public int addEdge(int from, int to, long cap, long cost) {
        rangeCheck(from, 0, n);
        rangeCheck(to, 0, n);
        nonNegativeCheck(cap, "Capacity");
        nonNegativeCheck(cost, "Cost");
        int m = pos.size();
        pos.add(new IntPair(from, g[from].size()));
        int fromId = g[from].size();
        int toId = g[to].size();
        // For a self-loop the reverse edge lands right after the forward one.
        if (from == to) toId++;
        g[from].add(new InternalWeightedCapEdge(to, toId, cap, cost));
        g[to].add(new InternalWeightedCapEdge(from, fromId, 0L, -cost));
        return m;
    }

    private InternalWeightedCapEdge getInternalEdge(int i) {
        return g[pos.get(i).first].get(pos.get(i).second);
    }

    private InternalWeightedCapEdge getInternalEdgeReversed(InternalWeightedCapEdge e) {
        return g[e.to].get(e.rev);
    }

    /** Current state of edge i. */
    public WeightedCapEdge getEdge(int i) {
        int m = pos.size();
        rangeCheck(i, 0, m);
        InternalWeightedCapEdge e = getInternalEdge(i);
        InternalWeightedCapEdge re = getInternalEdgeReversed(e);
        return new WeightedCapEdge(re.to, e.to, e.cap + re.cap, re.cap, e.cost);
    }

    /** Current state of all edges, in insertion order. */
    public WeightedCapEdge[] getEdges() {
        WeightedCapEdge[] res = new WeightedCapEdge[pos.size()];
        java.util.Arrays.setAll(res, this::getEdge);
        return res;
    }

    public FlowAndCost minCostMaxFlow(int s, int t) {
        return minCostFlow(s, t, INF);
    }

    /** Pushes up to flowLimit units from s to t at minimum cost. */
    public FlowAndCost minCostFlow(int s, int t, long flowLimit) {
        return minCostSlope(s, t, flowLimit).getLast();
    }

    java.util.LinkedList<FlowAndCost> minCostSlope(int s, int t) {
        return minCostSlope(s, t, INF);
    }

    /**
     * Breakpoints of the piecewise-linear "minimum cost as a function of flow" curve.
     *
     * <p>The list starts at (0, 0). Consecutive augmentations with the same
     * per-unit cost are merged into a single segment, so every interior point
     * is a genuine change of slope.
     *
     * @throws IllegalArgumentException if s == t
     */
    public java.util.LinkedList<FlowAndCost> minCostSlope(int s, int t, long flowLimit) {
        rangeCheck(s, 0, n);
        rangeCheck(t, 0, n);
        if (s == t) {
            throw new IllegalArgumentException(
                String.format("%d and %d is the same vertex.", s, t)
            );
        }
        long[] dual = new long[n];
        long[] dist = new long[n];
        int[] pv = new int[n];
        int[] pe = new int[n];
        boolean[] vis = new boolean[n];
        long flow = 0;
        long cost = 0;
        // Per-unit cost of the previous augmentation; -1 never matches because costs are >= 0.
        long prevCostPerFlow = -1;
        java.util.LinkedList<FlowAndCost> result = new java.util.LinkedList<>();
        result.addLast(new FlowAndCost(flow, cost));
        while (flow < flowLimit) {
            if (!dualRef(s, t, dual, dist, pv, pe, vis)) break;
            long c = flowLimit - flow;
            for (int v = t; v != s; v = pv[v]) {
                c = Math.min(c, g[pv[v]].get(pe[v]).cap);
            }
            for (int v = t; v != s; v = pv[v]) {
                InternalWeightedCapEdge e = g[pv[v]].get(pe[v]);
                e.cap -= c;
                g[v].get(e.rev).cap += c;
            }
            long d = -dual[s]; // cost per unit of flow along this augmenting path
            flow += c;
            cost += c * d;
            if (prevCostPerFlow == d) {
                // Same slope as the previous segment: extend it instead of adding a breakpoint.
                result.removeLast();
            }
            result.addLast(new FlowAndCost(flow, cost));
            prevCostPerFlow = d;
        }
        return result;
    }

    /**
     * Dijkstra on reduced costs from s, then updates the potentials in dual.
     *
     * @return false if t is unreachable in the residual graph
     */
    private boolean dualRef(int s, int t, long[] dual, long[] dist, int[] pv, int[] pe, boolean[] vis) {
        java.util.Arrays.fill(dist, INF);
        java.util.Arrays.fill(pv, -1);
        java.util.Arrays.fill(pe, -1);
        java.util.Arrays.fill(vis, false);
        class State implements Comparable<State> {
            final long key;
            final int to;
            State(long key, int to) { this.key = key; this.to = to; }
            @Override
            public int compareTo(State q) {
                return Long.compare(key, q.key);
            }
        }
        java.util.PriorityQueue<State> pq = new java.util.PriorityQueue<>();
        dist[s] = 0;
        pq.add(new State(0L, s));
        while (pq.size() > 0) {
            int v = pq.poll().to;
            if (vis[v]) continue;
            vis[v] = true;
            if (v == t) break;
            for (int i = 0, deg = g[v].size(); i < deg; i++) {
                InternalWeightedCapEdge e = g[v].get(i);
                if (vis[e.to] || e.cap == 0) continue;
                long cost = e.cost - dual[e.to] + dual[v];
                if (dist[e.to] - dist[v] > cost) {
                    dist[e.to] = dist[v] + cost;
                    pv[e.to] = v;
                    pe[e.to] = i;
                    pq.add(new State(dist[e.to], e.to));
                }
            }
        }
        if (!vis[t]) {
            return false;
        }
        for (int v = 0; v < n; v++) {
            if (!vis[v]) continue;
            dual[v] -= dist[t] - dist[v];
        }
        return true;
    }

    private void rangeCheck(int i, int minInclusive, int maxExclusive) {
        if (i < minInclusive || i >= maxExclusive) {
            throw new IndexOutOfBoundsException(
                String.format("Index %d out of bounds for length %d", i, maxExclusive)
            );
        }
    }

    private void nonNegativeCheck(long cap, java.lang.String attribute) {
        if (cap < 0) {
            throw new IllegalArgumentException(
                String.format("%s %d is negative.", attribute, cap)
            );
        }
    }
}
LCT
/**
 * Link-cut tree over vertices 1..N with path aggregates (sum / min / max)
 * and lazy path assignment / increment. Index 0 is a sentinel meaning "none".
 */
public class LinkCutTree {
    static class Node {
        int p, count;  // p: splay parent or path-parent; count: splay-subtree size
        int[] c;       // c[0] / c[1]: left / right splay children
        long val, sum, min, max, assign, delta;
        boolean flip;
        // True when `assign` holds a pending assignment for the children. A flag is
        // used instead of a sentinel value so that every long (including -1) can be assigned.
        boolean hasAssign;
        public Node(long val) {
            c = new int[2];
            this.count = 1;
            this.val = this.sum = this.min = this.max = val;
        }
    }

    static class LCT {
        Node[] T;

        /**
         * @param N number of vertices (1..N)
         * @param A initial values; must have at least N + 1 entries (A[0] is for the sentinel)
         */
        public LCT(int N, long[] A) {
            T = new Node[N + 1];
            for (int i = 0; i <= N; i++)
                T[i] = new Node(A[i]);
        }

        private int count(int x) {
            return x == 0 ? 0 : T[x].count;
        }

        private long sum(int x) {
            return x == 0 ? 0 : T[x].sum;
        }

        private long min(int x) {
            return x == 0 ? Long.MAX_VALUE : T[x].min;
        }

        private long max(int x) {
            return x == 0 ? Long.MIN_VALUE : T[x].max;
        }

        /** Adds val to every node in x's splay subtree (applied to x now, lazily below). */
        private void increment(int x, long val) {
            if (x == 0) return;
            T[x].delta += val;
            T[x].val += val;
            T[x].min += val;
            T[x].max += val;
            T[x].sum += val * T[x].count;
        }

        /** Sets every node in x's splay subtree to val (applied to x now, lazily below). */
        private void assign(int x, long val) {
            if (x == 0) return;
            // A new assignment overrides any pending increment.
            T[x].delta = 0;
            T[x].val = val;
            T[x].min = val;
            T[x].max = val;
            T[x].assign = val;
            T[x].hasAssign = true;
            T[x].sum = val * T[x].count;
        }

        private void flip(int x) {
            if (x == 0) return;
            else T[x].flip = !T[x].flip;
        }

        /** True if x is not the root of its splay tree (x is a real child of T[x].p). */
        private boolean notRoot(int x) {
            return T[x].p != 0 && (T[T[x].p].c[0] == x || T[T[x].p].c[1] == x);
        }

        private void pullUp(int x) {
            if (x == 0) return;
            T[x].count = 1 + count(T[x].c[0]) + count(T[x].c[1]);
            T[x].sum = T[x].val + sum(T[x].c[0]) + sum(T[x].c[1]);
            T[x].min = Math.min(T[x].val, Math.min(min(T[x].c[0]), min(T[x].c[1])));
            T[x].max = Math.max(T[x].val, Math.max(max(T[x].c[0]), max(T[x].c[1])));
        }

        /** Pushes x's pending flip, then assignment, then increment, to its children. */
        private void pushDown(int x) {
            if (x != 0 && T[x].flip) {
                int temp = T[x].c[0];
                T[x].c[0] = T[x].c[1];
                T[x].c[1] = temp;
                flip(T[x].c[0]);
                flip(T[x].c[1]);
                T[x].flip = false;
            }
            if (x != 0 && T[x].hasAssign) {
                assign(T[x].c[0], T[x].assign);
                assign(T[x].c[1], T[x].assign);
                T[x].hasAssign = false;
            }
            if (x != 0) {
                increment(T[x].c[0], T[x].delta);
                increment(T[x].c[1], T[x].delta);
                T[x].delta = 0;
            }
        }

        private void rotate(int x) {
            int p = T[x].p;
            int d = (p != 0 && T[p].c[0] == x ? 1 : 0);
            T[x].p = T[p].p;
            if (notRoot(p)) {
                if (T[T[x].p].c[0] == p) T[T[x].p].c[0] = x;
                else T[T[x].p].c[1] = x;
            }
            T[p].c[(d ^ 1)] = T[x].c[d];
            if (T[p].c[(d ^ 1)] != 0)
                T[T[p].c[(d ^ 1)]].p = p;
            T[x].c[d] = p;
            T[p].p = x;
            pullUp(p);
        }

        private void splay(int x) {
            while (notRoot(x)) {
                int p = T[x].p;
                if (notRoot(p))
                    pushDown(T[p].p);
                pushDown(p);
                pushDown(x);
                if (notRoot(p)) {
                    boolean d1 = (p != 0 && T[p].p != 0 && T[T[p].p].c[0] == p);
                    boolean d2 = (p != 0 && T[p].c[0] == x);
                    // Zig-zag rotates x twice; zig-zig rotates the parent first.
                    if (d1 ^ d2) rotate(x);
                    else rotate(p);
                }
                rotate(x);
            }
            pushDown(x);
            pullUp(x);
        }

        /**
         * Makes the root-to-_u path preferred and splays _u to the top.
         *
         * @return the last node reached while climbing (used by LCA)
         */
        public int access(int _u) {
            int v = 0;
            for (int u = _u; u != 0; u = T[u].p) {
                splay(u);
                T[u].c[1] = v;
                v = u;
            }
            // The final splay also recomputes aggregates of the nodes re-linked above.
            splay(_u);
            return v;
        }

        public void makeRoot(int u) {
            access(u);
            // After access(u), u's own flip has been pushed down, so this is a toggle.
            T[u].flip = true;
        }

        /** Adds edge u-v; u and v must be in different trees. */
        public void link(int u, int v) {
            makeRoot(u);
            T[u].p = v;
        }

        /** Removes edge u-v; assumes u and v are adjacent. */
        public void cut(int u, int v) {
            makeRoot(u);
            access(v);
            if (T[v].c[0] != 0)
                T[T[v].c[0]].p = 0;
            T[v].c[0] = 0;
            pullUp(v);
        }

        /** Exposes the u..v path; the returned splay root v aggregates exactly that path. */
        private int getPath(int u, int v) {
            makeRoot(u);
            access(v);
            return v;
        }

        /** Sets every vertex on the x..y path to z. */
        public void pathModify(int x, int y, long z) {
            int u = getPath(x, y);
            assign(u, z);
        }

        /** Adds z to every vertex on the x..y path. */
        public void pathIncrement(int x, int y, long z) {
            int u = getPath(x, y);
            increment(u, z);
        }

        public long pathMin(int x, int y) {
            int u = getPath(x, y);
            return T[u].min;
        }

        public long pathMax(int x, int y) {
            int u = getPath(x, y);
            return T[u].max;
        }

        public long pathSum(int x, int y) {
            int u = getPath(x, y);
            return T[u].sum;
        }

        /** With the tree rooted at root, re-attaches x under y unless x is an ancestor of y. */
        public void changeParent(int root, int x, int y) {
            if (x == LCA(root, x, y)) return;
            cut(root, x);
            link(x, y);
        }

        /** Lowest common ancestor of x and y when the tree is rooted at root. */
        public int LCA(int root, int x, int y) {
            makeRoot(root);
            access(x);
            return access(y);
        }
    }
}