算法分析
状态压缩dp
预处理过程
- 1、抛物线$y = ax^2 + bx + c$经过原点,且开口向下,因此用两个点的位置描述该抛物线的状态,例如
path[i][j]
表示i
号点和j
号点对应的抛物线的状态 - 2、先预处理出所有抛物线对应有哪些点的状态,例如
path[i][j] = 10110
,则表示该抛物线有1
号点,2
号点,4
号点(从0
号点开始),再算该抛物线有哪些点时,需要先算出该抛物线的a
和b
值,a = (y1 / x1 - y2 / x2) / (x1 - x2),b = y1 / x1 - a * x1
,再枚举所有点,判断该点是否在该抛物线上
注意:由于存在精度的问题,判断两个值是否相等,只需要判断Math.abs(a - b) < esp
是否成立即可
算法1:记忆化搜索
f[state] = dfs(state)
表示从该状态覆盖所有的点需要多少条抛物线
int dfs(int state) // state表示当前状态哪些点被覆盖
{
if(state 已经覆盖所有的点) return 0
if(state 状态已经搜索过) retrun f[state]的值
任选出没有被覆盖的点x
枚举所有与x号点相关的抛物线
当该抛物线不能覆盖x号点 即path[x][i] == 0 continue;
f[state] = Math.min(f[state],dfs(state | path[x][i]) + 1);
返回 f[state];
}
算法2:线性dp
- 切入点:
i | path[x][j]
是大于等于i
,i
才能 从小到大枚举,从i
转移到i | path[x][j]
f[state]
: 表示从该状态覆盖所有的点需要多少条抛物线- 从
0
枚举到(1 << n) - 1
,找到任意一个未覆盖的点,枚举所有与x
号点相关的抛物线,最终返回f[(1 << n) - 1]
时间复杂度$O(T(n^3 + n * 2^n))$
参考文献
算法提高课
Java 代码(记忆化搜索)
import java.util.Arrays;
import java.util.Scanner;
public class Main {
static int n,m;
static int N = 18,M = 1 << 18;
static int[][] path = new int[N][N];
static Pair[] pair = new Pair[N];
static double esp = 1e-8;
static int[] f = new int[M];//表示该状态需要抛物线的数量
static int cmp(double a,double b)
{
if(Math.abs(a - b) < esp) return 0;
if(a > b) return 1;
return -1;
}
static int dfs(int state)
{
if(state == (1 << n) - 1) return 0;
if(f[state] != -1) return f[state];
//任选出没有被覆盖的点x
int x = 0;
for(int i = 0;i < n;i ++)
{
if((state >> i & 1) == 0)
{
x = i;
break;
}
}
int res = 0x3f3f3f3f;
//枚举所有与x号点相关的抛物线
for(int i = 0;i < n;i ++)
{
if (path[x][i] == 0) continue;//当该抛物线不能覆盖x号点
res = Math.min(res,dfs(state | path[x][i]) + 1);
}
return f[state] = res;
}
public static void main(String[] args) {
Scanner scan = new Scanner(System.in);
int T = scan.nextInt();
while(T -- > 0)
{
n = scan.nextInt();
m = scan.nextInt();
for(int i = 0;i < n;i ++)
{
double x = scan.nextDouble();
double y = scan.nextDouble();
pair[i] = new Pair(x,y);
}
for(int i = 0;i < n;i ++) Arrays.fill(path[i], 0);
//找出所有抛物线对应点的状态
for(int i = 0;i < n;i ++)
{
path[i][i] = 1 << i;
for(int j = 0;j < n;j ++)
{
double x1 = pair[i].x, y1 = pair[i].y;
double x2 = pair[j].x, y2 = pair[j].y;
if(cmp(x1,x2) == 0) continue;
double a = (y1 / x1 - y2 / x2) / (x1 - x2);
double b = y1 / x1 - a * x1;
if(cmp(a,0) >= 0) continue;
int state = 0;
//判断哪些点在该抛物线上
for(int k = 0;k < n;k ++)
{
double x = pair[k].x;
double y = pair[k].y;
if(cmp(a * x * x + b * x,y) == 0) state += 1 << k;
}
path[i][j] = state;
}
}
Arrays.fill(f, -1);
System.out.println(dfs(0));
}
}
}
class Pair
{
double x,y;
Pair(double x,double y)
{
this.x = x;
this.y = y;
}
}
Java代码(线性dp)
import java.util.Arrays;
import java.util.Scanner;
public class Main {
static int n,m;
static int N = 18,M = 1 << 18;
static int[][] path = new int[N][N];
static Pair[] pair = new Pair[N];
static double esp = 1e-8;
static int[] f = new int[M];//表示从该状态覆盖所有的点需要多少条抛物线
static int cmp(double a,double b)
{
if(Math.abs(a - b) < esp) return 0;
if(a > b) return 1;
return -1;
}
public static void main(String[] args) {
Scanner scan = new Scanner(System.in);
int T = scan.nextInt();
while(T -- > 0)
{
n = scan.nextInt();
m = scan.nextInt();
for(int i = 0;i < n;i ++)
{
double x = scan.nextDouble();
double y = scan.nextDouble();
pair[i] = new Pair(x,y);
}
for(int i = 0;i < n;i ++) Arrays.fill(path[i], 0);
//找出所有抛物线对应点的状态
for(int i = 0;i < n;i ++)
{
path[i][i] = 1 << i;
for(int j = 0;j < n;j ++)
{
double x1 = pair[i].x, y1 = pair[i].y;
double x2 = pair[j].x, y2 = pair[j].y;
if(cmp(x1,x2) == 0) continue;
double a = (y1 / x1 - y2 / x2) / (x1 - x2);
double b = y1 / x1 - a * x1;
if(cmp(a,0) >= 0) continue;
int state = 0;
//判断哪些点在该抛物线上
for(int k = 0;k < n;k ++)
{
double x = pair[k].x;
double y = pair[k].y;
if(cmp(a * x * x + b * x,y) == 0) state += 1 << k;
}
path[i][j] = state;
}
}
Arrays.fill(f, 0x3f3f3f3f);
f[0] = 0;
//注意:i | path[x][j] 是大于等于i,i才能 从小到大枚举,从i转移到i | path[x][j]
for(int i = 0;i < 1 << n;i ++)
{
//找到任意一个未覆盖的点
int x = 0;
for(int j = 0;j < n;j ++)
{
if((i >> j & 1) == 0)
{
x = j;
break;
}
}
//枚举所有与x号点相关的抛物线
for(int j = 0;j < n;j ++)
f[i | path[x][j]] = Math.min(f[i | path[x][j]], f[i] + 1);
}
System.out.println(f[(1 << n) - 1]);
}
}
}
class Pair
{
double x,y;
Pair(double x,double y)
{
this.x = x;
this.y = y;
}
}
线性dp方法哪里的?
太难了 没懂 QAQ
加油