虚树
2025-06-03

凡所有相,皆是虚妄①。

注释:①虚妄:犹言虚树是狂妄的算法。


定义

给定一个大小为 \(n\) 的树和树上 \(k\) 个关键点。取出这 \(k\) 个关键点和它们任意两个间的 LCA 作为虚树的点集,按照原树上的祖孙关系连边得到虚树。


求虚树

将关键点按 DFN 排序,求出所有 DFN 相邻两点的 LCA,与关键点一起即为虚树点集。再将点集按 DFN 排序,用栈维护当前的根链(借助子树的 DFN 区间判断祖孙关系)即可连边。排序为 \(O(k\log k)\),另需 \(k-1\) 次 LCA 查询(树剖为每次 \(O(\log n)\))。

每个新加入的虚点都是一对 DFN 相邻关键点的 LCA,这样的对至多 \(k-1\) 个,故虚树大小不超过 \(2k-1\),即 \(O(k)\)。


用途

注意到除去一次 \(O(n)\) 的预处理(DFS 与树剖)外,构建和遍历虚树的复杂度只与 \(k\) 有关(求 LCA 时带一个 \(\log n\))。因而适用于对 \(\sum k\) 有限制的题目。


B - Leaf Color

https://atcoder.jp/contests/abc340/tasks/abc340_g

枚举所有颜色,每次对该颜色对应的所有点建立虚树,发现不能选虚树外的其他点,虚树上 DP 即可。

注意所选连通块的最高点(根)有可能是叶子(只选了它的一个儿子),需要特判:以某点为最高点统计答案时,异色点至少要选两个儿子;之后再把恰好选一个儿子的方案加回去,供父亲转移使用。

#include <bits/stdc++.h>
const int mod = 998244353;
// ABC340 G - Leaf Color.
// Reads a tree with coloured vertices. For every colour k it builds the virtual
// tree of the colour-k vertices and counts, with a DP on that virtual tree, the
// connected selections that are valid for colour k (every leaf of the selection
// must have colour k). Prints the total modulo `mod`.
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n;
    std::cin >> n;
    // a[i]: colour of vertex i; t[c]: all vertices of colour c;
    // g1: original tree; g: virtual tree, rebuilt for each colour.
    std::vector<int> a(n + 1);
    std::vector<std::vector<int> > g1(n + 1), t(n + 1), g(n + 1);
    for (int i = 1; i <= n; ++i)
        std::cin >> a[i], t[a[i]].push_back(i);
    for (int i = 1, x, y; i < n; ++i) {
        std::cin >> x >> y;
        g1[x].push_back(y), g1[y].push_back(x);
    }
    // Heavy-light decomposition, pass 1: depth, parent, subtree size, heavy child.
    std::vector<int> siz(n + 1), son(n + 1), dep(n + 1), fa(n + 1);
    std::function<void(int, int)> DFS = [&](int x, int faa) {
        siz[x] = 1;
        for (auto i : g1[x])
            if (i != faa) {
                dep[i] = dep[x] + 1;
                fa[i] = x;
                DFS(i, x);
                // Without this accumulation every siz[] stays 1, son[x]
                // degenerates to "first child" and getLCA may need O(n)
                // chain jumps per query instead of O(log n).
                siz[x] += siz[i];
                if (siz[i] > siz[son[x]])
                    son[x] = i;
            }
        return;
    };
    DFS(1, -1);
    // Pass 2: DFS order (heavy child first) and chain tops.
    // rfn[x] is the largest dfn inside x's subtree, so y lies in x's subtree
    // iff dfn[x] <= dfn[y] <= rfn[x].
    std::vector<int> dfn(n + 1), rfn(n + 1), top(n + 1);
    DFS = [&](int x, int fa) {
        static int now = 0;
        dfn[x] = ++now;
        if (son[x])
            top[son[x]] = top[x], DFS(son[x], x);
        for (auto i : g1[x])
            if (i != fa && i != son[x])
                top[i] = i, DFS(i, x);
        rfn[x] = now;
        return;
    };
    top[1] = 1, DFS(1, -1);
    // LCA by jumping the chain whose top is deeper.
    auto getLCA = [&](int x, int y) {
        for (; top[x] != top[y]; x = fa[top[x]])
            if (dep[top[x]] < dep[top[y]])
                std::swap(x, y);
        return dep[x] < dep[y] ? x : y;
    };
    // tag[x] marks vertices already in the current virtual tree; f is the DP.
    std::vector<int> tag(n + 1);
    std::vector<long long> f(n + 1);
    auto res(0ll);
    for (int k = 1; k <= n; ++k)
        if (!t[k].empty()) {
            // Virtual-tree vertex set: the colour-k vertices plus the LCAs of
            // DFN-adjacent pairs.
            std::sort(t[k].begin(), t[k].end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
            std::vector<int> p;
            for (auto i : t[k])
                tag[i] = 1, p.push_back(i);
            for (int i = 1; i < (int)t[k].size(); ++i) {
                int fa = getLCA(t[k][i - 1], t[k][i]);
                if (!tag[fa])
                    tag[fa] = 1, p.push_back(fa);
            }
            std::sort(p.begin(), p.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
            // Link each vertex to the deepest stacked vertex whose subtree
            // contains it; the stack holds the current root-to-vertex chain.
            std::vector<int> tmp;
            for (auto i : p) {
                if (!tmp.empty()) {
                    for (; rfn[tmp.back()] < dfn[i]; tmp.pop_back());
                    g[tmp.back()].push_back(i);
                }
                tmp.push_back(i);
            }
            // f[x]: number of connected selections inside x's virtual subtree
            // that contain x and stay valid if x is also joined to its parent.
            // s = 1 + sum f[i]: selections with at most one chosen child.
            std::function<void(int)> DFS = [&](int x) {
                f[x] = 1ll;
                auto s(1ll);
                for (auto i : g[x]) {
                    DFS(i), (s += f[i]) %= mod;
                    (f[x] *= f[i] + 1) %= mod;
                }
                // Counting x as the topmost chosen vertex: a vertex of another
                // colour with <= 1 chosen child would be a leaf (or isolated),
                // so those selections are dropped before adding to res.
                if (a[x] != k)
                    (f[x] += mod - s) %= mod;
                // printf("color = %d, res += f[%d](%lld)\n", k, x, f[x]);
                (res += f[x]) %= mod;
                // When x has a chosen parent, exactly one chosen child gives x
                // degree 2, so those selections become valid again.
                if (a[x] != k)
                    --s, (f[x] += s) %= mod;
                return;
            };
            DFS(p.front());
            // Reset only what this colour touched.
            for (auto i : p) {
                tag[i] = 0, f[i] = 0ll;
                std::vector<int>().swap(g[i]);
            }
        }
    std::cout << res << '\n';
    return 0;
}

C - Watching Cowflix P

https://www.luogu.com.cn/problem/P9132

会想到钦定 \(k\) 再来做。发现任意情况下都有:假如两个连通块距离 \(\le k\),那么合并起来不劣。所以把距离 \(\le k\) 的所有点都合并起来发现只剩下 \(O(\frac nk)\) 个点了,想到用虚树。

然后虚树上枚举点选不选,DP 一下就完了。

但是实现起来好史啊。合并需要用并查集维护父亲(而非本身),特别打脑壳。

我的天哪滔天巨史。

#include <bits/stdc++.h>
const int inf = 0x3f3f3f3f;
// USACO P9132 (Watching Cowflix, Platinum).
// For every k = 1..n prints the minimum total cost, where each chosen connected
// component costs (its number of vertices + k), and every vertex marked '1'
// must be chosen (f[x][0] = inf for such x below). Only the virtual tree of the
// marked vertices is processed; the original vertices on a virtual edge are
// accounted for through len[].
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, rt;
    std::cin >> n;
    // p: marked vertices; tag[x] = 1 means x must lie in a chosen component.
    std::vector<int> p, tag(n + 1);
    for (int i = 1; i <= n; ++i) {
        char t;
        std::cin >> t;
        if (t == '1')
            p.push_back(i), tag[i] = 1;
    }
    // g1: original tree; g: virtual-tree edges between DSU representatives,
    // rebuilt for every k.
    std::vector<std::vector<int> > g(n + 1), g1(n + 1);
    for (int i = 1, x, y; i < n; ++i) {
        std::cin >> x >> y;
        g1[x].push_back(y), g1[y].push_back(x);
    }
    // st: virtual-tree vertices that are still their own DSU representative.
    // to[i]: virtual parent of i (0 for the virtual root).
    // len[i]: number of original vertices strictly between i and to[i].
    // NOTE(review): cnt is never used.
    std::set<int> st;
    std::vector<int> to(n + 1), len(n + 1), cnt(n + 1), dfn(n + 1);
    {
        // Heavy-light decomposition, pass 1: depth, parent, size, heavy child.
        std::vector<int> siz(n + 1), son(n + 1), fa(n + 1), dep(n + 1);
        std::function<void(int, int)> DFS = [&](int x, int faa) {
            siz[x] = 1;
            for (auto i : g1[x])
                if (i != faa) {
                    dep[i] = dep[x] + 1;
                    fa[i] = x, DFS(i, x);
                    siz[x] += siz[i];
                    if (siz[i] > siz[son[x]])
                        son[x] = i;
                }
            return;
        };
        DFS(1, -1);
        // Pass 2: dfn/rfn (subtree = dfn range) and chain tops.
        std::vector<int> rfn(n + 1), top(n + 1);
        DFS = [&](int x, int fa) {
            static int now = 0;
            dfn[x] = ++now;
            if (son[x])
                top[son[x]] = top[x], DFS(son[x], x);
            for (auto i : g1[x])
                if (i != son[x] && i != fa)
                    top[i] = i, DFS(i, x);
            rfn[x] = now;
            return;
        };
        // NOTE(review): unlike the other listings, top[1] is left as 0, so the
        // root chain has top 0. getLCA still works because dep[0] == dep[1] == 0,
        // so the root chain is never the one jumped from; `top[1] = 1` would be
        // clearer.
        DFS(1, -1);
        auto getLCA = [&](int x, int y) {
            for (; top[x] != top[y]; x = fa[top[x]])
                if (dep[top[x]] < dep[top[y]])
                    std::swap(x, y);
            return dep[x] < dep[y] ? x : y;
        };
        // Virtual tree: marked vertices plus LCAs of DFN-adjacent pairs.
        std::sort(p.begin(), p.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
        std::vector<int> vis(tag), t(p);
        for (int i = 1; i < (int)p.size(); ++i) {
            int fa = getLCA(p[i - 1], p[i]);
            if (!vis[fa])
                vis[fa] = 1, t.push_back(fa);
        }
        std::sort(t.begin(), t.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
        rt = t.front();
        // Stack of the current root-to-vertex chain; records virtual parents
        // and path lengths instead of building g directly.
        std::vector<int> stk;
        for (auto i : t) {
            if (!stk.empty()) {
                for (; rfn[stk.back()] < dfn[i]; stk.pop_back());
                to[i] = stk.back(), len[i] = dep[i] - dep[to[i]] - 1;
            }
            st.insert(i), stk.push_back(i);
        }
    }
    // DSU over virtual vertices: fa[] is the DSU parent here, and siz[x] is the
    // number of original vertices in x's merged component.
    std::vector<int> fa(n + 1), siz(n + 1, 1);
    std::iota(fa.begin() + 1, fa.end(), 1);
    std::function<int(int)> find = [&](int x) {
        return x == fa[x] ? x : fa[x] = find(fa[x]);
    };
    // f[x][0]: min cost of x's virtual subtree if x is not chosen (inf if x
    // must be chosen); f[x][1]: min cost if x is chosen. For a chosen child i,
    // joining it to x along the path adds len[i] vertices but saves one k.
    std::vector<std::array<int, 2> > f(n + 1);
    std::function<void(int, int)> DFS = [&](int x, int k) {
        if (tag[x])
            f[x][0] = inf;
        f[x][1] = siz[x] + k;
        for (auto i : g[x]) {
            DFS(i, k);
            if (!tag[x])
                f[x][0] += std::min(f[i][0], f[i][1]);
            f[x][1] += std::min({ f[i][0], f[i][1], f[i][1] + len[i] - k });
        }
        return;
    };
    // Re-walks the optimal choices made by DFS. flag says whether x is chosen.
    // Every chosen child joined to a chosen x (len[i] < k) is merged into x
    // permanently. Presumably this is safe because len[i] < k stays true for
    // all larger k (see the merging argument in the text above).
    std::function<void(int, int, int)> DFS1 = [&](int x, bool flag, int k) {
        for (auto i : g[x])
            if (flag) {
                if (f[i][0] <= std::min(f[i][1], f[i][1] + len[i] - k))
                    DFS1(i, 0, k);
                else {
                    DFS1(i, 1, k);
                    if (f[i][1] + len[i] - k < f[i][1]) {
                        tag[x] |= tag[i];
                        siz[x] += siz[i] + len[i];
                        st.erase(i), fa[i] = x;
                    }
                }
            }
            else {
                if (f[i][0] <= f[i][1])
                    DFS1(i, 0, k);
                else
                    DFS1(i, 1, k);
            }
        return;
    };
    for (int k = 1; k <= n; ++k) {
        // Rebuild the compressed virtual tree: each surviving vertex hangs
        // under the representative of its original virtual parent.
        // NOTE(review): p is filled but never read.
        std::vector<int> p;
        for (auto i : st) {
            p.push_back(i);
            if (to[i])
                g[find(to[i])].push_back(i);
        }
        DFS(rt, k);
        std::cout << std::min(f[rt][0], f[rt][1]) << '\n';
        DFS1(rt, f[rt][1] <= f[rt][0], k);
        // Only representatives receive edges next round, so resetting the
        // survivors is enough; merged vertices are unreachable from rt.
        for (auto i : st)
            f[i][0] = f[i][1] = 0, std::vector<int>().swap(g[i]);
    }
    return 0;
}

D - Smuggling Marbles

https://atcoder.jp/contests/arc086/tasks/arc086_c

容易想到从贡献角度思考问题;那么每个点只与同深度的所有点存在竞争关系。

把每个深度的点拿出来建虚树,在虚树上跑 DP 即可。

#include <bits/stdc++.h>
const int mod = 1e9 + 7;
// ARC086 E (Smuggling Marbles).
// Input vertex v (0-indexed) is stored as v + 1, so `n` below is N + 1 and the
// root is vertex 1. Marbles only interact with marbles at the same depth, so
// for every depth k the virtual tree of the depth-k vertices is built and a DP
// counts placements in which exactly one marble reaches the virtual root. All
// other vertices are free (factor 2 each). Prints the total modulo `mod`.
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    // m ends up as the maximum depth (the root has depth 1).
    int n, m = 0;
    std::cin >> n, ++n;
    // g1: original tree as child lists; g: virtual tree of the current depth.
    std::vector<std::vector<int> > g(n + 1), g1(n + 1);
    // Vertex i (1-based) reads its 0-based parent x, stored as x + 1.
    for (int i = 2, x; i <= n; ++i)
        std::cin >> x, g1[x + 1].push_back(i);
    std::vector<int> top(n + 1), dep(n + 1), dfn(n + 1), rfn(n + 1), fa(n + 1);
    {
        // Heavy-light decomposition, pass 1: depth, parent, size, heavy child.
        std::vector<int> siz(n + 1), son(n + 1);
        std::function<void(int)> DFS = [&](int x) {
            siz[x] = 1;
            m = std::max(m, dep[x]);
            for (auto i : g1[x]) {
                dep[i] = dep[x] + 1;
                fa[i] = x, DFS(i);
                siz[x] += siz[i];
                if (siz[i] > siz[son[x]])
                    son[x] = i;
            }
            return;
        };
        dep[1] = 1, DFS(1);
        // Pass 2: dfn/rfn (subtree = dfn range) and chain tops.
        DFS = [&](int x) {
            static int now = 0;
            dfn[x] = ++now;
            if (son[x])
                top[son[x]] = top[x], DFS(son[x]);
            for (auto i : g1[x])
                if (i != son[x])
                    top[i] = i, DFS(i);
            rfn[x] = now;
            return;
        };
        top[1] = 1, DFS(1);
    }
    auto getLCA = [&](int x, int y) {
        for (; top[x] != top[y]; x = fa[top[x]])
            if (dep[top[x]] < dep[top[y]])
                std::swap(x, y);
        return dep[x] < dep[y] ? x : y;
    };
    // _p[d]: all vertices of depth d.
    std::vector<std::vector<int> > _p(m + 1);
    for (int i = 1; i <= n; ++i)
        _p[dep[i]].push_back(i);
    // tag: LCA already added to the virtual tree; flag: vertex of the current
    // depth (may hold a marble).
    std::vector<int> tag(n + 1), flag(n + 1);
    // f[x][1]: placements under x after which exactly one marble reaches x.
    // f[x][0]: placements after which nothing survives at x (zero marbles
    // arrive, or two or more arrive and are removed).
    std::vector<std::array<long long, 2> > f(n + 1);
    std::function<void(int)> DFS = [&](int x) {
        // f[x][0] temporarily holds the total count (1 + flag[x] choices at x).
        f[x][0] = 1ll + flag[x], f[x][1] = flag[x];
        // s: product of f[i][0] over the children processed so far.
        auto s(1ll);
        for (auto i : g[x]) {
            DFS(i);
            f[x][1] = (f[x][1] * f[i][0] + s * f[i][1]) % mod;
            (f[x][0] *= f[i][0] + f[i][1]) %= mod;
            (s *= f[i][0]) %= mod;
        }
        // Everything that is not "exactly one" leaves nothing at x.
        (f[x][0] += mod - f[x][1]) %= mod;
        // printf("f[%d][0] = %lld, f[%d][1] = %lld\n", x, f[x][0], x, f[x][1]);
        return;
    };
    // Fast modular exponentiation x^y mod `mod`.
    auto qkp = [&](long long x, int y) {
        auto res(1ll);
        for (; y; (x *= x) %= mod, y >>= 1)
            if (y & 1)
                (res *= x) %= mod;
        return res;
    };
    auto res(0ll);
    for (int k = 1; k <= m; ++k) {
        // p aliases _p[k] (sorted in place below); t is an unsorted copy that
        // also collects the LCAs. Depth-k vertices are never tagged: the LCA
        // of two distinct vertices of the same depth is strictly shallower.
        auto &p = _p[k], t(p);
        for (auto i : p)
            flag[i] = 1;
        std::sort(p.begin(), p.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
        for (int i = 1; i < (int)p.size(); ++i) {
            int fa = getLCA(p[i - 1], p[i]);
            if (!tag[fa])
                tag[fa] = 1, t.push_back(fa);
        }
        std::sort(t.begin(), t.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
        int rt = t.front();
        // Stack of the current root-to-vertex chain builds the virtual edges.
        std::vector<int> st;
        for (auto i : t) {
            if (!st.empty()) {
                for (; rfn[st.back()] < dfn[i]; st.pop_back());
                g[st.back()].push_back(i);
            }
            st.push_back(i);
        }
        DFS(rt);
        // A single marble at rt travels to the root unopposed; the other
        // n - |p| vertices each contribute a factor of 2.
        (res += f[rt][1] * qkp(2ll, n - (int)p.size())) %= mod;
        // std::cout << f[rt][1] * qkp(2ll, n - (int)p.size()) % mod << '\n';
        for (auto i : t) {
            tag[i] = flag[i] = 0;
            std::vector<int>().swap(g[i]);
        }
    }
    std::cout << res << '\n';
    return 0;
}

E - 世界树

https://www.luogu.com.cn/problem/P3233

会想到在虚树上两次 DFS 找到离任意点最近的实点。具体地,第一次找下方,第二次尝试用上方更新。

接着发现对于虚树上的实点是好做的;对于实点的不在树上的儿子是好做的;接下来是虚点及其不在树上的儿子。

这时就要用到刚刚求出的最近实点信息:对虚树上每条边对应的原树链,二分出“归上端的最近实点”与“归下端的最近实点”的分界位置即可。说起来很简单,然而实际写起来很苦恼。

#include <bits/stdc++.h>
// HNOI2014 P3233 (世界树).
// For each query of k key vertices, every vertex of the tree is assigned to its
// nearest key vertex (ties broken by smaller id, via pair comparison in dis),
// and the number of vertices assigned to each key vertex is printed in input
// order. Works on the virtual tree of the key vertices plus vertex 1.
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, q;
    std::cin >> n;
    // g1: original tree; g: virtual tree of the current query.
    std::vector<std::vector<int> > g(n + 1), g1(n + 1);
    for (int i = 1, x, y; i < n; ++i) {
        std::cin >> x >> y;
        g1[x].push_back(y), g1[y].push_back(x);
    }
    // fa[x][j]: the 2^j-th ancestor of x (0 above the root).
    std::vector<std::array<int, 21> > fa(n + 1);
    std::vector<int> siz(n + 1), top(n + 1), dep(n + 1), dfn(n + 1), rfn(n + 1);
    {
        // Pass 1: depth, binary-lifting table, subtree size, heavy child.
        std::vector<int> son(n + 1);
        std::function<void(int, int)> DFS = [&](int x, int faa) {
            siz[x] = 1;
            for (auto i : g1[x])
                if (i != faa) {
                    dep[i] = dep[x] + 1;
                    fa[i][0] = x;
                    for (int j = 1; j <= 20; ++j)
                        fa[i][j] = fa[fa[i][j - 1]][j - 1];
                    DFS(i, x);
                    siz[x] += siz[i];
                    if (siz[i] > siz[son[x]])
                        son[x] = i;
                }
            return;
        };
        DFS(1, -1);
        // Pass 2: dfn/rfn (subtree = dfn range) and chain tops.
        int now = 0;
        DFS = [&](int x, int fa) {
            dfn[x] = ++now;
            if (son[x])
                top[son[x]] = top[x], DFS(son[x], x);
            for (auto i : g1[x])
                if (i != fa && i != son[x])
                    top[i] = i, DFS(i, x);
            rfn[x] = now;
            return;
        };
        top[1] = 1, DFS(1, -1);
    }
    auto getLCA = [&](int x, int y) {
        for (; top[x] != top[y]; x = fa[top[x]][0])
            if (dep[top[x]] < dep[top[y]])
                std::swap(x, y);
        return dep[x] < dep[y] ? x : y;
    };
    // The p-th ancestor of x, via binary lifting.
    auto getfa = [&](int x, int p) {
        for (int i = 20; ~i; --i)
            if (p >= (1 << i))
                x = fa[x][i], p -= (1 << i);
        return x;
    };
    // to[x]: nearest key vertex of virtual vertex x (-1 while unknown).
    std::vector<int> to(n + 1);
    // tag: in the virtual tree; flag: key vertex; res: answer per key vertex.
    std::vector<int> tag(n + 1), flag(n + 1), res(n + 1);
    std::cin >> q;
    // (distance, y): comparing pairs prefers shorter distance, then smaller y.
    auto dis = [&](int x, int y) {
        return std::make_pair(dep[x] + dep[y] - 2 * dep[getLCA(x, y)], y);
    };
    std::function<void(int)> DFS1 = [&](int x) {
        // Bottom-up: nearest key vertex inside x's virtual subtree.
        to[x] = (flag[x] ? x : -1);
        for (auto i : g[x]) {
            DFS1(i);
            if (~to[i] && (to[x] == -1 || dis(x, to[i]) < dis(x, to[x])))
                to[x] = to[i];
        }
        // printf("to[%d] = %d\n", x, to[x]);
        return;
    }, DFS2 = [&](int x) {
        // Top-down: let the parent's nearest key vertex improve the child's.
        // printf("to[%d] = %d\n", x, to[x]);
        for (auto i : g[x]) {
            if (to[i] == -1 || dis(i, to[x]) < dis(i, to[i]))
                to[i] = to[x];
            DFS2(i);
        }
        return;
    }, DFS3 = [&](int x) {
        // Credit x and every original-tree branch below x that contains no
        // virtual child to to[x]: subtract the branch leading to each child.
        res[to[x]] += siz[x];
        for (auto i : g[x]) {
            res[to[x]] -= siz[getfa(i, dep[i] - dep[x] - 1)];
            DFS3(i);
        }
        // printf("res[%d] = %d\n", x, res[x]);
        return;
    }, DFS4 = [&](int x) {
        // Distribute the vertices strictly between x and each virtual child i
        // (with their side branches) between to[x] and to[i].
        for (auto i : g[x]) {
            if (to[x] == to[i])
                res[to[x]] += siz[getfa(i, dep[i] - dep[x] - 1)] - siz[i];
            else {
                // Binary search for `at`: the largest number of steps below x
                // whose vertex still prefers to[x]. The vertex mid steps below
                // x is dx + mid from to[x] and di + (len + 1 - mid) from to[i].
                auto dx(dis(x, to[x])), di(dis(i, to[i]));
                int at = -1, len = dep[i] - dep[x] - 1;
                for (int l = 0, r = len, mid; l <= r; ) {
                    mid = (l + r) >> 1;
                    if ([&](auto dx, auto dy) {
                        dx.first += mid, dy.first += len - mid + 1;
                        return dx < dy;
                    } (dx, di))
                        at = mid, l = mid + 1;
                    else
                        r = mid - 1;
                }
                // fa: topmost vertex on the path that goes to to[i].
                int fa = getfa(i, len - at);
                res[to[x]] += siz[getfa(i, len)] - siz[fa];
                res[to[i]] += siz[fa] - siz[i];
            }
            DFS4(i);
        }
        return;
    };
    for (int k; q--; ) {
        std::cin >> k;
        std::vector<int> p(k);
        for (int i = 0; i < k; ++i)
            std::cin >> p[i], tag[p[i]] = flag[p[i]] = 1;
        // org keeps the input order for printing.
        auto org(p);
        std::sort(p.begin(), p.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
        std::vector<int> t(p);
        for (int i = 1; i < (int)p.size(); ++i) {
            int fa = getLCA(p[i - 1], p[i]);
            if (!tag[fa])
                tag[fa] = 1, t.push_back(fa);
        }
        // Vertex 1 is always included so the virtual tree covers the whole tree.
        if (!tag[1])
            t.push_back(1), tag[1] = 1;
        std::sort(t.begin(), t.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
        std::vector<int> st;
        for (auto i : t) {
            if (!st.empty()) {
                for (; rfn[st.back()] < dfn[i]; st.pop_back());
                g[st.back()].push_back(i);
            }
            st.push_back(i);
        }
        DFS1(1), DFS2(1), DFS3(1), DFS4(1);
        for (auto i : org)
            std::cout << res[i] << ' ';
        std::cout << '\n';
        // res is only indexed by key vertices, all of which are in t.
        for (auto i : t) {
            res[i] = 0;
            tag[i] = flag[i] = 0;
            std::vector<int>().swap(g[i]);
        }
    }
    return 0;
}

F - 大工程

https://www.luogu.com.cn/problem/P4103

虚树上 DP 统计相关信息即可。

#include <bits/stdc++.h>
const long long inf = 0x3f3f3f3f;
// HEOI2014 P4103 (大工程).
// For each query of k key vertices prints: the sum of distances over all
// unordered key pairs, the minimum pair distance, and the maximum pair distance.
// Computed by one DP over the virtual tree of the key vertices.
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, q;
    std::cin >> n;
    // g1: original tree; g: virtual tree of the current query.
    std::vector<std::vector<int> > g(n + 1), g1(n + 1);
    for (int i = 1, x, y; i < n; ++i) {
        std::cin >> x >> y;
        g1[x].push_back(y), g1[y].push_back(x);
    }
    std::vector<int> top(n + 1), dep(n + 1), dfn(n + 1), rfn(n + 1), fa(n + 1);
    {
        // Heavy-light decomposition, pass 1: depth, parent, size, heavy child.
        std::vector<int> siz(n + 1), son(n + 1);
        std::function<void(int, int)> DFS = [&](int x, int faa) {
            siz[x] = 1;
            for (auto i : g1[x])
                if (i != faa) {
                    dep[i] = dep[x] + 1;
                    fa[i] = x, DFS(i, x);
                    siz[x] += siz[i];
                    if (siz[i] > siz[son[x]])
                        son[x] = i;
                }
            return;
        };
        DFS(1, -1);
        // Pass 2: dfn/rfn (subtree = dfn range) and chain tops.
        int now = 0;
        DFS = [&](int x, int fa) {
            dfn[x] = ++now;
            if (son[x])
                top[son[x]] = top[x], DFS(son[x], x);
            for (auto i : g1[x])
                if (i != fa && i != son[x])
                    top[i] = i, DFS(i, x);
            rfn[x] = now;
            return;
        };
        top[1] = 1, DFS(1, -1);
    }
    auto getLCA = [&](int x, int y) {
        for (; top[x] != top[y]; x = fa[top[x]])
            if (dep[top[x]] < dep[top[y]])
                std::swap(x, y);
        return dep[x] < dep[y] ? x : y;
    };
    // For a virtual vertex x, over the key vertices already merged into it:
    // s[x] = sum of their distances to x, siz[x] = their count,
    // mx[x]/mn[x] = max/min distance to x (-inf/inf while there are none).
    std::vector<long long> s(n + 1);
    std::vector<int> mx(n + 1), mn(n + 1);
    // tag: in the virtual tree; flag: key vertex.
    std::vector<int> tag(n + 1), siz(n + 1), flag(n + 1);
    // Per-query answers: max distance, min distance, distance sum.
    int rmx, rmn;
    long long rs;
    std::function<void(int)> DFS = [&](int x) {
        if (flag[x]) {
            siz[x] = 1;
            mx[x] = mn[x] = s[x] = 0;
        }
        else {
            siz[x] = s[x] = 0;
            mn[x] = inf, mx[x] = -inf;
        }
        for (auto i : g[x]) {
            DFS(i);
            int len = dep[i] - dep[x];
            // Pairs (one key vertex already merged into x, one under i) are
            // combined before x's aggregates absorb child i.
            rmx = std::max(rmx, mx[x] + mx[i] + len);
            mx[x] = std::max(mx[x], mx[i] + len);
            rmn = std::min(rmn, mn[x] + mn[i] + len);
            mn[x] = std::min(mn[x], mn[i] + len);
            rs += siz[x] * (s[i] + (long long)siz[i] * len) + siz[i] * s[x]; 
            s[x] += s[i] + (long long)siz[i] * len;
            siz[x] += siz[i];
            // printf("%d -> %d, mx = %d, mn = %d, s = %lld\n", x, i, rmx, rmn, rs);
        }
        return;
    };
    std::cin >> q;
    for (int k; q--; ) {
        std::cin >> k;
        std::vector<int> p(k);
        for (int i = 0; i < k; ++i)
            std::cin >> p[i], tag[p[i]] = flag[p[i]] = 1;
        // Virtual tree: key vertices plus LCAs of DFN-adjacent pairs.
        std::sort(p.begin(), p.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
        std::vector<int> t(p);
        for (int i = 1; i < (int)p.size(); ++i) {
            int fa = getLCA(p[i - 1], p[i]);
            if (!tag[fa])
                tag[fa] = 1, t.push_back(fa);
        }
        std::sort(t.begin(), t.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
        int rt = t.front();
        std::vector<int> st;
        for (auto i : t) {
            if (!st.empty()) {
                for (; rfn[st.back()] < dfn[i]; st.pop_back());
                g[st.back()].push_back(i);
            }
            st.push_back(i);
        }
        // NOTE(review): with k == 1 rmn/rmx would stay at +inf/-inf; this
        // relies on the problem guaranteeing k >= 2.
        rs = 0ll, rmx = -inf, rmn = inf;
        DFS(rt);
        std::cout << rs << ' ' << rmn << ' ' << rmx << '\n';
        for (auto i : t) {
            tag[i] = flag[i] = 0;
            std::vector<int>().swap(g[i]);
        }
    }
    return 0;
}

言论