多维 DP
2025-05-26

上楼听到有人在哼 Climbing on the Trees,上来听到隔壁班在放 Burn the House Down,果然春天是 AJR 的季节()


主要是多维 DP 特有的优化!

分步转移

如果状态的若干维之间只存在很弱的联系,可以分步转移每一维。

假设有二维状态转移 \((a,b)\to (a',b')\),『很弱的联系』举例:

  1. DP 值中包含 \(w(a',b')\),但不包含诸如 \(w(a',b),w(a,b')\) 之类。也即该值的求解不依赖于上一个状态、不依赖于转移。
  2. 对于 \((a',b')\) 的取值有限制,比如 \((114,514)\) 不能取到之类。显然这也不依赖于上一个状态、不依赖于转移。
  3. 符号限制(如转移间为 \(+\),状态间为 \(\times\) 之类),在计数中常出现。
  4. 一个变动的时候另一个必须也一起动。

注意!有的时候一个状态也可以拆成两个状态,然后分步转移。


例题:经典题

给定 \(w_{N\times M}\),求 \((A,B)_{1\sim K}\),最大化 \(\sum\limits_{i=1}^K w_{A_i,B_i}+w_{A_{i-1},A_i}+w_{B_{i-1},B_i}\) 的值。

发现如果设 \(f_{i,a,b}\) 表示 \((A,B)_i=(a,b)\) 的话,转移就会达到可观的 \(O(N\times M)\);但发现里面的 \(A\)\(B\) 其实没什么有机联系;唯一的联系 \(w_{A_i,B_i}\)(上面『很弱的联系』中的第一、四种情况)。因此分布转移:

\[ f'_{a',b}\gets \max\limits_a f_{a,b}+w_{a',a}\\ f_{a',b'}\gets \max\limits_b f'_{a', b} + w_{b,b'}+w_{a',b'} \]

由此便优化转移复杂度到 \(O(N+M)\)


例题:彩灯晚会

goto link

Tips:发现 \(l_1,l_2\) 之间也没啥有机联系(上面『很弱的联系』中的第三、四种情况),故分步转移。


例题:序列妙妙值

https://uoj.ac/problem/549

朴素地,设 \(f_{i,j}\) 表示在第 \(i\) 个处分第 \(j\) 段的最大价值,显然有 \(O(k\times n^2)\),且并没有优化的空间。

考虑利用 \(a_i\) 很小这一条件,发现上一条无法优化是因为要获取 \(sum_{i}\) 的值;那么把 \(sum_i\) 塞到状态里。设 \(f_{s',j}\) 表示 \(sum=s'\) 时,分了 \(j\) 段的最大价值。同样可以 \(O(k\cdot n^2)\)

接下来有两个理解这个优化的角度:

  1. 从平衡角度,原问题等价于 \(O(1)\) 更新 \((s,j)\) 处的最大值,\(O(v)\) 查找 \(j-1\) 处的最大值;把 \(s\) 拆成前 \(8\) 位、后 \(8\) 位两个 part,当更新

    对于修改:相当于固定前 \(8\)\(s\),枚举后 \(8\)\(x\),并更新 \((s,x)\)

    对于查询:相当于固定后 \(8\)\(x\),枚举前 \(8\)\(s\),并查询 \((s,x)\)
  2. 从分步转移角度,由于『现时刻』的贡献是按位的,二进制状态的前 \(8\) 位和后 \(8\) 位没啥有机联系,故直接拆开,先转移前 \(8\) 位,再转移后 \(8\) 位。

这样就被优化为 \(O(k\cdot n\times \sqrt v)\)

#include <bits/stdc++.h>
const int maxv = 1 << 8;
const int inf = 0x3f3f3f3f;
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, k;
    std::cin >> n >> k;
    using arr = std::vector<int>;
    using brr = std::vector<arr>;
    using crr = std::vector<brr>;
    arr s(n + 1), tag0(maxv), d0;
    crr f(maxv, brr(maxv, arr(k + 1, inf))), g(maxv, brr(maxv, arr(k + 1, inf)));
    f[0][0][0] = 0;
    for (int j = 0; j < maxv; ++j)
        g[0][j][1] = j;
    tag0[0] = 1, d0.push_back(0);
    for (int i = 1; i <= n; ++i) {
        std::cin >> s[i], s[i] ^= s[i - 1];
        int p0 = s[i] >> 8, p1 = s[i] & ((1 << 8) - 1);
        for (int j = k; j; --j) {
            for (auto a : d0)
                f[p0][p1][j] = std::min(f[p0][p1][j], g[a][p1][j] + ((a ^ p0) << 8));
            // printf("f[%d][%d][%d] = %d\n", p0, p1, j, f[p0][p1][j]);
            if (j != k)
                for (int a = 0; a < maxv; ++a)
                    g[p0][a][j + 1] = std::min(g[p0][a][j + 1], f[p0][p1][j] + (a ^ p1));
        }
        if (i >= k)
            std::cout << f[p0][p1][k] << ' ';
        if (!tag0[p0])
            tag0[p0] = 1, d0.push_back(p0);
    }
    std::cout << '\n';
    return 0;
}

割裂

如果状态的若干维之间连无机联系都没了,可以直接把 DP 数组拆开,各自转移各自的。

最后答案的拼接,可能是枚举、直接拼、用一个转移另一个之类。


例题:经典题

给定 \(w_{N\times M}\),求 \((A,B)_{1\sim K}\),最大化 \(\sum\limits_{i=1}^K w_{A_{i-1},A_i}+w_{B_{i-1},B_i}\) 的值。

先 DP 出来 \(A\),再 DP 出来 \(B\),直接相加即可。


例题:MNOGOMET

https://www.luogu.com.cn/problem/P7648

想到设 \(f_{t,i,a,b}\) 表示『过去了 \(t\) 秒,且球在球员 \(i\) 手上,且两个队伍得分分别为 \(a,b\)』的概率。那么枚举球上一次在谁手上,有美观的 \(O(T\cdot N^2R^2)\)

发现比分变动时(上半个时刻完成射门并传球)球一定在某个队的 \(1\) 号手上;\(i\) 这一维和 \(a,b\) 也没有相互的干扰;故可以将 \(i\) 提取出来(作为对比,射门和夺球、传球都会影响 \(t\),所以分裂出来的状态中必须包含 \(t\))。具体地:

  1. \(g_{t,0/1,i}\) 表示『一开始球在哪个球队的 \(1\) 号,花费 \(t\) 秒,且没人射门,且球在球员 \(i\) 手上』的概率。发现可以 \(O(T\times N^2)\) 计算。

    当然这里就是条件概率了。其中『一开始球在哪个球队的 \(1\) 号』就是条件。

    再预处理可以得到 \(G_{t,0/1,0/1,0/1}\) 表示『一开始球在哪个球队的 \(1\) 号,花费 \(t\) 秒,哪个球队射门,(没)射进』的概率。
  2. \(f_{t,0/1,a,b}\) 表示『过去了 \(t\) 秒,球在哪个队的 \(1\) 号,且得分为 \(a,b\)』的概率。枚举距离上一次射门的时间,可以 \(O(T^2\times R^2)\) 完成转移。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, r, T;
    std::cin >> n >> r >> T;
    using arr = std::vector<double>;
    using brr = std::vector<arr>;
    using crr = std::vector<brr>;
    brr p(2 * n + 1, arr(2)), pp(2 * n + 1, arr(2 * n + 1));
    for (int i = 1; i <= n; ++i) {
        std::cin >> p[i][1], p[i][0] = 1. - p[i][1];
        int ss, sd;
        std::cin >> ss >> sd;
        double P = 1. / (ss + sd + 1);
        p[i][0] *= P, p[i][1] *= P;
        for (int x; ss--; pp[i][x] = P)
            std::cin >> x;
        for (int x; sd--; pp[i][x + n] = P)
            std::cin >> x;
    }
    for (int i = n + 1; i <= 2 * n; ++i) {
        std::cin >> p[i][1], p[i][0] = 1. - p[i][1];
        int ss, sd;
        std::cin >> ss >> sd;
        double P = 1. / (ss + sd + 1);
        p[i][0] *= P, p[i][1] *= P;
        for (int x; ss--; pp[i][x + n] = P)
            std::cin >> x;
        for (int x; sd--; pp[i][x] = P)
            std::cin >> x;
    }
    brr s(T + 1, arr(2));
    crr g(T + 1, brr(2, arr(2 * n + 1)));
    std::vector<crr> G(T + 1, crr(2, brr(2, arr(2)))), f(T + 1, crr(2, brr(r + 1, arr(r + 1))));
    g[0][0][1] = g[0][1][n + 1] = 1.;
    for (int t = 1; t <= T; ++t) {
        for (int f1 = 0; f1 <= 1; ++f1)
            for (int i = 1; i <= 2 * n; ++i) {
                G[t][f1][i > n][0] += g[t - 1][f1][i] * p[i][0];
                G[t][f1][i > n][1] += g[t - 1][f1][i] * p[i][1];
                for (int j = 1; j <= 2 * n; ++j)
                    if (j != i)
                        g[t][f1][i] += pp[j][i] * g[t - 1][f1][j];
                // printf("spend %ds, start from %d, now at %d: %.10lf\n", t, 1 + f1 * n, i, g[t][f1][i]);
                s[t][f1] += g[t][f1][i];
            }
    }
    // puts("------------------------------------");
    // for (int t = 1; t <= T; ++t) {
    //     for (int f1 = 0; f1 <= 1; ++f1)
    //         for (int f2 = 0; f2 <= 1; ++f2)
    //             for (int f3 = 0; f3 <= 1; ++f3)
    //                 printf("spend %ds, start from %d, %d shoots, STATUS: %d, P = %.10lf\n", t, 1 + f1 * n, 1 + f2, f3, G[t][f1][f2][f3]);
    // }
    // puts("------------------------------------");
    f[0][0][0][0] = 1.;
    for (int t = 0; t < T; ++t)
        for (int f1 = 0; f1 <= 1; ++f1)
            for (int a = 0; a < r; ++a)
                for (int b = 0; b < r; ++b) {
                    for (int t1 = 1; t + t1 <= T; ++t1) {
                        f[t + t1][0][a][b] += f[t][f1][a][b] * G[t1][f1][1][0];
                        f[t + t1][0][a][b + 1] += f[t][f1][a][b] * G[t1][f1][1][1];
                        f[t + t1][1][a][b] += f[t][f1][a][b] * G[t1][f1][0][0];
                        f[t + t1][1][a + 1][b] += f[t][f1][a][b] * G[t1][f1][0][1];
                        if (t + t1 == T)
                            f[T][0][a][b] += f[t][f1][a][b] * s[t1][f1];
                    }
                }
    // for (int t = 0; t <= T; ++t)
    //     for (int f1 = 0; f1 <= 1; ++f1)
    //         for (int a = 0; a <= r; ++a)
    //             for (int b = 0; b <= r; ++b)
    //                 if (f[t][f1][a][b] > 1e-10)
    //                     printf("%ds later, %d shoots, %d : %d, P = %.10lf\n", t, 1 + !f1 * n, a, b, f[t][f1][a][b]);
    std::cout << std::fixed << std::setprecision(10);
    for (int a = 0; a <= r; ++a) {
        for (int b = 0; b <= r; ++b) {
            if (a == r && b == r)
                continue;
            double res(0.);
            if (a == r)
                for (int t = r; t <= T; ++t)
                    res += f[t][1][r][b];
            else if (b == r)
                for (int t = r; t <= T; ++t)
                    res += f[t][0][a][r];
            else
                res = f[T][0][a][b] + f[T][1][a][b];
            std::cout << res << '\n';
        }
    }
    return 0;
}

言论