题目描述
难度分:$2100$
输入$n(1 \leq n \leq 3 \times 10^5)$、$m(1 \leq m \leq 3 \times 10^5)$和长为$n$的数组$a(1 \leq a[i] \leq m)$。
定义振荡序列为形如$[3,1,3]$,$[3,2,3,2]$,$[1]$这种最多由两种元素组成的,相邻元素不同的序列。
定义$f(x,y)$为$a$中最长振荡子序列的长度,其中子序列的每个数要么是$x$,要么是$y$。
注:子序列不一定连续。
输出所有$f(x,y)$的和,其中$1 \leq x \lt y \leq m$。
输入样例$1$
5 4
3 2 1 3 2
输出样例$1$
13
输入样例$2$
3 3
1 1 1
输出样例$2$
2
算法
贡献法
这种题目看着就很贡献法,因为不可能枚举所有的数对$(x,y)$,但是本题求贡献的方式还挺抽象的。对于一个$a[i]$,我们计算它对答案的贡献分为以下两种情况:
- $a[i]$是第一次出现:它和一个不在$a$数组中的数贡献就是$1$,设总共有$c$个$[1,m]$内的数不在$a$数组中,总的贡献就是$c$。它和在$a$中的数的贡献又分为两种情况:$(1)$ 如果这个数第一次出现在$a[i]$的左边,那么$a[i]$直接接在后面形成更长的振荡序列即可,为它们形成的最长振荡序列贡献$1$。$(2)$ 如果这个数第一次出现在$a[i]$的右边,那么$a[i]$就在它的前面,也为它们形成的最长振荡序列贡献$1$。所以一共有$m-1-c$个这样的数,就能有$m-1-c$的贡献。这两种情况的总贡献加起来就是$m-1-c+c=m-1$。
- $a[i]$不是第一次出现,假设它上一次出现的位置是$pre[a[i]]$(可以用一个哈希表来维护),那么$(pre[a[i]],i)$中间有多少个不同的数就能贡献多少,假设一共有$k$个不同的数。这$k$个数都可以接在上一个$a[i]$的后面形成振荡序列,为它们的长度都增加$1$,$k$个序列就增加$k$的贡献。
至于如何统计区间内不同数的个数?可以用主席树,也可以用树状数组维护每个数的最新下标,即每个数值$a[i]$最后一次出现的时候在树状数组的$i$位置加$1$,之前出现数值$a[i]$加$1$的下标$j$上加$-1$撤销掉。查询区间内下标个数,这样对应的就是区间中有多少个不同的数,详见代码。
复杂度分析
时间复杂度
遍历$i \in [1,n]$时计算答案,而每个下标$i$都有一次对树状数组的操作,时间复杂度为$O(log_2n)$。因此,整个算法的时间复杂度为$O(nlog_2n)$。
空间复杂度
空间消耗主要是两部分:哈希表$pre$,在最差情况下需要存$O(n)$个值的最早出现位置,因此空间复杂度为$O(n)$;树状数组$tr$,其长度也是$O(n)$。因此,整个算法的额外空间复杂度为$O(n)$。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
int n, m;
class Fenwick {
public:
explicit Fenwick(int n): sums_(n + 1) {}
int lowbit(int x) {
return x&-x;
}
void add(int idx, int val) {
for(; idx < sums_.size(); idx += lowbit(idx)) {
sums_[idx] += val;
}
}
int query(int idx) {
int ans = 0;
for(; idx > 0; idx -= lowbit(idx)) {
ans += sums_[idx];
}
return ans;
}
int query(int left, int right) {
if(left > right) return 0;
return query(right) - query(left - 1);
}
private:
vector<int> sums_;
};
int main() {
scanf("%d%d", &n, &m);
LL ans = 0;
unordered_map<int, int> pre;
Fenwick tr(n + 1);
for(int i = 1; i <= n; i++) {
int a;
scanf("%d", &a);
if(pre.count(a)) {
int j = pre[a];
ans += tr.query(j + 1, i - 1);
tr.add(j, -1);
}else {
// 首次出现
ans += m - 1;
}
tr.add(i, 1);
pre[a] = i;
}
printf("%lld\n", ans);
return 0;
}