原文地址:
https://www.cnblogs.com/fallingdust/p/15153898.html
安利我的博客:
https://www.cnblogs.com/fallingdust
虫食算
题意
给出三个数$a,b,c$,每个数的位数为n(可能有前导$0$),这三个数是$n$进制数,求出唯一解可以满足$a+b=c$。(即求出这$1到n$个位置的唯一量,保证$0-(n-1)$有且仅出现一次)
分析
$n<=26$
- 我看到这个n之后认真想了一下为什么是26,一开始我算了一下26的阶乘:403,291,461,126,605,635,584,000,000 。发现这完全不是用来卡全排列的意思,后来我恍然大悟,这完全是因为数据规模大小差不多合适,而且只有26个字母......
- 然后开始分析题目。正如上一段话中26!的大小,很显然朴素算法爆了(不看也知道<_<),所以针对于搜索,我们思考可不可以进行优化。
搜索
- 先思考怎么搜索:
-
- 枚举每一个空闲字母,从剩下的0到$n-1$中枚举一个没用的给它;
- 直到n个全部枚举结束
- 进行判断,输出。
- 很显然,我们需要进行剪枝,怎么想?
-
- 首先这是加法,不然你是几进制,一次一位加法,你一定最多往前补1,原因:对于每一位,最大为$n-1$,两个相加:$2n-2$,他只能向前+1,而$2n-2+1$仍旧$<2n$,这很显然。
- 所以我们可以判断:$number[a[i]]+numbei[b[i]]==number[c[i]]$或$number[a[i]]+numbei[b[i]]+1==number[c[i]]$是不是都不成立,如果都不成立,就可以跳过
- 但是又能发现:我们从A搜索到A+n-1,还是会超时,例如:A+n-1是第一位,依次向后,第二行随即构成,然后第三行找一个可以满足的,尽量让A+n-1大,这样我们算法修正次数会阶乘级别递增,会挂(luogu有这种数据),那么怎么优化?
- 整体式子,从后向前按出现次序来枚举,例如:ABCED+BDACE=EBBAA,我们从后向前,依次把DE,C,A,B进行枚举,这样即使我们的枚举出现错误,我们需要退回最后,带来的代价也仅限到当前枚举位数的阶乘,而不是从头到尾再一次重新枚举。同时,我们从后向前,按位来枚举,每一次一列枚举完可以立刻进行$check$,而从前往后,可能真的必须把n个全部填满才可以$check$。样例就是这种情况。
代码实现
第一步,预处理
我们把A,B,C,D…A+n-1用tag[N]给他编号,这样枚举的时候可以直接取,同时编号大小也是按照出现顺序排序的。原因也很简单,我们dfs习惯写法是从0到n-1来顺次写,那么tag按照从后向前顺序进行编号,我们每次按照dfs(Now)的Now位置的tag赋值,可以有效从后向前按出现次序赋值,进行优化。
inline int id(char c) {
return c-'A';
}
void add(int x) {
if(vis[x]==0) {
vis[x]=1;
tag[cnt++]=x;
}
return;
}
scanf("%d",&n);
cin>>s1>>s2>>s3;
for(int i=0;i<n;i++) {
a[i]=id(s1[i]);
b[i]=id(s2[i]);
c[i]=id(s3[i]);
}
for(int i=n-1;i>=0;i--) {
add(a[i]);
add(b[i]);
add(c[i]);
}
memset(num,-1,sizeof num);
\\注意,由于我们数字赋值是从0到n-1,所以num初值为-1
memset(vis,0,sizeof vis);
剪枝&边界到达后的check
详情见上文
bool check() {
int k=0;
for(int i=n-1;i>=0;i--) {
if((num[a[i]]+num[b[i]]+k)%n!=num[c[i]]) return 0;
k=(num[a[i]]+num[b[i]]+k)/n;
//进位
}
return 1;
}
bool check2() {
if(num[a[0]]+num[b[0]]>=n)
return 0;
//首位有进位则退出
for(int i=n-1;i>=0;i--) {
if(num[a[i]]==-1||num[b[i]]==-1||num[c[i]]==-1)
continue;
if((num[a[i]]+num[b[i]])%n!=num[c[i]]&&(num[a[i]]+num[b[i]]+1)%n!=num[c[i]])
return 0;
//每一位进行演算,注意,这个是宽泛的,而非唯一的,所有这个check2只能剪掉一部分,原因是无法判断前一列到底会不会进行进位
}
return 1;
}
总代码
#include <bits/stdc++.h>
using namespace std;
int n,cnt;
int a[27],b[27],c[27];
int num[27],tag[27];
string s1,s2,s3;
bool vis[27];
inline int id(char c) {
return c-'A';
}
void add(int x) {
if(vis[x]==0) {
vis[x]=1;
tag[cnt++]=x;
}
return;
}
bool check() {
int k=0;
for(int i=n-1;i>=0;i--) {
if((num[a[i]]+num[b[i]]+k)%n!=num[c[i]]) return 0;
k=(num[a[i]]+num[b[i]]+k)/n;
}
return 1;
}
bool check2() {
if(num[a[0]]+num[b[0]]>=n)
return 0;
for(int i=n-1;i>=0;i--) {
if(num[a[i]]==-1||num[b[i]]==-1||num[c[i]]==-1)
continue;
if((num[a[i]]+num[b[i]])%n!=num[c[i]]&&(num[a[i]]+num[b[i]]+1)%n!=num[c[i]])
return 0;
}
return 1;
}
void print() {
for(int i=0;i<n-1;i++)
printf("%d ",num[i]);
printf ("%d",num[n-1]);
}
void dfs(int u) {
if(!check2())
return;
if(u==n)
if(check()){
print();
exit(0);
}
for(int i=n-1;i>=0;i--)
if(vis[i]==0) {
num[tag[u]]=i;
vis[i]=1;
dfs(u+1);
num[tag[u]]=-1;
vis[i]=0;
}
return;
}
int main() {
scanf("%d",&n);
cin>>s1>>s2>>s3;
for(int i=0;i<n;i++) {
a[i]=id(s1[i]);
b[i]=id(s2[i]);
c[i]=id(s3[i]);
}
for(int i=n-1;i>=0;i--) {
add(a[i]);
add(b[i]);
add(c[i]);
}
memset(num,-1,sizeof num);
memset(vis,0,sizeof vis);
dfs(0);
return 0;
}