博客食用更佳~ https://www.cnblogs.com/czyty114/p/14443656.html
引入
$\;$
假设现在我们得到了一个$n$次多项式$f(x)=a_0+a_1x+a_2x^2+\cdots+a_nx^n$,$n+1$个条件形如$(x_i,y_i)$,表示当$x=x_i$时,多项式的值为$y_i$
求:$a_0,a_1,\cdots,a_n$
对于这类问题,我们当然可以根据$(x_i,y_i)$写出$n+1$个$n+1$次方程组成的线性方程组,然后再$O(n^3)$的时间内用高斯消元求解
但对于$n\leq 2000$这样的范围,高斯消元就不适用了。
接下来介绍的拉格朗日插值法,时间复杂度是$O(n^2)$的
$\;$
拉格朗日插值
$\;$
我们考虑构造出$n+1$个多项式,满足第$i$个多项式只有当自变量取值为$x_i$时,其值为1,否则为0
那么对于第$i$个多项式,其形式即为:
$$fi(k)=\prod_{i\neq j} \frac{k-x_j}{x_i-x_j}$$
显然,上式是满足条件的。
那么对于原多项式,显然:
$$f(k)=\sum_{i=0}^n y_i fi(k)$$
如果题目只是要求这个多项式在给定$k$下的函数值,显然可以$O(n^2)$来解决。
但如果要求每一项系数,仍然是$O(n^3)$
$\;$
特殊情况
$\;$
若给定的$x_i$是连续的数,即:$x_i=i$,我们来看这个东西有什么更好的性质
$fi(k)$可以变为$\prod_{i\neq j} \frac{k-j}{i-j}$
我们把整个柿子抄一遍
$$f(k)=\sum_{i=0}^n y_i \prod_{i\neq j} \frac{k-j}{i-j}$$
设$h(i)=\prod_{j=0}^i (k-j),r(i)=\prod_{j=i}^n (k-j), fac(i)=i!$,那么:
$$f(k)=\sum_{i=0}^n y_i \frac{h(i-1)r(i+1)}{fac(i)fac(n-i)(-1)^{n-i}}$$
那么我们预处理好$h,r,fac$,$f(k)$就可以$O(n)$的算出来了
但是若要求表达式仍是O(n^3)的
$\;$
优化
$\;$
其实也不算是优化,是为了解决另一种更繁琐的情况,若有时候要减少一个或加入一个插值点,即:$(x_i,y_i)$
按原来的式子还必须重新算一遍,如何优化呢?
观察上面的式子:
$f(k)=\sum_{i=0}^n y_i \prod_{i\neq j} \frac{k-x_j}{x_i-x_j}$
我们发现$k-x_j$这里是与$i$无关的,提到前面。
设$g(k)=\prod_{i=0}^n (k-x_i)$
于是原式就变成了:
$f(k)=g(k) \sum_{i=0}^n \frac{y_i}{k-x_i} \prod_{i\neq j} \frac{1}{x_i-x_j}$
设$t(i)= \prod_{i\neq j} \frac{1}{x_i-x_j}$
$f(k)=g(k) \sum_{i=0}^n \frac{y_it_i}{k-x_i}$
我们发现,$t(1),t(2),\cdots,t(n)$是可以$O(n^2)$预处理的,其余用$O(n)$时间即可解决
那么如果我们要加入一个插值点$(x_{n+1},y_{n+1})$,显然只需要把所有的$t(i)$除以$x_i-x_{n+1}$
这样修改的复杂度是$O(n)$的,然后我们再用$O(n)$的时间求值即可
Code
$\;$
求$f(k)$的值。
代码用的是那个支持加入插值点方法(其实第一种也可以做)
#include <bits/stdc++.h>
const int N = 2010, mod = 998244353;
int n, k, g = 1, t[N], x[N], y[N];
int power(int a, int b) {
int ans = 1;
while(b) {
if(b & 1) ans = 1ll * ans * a % mod;
a = 1ll * a * a % mod;
b >>= 1;
}
return ans;
}
int main() {
scanf("%d%d", &n, &k);
for(int i=0;i<n;i++) scanf("%d%d", &x[i], &y[i]);
for(int i=0;i<n;i++) g = 1ll * g * (k - x[i] + mod) % mod;
for(int i=0;i<n;i++) {
int now = 1;
for(int j=0;j<n;j++) {
if(j == i) continue;
now = 1ll * now * (x[i] - x[j] + mod) % mod;
}
t[i] = power(now, mod - 2);
}
int ans = 0;
for(int i=0;i<n;i++) {
int tmp = 1ll * y[i] * t[i] % mod * power(k - x[i] + mod, mod - 2) % mod;
ans = (ans + tmp) % mod;
}
printf("%d", 1ll * ans * g % mod);
return 0;
}