DDP 与全局平衡二叉树
2025-08-12

其实动态 DP 不是动态树 DP 的同义词


动态 DP

对于转移只依赖前几项的 DP,可以放进矩阵。

默认转移范围是 \([1,n]\);如果我们需要改变求解范围,发现如果能获取转移矩阵的乘积就很快了。考虑存下来。

如果需要单点修改,等价于修改其所在矩阵,则想要尽可能少地修改存储的答案,考虑使用线段树优化该过程。

例:海报

https://www.luogu.com.cn/problem/P9790

容易列出暴力 DP 式:令 \(f_{i,j}\) 表示枚举到 \(i\) 时,包含 \(i\) 在内已经有 \(j\) 个连续的人举起海报,易得:

\[ f_{i,0}=\max(f_{i-1,0},f_{i-1,1},f_{i-1,2},f_{i-1,3})\\ f_{i,j}=f_{i-1,j-1}+a_i \forall 1\le j\le 3 \]

发现满足 + / max 矩阵乘法 的形式;想到用线段树保存每段区间对应矩阵(对应性质:结合律),每次修改 / 查询就能在 \(O(\log n)\) 之内完成。

原问题是环形的,可以再加一维 \(k\) 表示钦定选了前 \(k\) 个且不选第 \(k+1\) 个时的答案。

#include <bits/stdc++.h>
const int maxn = 4e4 + 5;
const long long inf = 1e18;
struct mat {
    int n, m;
    long long a[4][4];
    mat() {}
    mat(int n1, int m1): n(n1), m(m1) {
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j)
                a[i][j] = -inf;
        return;
    }
    long long* operator[] (const int q) {
        return a[q];
    }
    mat operator* (mat &q) const {
        mat res(n, q.m);
        for (int i = 0; i < n; ++i)
            for (int k = 0; k < q.m; ++k)
                for (int j = 0; j < m; ++j)
                    res[i][k] = std::max(res[i][k], a[i][j] + q[j][k]);
        return res;
    }
    mat& operator*= (mat &q) {
        return *this = *this * q;
    }
};
struct { int l, r; mat u[4]; } t[maxn << 2];
int a[maxn];
#define lt (p << 1)
#define rt (lt | 1)
#define c t[p].u[i]
void bld(int p, int l, int r) {
    t[p].l = l, t[p].r = r;
    if (l == r) {
        for (int i = 0; i <= 3; ++i) {
            c = mat(4, 4);
            if (l > i + 1) {
                c[0][0] = c[1][0] = c[2][0] = c[3][0] = 0;
                c[0][1] = c[1][2] = c[2][3] = a[l];
            }
            else if (l == i + 1)
                c[0][0] = c[1][0] = c[2][0] = c[3][0] = 0;
            else
                c[0][1] = c[1][2] = c[2][3] = a[l];
        }
        return;
    }
    int mid = (l + r) >> 1;
    bld(lt, l, mid), bld(rt, mid + 1, r);
    for (int i = 0; i <= 3; ++i)
        t[p].u[i] = t[lt].u[i] * t[rt].u[i];
    return;
}
void add(int p, int x, int v) {
    if (t[p].l == t[p].r) {
        for (int i = 0; i <= 3; ++i)
            if (t[p].l != i + 1) 
                c[0][1] = c[1][2] = c[2][3] = v;
        return;
    }
    int mid = (t[p].l + t[p].r) >> 1;
    if (x <= mid)
        add(lt, x, v);
    else
        add(rt, x, v);
    for (int i = 0; i <= 3; ++i)
        t[p].u[i] = t[lt].u[i] * t[rt].u[i];
    return;
}
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
    const auto stime = std::chrono::steady_clock::now();
#endif
    int n;
    std::cin >> n;
    for (int i = 1; i <= n; ++i)
        std::cin >> a[i];
    bld(1, 1, n);
    auto calc = [&](void) {
        mat f(1, 4);
        f[0][0] = 0;
        auto res = -inf;
        for (int i = 0; i <= 3; ++i) {
            auto r = (f * t[1].u[i]);
            for (int j = 0; j <= 3; ++j)
                if (i + j <= 3) {
                    res = std::max(res, r[0][j]);
                    // printf("f[%d][%d] = %lld\n", i, j, f[i][j]);
                }
        }
        return res;
    };
    std::cout << calc() << '\n';
    int q;
    std::cin >> q;
    for (int x, v; q--; ) {
        std::cin >> x >> v;
        add(1, x, v);
        std::cout << calc() << '\n';
    }
#ifndef ONLINE_JUDGE
    std::cerr << std::fixed << std::setprecision(6) << std::chrono::duration<double> (std::chrono::steady_clock::now() - stime).count() << "s\n";
#endif
    return 0;
}

发现矩阵本身和 \(k\) 无关,还可以可以共用一个线段树上的信息,就可以只开一个线段树了。

我写这一版本的原因是 maxn 开大了导致 MLE,实际上四个线段树是没有任何空间压力的 😅

#include <bits/stdc++.h>
const int maxn = 4e5 + 5;
const long long inf = 1e18;
struct mat {
    int n, m;
    long long a[4][4];
    mat() {}
    mat(int n1, int m1): n(n1), m(m1) {
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j)
                a[i][j] = -inf;
        return;
    }
    long long* operator[] (const int q) {
        return a[q];
    }
    mat operator* (mat q) const {
        mat res(n, q.m);
        for (int i = 0; i < n; ++i)
            for (int k = 0; k < q.m; ++k)
                for (int j = 0; j < m; ++j)
                    res[i][k] = std::max(res[i][k], a[i][j] + q[j][k]);
        return res;
    }
    mat& operator*= (mat &q) {
        return *this = *this * q;
    }
};
struct { int l, r; mat u; } t[maxn << 2];
int a[maxn];
#define lt (p << 1)
#define rt (lt | 1)
void bld(int p, int l, int r) {
    t[p].l = l, t[p].r = r;
    if (l == r) {
        t[p].u = mat(4, 4);
        t[p].u[0][0] = t[p].u[1][0] = t[p].u[2][0] = t[p].u[3][0] = 0;
        t[p].u[0][1] = t[p].u[1][2] = t[p].u[2][3] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    bld(lt, l, mid), bld(rt, mid + 1, r);
    t[p].u = t[lt].u * t[rt].u;
    return;
}
void add(int p, int x, int v) {
    if (t[p].l == t[p].r) {
        t[p].u[0][1] = t[p].u[1][2] = t[p].u[2][3] = v;
        return;
    }
    int mid = (t[p].l + t[p].r) >> 1;
    if (x <= mid)
        add(lt, x, v);
    else
        add(rt, x, v);
    t[p].u = t[lt].u * t[rt].u;
    return;
}
mat ask(int p, int l, int r) {
    if (l <= t[p].l && t[p].r <= r)
        return t[p].u;
    int mid = (t[p].l + t[p].r) >> 1;
    if (r <= mid)
        return ask(lt, l, r);
    if (l > mid)
        return ask(rt, l, r);
    return ask(lt, l, r) * ask(rt, l, r);
}
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
    const auto stime = std::chrono::steady_clock::now();
#endif
    int n;
    std::cin >> n;
    for (int i = 1; i <= n; ++i)
        std::cin >> a[i];
    bld(1, 1, n);
    auto calc = [&](void) {
        mat f(1, 4);
        f[0][0] = 0, f[0][1] = f[0][2] = f[0][3] = -inf;
        auto r(f * ask(1, 2, n));
        auto res(*std::max_element(r[0], r[0] + 4));
        mat z(4, 4);
        z[0][0] = z[1][0] = z[2][0] = z[3][0] = 0ll;
        for (int i = 1; i <= 3; ++i) {
            mat op(4, 4);
            op[0][1] = op[1][2] = op[2][3] = a[i];
            f *= op;
            if (i + 2 <= n)
                r = f * z * ask(1, i + 2, n);
            else
                r = f * z;
            res = std::max(res, *std::max_element(r[0], r[0] + 4 - i));
        }
        return res;
    };
    std::cout << calc() << '\n';
    int q;
    std::cin >> q;
    for (int x, v; q--; ) {
        std::cin >> x >> v, a[x] = v;
        add(1, x, v);
        std::cout << calc() << '\n';
    }
#ifndef ONLINE_JUDGE
    std::cerr << std::fixed << std::setprecision(6) << std::chrono::duration<double> (std::chrono::steady_clock::now() - stime).count() << "s\n";
#endif
    return 0;
}

经典题:GSS3

https://www.luogu.com.cn/problem/SP1716

怎么是子段和 😓

同样列出能够矩阵乘法的 DP 式,发现限制在于至少要选一个数。设 \(f_{i}\) 表示选了 \(i\) 的最大值,\(g_i\) 表示历史最大值,则:

\[ f_i=\max(f_{i-1}+a_i,0+a_i)\\ g_i=\max(g_{i-1}+0,f_{i-1}+a_i,0+a_i) \]

容易发现是一个 + / max 矩乘,线段树维护即可。

#include <bits/stdc++.h>
const int maxn = 4e5 + 5;
const long long inf = 1e18;
struct mat {
    int n, m;
    long long a[3][3];
    mat() {}
    mat(int n1, int m1): n(n1), m(m1) {
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j)
                a[i][j] = -inf;
        return;
    }
    long long* operator[] (const int q) {
        return a[q];
    }
    mat operator* (mat q) const {
        mat res(n, q.m);
        for (int i = 0; i < n; ++i)
            for (int k = 0; k < q.m; ++k)
                for (int j = 0; j < m; ++j)
                    res[i][k] = std::max(res[i][k], a[i][j] + q[j][k]);
        return res;
    }
    mat& operator*= (mat &q) {
        return *this = *this * q;
    }
};
struct { int l, r; mat u; } t[maxn << 2]; 
int a[maxn];
#define lt (p << 1)
#define rt (lt | 1)
void bld(int p, int l, int r) {
    t[p].l = l, t[p].r = r;
    if (l == r) {
        t[p].u = mat(3, 3);
        t[p].u[0][0] = t[p].u[2][0] = a[l];
        t[p].u[0][1] = a[l], t[p].u[1][1] = 0ll, t[p].u[2][1] = a[l];
        t[p].u[2][2] = 0ll;
        return;
    }
    int mid = (l + r) >> 1;
    bld(lt, l, mid), bld(rt, mid + 1, r);
    t[p].u = t[lt].u * t[rt].u;
    return;
}
void add(int p, int x, int v) {
    if (t[p].l == t[p].r) {
        t[p].u[0][0] = t[p].u[2][0] = t[p].u[0][1] = t[p].u[2][1] = v;
        return;
    }
    int mid = (t[p].l + t[p].r) >> 1;
    if (x <= mid)
        add(lt, x, v);
    else
        add(rt, x, v);
    t[p].u = t[lt].u * t[rt].u;
    return;
}
mat ask(int p, int l, int r) {
    if (l <= t[p].l && t[p].r <= r)
        return t[p].u;
    int mid = (t[p].l + t[p].r) >> 1;
    if (r <= mid)
        return ask(lt, l, r);
    if (l > mid)
        return ask(rt, l, r);
    return ask(lt, l, r) * ask(rt, l, r);
}
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
    const auto stime = std::chrono::steady_clock::now();
#endif
    int n;
    std::cin >> n;
    for (int i = 1; i <= n; ++i)
        std::cin >> a[i];
    bld(1, 1, n);
    auto calc = [&](int l, int r) {
        mat f(1, 3);
        f[0][2] = 0ll;
        auto res(f * ask(1, l, r));
        return res[0][1];
    };
    int q;
    std::cin >> q;
    for (int op; q--; ) {
        std::cin >> op;
        if (op == 1) {
            int l, r;
            std::cin >> l >> r;
            std::cout << calc(l, r) << '\n';
        }
        else {
            int x, v;
            std::cin >> x >> v;
            a[x] = v, add(1, x, v);
        }
    }
#ifndef ONLINE_JUDGE
    std::cerr << std::fixed << std::setprecision(6) << std::chrono::duration<double> (std::chrono::steady_clock::now() - stime).count() << "s\n";
#endif
    return 0;
}

动态树 DP

把上述过程放到树上,很容易想到树剖 + 线段树。

由于认为线段树上的矩乘只能进行从重儿子到父亲的转移,轻儿子的转移会被合并为一个新函数(同时是矩阵的系数),在跳重链的时候被单独更新。

由于两个函数相互依赖,需要思考清楚更新的先后顺序。

【模板】动态 DP

https://www.luogu.com.cn/problem/P4719

\(f_{u,0/1}\) 表示在 \(u\) 上,选 / 不选 \(u\) 的最大价值。容易得出转移:

\[ f_{u,0}=\sum\max(f_{v,0},f_{v,1})\\ f_{u,1}=a_i+\sum f_{v,0} \]

把转移矩阵放到树剖上后,考虑更新,发现求和这一步很困难。解决方案是直接将求和用另一个函数代替。定义 \(g_{u,0}\) 表示取 \(u\)、不取 \(u\) 的所有轻儿子的答案,\(g_{u,1}\) 不取 \(u\),轻儿子可选可不选的答案。

\(g\) 是可求的,且只需要在跳重链的时候更新 \(g\)。具体更新起来非常绞,因为 \(g\)\(f\) 是相互依赖的,需要分清楚先后关系。

首先 \(g_u\)\(f_u\) 都会被更新;接着,重链上其它的 \(g\) 不会被影响,而链顶的 \(f\) 需要被新的 \(g_u\) 更新;由此链顶父亲的 \(g\) 被更新;依次类推。注意到对于一个链顶,其 \(f\) 值是整条重链的乘积,故需要记录链底。

复杂度 \(O(q\log ^2n)\)。注意很重要的一点是线段树内乘法应从右往左。

#include <bits/stdc++.h>
const int maxn = 1e5 + 5;
const int inf = 0x3f3f3f3f;
struct mat {
    int n, m, a[2][2];
    mat() {}
    mat(int n1, int m1): n(n1), m(m1) {
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j)
                a[i][j] = -inf;
        return;
    }
    int* operator[] (const int q) {
        return a[q];
    }
    mat operator* (mat q) const {
        mat res(n, q.m);
        for (int i = 0; i < n; ++i)
            for (int k = 0; k < q.m; ++k)
                for (int j = 0; j < m; ++j)
                    res[i][k] = std::max(res[i][k], a[i][j] + q[j][k]);
        return res;
    }
    mat& operator*= (mat &q) {
        return *this = *this * q;
    }
};
struct { int l, r; mat u; } t[maxn << 2]; 
int g[maxn][2], tab[maxn];
#define lt (p << 1)
#define rt (lt | 1)
void bld(int p, int l, int r) {
    t[p].l = l, t[p].r = r;
    if (l == r) {
        int u = tab[l];
        t[p].u = mat(2, 2);
        t[p].u[0][0] = t[p].u[1][0] = g[u][1];
        t[p].u[0][1] = g[u][0];
        return;
    }
    int mid = (l + r) >> 1;
    bld(lt, l, mid), bld(rt, mid + 1, r);
    t[p].u = t[rt].u * t[lt].u;
    return;
}
void add(int p, int x) {
    if (t[p].l == t[p].r) {
        int u = tab[x];
        t[p].u[0][0] = t[p].u[1][0] = g[u][1];
        t[p].u[0][1] = g[u][0];
        return;
    }
    int mid = (t[p].l + t[p].r) >> 1;
    if (x <= mid)
        add(lt, x);
    else
        add(rt, x);
    t[p].u = t[rt].u * t[lt].u;
    return;
}
mat ask(int p, int l, int r) {
    if (l <= t[p].l && t[p].r <= r)
        return t[p].u;
    int mid = (t[p].l + t[p].r) >> 1;
    if (r <= mid)
        return ask(lt, l, r);
    if (l > mid)
        return ask(rt, l, r);
    return ask(rt, l, r) * ask(lt, l, r);
}
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
    const auto stime = std::chrono::steady_clock::now();
#endif
    int n, q;
    std::cin >> n >> q;
    std::vector<int> a(n + 1);
    for (int i = 1; i <= n; ++i)
        std::cin >> a[i];
    std::vector<std::vector<int> > g1(n + 1);
    for (int i = 1, x, y; i < n; ++i) {
        std::cin >> x >> y;
        g1[x].push_back(y), g1[y].push_back(x);
    }
    std::vector<int> siz(n + 1), son(n + 1), fa(n + 1);
    std::function<void(int)> DFS = [&](int x) {
        siz[x] = 1;
        for (auto i : g1[x])
            if (i != fa[x]) {
                fa[i] = x, DFS(i);
                siz[x] += siz[i];
                if (siz[i] > siz[son[x]])
                    son[x] = i;
            }
        return;
    };
    DFS(1);
    std::vector<std::array<int, 2> > f(n + 1);
    std::vector<int> top(n + 1), bot(n + 1), dfn(n + 1);
    DFS = [&](int x) {
        static int now = 0;
        dfn[x] = ++now, tab[now] = x;
        if (son[x]) {
            top[son[x]] = top[x], DFS(son[x]);
            bot[x] = bot[son[x]];
            g[x][0] = a[x];
            for (auto i : g1[x])
                if (i != son[x] && i != fa[x]) {
                    top[i] = i, DFS(i);
                    g[x][0] += f[i][0];
                    g[x][1] += std::max(f[i][1], f[i][0]);
                }
            f[x][0] = g[x][1] + std::max(f[son[x]][0], f[son[x]][1]);
            f[x][1] = g[x][0] + f[son[x]][0];
        }
        else
            f[x][1] = g[x][0] = a[x], bot[x] = x;
        return;
    };
    top[1] = 1, DFS(1);
    bld(1, 1, n);
    for (int x, v; q--; ) {
        std::cin >> x >> v;
        g[x][0] -= a[x], g[x][0] += v, a[x] = v;
        for (; top[x] != 1; ) {
            auto r = ask(1, dfn[top[x]], dfn[bot[x]]);
            f[top[x]][0] = r[0][0], f[top[x]][1] = r[0][1];
            g[fa[top[x]]][0] -= f[top[x]][0];
            g[fa[top[x]]][1] -= std::max(f[top[x]][0], f[top[x]][1]);
            add(1, dfn[x]);
            r = ask(1, dfn[top[x]], dfn[bot[x]]);
            f[top[x]][0] = r[0][0], f[top[x]][1] = r[0][1];
            g[fa[top[x]]][0] += f[top[x]][0];
            g[fa[top[x]]][1] += std::max(f[top[x]][0], f[top[x]][1]);
            x = fa[top[x]];
        }
        add(1, dfn[x]);
        auto r = ask(1, dfn[1], dfn[bot[1]]);
        f[1][0] = r[0][0], f[1][1] = r[0][1];
        std::cout << std::max(f[1][0], f[1][1]) << '\n';
    }
#ifndef ONLINE_JUDGE
    std::cerr << std::fixed << std::setprecision(6) << std::chrono::duration<double> (std::chrono::steady_clock::now() - stime).count() << "s\n";
#endif
    return 0;
}

全局平衡二叉树

回顾树剖(重剖),功能在于解决路径问题,单次操作能够在 \(O(\log^2 n)\) 的时间内完成。这个功能可以被全局平衡二叉树(GBT)上位替代。GBT 能够在单次 \(O(\log n)\) 的复杂度内完成链操作、子树操作。Yang Zhe - SPOJ375 QTREE 解法的一些研究 中更为详细严谨地对 GBT 进行了说明,我传了份文件上来。

考虑树剖能被卡的原因:每次线段树询问都会卡满 \(O(\log n)\),找一条卡得满 \(O(\log n)\) 次跳重链次数的路径一直薅,就可以卡到 \(O(\log^2 n)\)

在实现线段树时发现,对于路径操作单点操作,树剖只需要维护同一条重链的信息,建一个大线段树会产生许多重链间的无效维护。故一种经典的树剖卡常技巧是对于每一条重链建出线段树。

在本文中定义全局二叉树:将单个线段树按照在原树上重链顶的相对祖孙关系连边得到的模型。这里为了和全局平衡二叉树形成照应而命名,实际上模型并不是二叉树。容易发现修改某个点花费的操作次数和其在全局二叉树中的深度相同

考虑本方法理论上仍可卡到 \(O(\log^2 n)\) 单次操作的原因,虽然单个线段树平衡,但全局二叉树并不平衡;能够构造数据使得树高达到 \(\log^2 n\)。考虑使得全局二叉树平衡,即调整线段树结构使得任何一个点在全局二叉树上的左右儿子大小最接近。发现是易做的,只需在建线段树时移动左右儿子分割点使得两边子树大小均为全树的一半即可。

法一:求出每个点的轻子树大小 \(ls_u=1+\sum siz_v\),作为加权在线段树上找 mid 就能满足全局平衡;证明可以见上面的论文。这里用线段树代替了 BST,常数很大。

法二:用一个 BST 实现上述功能,需要满足:任意子树的根为子树的带权 mid;BST 的中序遍历为原重链。显然有:树高为 log 级别。这就决定了所有问题都可以通过暴力爬山解决。

GBT 能够快速维护普通树剖操作DDP 信息


维护 DDP:【模板】动态 DP(加强版)

https://www.luogu.com.cn/problem/P4751

和未加强版类似,把所有线段树操作替换为 BST 即可。如果写得丑可能需要一些额外的卡常技巧。

#include <bits/stdc++.h>
const int maxn = 1e6 + 5;
const int inf = 0x3f3f3f3f;
const int LEN = (1 << 20);
int nec(void) {
    static char buf[LEN], *p = buf, *e = buf;
    if (p == e) {
        e = buf + fread(buf, 1, LEN, stdin);
        if (e == buf) return EOF;
        p = buf;
    }
    return *p++;
}
bool read(int &x) {
    x = 0;
    bool f = 0;
    char ch = nec();
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = 1;
        ch = nec();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = nec();
    }
    if (f) x = -x;
    return 1;
}
void print(int x) {
    if (x < 0)
        putchar('-'), x = -x;
    if (x >= 10) print(x / 10);
    putchar(x % 10 + '0');
    return;
}
void print(int x, char ch) {
    print(x), putchar(ch);
    return;
}
struct mat {
    int a[2][2];
    int* operator[] (const int q) { 
        return a[q];
    }
    mat operator* (mat &q) const {
        mat res;
        res[0][0] = res[0][1] = res[1][0] = res[1][1] = -inf;
        for (int i = 0; i < 2; ++i)
            for (int k = 0; k < 2; ++k)
                for (int j = 0; j < 2; ++j)
                    res[i][k] = std::max(res[i][k], a[i][j] + q[j][k]);
        return res;
    }
    mat& operator*= (mat &q) {
        return *this = *this * q;
    }
};
struct { int lc, rc, fa; mat u; } t[maxn]; 
int g[maxn][2], tab[maxn], ls[maxn];
mat p[maxn];
void pushup(int x) {
    t[x].u = t[t[x].rc].u * p[x] * t[t[x].lc].u;
    return;
}
void bld(int &x, int l, int r) {
    if (l > r)
        return;
    int s = 0, k = 0;
    for (int i = l; i <= r; ++i)
        s += ls[tab[i]];
    for (int i = l; i <= r; ++i, k += ls[tab[i]])
        if ((k + ls[tab[i]]) * 2 > s) {
            x = tab[i];
            bld(t[x].lc, l, i - 1), t[t[x].lc].fa = x;
            bld(t[x].rc, i + 1, r), t[t[x].rc].fa = x;
            pushup(x);
            break;
        }
    return;
}
#undef lt
#undef rt
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
    const auto stime = std::chrono::steady_clock::now();
#endif
    t[0].u[0][0] = t[0].u[1][1] = 0, t[0].u[0][1] = t[0].u[1][0] = -inf;
    int n, q;
    read(n), read(q);
    std::vector<int> a(n + 1), rt(n + 1);
    for (int i = 1; i <= n; ++i)
        read(a[i]);
    std::vector<std::vector<int> > g1(n + 1);
    for (int i = 1, x, y; i < n; ++i) {
        read(x), read(y);
        g1[x].push_back(y), g1[y].push_back(x);
    }
    std::vector<int> son(n + 1), fa(n + 1), siz(n + 1);
    std::function<void(int)> DFS = [&](int x) {
        siz[x] = 1;
        for (auto i : g1[x])
            if (i != fa[x]) {
                fa[i] = x, DFS(i);
                siz[x] += siz[i];
                if (siz[i] > siz[son[x]])
                    son[x] = i;
            }
        return;
    };
    DFS(1);
    std::vector<std::array<int, 2> > f(n + 1);
    std::vector<int> top(n + 1), bot(n + 1), dfn(n + 1);
    DFS = [&](int x) {
        static int now = 0;
        dfn[x] = ++now, tab[now] = x, ls[x] = 1;
        if (son[x]) {
            top[son[x]] = top[x], DFS(son[x]);
            bot[x] = bot[son[x]];
            g[x][0] = a[x];
            for (auto i : g1[x])
                if (i != son[x] && i != fa[x]) {
                    top[i] = i, DFS(i), t[rt[i]].fa = x;
                    ls[x] += siz[i];
                    g[x][0] += f[i][0];
                    g[x][1] += std::max(f[i][1], f[i][0]);
                }
            f[x][0] = g[x][1] + std::max(f[son[x]][0], f[son[x]][1]);
            f[x][1] = g[x][0] + f[son[x]][0];
        }
        else
            f[x][1] = g[x][0] = a[x], bot[x] = x;
        p[x][0][0] = p[x][1][0] = g[x][1];
        p[x][0][1] = g[x][0], p[x][1][1] = -inf;
        if (x == top[x])
            bld(rt[x], dfn[x], dfn[bot[x]]);
        return;
    };
    top[1] = 1, DFS(1);
    for (int x, v, la = 0; q--; ) {
        read(x), read(v), x ^= la;
        p[x][0][1] += v - a[x], a[x] = v;
        for (; x; ) {
            int fa = t[x].fa;
            if (fa && x != t[fa].lc && x != t[fa].rc) {
                int f0 = t[x].u[0][0], f1 = std::max(f0, t[x].u[0][1]);
                pushup(x);
                int F0 = t[x].u[0][0], F1 = std::max(F0, t[x].u[0][1]);
                p[fa][0][0] += F1 - f1, p[fa][1][0] += F1 - f1;
                p[fa][0][1] += F0 - f0;
            }
            else
                pushup(x);
            x = fa;
        }
        print(la = std::max(t[rt[1]].u[0][0], t[rt[1]].u[0][1]), '\n');
    }
#ifndef ONLINE_JUDGE
    std::cerr << std::fixed << std::setprecision(6) << std::chrono::duration<double> (std::chrono::steady_clock::now() - stime).count() << "s\n";
#endif
    return 0;
}

也给出一个线段树版本的

#include <bits/stdc++.h>
const int maxn = 1e6 + 5;
const int inf = 0x3f3f3f3f;
const int LEN = (1 << 20);
int nec(void) {
    static char buf[LEN], *p = buf, *e = buf;
    if (p == e) {
        e = buf + fread(buf, 1, LEN, stdin);
        if (e == buf) return EOF;
        p = buf;
    }
    return *p++;
}
bool read(int &x) {
    x = 0;
    bool f = 0;
    char ch = nec();
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = 1;
        ch = nec();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = nec();
    }
    if (f) x = -x;
    return 1;
}
void print(int x) {
    if (x < 0)
        putchar('-'), x = -x;
    if (x >= 10) print(x / 10);
    putchar(x % 10 + '0');
    return;
}
void print(int x, char ch) {
    print(x), putchar(ch);
    return;
}
struct mat {
    int a[2][2];
    int* operator[] (const int q) {
        return a[q];
    }
    mat operator* (mat q) const {
        mat res;
        res[0][0] = res[0][1] = res[1][0] = res[1][1] = -inf;
        for (int i = 0; i < 2; ++i)
            for (int k = 0; k < 2; ++k)
                for (int j = 0; j < 2; ++j)
                    res[i][k] = std::max(res[i][k], a[i][j] + q[j][k]);
        return res;
    }
    mat& operator*= (mat &q) {
        return *this = *this * q;
    }
};
struct { int lc, rc, l, r, mid; mat u; } t[maxn << 2]; 
int g[maxn][2], tab[maxn], ls[maxn];
#define lt t[p].lc
#define rt t[p].rc
void bld(int &p, int l, int r) {
    static int tot = 0;
    p = ++tot, t[p].l = l, t[p].r = r;
    if (l == r) {
        int u = tab[l];
        t[p].u[0][0] = t[p].u[1][0] = g[u][1];
        t[p].u[0][1] = g[u][0], t[p].u[1][1] = -inf;
        return;
    }
    int s = 0, u = ls[tab[l]];
    for (int i = l; i <= r; ++i)
        s += ls[tab[i]];
    s >>= 1;
    t[p].mid = r - 1;
    for (int i = l + 1; i < r; ++i) {
        u += ls[tab[i]];
        if (u + ls[tab[i + 1]] > s) {
            t[p].mid = i;
            break;
        }
    }
    bld(lt, l, t[p].mid), bld(rt, t[p].mid + 1, r);
    t[p].u = t[rt].u * t[lt].u;
    return;
}
void add(int p, int x) {
    if (t[p].l == t[p].r) {
        int u = tab[x];
        t[p].u[0][0] = t[p].u[1][0] = g[u][1];
        t[p].u[0][1] = g[u][0];
        return;
    }
    if (x <= t[p].mid)
        add(lt, x);
    else
        add(rt, x);
    t[p].u = t[rt].u * t[lt].u;
    return;
}
mat ask(int p, int l, int r) {
    if (l <= t[p].l && t[p].r <= r)
        return t[p].u;
    if (r <= t[p].mid)
        return ask(lt, l, r);
    if (l > t[p].mid)
        return ask(rt, l, r);
    return ask(rt, l, r) * ask(lt, l, r);
}
#undef lt
#undef rt
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
    const auto stime = std::chrono::steady_clock::now();
#endif
    int n, q;
    read(n), read(q);
    std::vector<int> a(n + 1), rt(n + 1);
    for (int i = 1; i <= n; ++i)
        read(a[i]);
    std::vector<std::vector<int> > g1(n + 1);
    for (int i = 1, x, y; i < n; ++i) {
        read(x), read(y);
        g1[x].push_back(y), g1[y].push_back(x);
    }
    std::vector<int> siz(n + 1), son(n + 1), fa(n + 1);
    std::function<void(int)> DFS = [&](int x) {
        siz[x] = 1;
        for (auto i : g1[x])
            if (i != fa[x]) {
                fa[i] = x, DFS(i);
                siz[x] += siz[i];
                if (siz[i] > siz[son[x]])
                    son[x] = i;
            }
        return;
    };
    DFS(1);
    std::vector<std::array<int, 2> > f(n + 1);
    std::vector<int> top(n + 1), bot(n + 1), dfn(n + 1);
    DFS = [&](int x) {
        static int now = 0;
        dfn[x] = ++now, tab[now] = x;
        ls[x] = 1;
        if (son[x]) {
            top[son[x]] = top[x], DFS(son[x]);
            bot[x] = bot[son[x]];
            g[x][0] = a[x];
            for (auto i : g1[x])
                if (i != son[x] && i != fa[x]) {
                    top[i] = i, DFS(i);
                    ls[x] += siz[i];
                    g[x][0] += f[i][0];
                    g[x][1] += std::max(f[i][1], f[i][0]);
                }
            f[x][0] = g[x][1] + std::max(f[son[x]][0], f[son[x]][1]);
            f[x][1] = g[x][0] + f[son[x]][0];
        }
        else
            f[x][1] = g[x][0] = a[x], bot[x] = x;
        if (x == top[x])
            bld(rt[x], dfn[x], dfn[bot[x]]);
        return;
    };
    top[1] = 1, DFS(1);
    for (int x, v, la = 0; q--; ) {
        read(x), read(v), x ^= la;
        g[x][0] -= a[x], g[x][0] += v, a[x] = v;
        for (; top[x] != 1; ) {
            auto r = t[rt[top[x]]].u;
            f[top[x]][0] = r[0][0], f[top[x]][1] = r[0][1];
            g[fa[top[x]]][0] -= f[top[x]][0];
            g[fa[top[x]]][1] -= std::max(f[top[x]][0], f[top[x]][1]);
            add(rt[top[x]], dfn[x]);
            r = t[rt[top[x]]].u;
            f[top[x]][0] = r[0][0], f[top[x]][1] = r[0][1];
            g[fa[top[x]]][0] += f[top[x]][0];
            g[fa[top[x]]][1] += std::max(f[top[x]][0], f[top[x]][1]);
            x = fa[top[x]];
        }
        add(rt[1], dfn[x]);
        auto r =  t[rt[1]].u;
        f[1][0] = r[0][0], f[1][1] = r[0][1];
        std::cout << (la = std::max(f[1][0], f[1][1])) << '\n';
    }
#ifndef ONLINE_JUDGE
    std::cerr << std::fixed << std::setprecision(6) << std::chrono::duration<double> (std::chrono::steady_clock::now() - stime).count() << "s\n";
#endif
    return 0;
}

树上路径问题:染色

https://www.luogu.com.cn/problem/P2486

GBT 上的路径问题

BST:首先暴力爬山到 LCA,统计一路上的答案;然后跳到当前 BST 的根。由于本题有标记,所以需要在跳的同时想办法把标记问题解决一下。

线段树:把树剖的线段树略改一下就过了。从上一行模棱两可的描述就可以看出来 BST 实现起来不太轻松;还是线段树更轮椅啊!

跑得没纯树剖快,因为数据没有刻意构造导致树高很低,再加上常数的影响吧。

#include <bits/stdc++.h>
const int maxn = 1e5 + 5;
const int LEN = (1 << 20);
int nec(void) {
    static char buf[LEN], *p = buf, *e = buf;
    if (p == e) {
        e = buf + fread(buf, 1, LEN, stdin);
        if (e == buf) return EOF;
        p = buf;
    }
    return *p++;
}
bool read(int &x) {
    x = 0;
    bool f = 0;
    char ch = nec();
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = 1;
        ch = nec();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = nec();
    }
    if (f) x = -x;
    return 1;
}
void read(char &x) {
    for (x = nec(); x != 'C' && x != 'Q'; x = nec());
    return;
}
void print(int x) {
    if (x < 0)
        putchar('-'), x = -x;
    if (x >= 10) print(x / 10);
    putchar(x % 10 + '0');
    return;
}
void print(int x, char ch) {
    print(x), putchar(ch);
    return;
}
class node {
private:
    int lt, rt;
public:
    int l, r, mid, lc, rc, u, d;
    node(): u(-1) {}
    inline int& ls(void) { return lt; }
    inline int& rs(void) { return rt; }
    node& operator= (const node &q) {
        lc = q.lc, rc = q.rc, u = q.u;
        return *this;
    }
    node operator+ (const node &q) const {
        if (u == -1)
            return q;
        if (q.u == -1)
            return *this;
        node res;
        res.lc = lc, res.rc = q.rc, res.u = u + q.u - (rc == q.lc);
        return res;
    }
    inline void swap(void) {
        if (~u)
            lc ^= rc ^= lc ^= rc;
        return;
    }
} t[maxn << 2]; 
int tab[maxn], ls[maxn], a[maxn];
#define lt t[p].ls()
#define rt t[p].rs()
void bld(int &p, int l, int r) {
    static int tot = 0;
    p = ++tot, t[p].l = l, t[p].r = r;
    if (l == r) {
        t[p].u = 1;
        t[p].lc = t[p].rc = a[tab[l]];
        return;
    }
    int s = 0, u = ls[tab[l]];
    for (int i = l; i <= r; ++i)
        s += ls[tab[i]];
    s >>= 1;
    t[p].mid = r - 1;
    for (int i = l + 1; i < r; ++i) {
        u += ls[tab[i]];
        if (u + ls[tab[i + 1]] > s) {
            t[p].mid = i;
            break;
        }
    }
    bld(lt, l, t[p].mid), bld(rt, t[p].mid + 1, r);
    t[p] = t[lt] + t[rt];
    return;
}
void pushdown(int p) {
    if (t[p].d) {
        t[lt].d = t[lt].lc = t[lt].rc = t[rt].d = t[rt].lc = t[rt].rc = t[p].d;
        t[lt].u = t[rt].u = 1;
        t[p].d = 0;
    }
    return;
}
void add(int p, int l, int r, int v) {
    if (l <= t[p].l && t[p].r <= r) {
        t[p].u = 1;
        t[p].lc = t[p].rc = t[p].d = v;
        return;
    }
    pushdown(p);
    if (l <= t[p].mid)
        add(lt, l, r, v);
    if (r > t[p].mid)
        add(rt, l, r, v);
    t[p] = t[lt] + t[rt];
    return;
}
node ask(int p, int l, int r) {
    if (l <= t[p].l && t[p].r <= r)
        return t[p];
    pushdown(p);
    if (r <= t[p].mid)
        return ask(lt, l, r);
    if (l > t[p].mid)
        return ask(rt, l, r);
    return ask(lt, l, r) + ask(rt, l, r);
}
#undef lt
#undef rt
int main() {
#ifdef ONLINE_JUDGE
#else
    std::freopen("paint17.in", "r", stdin);
    std::freopen(".out", "w", stdout);
    const auto stime = std::chrono::steady_clock::now();
#endif
    int n, q;
    read(n), read(q);
    std::vector<int> rt(n + 1);
    for (int i = 1; i <= n; ++i)
        read(a[i]);
    std::vector<std::vector<int> > g(n + 1);
    for (int i = 1, x, y; i < n; ++i) {
        read(x), read(y);
        g[x].push_back(y), g[y].push_back(x);
    }
    std::vector<int> siz(n + 1), son(n + 1), fa(n + 1), dep(n + 1);
    std::function<void(int)> DFS = [&](int x) {
        siz[x] = 1;
        for (auto i : g[x])
            if (i != fa[x]) {
                dep[i] = dep[x] + 1;
                fa[i] = x, DFS(i);
                siz[x] += siz[i];
                if (siz[i] > siz[son[x]])
                    son[x] = i;
            }
        return;
    };
    DFS(1);
    std::vector<std::array<int, 2> > f(n + 1);
    std::vector<int> top(n + 1), bot(n + 1), dfn(n + 1);
    DFS = [&](int x) {
        static int now = 0;
        dfn[x] = ++now, tab[now] = x, ls[x] = 1;
        if (son[x]) {
            top[son[x]] = top[x], DFS(son[x]);
            bot[x] = bot[son[x]];
        }
        else
            bot[x] = x;
        for (auto i : g[x])
            if (i != son[x] && i != fa[x]) {
                top[i] = i;
                DFS(i);
                ls[x] += siz[i];
            }
        if (x == top[x])
            bld(rt[x], dfn[x], dfn[bot[x]]);
        return;
    };
    top[1] = 1, DFS(1);
    for (char op; q--; ) {
        read(op);
        if (op == 'C') {
            int x, y, c;
            read(x), read(y), read(c);
            for (; top[x] != top[y]; x = fa[top[x]]) {
                if (dep[top[x]] < dep[top[y]])
                    std::swap(x, y);
                add(rt[top[x]], dfn[top[x]], dfn[x], c);
            }
            if (dep[x] > dep[y])
                std::swap(x, y);
            add(rt[top[x]], dfn[x], dfn[y], c);
        }
        else {
            int x, y;
            read(x), read(y);
            node res1, res2;
            for (; top[x] != top[y]; )
                if (dep[top[x]] < dep[top[y]]) {
                    res2 = ask(rt[top[y]], dfn[top[y]], dfn[y]) + res2;
                    y = fa[top[y]];
                }
                else {
                    res1 = ask(rt[top[x]], dfn[top[x]], dfn[x]) + res1;
                    x = fa[top[x]];
                }
            if (dep[x] > dep[y])
                res1 = ask(rt[top[y]], dfn[y], dfn[x]) + res1;
            else
                res2 = ask(rt[top[x]], dfn[x], dfn[y]) + res2;
            res1.swap();
            print((res1 + res2).u, '\n');
        }
    }
#ifndef ONLINE_JUDGE
    std::cerr << std::fixed << std::setprecision(6) << std::chrono::duration<double> (std::chrono::steady_clock::now() - stime).count() << "s\n";
#endif
    return 0;
}

树上路径查询?

做题的时候可以感受到,路径询问的存在很诡异,因为只需要一条链的操作,更偏向链上 DDP 而非树上的;轻儿子的信息是不用维护的,形式上更像树剖(DS);当成链上的 DDP 就可以解决了。

可以用倍增维护 DDP(不用考虑轻重儿子,只用维护父子关系,进一步向序列 DDP 靠近),就不用打 GBT 了,常数也会小一些。


习题

GBT 就统一用线段树了。DDP 也就可以顺带用 GBT 优化了。


E - 猫或狗 / Cats or Dogs

https://www.luogu.com.cn/problem/P9597

\(f_{u,0/1}\) 表示从根上颜色为 \(0/1\) 时的最小断边数,显然无色可以视作任选一个颜色。那么有:

\[ f_{u,a}=\sum\limits_v\min\limits_{b\in\{0,1\}}\{f_{v,b}+[a\ne b]\} \]

直接把轻儿子的项提出来,记 \(g_{u,0/1}\) 表示 \(f_{u,0/1}\) 对应的轻儿子贡献即可。

#include <bits/stdc++.h>
const int maxn = 1e5 + 5;
const int inf = 0x3f3f3f3f;
struct mat {
    int a[2][2];
    int* operator[] (const int q) { 
        return a[q];
    }
    mat operator* (mat &q) const {
        mat res;
        res[0][0] = res[0][1] = res[1][0] = res[1][1] = inf;
        for (int i = 0; i < 2; ++i)
            for (int k = 0; k < 2; ++k)
                for (int j = 0; j < 2; ++j)
                    res[i][k] = std::min(res[i][k], a[i][j] + q[j][k]);
        return res;
    }
    mat& operator*= (mat &q) {
        return *this = *this * q;
    }
} p[maxn];
struct { int lc, rc, l, r, mid; mat u; } t[maxn << 2]; 
int ls[maxn], tab[maxn], g[maxn][2], rt[maxn], top[maxn], dfn[maxn], fa[maxn];
std::vector<int> a;
#define lt t[p].lc
#define rt t[p].rc
void bld(int &p, int l, int r) {
    static int tot = 0;
    p = ++tot, t[p].l = l, t[p].r = r;
    if (l == r) {
        t[p].u[0][1] = t[p].u[1][0] = 1;
        return;
    }
    int s = 0, u = ls[tab[l]];
    for (int i = l; i <= r; ++i)
        s += ls[tab[i]];
    s >>= 1;
    t[p].mid = r - 1;
    for (int i = l + 1; i < r; ++i) {
        u += ls[tab[i]];
        if (u + ls[tab[i + 1]] > s) {
            t[p].mid = i;
            break;
        }
    }
    bld(lt, l, t[p].mid), bld(rt, t[p].mid + 1, r);
    t[p].u = t[rt].u * t[lt].u;
    return;
}
void add(int p, int x) {
    if (t[p].l == t[p].r) {
        int u = tab[x], g0 = (a[u] == 1 ? inf : g[u][0]), g1 = (a[u] == 0 ? inf : g[u][1]);
        t[p].u[0][0] = g0, t[p].u[0][1] = g1 + 1;
        t[p].u[1][0] = g0 + 1, t[p].u[1][1] = g1;
        return;
    }
    if (x <= t[p].mid)
        add(lt, x);
    else
        add(rt, x);
    t[p].u = t[rt].u * t[lt].u;
    return;
}
#undef lt
#undef rt
void initialize(int n, std::vector<int> A, std::vector<int> B) {
    a.assign(n + 1, -1);
    std::vector<std::vector<int> > g1(n + 1);
    for (int i = 0; i < n - 1; ++i)
        g1[A[i]].push_back(B[i]), g1[B[i]].push_back(A[i]);
    std::vector<int> siz(n + 1), son(n + 1);
    std::function<void(int)> DFS = [&](int x) {
        siz[x] = 1;
        for (auto i : g1[x])
            if (i != fa[x]) {
                fa[i] = x, DFS(i);
                siz[x] += siz[i];
                if (siz[i] > siz[son[x]])
                    son[x] = i;
            }
        return;
    };
    DFS(1);
    std::vector<int> bot(n + 1);
    DFS = [&](int x) {
        static int now = 0;
        dfn[x] = ++now, tab[now] = x, ls[x] = 1;
        if (son[x]) {
            top[son[x]] = top[x], DFS(son[x]);
            bot[x] = bot[son[x]];
            for (auto i : g1[x])
                if (i != son[x] && i != fa[x]) {
                    top[i] = i, DFS(i);
                    ls[x] += siz[i];
                }
        }
        else
            bot[x] = x;
        if (x == top[x])
            bld(rt[x], dfn[x], dfn[bot[x]]);
        return;
    };
    top[1] = 1, DFS(1);
    return;
}
int upd(int x) {
    for (; top[x] != 1; ) {
        int faa = fa[top[x]];
        auto &id = t[rt[top[x]]].u;
        int f0 = std::min(id[0][0], id[1][0]), f1 = std::min(id[0][1], id[1][1]);
        g[faa][0] -= std::min(f0, f1 + 1);
        g[faa][1] -= std::min(f0 + 1, f1);
        add(rt[top[x]], dfn[x]);
        f0 = std::min(id[0][0], id[1][0]), f1 = std::min(id[0][1], id[1][1]);
        // printf("%d: f0 = %d, f1 = %d\n", top[x], f0, f1);
        g[faa][0] += std::min(f0, f1 + 1);
        g[faa][1] += std::min(f0 + 1, f1);
        x = faa;
    }
    add(rt[1], dfn[x]);
    auto &id = t[rt[1]].u;
    int f0 = std::min(id[0][0], id[1][0]), f1 = std::min(id[0][1], id[1][1]);
    // printf("%d: f0 = %d, f1 = %d\n", 1, f0, f1);
    return std::min(f0, f1);
}
int cat(int x) {
    a[x] = 0;
    return upd(x);
}
int dog(int x) {
    a[x] = 1;
    return upd(x);
}
int neighbor(int x) {
    a[x] = -1;
    return upd(x);
}
#ifndef ONLINE_JUDGE
int main() {
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
    const auto stime = std::chrono::steady_clock::now();
    int n;
    std::cin >> n;
    std::vector<int> A(n - 1), B(n - 1);
    for (int i = 0; i < n - 1; ++i)
        std::cin >> A[i] >> B[i];
    initialize(n, A, B);
    int q;
    std::cin >> q;
    for (int op, x; q--; ) {
        std::cin >> op >> x;
        std::cout << (op == 1 ? cat(x) : (op == 2 ? dog(x) : neighbor(x))) << '\n';
    }
    std::cerr << std::fixed << std::setprecision(6) << std::chrono::duration<double> (std::chrono::steady_clock::now() - stime).count() << "s\n";
    return 0;
}
#endif

F - Hash on Tree

https://www.luogu.com.cn/problem/AT_abc351_g

\(g_x\) 为轻儿子的哈希值之积,则 \(f_u=A_u+f_{son}\cdot g_u\)

然后开一个常数维即可。快速幂的 log 省不掉,所以是双 log 的。

初值和修改都有可能为 \(0\),需要维护实际哈希值和去掉 \(0\) 的哈希值。

#include <bits/stdc++.h>
const int maxn = 1e6 + 5;
const int LEN = (1 << 20);
const int mod = 998244353;
int nec(void) {
    static char buf[LEN], *p = buf, *e = buf;
    if (p == e) {
        e = buf + fread(buf, 1, LEN, stdin);
        if (e == buf) return EOF;
        p = buf;
    }
    return *p++;
}
bool read(int &x) {
    x = 0;
    bool f = 0;
    char ch = nec();
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = 1;
        ch = nec();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = nec();
    }
    if (f) x = -x;
    return 1;
}
void print(int x) {
    if (x < 0)
        putchar('-'), x = -x;
    if (x >= 10) print(x / 10);
    putchar(x % 10 + '0');
    return;
}
void print(int x, char ch) {
    print(x), putchar(ch);
    return;
}
struct mat {
    int n, m;
    long long a[2][2];
    long long* operator[] (const int q) {
        return a[q];
    }
    mat operator* (mat &q) const {
        mat res;
        res.n = n, res.m = q.m;
        res[0][0] = res[0][1] = res[1][0] = res[1][1] = 0ll;
        for (int i = 0; i < n; ++i)
            for (int k = 0; k < q.m; ++k) {
                for (int j = 0; j < m; ++j)
                    res[i][k] += a[i][j] * q[j][k];
                res[i][k] %= mod;
            }
        return res;
    }
    mat& operator*= (mat &q) {
        return *this = *this * q;
    }
};
struct { int lc, rc, l, r, mid; mat u; } t[maxn << 2]; 
long long g[maxn];
int tab[maxn], ls[maxn], a[maxn];
#define lt t[p].lc
#define rt t[p].rc
void bld(int &p, int l, int r) {
    static int tot = 0;
    p = ++tot, t[p].l = l, t[p].r = r;
    if (l == r) {
        int u = tab[l];
        t[p].u.n = t[p].u.m = 2;
        t[p].u[0][0] = g[u];
        t[p].u[0][1] = 0ll;
        t[p].u[1][0] = a[u];
        t[p].u[1][1] = 1ll;
        return;
    }
    int s = 0, u = ls[tab[l]];
    for (int i = l; i <= r; ++i)
        s += ls[tab[i]];
    s >>= 1;
    t[p].mid = r - 1;
    for (int i = l + 1; i < r; ++i) {
        u += ls[tab[i]];
        if (u + ls[tab[i + 1]] > s) {
            t[p].mid = i;
            break;
        }
    }
    bld(lt, l, t[p].mid), bld(rt, t[p].mid + 1, r);
    t[p].u = t[rt].u * t[lt].u;
    return;
}
void add(int p, int x) {
    if (t[p].l == t[p].r) {
        int u = tab[x];
        t[p].u[0][0] = g[u];
        t[p].u[1][0] = a[u];
        return;
    }
    if (x <= t[p].mid)
        add(lt, x);
    else
        add(rt, x);
    t[p].u = t[rt].u * t[lt].u;
    return;
}
#undef lt
#undef rt
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
    const auto stime = std::chrono::steady_clock::now();
#endif
    int n, q;
    read(n), read(q);
    std::vector<int> rt(n + 1);
    std::vector<std::vector<int> > g1(n + 1);
    for (int i = 2, x; i <= n; ++i) {
        read(x);
        g1[i].push_back(x), g1[x].push_back(i);
    }
    for (int i = 1; i <= n; ++i)
        read(a[i]);
    std::vector<int> siz(n + 1), son(n + 1), fa(n + 1);
    std::function<void(int)> DFS = [&](int x) {
        siz[x] = 1;
        for (auto i : g1[x])
            if (i != fa[x]) {
                fa[i] = x, DFS(i);
                siz[x] += siz[i];
                if (siz[i] > siz[son[x]])
                    son[x] = i;
            }
        return;
    };
    DFS(1);
    std::vector<long long> f(n + 1);
    std::vector<int> top(n + 1), bot(n + 1), dfn(n + 1), la(n + 1), cnt(n + 1);
    DFS = [&](int x) {
        static int now = 0;
        dfn[x] = ++now, tab[now] = x, ls[x] = 1;
        if (son[x]) {
            top[son[x]] = top[x], DFS(son[x]);
            bot[x] = bot[son[x]];
            g[x] = 1ll;
            for (auto i : g1[x])
                if (i != son[x] && i != fa[x]) {
                    top[i] = i, DFS(i);
                    ls[x] += siz[i];
                    if (f[i] == 0ll)
                        ++cnt[x];
                    else
                        (la[x] *= f[i]) %= mod;
                    (g[x] *= f[i]) %= mod;
                }
            f[x] = (a[x] + g[x] * f[son[x]]) % mod;
        }
        else
            f[x] = a[x], bot[x] = x;
        // printf("f[%d] = %lld\n", x, f[x]);
        if (x == top[x])
            bld(rt[x], dfn[x], dfn[bot[x]]);
        return;
    };
    top[1] = 1, DFS(1);
    mat init;
    init.n = 1, init.m = 2;
    init[0][0] = 0ll, init[0][1] = 1ll;
    auto qkp = [&](long long x, int y) {
        auto res(1ll);
        for (; y; (x *= x) %= mod, y >>= 1)
            if (y & 1)
                (res *= x) %= mod;
        return res;
    };
    auto inv = [&](int x) {
        return qkp(x, mod - 2);
    };
    for (int x, v; q--; ) {
        read(x), read(v);
        a[x] = v;
        for (; top[x] != 1; ) {
            int faa = fa[top[x]];
            auto &id = t[rt[top[x]]].u;
            int f = (init * id)[0][0];
            if (f == 0) {
                if (--cnt[faa] == 0)
                    g[faa] = la[faa];
            }
            else {
                f = inv(f);
                (g[faa] *= f) %= mod;
                (la[faa] *= f) %= mod;
            }
            add(rt[top[x]], dfn[x]);
            f = (init * id)[0][0];
            if (f == 0) {
                if (cnt[faa]++ == 0)
                    la[faa] = g[faa];
                g[faa] = 0ll;
            }
            else {
                (g[faa] *= f) %= mod;
                (la[faa] *= f) %= mod;
            }
            x = faa;
        }
        add(rt[1], dfn[x]);
        auto &id = t[rt[1]].u;
        int f = (init * id)[0][0];
        std::cout << f << '\n';
    }
#ifndef ONLINE_JUDGE
    std::cerr << std::fixed << std::setprecision(6) << std::chrono::duration<double> (std::chrono::steady_clock::now() - stime).count() << "s\n";
#endif
    return 0;
}

G - 考试 2

https://www.luogu.com.cn/problem/P10626

运算均是线性的,容易想到『按位』维护,即只维护某个特定点值。如果将询问离线下来排序,每个函数的点值只会变化 \(O(1)\) 次。

建出符号二叉树,进行 DDP 即可。具体地,令 \(f_u\) 表示在运算 \(u\) 处的答案,计算 \(g_u\) 为轻儿子的答案,按照 \(u\) 处的符号写矩阵即可。叶子不是一次运算,应该直接填入点值(注意线段树上的叶子不一定是原树上的叶子)。

建树和矩阵更新有点史,适当封装一下感觉会好一点

#include <bits/stdc++.h>
const int maxn = 1e6 + 5;
const char mp[] = "x!&|^";
const int LEN = (1 << 20);
#define nec getchar
inline bool read(int &x) {
    x = 0;
    bool f = 0;
    char ch = nec();
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = 1;
        ch = nec();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = nec();
    }
    if (f) x = -x;
    return 1;
}
void print(int x) {
    if (x < 0)
        putchar('-'), x = -x;
    if (x >= 10) print(x / 10);
    putchar(x % 10 + '0');
    return;
}
void print(int x, char ch) {
    print(x), putchar(ch);
    return;
}
struct mat {
    int n, m;
    int a[2][2];
    int* operator[] (const int q) {
        return a[q];
    }
    inline mat operator* (mat &q) const {
        mat res;
        res.n = n, res.m = q.m;
        res[0][0] = res[0][1] = res[1][0] = res[1][1] = 0ll;
        for (int i = 0; i < n; ++i)
            for (int k = 0; k < q.m; ++k)
                for (int j = 0; j < m; ++j)
                    res[i][k] += a[i][j] * q[j][k];
        return res;
    }
    mat& operator*= (mat &q) {
        return *this = *this * q;
    }
};
struct { int lc, rc, l, r, mid; mat u; } t[maxn << 2]; 
int g1[maxn][2];
int g[maxn], ty[maxn];
int tab[maxn], ls[maxn], lim[maxn];
#define lt t[p].lc
#define rt t[p].rc
void fillmat(mat &a, int op, int g) {
    if (op == 0) {
        a.n = 1, a.m = 2;
        a[0][!g] = 0, a[0][g] = 1;
        return;
    }
    a.n = 2, a.m = 2;
    switch (op) {
    case 1: // !
        a[0][0] = 0, a[0][1] = 1;
        a[1][0] = 1, a[1][1] = 0;
        break;
    case 2: // &
        a[0][0] = 1, a[0][1] = 0;
        a[1][0] = !g, a[1][1] = g;
        break;
    case 3: // |
        a[0][0] = !g, a[0][1] = g;
        a[1][0] = 0, a[1][1] = 1;
        break;
    case 4: // ^
        a[0][0] = !g, a[0][1] = g;
        a[1][0] = g, a[1][1] = !g;
        break;
    default:
        assert(0);
    }
    return;
}
void bld(int &p, int l, int r) {
    static int tot = 0;
    p = ++tot, t[p].l = l, t[p].r = r;
    if (l == r) {
        fillmat(t[p].u, ty[tab[l]], g[tab[l]]);
        return;
    }
    int s = 0, u = ls[tab[l]];
    for (int i = l; i <= r; ++i)
        s += ls[tab[i]];
    s >>= 1;
    t[p].mid = r - 1;
    for (int i = l + 1; i < r; ++i) {
        u += ls[tab[i]];
        if (u + ls[tab[i + 1]] > s) {
            t[p].mid = i;
            break;
        }
    }
    bld(lt, l, t[p].mid), bld(rt, t[p].mid + 1, r);
    t[p].u = t[rt].u * t[lt].u;
    return;
}
void add(int p, int x) {
    if (t[p].l == t[p].r) {
        fillmat(t[p].u, ty[tab[t[p].l]], g[tab[t[p].l]]);
        return;
    }
    if (x <= t[p].mid)
        add(lt, x);
    else
        add(rt, x);
    t[p].u = t[rt].u * t[lt].u;
    return;
}
#undef lt
#undef rt
int main() {
#ifndef ONLINE_JUDGE
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
    const auto stime = std::chrono::steady_clock::now();
#endif
    int n = 0, q, p;
    read(q), read(q);
    {
        std::string s;
        std::cin >> s;
        std::stringstream in(s);
        std::stack<char> op;
        std::stack<int> id;
        auto trans = [&](char t) {
            switch (t) {
            case '!':
                return 1;
            case '&':
                return 2;
            case '|':
                return 3;
            case '^':
                return 4;
            }
            assert(0);
            return -1;
        };
        auto calcNot = [&](void) {
            for (; !op.empty() && op.top() == '!'; ) {
                ty[++n] = trans(op.top());
                g1[n][0] = id.top(), id.pop(), id.push(n);
                op.pop();
            }
            return;
        };
        auto opt = [&](void) {
            ty[++n] = trans(op.top());
            g1[n][0] = id.top(), id.pop();
            g1[n][1] = id.top(), id.pop();
            id.push(n), op.pop();
            return;
        };
        for (char t; in >> t; )
            if (t == '[') {
                int x;
                in >> x, in >> t;
                lim[++n] = x - 1;
                id.push(n), calcNot();
            }
            else if (t == '(')
                op.push(t);
            else if (t == ')') {
                for (; op.top() != '('; opt());
                op.pop(), calcNot();
            }
            else if (t == '&')
                op.push(t);
            else if (t == '^') {
                for (; !op.empty() && op.top() == '&'; opt());
                op.push(t);
            }
            else if (t == '|') {
                for (; !op.empty() && (op.top() == '&' || op.top() == '^'); opt());
                op.push(t);
            }
            else {
                assert(t == '!');
                op.push(t);
            }
        for (; !op.empty(); opt());
        p = id.top();
    }
    std::vector<int> rt(n + 1);
    std::vector<int> siz(n + 1), son(n + 1), fa(n + 1);
    std::function<void(int)> DFS = [&](int x) {
        siz[x] = 1;
        for (auto i : g1[x])
            if (i != 0) {
                fa[i] = x, DFS(i);
                siz[x] += siz[i];
                if (siz[i] > siz[son[x]])
                    son[x] = i;
            }
        return;
    };
    DFS(p);
    std::vector<long long> f(n + 1);
    std::vector<int> top(n + 1), bot(n + 1), dfn(n + 1);
    DFS = [&](int x) {
        static int now = 0;
        dfn[x] = ++now, tab[now] = x, ls[x] = 1;
        if (son[x]) {
            top[son[x]] = top[x], DFS(son[x]);
            bot[x] = bot[son[x]];
            for (auto i : g1[x])
                if (i != son[x] && i != 0) {
                    top[i] = i, DFS(i);
                    ls[x] += siz[i], g[x] = f[i];
                }
            switch (ty[x]) {
            case 1:
                f[x] = !f[son[x]];
                break;
            case 2:
                f[x] = g[x] & f[son[x]];
                break;
            case 3:
                f[x] = g[x] | f[son[x]];
                break;
            case 4:
                f[x] = g[x] ^ f[son[x]];
                break;
            default:
                assert(0);
                break;
            }
        }
        else
            assert(!ty[x]), f[x] = g[x] = 0, bot[x] = x;
        if (x == top[x])
            bld(rt[x], dfn[x], dfn[bot[x]]);
        return;
    };
    top[p] = p, DFS(p);
    std::vector<int> a(q + 1), res(q + 1);
    for (int i = 1; i <= q; ++i)
        read(a[i]);
    std::vector<int> qid(q), nid;
    for (int i = 1; i <= n; ++i)
        if (ty[i] == 0)
            nid.push_back(i);
    std::iota(qid.begin(), qid.end(), 1);
    std::sort(qid.begin(), qid.end(), [&](int x, int y) { return a[x] < a[y]; });
    std::sort(nid.begin(), nid.end(), [&](int x, int y) { return lim[x] > lim[y]; });
    for (auto i : qid) {
        for (; !nid.empty() && lim[nid.back()] < a[i]; ) {
            int x = nid.back();
            nid.pop_back();
            g[x] = 1;
            for (; top[x] != p; ) {
                add(rt[top[x]], dfn[x]);
                g[fa[top[x]]] = t[rt[top[x]]].u[0][1];
                x = fa[top[x]];
            }
            add(rt[p], dfn[x]);
        }
        res[i] = t[rt[p]].u[0][1];
    }
    for (int i = 1; i <= q; ++i)
        if (res[i])
            std::cout << "True\n";
        else
            std::cout << "False\n";
#ifndef ONLINE_JUDGE
    std::cerr << std::fixed << std::setprecision(6) << std::chrono::duration<double> (std::chrono::steady_clock::now() - stime).count() << "s\n";
#endif
    return 0;
}

还口胡了另一个做法:最后的数组由若干连续段组成;段数不超过操作次数。故离散化后线段树维护分段函数即可。


言论