class DLX(object):
    """Dancing Links (Knuth's Algorithm X) solver for the exact-cover problem.

    Columns are numbered 1..n. Slot 0 is the root header and slots 1..n are
    the column headers; every 1 added with ``insert`` occupies the next free
    slot in a set of preallocated parallel arrays (left/right/up/down links,
    plus the row id and column of each node).
    """

    def __init__(self, n, MAXN):
        """Create an empty matrix.

        Args:
            n: number of columns (valid column indices are 1..n).
            MAXN: maximum number of 1s (data nodes) that will be inserted.
        """
        self.n = n
        self.idx = n + 1  # next free node slot; slots 0..n are headers
        self.l, self.r, self.u, self.d, self.row, self.col = \
            [list(range(MAXN + n + 1)) for _ in range(6)]
        for i in range(n + 1):
            self.l[i], self.r[i], self.row[i] = i - 1, i + 1, 0
            self.u[i] = self.d[i] = self.col[i] = i
        # Close the header ring: root 0 <-> 1 <-> ... <-> n <-> root 0.
        self.l[0], self.r[n] = n, 0
        # Number of 1s in each column; the search branches on the column with
        # the fewest 1s. size[0] belongs to the root and is never read.
        self.size = [0] * (n + 1)
        self.size[0] = n + 2

    def insert(self, i, j):
        """Add a 1 at row ``i``, column ``j`` (``1 <= j <= n``).

        All entries of one row must be inserted consecutively: a new node
        joins the previous node's row ring only if the most recently inserted
        node carries the same row id.

        Raises:
            IndexError: if ``j`` is not a valid column, or the node pool
                sized by ``MAXN`` is exhausted. Nothing is modified in
                either case.
        """
        l, r, u, d, size, col, row = \
            self.l, self.r, self.u, self.d, self.size, self.col, self.row
        if not 1 <= j <= self.n:
            raise IndexError("DLX column index out of range: %r" % (j,))
        cur = self.idx
        if cur >= len(l):
            raise IndexError("DLX node pool exhausted; increase MAXN")
        # Join the previous slot's row only if that slot is a data node.
        # Slot n is the last column header and has row id 0, so without the
        # ``cur > n + 1`` guard a row numbered 0 would be spliced into the
        # header ring.
        if cur > self.n + 1 and row[cur - 1] == i:
            fr, la = r[cur - 1], cur - 1
            l[cur], r[cur], l[fr], r[la] = la, fr, cur, cur
        else:
            # First node of a new row: a one-node ring pointing to itself.
            l[cur] = r[cur] = cur
        # Append at the bottom of column j's vertical ring.
        fr, la = j, u[j]
        u[cur], d[cur], u[fr], d[la] = la, fr, cur, cur
        row[cur], col[cur] = i, j
        size[j] += 1
        self.idx = cur + 1

    def remove(self, c):
        """Cover column ``c``.

        Unlinks the header from the header ring, and unlinks every other node
        of each row that has a 1 in ``c`` from its column.
        """
        l, r, u, d, size, col = \
            self.l, self.r, self.u, self.d, self.size, self.col
        r[l[c]], l[r[c]] = r[c], l[c]
        i = d[c]
        while i != c:
            j = r[i]
            while j != i:
                d[u[j]] = d[j]
                u[d[j]] = u[j]
                size[col[j]] -= 1
                j = r[j]
            i = d[i]

    def recover(self, c):
        """Uncover column ``c``, undoing ``remove(c)`` in exact reverse order."""
        l, r, u, d, size, col = \
            self.l, self.r, self.u, self.d, self.size, self.col
        i = u[c]
        while i != c:
            j = l[i]
            while j != i:
                d[u[j]] = j
                u[d[j]] = j
                size[col[j]] += 1
                j = l[j]
            i = u[i]
        r[l[c]] = l[r[c]] = c

    def __dancing(self, ans):
        """Depth-first search; return True once every column is covered.

        On success ``ans`` holds the chosen row ids. In both outcomes every
        cover made here is undone before returning, so the matrix is left as
        it was found. Recursion depth equals the number of rows selected.
        """
        l, r, d, size, col, row = \
            self.l, self.r, self.d, self.size, self.col, self.row
        if r[0] == 0:
            return True
        # Find the column with the fewest 1s (first one wins ties).
        c = i = r[0]
        while i != 0:
            if size[i] < size[c]:
                c = i
            i = r[i]
        self.remove(c)
        found = False
        i = d[c]
        while i != c:
            # Select row i: cover every other column it touches.
            j = r[i]
            while j != i:
                self.remove(col[j])
                j = r[j]
            ans.append(row[i])
            found = self.__dancing(ans)
            # Restore in reverse order, even on success, so the matrix
            # remains reusable.
            j = l[i]
            while j != i:
                self.recover(col[j])
                j = l[j]
            if found:
                break
            ans.pop()
            i = d[i]
        self.recover(c)
        return found

    def dancing(self):
        """Solve the exact-cover problem.

        Returns:
            A list of row ids forming one exact cover, or ``[]`` if none
            exists. A matrix with no columns is trivially covered by the
            empty selection, which also yields ``[]``. The matrix is left
            intact, so the method may be called again.
        """
        ans = []
        return ans if self.__dancing(ans) else []