class Solution:
    """Solver for the "minimum falling path sum with non-zero shifts" problem."""

    # The annotation is quoted so it is never evaluated at class-definition
    # time; this file does not import ``typing.List`` itself.
    def minFallingPathSum(self, a: "List[List[int]]") -> int:
        """Return the minimum sum of a falling path with non-zero shifts.

        A falling path picks exactly one element from each row of ``a``, with
        the constraint that elements picked from adjacent rows must lie in
        different columns.

        Args:
            a: Grid of integers with ``m`` rows; the column count ``n`` is
                taken from the first row, and only the first ``n`` entries of
                each later row are considered.

        Returns:
            The smallest achievable path sum.

        Raises:
            ValueError: if ``a`` has no rows or no columns, or if it has more
                than one row but only one column (no valid path exists).
        """
        if not a or not a[0]:
            raise ValueError("grid must have at least one row and one column")
        m, n = len(a), len(a[0])
        if m > 1 and n < 2:
            # Adjacent rows would have to reuse the only column.
            raise ValueError("no falling path exists: multiple rows but one column")

        # dp[j] = best path sum ending at column j of the row processed so far.
        dp = list(a[0])
        for row in a[1:]:
            # Only the two smallest previous values matter: each column takes
            # the overall minimum unless that minimum sits in the same column,
            # in which case it takes the best value from any other column.
            best_col = min(range(n), key=dp.__getitem__)
            best = dp[best_col]
            second = min(dp[j] for j in range(n) if j != best_col)
            dp = [row[j] + (second if j == best_col else best) for j in range(n)]
        return min(dp)