class Solution:
    """Dynamic-programming solver for the partition-for-maximum-sum problem."""

    def maxSumAfterPartitioning(self, a: List[int], k: int) -> int:
        """Return the largest sum obtainable by partitioning ``a``.

        The list is split into contiguous segments of length at most ``k``;
        every element of a segment is then replaced by that segment's
        maximum, and the replaced values are summed.

        Args:
            a: The values to partition.
            k: Maximum allowed segment length.

        Returns:
            The best achievable sum. ``0`` is returned when ``a`` is empty
            or when ``k`` is smaller than 1 (no segment can be formed).
        """
        n = len(a)
        if n == 0 or k < 1:
            # Nothing can be partitioned; 0 is what the DP table would
            # yield anyway when no segment is ever considered.
            return 0

        # best[i] is the optimal sum for the prefix a[:i]; best[0] == 0
        # is the empty prefix.
        best = [0] * (n + 1)
        for end in range(1, n + 1):
            # Seed with the length-1 segment (always legal because k >= 1)
            # instead of 0, so that all-negative prefixes are not clamped
            # to a spurious 0.
            seg_max = a[end - 1]
            candidate = seg_max + best[end - 1]

            # Grow the last segment leftwards while it still fits in k,
            # tracking its running maximum.
            for start in range(end - 2, -1, -1):
                length = end - start
                if length > k:
                    break
                seg_max = max(seg_max, a[start])
                candidate = max(candidate, seg_max * length + best[start])

            best[end] = candidate
        return best[n]