AcWing 241. 楼兰图腾
原题链接
简单
作者:
皓首不倦
,
2020-08-27 15:48:20
,
所有人可见
,
阅读 384
'''
本质上是要求任意一个位置k,其左边比该位置小的数的数量和右边比该位置小的数值的数量
可以先从左向右枚举每一个数值val,经历过的数值就在序列A的对应位置加1,A[1: val+1]
总和就是val左边比其大的数字个数,同理A[val+1: n+1]总和就是val左边比其大的数字个数
每次枚举会导致A序列中一个数值更改,然后又需要在更改后快速求区间和,正好是树状数组
的特性,所以使用树状数组分两次分别正向和逆向扫描一遍数组,就可以得出每一个位置的
数值其左边和右边小于该数值的数的个数,然后用乘法原则累加结果即可,对于另外一个问题
大于和小于关系对调用一样的方法求解即可
'''
import sys
class FenwickTree:
def __init__(self, data):
if len(data) == 0:
raise ValueError('data length is zero!')
self.n = len(data)
prefix_sum = [0] * (self.n + 1)
self.sum_seg = [0] * (self.n + 1) # 分段存储的区间和 sum_seg[x]表示的是x位置结尾,长度为x&(-x)的区间中的数值的和
# 实际存储时候错一位存储,外部数据的下标从0开始,但是树状数组的下标是从1开始的
i = 1
for val in data:
prefix_sum[i] = prefix_sum[i-1] + val
self.sum_seg[i] = prefix_sum[i] - prefix_sum[i & (i-1)]
i += 1
# 获取x位置的原始数值
def getOrigValue(self, x):
x += 1
if x < 1 or x > self.n:
raise IndexError(f'error idx = {x}')
ans = self.sum_seg[x]
lca = x & (x-1) # 最近公共祖先
x -= 1
while x != lca:
ans -= self.sum_seg[x]
x &= (x-1)
return ans
# 获取区间[0, end]的数值和
def __getSumWithEnd(self, end): # 这里end是内部坐标,不是外部坐标
ans = 0
while end:
ans += self.sum_seg[end]
end &= (end - 1)
return ans
# 获取区间的数据和
def getSum(self, start, end):
start, end = start+1, end+1
if not (end >= start and start >= 1 and end <= self.n):
raise IndexError(f'bad range {(start-1, end-1)}')
if start == 1:
return self.__getSumWithEnd(end)
return self.__getSumWithEnd(end) - self.__getSumWithEnd(start-1)
# 更新x位置数值
def updateValue(self, x, val):
orig_val = self.getOrigValue(x)
x += 1
delta = val - orig_val
while x <= self.n:
self.sum_seg[x] += delta
x += x & (-x)
n = int(input())
arr = list( map(int, input().split()) )
left_min_cnt = [0] * (n+1) # 左边小于等于x的数值数量
right_min_cnt = [0] * (n+1) # 右边小于等于x的数值数量
left_max_cnt = [0] * (n+1) # 左边大于等于等于x的数值数量
right_max_cnt = [0] * (n+1) # 右边大于等于x的数值数量
tree = FenwickTree([0]*(n+1))
for val in arr:
tree.updateValue(val, 1)
left_min_cnt[val] = tree.getSum(1, val)
left_max_cnt[val] = tree.getSum(val, n)
tree = FenwickTree([0]*(n+1))
for val in reversed(arr):
tree.updateValue(val, 1)
right_min_cnt[val] = tree.getSum(1, val)
right_max_cnt[val] = tree.getSum(val, n)
ans1, ans2 = 0, 0
for i in range(1, n+1):
ans1 += (left_max_cnt[i] - 1) * (right_max_cnt[i] - 1)
ans2 += (left_min_cnt[i] - 1) * (right_min_cnt[i] - 1)
print(ans1, ans2)