C++ 版本
#include <stdio.h>
#include <vector>
#include <algorithm>
#include <string.h>
using namespace std;
struct point {
int ii; // 行号, 1开始
int jj; // 列号, 1开始
bool operator < (const point& other) const {
if (this->ii < other.ii) {
return true;
} else if (this->ii > other.ii) {
return false;
}
return this->jj < other.jj;
}
};
const int MAX_NODE_NUM = 10010; // 这个节点个数包含矩阵的1节点个数和DLX中col_num+1个哨兵的个数
const int MAX_COL_NUM = 150; // 最大列的个数
int uu[MAX_NODE_NUM]; // 节点的上指针
int dd[MAX_NODE_NUM]; // 节点的下指针
int ll[MAX_NODE_NUM]; // 节点左指针
int rr[MAX_NODE_NUM]; // 节点右指针
int row[MAX_NODE_NUM]; // 节点的行id, 从1开始
int col[MAX_NODE_NUM]; // 节点的列id, 从1开始
int one_num[MAX_COL_NUM]; // 某一列的1个数
int path_ans[MAX_COL_NUM]; // dfs时候使用的路径栈
int overlap[MAX_COL_NUM]; // 某一列是否被覆盖
point input[MAX_NODE_NUM]; // 输入数据的缓存
class DlxRepeatOverlap {
// c节点所在的列上除了c节点以外,其他的节点都在水平方向断开
bool remove_col(int c) {
for (int node = dd[c]; node != c; node = dd[node]) {
ll[rr[node]] = ll[node];
rr[ll[node]] = rr[node];
}
return true;
}
// c节点所在的列上除了c节点以外,其他的节点都在水平方向接回来
void resume_col(int c) {
for (int node = uu[c]; node != c; node = uu[node]) {
ll[rr[node]] = node;
rr[ll[node]] = node;
}
}
// 估价函数,预估最少还需要选择多少行才能完全覆盖所有列
int score() {
int cnt = 0;
memset(overlap, 0, sizeof(overlap));
for (int i = rr[0]; i; i = rr[i]) {
if (overlap[col[i]])
continue;
cnt++;
overlap[col[i]] = true;
for (int j = dd[i]; j != i; j = dd[j])
for (int k = rr[j]; k != j; k = rr[k])
overlap[col[k]] = true;
}
return cnt;
}
// DFS找最小的可行解
bool dfs(int step, int max_step) {
int score_val = score();
if (step + score_val > max_step) {
return false;
}
if (rr[0] == 0) {
return true;
}
// 找1最少的列处理
int c = rr[0];
int min_col = -1;
int min_one_num = 0x7fffffff;
while (c != 0) {
if (one_num[c] < min_one_num) {
min_one_num = one_num[c];
min_col = c;
}
c = rr[c];
}
// 枚举要选择的行
for (int node = dd[min_col]; node != min_col; node = dd[node]) {
path_ans[step] = row[node];
remove_col(node);
for (int nn = rr[node]; nn != node; nn = rr[nn]) {
remove_col(nn);
}
if (dfs(step + 1, max_step)) {
return true;
}
for (int nn = ll[node]; nn != node; nn = ll[nn]) {
resume_col(nn);
}
resume_col(node);
}
return false;
}
public:
// 十字链表初始化
void dlx_init(point *one_points, int points_num, int col_num) {
for (int j = 0; j <= col_num; j++) {
ll[j] = j - 1, rr[j] = j + 1;
uu[j] = j, dd[j] = j, col[j] = j;
one_num[j] = 0;
}
ll[0] = col_num, rr[col_num] = 0;
int pool_pos = col_num + 1;
int n = points_num;
sort(one_points, one_points + points_num);
int i = 0;
int ii, jj;
while (i < n) {
int j = i;
int h_node = -1; // 一行的第一个节点的id
while (j < n && one_points[j].ii == one_points[i].ii) {
int pos = pool_pos;
ii = one_points[j].ii, jj = one_points[j].jj;
row[pos] = ii, col[pos] = jj, one_num[jj]++;
uu[pos] = jj, dd[pos] = dd[jj], uu[dd[jj]] = pos, dd[jj] = pos;
if (h_node == -1) {
h_node = pos, ll[h_node] = rr[h_node] = h_node;
} else {
ll[pos] = h_node, rr[pos] = rr[h_node], ll[rr[h_node]] = pos, rr[h_node] = pos;
}
j += 1;
pool_pos++;
}
i = j;
}
}
// 迭代加深方式获取一个行数最少的可行解
vector<int> get_one_minimal_solution(int col_num) {
vector<int> ans;
for (int bound = 1; bound <= col_num; bound++) {
if (dfs(0, bound)) {
for (int i = 0; i < bound; i++) {
ans.push_back(path_ans[i]);
}
break;
}
}
return ans;
}
};
int main() {
int m, n, val;
scanf("%d %d", &m, &n);
int pos = 0;
for (int i = 1; i <= m; i++) {
for (int j = 1; j <= n; j++) {
scanf("%d", &val);
if (val == 1) {
input[pos].ii = i, input[pos].jj = j;
pos++;
}
}
}
DlxRepeatOverlap algo;
algo.dlx_init(input, pos, n);
auto ans = algo.get_one_minimal_solution(m);
printf("%d\n", ans.size());
for (auto val : ans) {
printf("%d ", val);
}
printf("\n");
}
Python3 版本
from functools import lru_cache
from typing import List
class DlxRepeatOvelap:
# max_node_num是节点池的大小,取1节点大小即可, row_num和col_num是矩阵的行数和列数
# one_points是1节点的坐标的二元组(ii, jj)的列表,行坐标和列坐标都是从1开始(注意不是从0开始)
def __init__(self, max_node_num, row_num, col_num, one_points: List):
max_node_num += col_num+1 # 加上哨兵的节点数
self.__row_num = row_num
self.__col_num = col_num
self.__pos = 0 # 当前空闲的末尾位置
# 上下左右后继的节点id
self.uu, self.dd, self.ll, self.rr = [0]*max_node_num, [0]*max_node_num, [0]*max_node_num, [0]*max_node_num
self.row, self.col = [0] * max_node_num, [0] * max_node_num # 节点的行和列
self.one_num = [0] * (self.__col_num + 1) # 列上的1计数值
self.one_points = one_points
self.__boot()
# 十字链表状态初始化
def __boot(self):
one_points = self.one_points
one_points.sort()
# 最上面一行0号节点到col_num号节点都是哨兵, 都是链表头,不是实际1节点,0号节点下面不会接其他节点,只作为是横向的哨兵
uu, dd, ll, rr = self.uu, self.dd, self.ll, self.rr
col_num, col = self.__col_num, self.col
for jj in range(col_num + 1):
ll[jj] = jj - 1
rr[jj] = jj + 1
uu[jj] = jj
dd[jj] = jj
col[jj] = jj
ll[0], rr[col_num] = col_num, 0
self.__pos = col_num + 1
# 逐行添加节点
i = 0
n = len(one_points)
row, col, one_num = self.row, self.col, self.one_num
while i < n:
j = i
h_node = None # 一行的第一个节点
while j < n and one_points[j][0] == one_points[i][0]:
ii, jj = one_points[j]
pos = self.__pos
row[pos], col[pos] = ii, jj
one_num[jj] += 1
uu[pos], dd[pos] = jj, dd[jj]
uu[dd[jj]] = pos
dd[jj] = pos
# 新节点总是放在头节点的右边
if h_node is None:
h_node = pos
ll[h_node], rr[h_node] = h_node, h_node
else:
ll[pos], rr[pos] = h_node, rr[h_node]
ll[rr[h_node]] = pos
rr[h_node] = pos
j += 1
self.__pos += 1
i = j
# c节点所在的列上除了c节点以外,其他的节点都在水平方向断开
def remove_col(self, c):
uu, dd, ll, rr = self.uu, self.dd, self.ll, self.rr
node = dd[c]
while node != c:
ll[rr[node]] = ll[node]
rr[ll[node]] = rr[node]
node = dd[node]
# c节点所在的列上除了c节点以外,其他的节点都在水平方向接回来
def resume_col(self, c):
uu, dd, ll, rr = self.uu, self.dd, self.ll, self.rr
node = uu[c]
while node != c:
ll[rr[node]] = node
rr[ll[node]] = node
node = uu[node]
# 估价函数,预估最少还需要选择多少行才能完全覆盖所有列
def score(self):
uu, dd, ll, rr = self.uu, self.dd, self.ll, self.rr
overlap = [0] * (self.__col_num + 1)
cnt = 0
node = rr[0]
while node != 0:
if overlap[node] == 1:
node = rr[node]
continue
cnt += 1
overlap[node] = 1
# 该列的所有行全部都选中
mm = dd[node]
while mm != node:
nn = rr[mm]
while nn != mm:
overlap[self.col[nn]] = 1
nn = rr[nn]
mm = dd[mm]
node = rr[node]
return cnt
# DFS找最小的可行解
def dfs(self, step, max_step, path: List):
uu, dd, ll, rr = self.uu, self.dd, self.ll, self.rr
if step + self.score() > max_step:
return False
if rr[0] == 0:
return True
# 选1最少的行
c = rr[0]
min_one_num = 0x7fffffff
min_col = None
while c != 0:
if self.one_num[c] < min_one_num:
min_one_num = self.one_num[c]
min_col = c
c = rr[c]
# 针对min_col列,枚举能够选择的行
node = dd[min_col]
while node != min_col:
nn = rr[node]
while nn != node:
self.remove_col(nn)
nn = rr[nn]
self.remove_col(node)
path.append(self.row[node])
if self.dfs(step+1, max_step, path):
return True
path.pop(-1)
self.resume_col(node)
nn = ll[node]
while nn != node:
self.resume_col(nn)
nn = ll[nn]
node = dd[node]
return False
# 迭代加深获取一个最小的可行解
def get_one_minimal_solution(self):
path = []
for bound in range(1, self.__col_num + 1):
if self.dfs(0, bound, path):
return path
return None
one_points = []
m, n = map(int, input().split())
for i in range(1, m + 1):
line = list(map(int, input().split()))
for j in range(n):
if line[j] == 1:
one_points.append((i, (j + 1)))
algo = DlxRepeatOvelap(550 * 550, m, n, one_points)
ret = algo.get_one_minimal_solution()
if ret is not None:
print(len(ret))
for val in ret:
print(val, end=' ')
print()
else:
print("No Solution!")
orz
# Orz