问题来自 poj 1160 Post Office
这是一道动态规划题。(下面的分析都是village start from 1)
先从简单的分析:假设有n个village,只要建一个post office,选在哪里建能得到shortest total distance?
继续简化:
- 当n等于1的时候,建在village 1里面
- 当n等于2的时候,建在village 1或village 2里面
- 当n等于3的时候,建在village 2里面
- 当n等于4的时候,建在village 2或village 3里面
- 当n等于5的时候,建在village 3里面
- 当n等于6的时候,建在village 3或village 4里面
- 当n等于7的时候,建在village 4里面
- 当n等于8的时候,建在village 4或village 5里面
所以,只建一个post office的时候,我们只需要把该post office建在“中点”即可。
然后继续分析我们原本的问题:定义数组dp[v][p]表示建p个post office给village 1到village v的shortest total distance,那么:
dp[v][p] = min{dp[i][p-1] + one[i+1][v]: p-1<=i<v}
(上面式子的one[i+1][v]表示的是shortest total distance for village i+1 to village v if we only build one post office for [i+1,…,v])
从上面的递推公式我们知道:需要先求出one这个二维数组,下面的代码用O(n2)的时间复杂度算出了这个二维数组。
#include <iostream> #include <stdio.h> #include <string.h> using namespace std; const int V = 310; const int P = 35; int pos[V]; int v, p; //shortestDis[i][j] means if we only consider village i to village j, //and build only one post office for them([i, j]), the shortest //total distance (i <= j). village start from index 1 int shortestDis[V][V]; void init() { scanf("%d%d", &v, &p); for(int i = 1; i <= v; ++i) scanf("%d", &pos[i]); memset(shortestDis, 0, sizeof(shortestDis)); } void calculateOne() { for(int mid = 1; mid <= v; ++mid) { int i = mid, j = mid; int total = 0; while(i > 0 && j <= v) { total += pos[mid]-pos[i]; total += pos[j]-pos[mid]; shortestDis[i][j] = total; --i; ++j; } } for(int mid1 = 1; mid1 <= v-1; ++mid1) { int i = mid1, j = mid1+1; int total = 0; while(i > 0 && j <= v) { total += pos[mid1]-pos[i]; total += pos[j]-pos[mid1]; shortestDis[i][j] = total; --i; ++j; } } } int solve() { calculateOne(); //dp[i][j]: when only build "j" post offices for villages [1,...,i] //the shorest total distance. so our goal is get dp[v][p]; int dp[V][P]; memset(dp, 0, sizeof(dp)); for(int r = 1; r <= v; ++r) dp[r][1] = shortestDis[1][r]; for(int c = 2; c <= p; ++c) for(int r = c+1; r <= v; ++r) { dp[r][c] = dp[r-1][c-1]; for(int i = r-1; i >= c; --i) dp[r][c] = min(dp[r][c], dp[i-1][c-1]+shortestDis[i][r]); } return dp[v][p]; } int main() { init(); printf("%d\n", solve()); return 0; }