题目描述
难度分:$1400$
输入$n(1 \leq n \leq 2 \times 10^5)$和长为$n$的字符串数组 $a$,每个字符串的长度均在$[1,5]$内,且只包含数字字符1
~9
。
定义【好字符串】满足:偶数长度,且左半的数字之和等于右半的数字之和。
输出有多少对$(i,j)$满足$a[i]+a[j]$是好字符串。
注意$i$和$j$的大小没有限制,$i \lt j$,$i=j$,$i \gt j$都可以。
输入样例$1$
10
5 93746 59 3746 593 746 5937 46 59374 6
输出样例$1$
20
输入样例$2$
5
2 22 222 2222 22222
输出样例$2$
13
输入样例$3$
3
1 1 1
输出样例$3$
9
算法
枚举
首先可以想到,由于好串的长度为偶数,所以必然是偶数长度的串相互组合,奇数长度的串相互组合,我们可以按长度分组计算答案,两个组的答案加起来就是最终答案。而每个数字字符串的长度很短,因此比较容易想到以它为突破口,但是在枚举的时候需要分为以下三种情况:
-
枚举好串的长度为$2,4,6,8,10$,对于每个长度$len$,枚举前一个数字$nums[i],i \in [0,n)$,由于$nums[i].size() \gt \frac{len}{2}$时我们可以确定好串前一半的数位和,更加方便推测后面那个串需要满足的条件,因此可以假设前一个串的长度是更大的。对于$nums[i]$,计算出前$\frac{len}{2}$的数位和$s_1$,那么后一个数字的长度就应该是$olen=len-nums[i].size()$,它的数位和应该是$s_2=s_1 \times 2-digitSum(nums[i])$,其中$digitSum(x)$表示$x$的数位和。知道有多少个$nums[j]$满足长度为$olen$,数位和为$s_2$即可,这可以预处理出来,存储到一个哈希表$counter$中,$counter[len][sum]$表示长度为$len$,数位和为$sum$的数字串个数(累加到答案上即可)。而为了快速得到某个数字串$nums[i]$前$x$位的数位和,我们还可以预处理出一个前缀和数组$s$,$s[i][j]$表示$nums[i]$前$j+1$位的数位和,$j$从$0$开始取值。
-
同理可以计算出后一个数的长度$\gt \frac{len}{2}$的情况,只不过$s$数组需要变成后缀和数组。
-
最后还需要考虑前后两个数字长度相等的情况(此时还可以自己和自己匹配),我们遍历哈希表$counter$,对于每个长度$len$,再遍历$counter[len]$的键值对$(sum,cnt)$,把数位和相同的组合起来,即把$cnt^2$累加到答案上。
复杂度分析
时间复杂度
遍历的长度为$10$以内的正偶数,可以看成常数$O(L)$。而对于每个长度,都需要遍历原始的字符串数组,时间复杂度为$O(n)$。预处理出$counter$和$s$需要遍历原始字符串数组中每个字符串的所有字符,时间复杂度为$O(nL)$。因此,整个算法的时间复杂度就是$O(nL)$。
空间复杂度
将字符串按照长度的奇偶分组需要额外$O(n)$的空间,二维前缀和数组$s$的空间为$O(nL)$,哈希表$counter$在极限情况下,数位和可以达到$90$,数位长度为$[1,10]$,相比$s$数组不值一提。因此,整个算法的额外空间复杂度为$O(nL)$。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
LL get(vector<string>& nums) {
int n = nums.size();
vector<vector<int>> s(n, vector<int>());
vector<unordered_map<int, int>> counter(11, unordered_map<int, int>());
for(int i = 0; i < n; i++) {
int bl = nums[i].size();
s[i].resize(bl);
for(int j = 0; j < bl; j++) {
s[i][j] = (j? s[i][j - 1]: 0) + (nums[i][j] - '0');
}
int tot = s[i].back();
counter[bl][tot]++;
}
LL ans = 0;
for(int len = 1; len <= 5; len++) {
for(auto&[sum, cnt]: counter[len]) {
ans += 1LL*cnt*cnt;
}
}
// 枚举前一半
for(int len = 2; len <= 10; len += 2) {
int half = len>>1;
for(int i = 0; i < n; i++) {
int bl = nums[i].size();
if(bl > len) continue;
if(half < bl) {
int olen = len - bl; // 另一个数字的长度
int s1 = s[i][half - 1], s2 = s1*2 - s[i].back();
// 要找一个长度为olen,且数位和为s2的数字和i拼接
ans += counter[olen][s2];
}
}
}
for(int i = 1; i <= 10; i++) {
counter[i].clear();
}
for(int i = 0; i < n; i++) {
int bl = nums[i].size();
for(int j = bl - 1; j >= 0; j--) {
s[i][j] = (j + 1 < bl? s[i][j + 1]: 0) + (nums[i][j] - '0');
}
counter[bl][s[i][0]]++;
}
// 枚举后一半
for(int len = 2; len <= 10; len += 2) {
int half = len>>1;
for(int i = 0; i < n; i++) {
int bl = nums[i].size();
if(bl > len) continue;
if(half < bl) {
int olen = len - bl; // 另一个数字的长度
int s1 = s[i][bl - half], s2 = s1*2 - s[i][0];
// 要找一个长度为olen,且数位和为s2的数字和i拼接
ans += counter[olen][s2];
}
}
}
return ans;
}
int main() {
int n;
cin >> n;
vector<string> odd, even;
for(int i = 0; i < n; i++) {
string num;
cin >> num;
if(num.size()&1) {
odd.push_back(num);
}else {
even.push_back(num);
}
}
printf("%lld\n", get(odd) + get(even));
return 0;
}