凡所有相,皆是虚妄①。
注释:①虚妄:犹言虚树是狂妄的算法。
定义
给定一个大小为 \(n\) 的树和树上 \(k\) 个关键点。取出这 \(k\) 个关键点和它们任意两个间的 LCA 作为虚树的点集,按照原树上的祖孙关系连边得到虚树。
求虚树
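将关键点按 DFN 排序,求出所有 DFN 相邻两点的 LCA 并加入点集,即可得到虚树点集。再将点集按 DFN 排序,用栈维护当前链,每个点向栈中仍包含它(\(\mathrm{rfn} \ge \mathrm{dfn}\))的最深祖先连边。排序复杂度为 \(O(k\log k)\),另需 \(O(k)\) 次 LCA 查询(树剖实现每次 \(O(\log n)\))。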
每个虚点(非关键点的 LCA)都在虚树上至少有两个儿子,即至少合并了两个关键点所在的分支,因此虚点不超过 \(k-1\) 个,虚树大小为 \(O(k)\)。
用途
注意到无论是构建还是遍历虚树,复杂度都只与 \(k\) 相关(不计一次性的 \(O(n)\) 预处理和每次 LCA 查询的 \(O(\log n)\)),与 \(n\) 无关。因而适用于对 \(\sum k\) 有限制的题目。
B - Leaf Color
https://atcoder.jp/contests/abc340/tasks/abc340_g
枚举所有颜色,每次对该颜色对应的所有点建立虚树。由于叶子都必须是该颜色,选出的点一定都落在虚树(含被压缩的链)上,不能选虚树外的其他点,于是在虚树上 DP 即可。
注意根有可能是叶子。需要特判一下。
#include <bits/stdc++.h>
const int mod = 998244353;

// ABC340 G - Leaf Color.
// Reads a tree with vertex colours a[1..n] and prints, modulo `mod`, the number of
// connected vertex sets whose degree-1 vertices all share one colour k. For every
// colour k the virtual tree of the colour-k vertices is built and a tree DP runs on it.
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n;
    std::cin >> n;
    std::vector<int> a(n + 1);
    // g1: original tree; t[c]: vertices of colour c; g: virtual-tree children (rebuilt per colour).
    std::vector<std::vector<int> > g1(n + 1), t(n + 1), g(n + 1);
    for (int i = 1; i <= n; ++i)
        std::cin >> a[i], t[a[i]].push_back(i);
    for (int i = 1, x, y; i < n; ++i) {
        std::cin >> x >> y;
        g1[x].push_back(y), g1[y].push_back(x);
    }

    // Heavy-light decomposition, pass 1: depth, parent, subtree size and heavy child.
    std::vector<int> siz(n + 1), son(n + 1), dep(n + 1), fa(n + 1);
    std::function<void(int, int)> DFS = [&](int x, int faa) {
        siz[x] = 1;
        for (auto i : g1[x])
            if (i != faa) {
                dep[i] = dep[x] + 1;
                fa[i] = x;
                DFS(i, x);
                // Subtree sizes must be accumulated: otherwise every siz[] stays 1,
                // son[x] degenerates to the first child, and getLCA may have to walk
                // O(n) light chains per query instead of O(log n).
                siz[x] += siz[i];
                if (siz[i] > siz[son[x]])
                    son[x] = i;
            }
    };
    DFS(1, -1);

    // Pass 2: DFS order (heavy child first), chain tops, rfn = largest dfn in the subtree.
    std::vector<int> dfn(n + 1), rfn(n + 1), top(n + 1);
    int now = 0;
    DFS = [&](int x, int parent) {
        dfn[x] = ++now;
        if (son[x])
            top[son[x]] = top[x], DFS(son[x], x);
        for (auto i : g1[x])
            if (i != parent && i != son[x])
                top[i] = i, DFS(i, x);
        rfn[x] = now;
    };
    top[1] = 1, DFS(1, -1);

    // LCA by jumping along heavy chains.
    auto getLCA = [&](int x, int y) {
        for (; top[x] != top[y]; x = fa[top[x]])
            if (dep[top[x]] < dep[top[y]])
                std::swap(x, y);
        return dep[x] < dep[y] ? x : y;
    };

    std::vector<int> tag(n + 1);      // tag[v] == 1 iff v is in the current virtual tree
    std::vector<long long> f(n + 1);  // f[v]: DP value v hands to its virtual parent
    auto res(0ll);
    for (int k = 1; k <= n; ++k)
        if (!t[k].empty()) {
            // Virtual-tree vertices: colour-k vertices plus LCAs of DFN-adjacent pairs.
            std::sort(t[k].begin(), t[k].end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
            std::vector<int> p;
            for (auto i : t[k])
                tag[i] = 1, p.push_back(i);
            for (int i = 1; i < (int)t[k].size(); ++i) {
                int lca = getLCA(t[k][i - 1], t[k][i]);
                if (!tag[lca])
                    tag[lca] = 1, p.push_back(lca);
            }
            std::sort(p.begin(), p.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
            // Link every vertex to the deepest stacked ancestor whose DFS interval contains it.
            std::vector<int> stk;
            for (auto i : p) {
                if (!stk.empty()) {
                    for (; rfn[stk.back()] < dfn[i]; stk.pop_back());
                    g[stk.back()].push_back(i);
                }
                stk.push_back(i);
            }
            // f[x] starts as the number of connected sets inside x's virtual subtree that
            // contain x (each child either excluded or attached: factor f[child] + 1).
            std::function<void(int)> solve = [&](int x) {
                f[x] = 1ll;
                auto s(1ll);  // 1 + sum of f[child]: choices attaching at most one child
                for (auto i : g[x]) {
                    solve(i), (s += f[i]) %= mod;
                    (f[x] *= f[i] + 1) %= mod;
                }
                // As the topmost vertex of the set, a vertex not of colour k must not end
                // up with degree <= 1, so the "at most one child" choices are removed.
                if (a[x] != k)
                    (f[x] += mod - s) %= mod;
                (res += f[x]) %= mod;
                // Handed upward, x also has its parent edge, so only "no child" is illegal:
                // add back the s - 1 single-child choices. Adding mod - 1 (instead of
                // decrementing s) keeps f[x] non-negative even when s == 0 (mod p).
                if (a[x] != k)
                    (f[x] += s + mod - 1) %= mod;
            };
            solve(p.front());
            for (auto i : p) {
                tag[i] = 0, f[i] = 0ll;
                std::vector<int>().swap(g[i]);
            }
        }
    std::cout << res << '\n';
    return 0;
}
C - Watching Cowflix P
https://www.luogu.com.cn/problem/P9132
会想到钦定 \(k\) 再来做。发现任意情况下都有:假如两个连通块距离 \(\le k\),那么合并起来不劣。所以把距离 \(\le k\) 的所有点都合并起来发现只剩下 \(O(\frac nk)\) 个点了,想到用虚树。
然后虚树上枚举点选不选,DP 一下就完了。
但是实现起来好史啊。合并需要用并查集维护父亲(而非本身),特别打脑壳。
我的天哪滔天巨史。
#include <bits/stdc++.h>
const int inf = 0x3f3f3f3f;

// Luogu P9132 (Watching Cowflix P).
// For every k = 1..n prints the minimum of the DP below. Selected vertices form
// connected components, each costing (its number of vertices + k), and every vertex
// marked '1' must be selected.
// The tree is compressed to the virtual tree of the marked vertices. After answering k,
// each virtual-tree edge that the optimum for k both selects and connects (len[i] < k)
// is contracted with a union-find, so later values of k work on a smaller tree.
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    // rt: root of the virtual tree. NOTE(review): it is only assigned when at least one
    // vertex is marked '1'; this relies on the input guaranteeing that.
    int n, rt;
    std::cin >> n;
    // p: marked vertices; tag[i] == 1 iff vertex i must be selected.
    std::vector<int> p, tag(n + 1);
    for (int i = 1; i <= n; ++i) {
        char t;
        std::cin >> t;
        if (t == '1')
            p.push_back(i), tag[i] = 1;
    }
    // g1: original (undirected) tree; g: children of the contracted virtual tree (rebuilt per k).
    std::vector<std::vector<int> > g(n + 1), g1(n + 1);
    for (int i = 1, x, y; i < n; ++i) {
        std::cin >> x >> y;
        g1[x].push_back(y), g1[y].push_back(x);
    }
    // st: vertices still alive in the contracted virtual tree.
    // to[i]: virtual-tree parent of i; len[i]: original vertices strictly between i and to[i].
    std::set<int> st;
    std::vector<int> to(n + 1), len(n + 1), dfn(n + 1);
    {
        std::vector<int> siz(n + 1), son(n + 1), fa(n + 1), dep(n + 1);
        // Heavy-light decomposition, pass 1: depth, parent, subtree size, heavy child.
        std::function<void(int, int)> DFS = [&](int x, int faa) {
            siz[x] = 1;
            for (auto i : g1[x])
                if (i != faa) {
                    dep[i] = dep[x] + 1;
                    fa[i] = x, DFS(i, x);
                    siz[x] += siz[i];
                    if (siz[i] > siz[son[x]])
                        son[x] = i;
                }
        };
        DFS(1, -1);
        // Pass 2: dfn (heavy child first), chain tops, rfn = largest dfn in the subtree.
        std::vector<int> rfn(n + 1), top(n + 1);
        int now = 0;
        DFS = [&](int x, int parent) {
            dfn[x] = ++now;
            if (son[x])
                top[son[x]] = top[x], DFS(son[x], x);
            for (auto i : g1[x])
                if (i != son[x] && i != parent)
                    top[i] = i, DFS(i, x);
            rfn[x] = now;
        };
        // The root heads its own chain. Leaving top[1] == 0 only worked by accident:
        // vertex 0 is not part of the tree, and getLCA relied on dep[0] == dep[1].
        top[1] = 1, DFS(1, -1);
        auto getLCA = [&](int x, int y) {
            for (; top[x] != top[y]; x = fa[top[x]])
                if (dep[top[x]] < dep[top[y]])
                    std::swap(x, y);
            return dep[x] < dep[y] ? x : y;
        };
        // Virtual-tree vertices: marked vertices plus LCAs of DFN-adjacent pairs.
        std::sort(p.begin(), p.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
        std::vector<int> vis(tag), t(p);
        for (int i = 1; i < (int)p.size(); ++i) {
            int lca = getLCA(p[i - 1], p[i]);
            if (!vis[lca])
                vis[lca] = 1, t.push_back(lca);
        }
        std::sort(t.begin(), t.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
        rt = t.front();
        // Stack construction: record each vertex's virtual parent and compressed path length.
        std::vector<int> stk;
        for (auto i : t) {
            if (!stk.empty()) {
                for (; rfn[stk.back()] < dfn[i]; stk.pop_back());
                to[i] = stk.back(), len[i] = dep[i] - dep[to[i]] - 1;
            }
            st.insert(i), stk.push_back(i);
        }
    }

    // Union-find over contracted vertices: fa[i] points to the vertex i was merged into.
    // siz[x]: number of original vertices currently represented by virtual vertex x.
    std::vector<int> fa(n + 1), siz(n + 1, 1);
    std::iota(fa.begin() + 1, fa.end(), 1);
    std::function<int(int)> find = [&](int x) {
        return x == fa[x] ? x : fa[x] = find(fa[x]);
    };

    // f[x][0]: min cost in x's virtual subtree with x unselected (inf when x must be selected).
    // f[x][1]: min cost with x selected; x's component is charged siz[x] + k here.
    std::vector<std::array<int, 2> > f(n + 1);
    std::function<void(int, int)> DFS = [&](int x, int k) {
        if (tag[x])
            f[x][0] = inf;
        f[x][1] = siz[x] + k;
        for (auto i : g[x]) {
            DFS(i, k);
            if (!tag[x])
                f[x][0] += std::min(f[i][0], f[i][1]);
            // With x selected, a selected child may also join x's component through the
            // len[i] path vertices: pay len[i] but save the k of its own component.
            f[x][1] += std::min({ f[i][0], f[i][1], f[i][1] + len[i] - k });
        }
    };

    // Replays the optimal choices for the current k (flag: whether x is selected).
    // A selected child that the optimum joins to its selected parent is merged into the
    // parent for good: x absorbs i's tag, i's size and the len[i] path vertices, i leaves
    // st, and the union-find redirects i to x.
    std::function<void(int, bool, int)> DFS1 = [&](int x, bool flag, int k) {
        for (auto i : g[x]) {
            if (flag) {
                if (f[i][0] <= std::min(f[i][1], f[i][1] + len[i] - k))
                    DFS1(i, 0, k);
                else {
                    DFS1(i, 1, k);
                    if (f[i][1] + len[i] - k < f[i][1]) {
                        tag[x] |= tag[i];
                        siz[x] += siz[i] + len[i];
                        st.erase(i), fa[i] = x;
                    }
                }
            }
            else {
                if (f[i][0] <= f[i][1])
                    DFS1(i, 0, k);
                else
                    DFS1(i, 1, k);
            }
        }
    };

    for (int k = 1; k <= n; ++k) {
        // Rebuild the contracted tree: each surviving vertex hangs under the current
        // representative of its original virtual parent.
        for (auto i : st)
            if (to[i])
                g[find(to[i])].push_back(i);
        DFS(rt, k);
        std::cout << std::min(f[rt][0], f[rt][1]) << '\n';
        DFS1(rt, f[rt][1] <= f[rt][0], k);
        for (auto i : st)
            f[i][0] = f[i][1] = 0, std::vector<int>().swap(g[i]);
    }
    return 0;
}
D - Smuggling Marbles
https://atcoder.jp/contests/arc086/tasks/arc086_c
容易想到从贡献角度思考问题;那么每个点只与同深度的所有点存在竞争关系。
把每个深度的点拿出来建虚树,在虚树上跑 DP 即可。
#include <bits/stdc++.h>
const int mod = 1e9 + 7;
// AtCoder ARC086 C (Smuggling Marbles).
// Every input vertex v is relabelled v + 1 (the root becomes 1) and n is incremented to
// the total number of vertices. For each depth, a virtual tree is built over that depth's
// vertices. A DP then counts the marble placements on those vertices that deliver exactly
// one marble to the virtual root. Each such count is multiplied by 2^(number of vertices
// at other depths), whose choices are free. The sum is printed modulo mod.
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen(".in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
// m: maximum depth (computed by the first DFS; the root has depth 1).
int n, m = 0;
std::cin >> n, ++n;
// g1: children lists of the original rooted tree; g: virtual-tree children (rebuilt per depth).
std::vector<std::vector<int> > g(n + 1), g1(n + 1);
for (int i = 2, x; i <= n; ++i)
std::cin >> x, g1[x + 1].push_back(i);
std::vector<int> top(n + 1), dep(n + 1), dfn(n + 1), rfn(n + 1), fa(n + 1);
{
std::vector<int> siz(n + 1), son(n + 1);
// Heavy-light decomposition, pass 1: depth, parent, subtree size, heavy child; also tracks m.
std::function<void(int)> DFS = [&](int x) {
siz[x] = 1;
m = std::max(m, dep[x]);
for (auto i : g1[x]) {
dep[i] = dep[x] + 1;
fa[i] = x, DFS(i);
siz[x] += siz[i];
if (siz[i] > siz[son[x]])
son[x] = i;
}
return;
};
dep[1] = 1, DFS(1);
// Pass 2: dfn (heavy child first), chain tops, rfn = largest dfn in the subtree.
DFS = [&](int x) {
static int now = 0;
dfn[x] = ++now;
if (son[x])
top[son[x]] = top[x], DFS(son[x]);
for (auto i : g1[x])
if (i != son[x])
top[i] = i, DFS(i);
rfn[x] = now;
return;
};
top[1] = 1, DFS(1);
}
// LCA by jumping along heavy chains.
auto getLCA = [&](int x, int y) {
for (; top[x] != top[y]; x = fa[top[x]])
if (dep[top[x]] < dep[top[y]])
std::swap(x, y);
return dep[x] < dep[y] ? x : y;
};
// _p[d]: all vertices at depth d.
std::vector<std::vector<int> > _p(m + 1);
for (int i = 1; i <= n; ++i)
_p[dep[i]].push_back(i);
// tag[v]: v was already added as an LCA for the current depth (an LCA of two distinct
// same-depth vertices is strictly shallower, so current-depth vertices never need it).
// flag[v]: v is at the current depth.
std::vector<int> tag(n + 1), flag(n + 1);
// f[x][1]: placements on the current-depth vertices below x (in the virtual tree) that
// deliver exactly one marble to x; f[x][0]: all other placements.
std::vector<std::array<long long, 2> > f(n + 1);
std::function<void(int)> DFS = [&](int x) {
// f[x][0] temporarily holds the total count and is corrected after the loop.
// A current-depth vertex has 2 placements, one of which puts a marble on it.
f[x][0] = 1ll + flag[x], f[x][1] = flag[x];
// s: product of f[child][0] over the children processed so far.
auto s(1ll);
for (auto i : g[x]) {
DFS(i);
// Exactly one marble: it came from an earlier child (and i sends none), or from i.
f[x][1] = (f[x][1] * f[i][0] + s * f[i][1]) % mod;
(f[x][0] *= f[i][0] + f[i][1]) %= mod;
(s *= f[i][0]) %= mod;
}
// Everything that is not "exactly one" (zero, or two or more) is lumped into f[x][0].
(f[x][0] += mod - f[x][1]) %= mod;
// printf("f[%d][0] = %lld, f[%d][1] = %lld\n", x, f[x][0], x, f[x][1]);
return;
};
// Modular exponentiation by squaring: x^y mod mod.
auto qkp = [&](long long x, int y) {
auto res(1ll);
for (; y; (x *= x) %= mod, y >>= 1)
if (y & 1)
(res *= x) %= mod;
return res;
};
auto res(0ll);
for (int k = 1; k <= m; ++k) {
// p aliases _p[k] (sorted in place below); t starts as an unsorted copy of it.
auto &p = _p[k], t(p);
for (auto i : p)
flag[i] = 1;
std::sort(p.begin(), p.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
// Add LCAs of DFN-adjacent pairs to the virtual-tree vertex set.
for (int i = 1; i < (int)p.size(); ++i) {
int fa = getLCA(p[i - 1], p[i]);
if (!tag[fa])
tag[fa] = 1, t.push_back(fa);
}
std::sort(t.begin(), t.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
int rt = t.front();
// Stack construction: link each vertex to the deepest stacked ancestor containing it.
std::vector<int> st;
for (auto i : t) {
if (!st.empty()) {
for (; rfn[st.back()] < dfn[i]; st.pop_back());
g[st.back()].push_back(i);
}
st.push_back(i);
}
DFS(rt);
// Vertices not at depth k may hold a marble or not independently.
(res += f[rt][1] * qkp(2ll, n - (int)p.size())) %= mod;
// std::cout << f[rt][1] * qkp(2ll, n - (int)p.size()) % mod << '\n';
// Reset per-depth state only on the touched vertices.
for (auto i : t) {
tag[i] = flag[i] = 0;
std::vector<int>().swap(g[i]);
}
}
std::cout << res << '\n';
return 0;
}
E - 世界树
https://www.luogu.com.cn/problem/P3233
会想到在虚树上两次 DFS 找到离任意点最近的实点。具体地,第一次找下方,第二次尝试用上方更新。
接着发现对于虚树上的实点是好做的;对于实点的不在树上的儿子是好做的;接下来是虚点及其不在树上的儿子。
这时就要用到刚刚求出的"最近实点"信息:对每条虚树边 \((x, i)\) 上被压缩掉的链,二分出分界点。分界点以上的链点(连同挂在其上、不含虚树点的子树)归 \(x\) 的最近实点,以下的部分归 \(i\) 的最近实点。说起来很简单,然而实际上写起来很苦恼。
#include <bits/stdc++.h>
// Luogu P3233 (世界树 / World Tree).
// Each query gives k marked vertices. Every vertex of the tree is assigned to its nearest
// marked vertex, with ties broken by the smaller id (dis() returns (distance, id) pairs
// compared with <). The program prints, in input order, how many vertices each marked
// vertex receives.
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen(".in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
int n, q;
std::cin >> n;
// g1: original tree; g: virtual-tree children (rebuilt per query).
std::vector<std::vector<int> > g(n + 1), g1(n + 1);
for (int i = 1, x, y; i < n; ++i) {
std::cin >> x >> y;
g1[x].push_back(y), g1[y].push_back(x);
}
// fa[v][j]: the 2^j-th ancestor of v (0 once past the root).
std::vector<std::array<int, 21> > fa(n + 1);
std::vector<int> siz(n + 1), top(n + 1), dep(n + 1), dfn(n + 1), rfn(n + 1);
{
std::vector<int> son(n + 1);
// Pass 1: depth, binary-lifting table, subtree size, heavy child.
std::function<void(int, int)> DFS = [&](int x, int faa) {
siz[x] = 1;
for (auto i : g1[x])
if (i != faa) {
dep[i] = dep[x] + 1;
fa[i][0] = x;
for (int j = 1; j <= 20; ++j)
fa[i][j] = fa[fa[i][j - 1]][j - 1];
DFS(i, x);
siz[x] += siz[i];
if (siz[i] > siz[son[x]])
son[x] = i;
}
return;
};
DFS(1, -1);
int now = 0;
// Pass 2: dfn (heavy child first), chain tops, rfn = largest dfn in the subtree.
DFS = [&](int x, int fa) {
dfn[x] = ++now;
if (son[x])
top[son[x]] = top[x], DFS(son[x], x);
for (auto i : g1[x])
if (i != fa && i != son[x])
top[i] = i, DFS(i, x);
rfn[x] = now;
return;
};
top[1] = 1, DFS(1, -1);
}
// LCA by jumping along heavy chains.
auto getLCA = [&](int x, int y) {
for (; top[x] != top[y]; x = fa[top[x]][0])
if (dep[top[x]] < dep[top[y]])
std::swap(x, y);
return dep[x] < dep[y] ? x : y;
};
// Returns the p-th ancestor of x via binary lifting.
auto getfa = [&](int x, int p) {
for (int i = 20; ~i; --i)
if (p >= (1 << i))
x = fa[x][i], p -= (1 << i);
return x;
};
// to[x]: nearest marked vertex of virtual-tree vertex x (-1 while unknown).
std::vector<int> to(n + 1);
// tag: in the current virtual tree; flag: marked in the current query;
// res[v]: number of vertices assigned to marked vertex v.
std::vector<int> tag(n + 1), flag(n + 1), res(n + 1);
std::cin >> q;
// (distance between x and y, y): comparing these pairs breaks distance ties by smaller id.
auto dis = [&](int x, int y) {
return std::make_pair(dep[x] + dep[y] - 2 * dep[getLCA(x, y)], y);
};
// DFS1 (bottom-up): nearest marked vertex inside x's virtual subtree.
std::function<void(int)> DFS1 = [&](int x) {
to[x] = (flag[x] ? x : -1);
for (auto i : g[x]) {
DFS1(i);
if (~to[i] && (to[x] == -1 || dis(x, to[i]) < dis(x, to[x])))
to[x] = to[i];
}
// printf("to[%d] = %d\n", x, to[x]);
return;
// DFS2 (top-down): a child adopts its parent's nearest marked vertex if that is closer.
}, DFS2 = [&](int x) {
// printf("to[%d] = %d\n", x, to[x]);
for (auto i : g[x]) {
if (to[i] == -1 || dis(i, to[x]) < dis(i, to[i]))
to[i] = to[x];
DFS2(i);
}
return;
// DFS3: credit to[x] with x's subtree minus the child-of-x subtrees leading to virtual
// children; the vertices there are distributed by DFS4.
}, DFS3 = [&](int x) {
res[to[x]] += siz[x];
for (auto i : g[x]) {
res[to[x]] -= siz[getfa(i, dep[i] - dep[x] - 1)];
DFS3(i);
}
// printf("res[%d] = %d\n", x, res[x]);
return;
// DFS4: for each virtual edge (x, i), split the len compressed path vertices (with the
// subtrees hanging off them) between to[x] and to[i].
}, DFS4 = [&](int x) {
for (auto i : g[x]) {
if (to[x] == to[i])
res[to[x]] += siz[getfa(i, dep[i] - dep[x] - 1)] - siz[i];
else {
auto dx(dis(x, to[x])), di(dis(i, to[i]));
// Binary search for the largest `at` in [0, len] such that the vertex `at` steps
// below x is closer to to[x] (distance dx + at) than to to[i] (di + len - at + 1),
// comparing (distance, id) pairs.
int at = -1, len = dep[i] - dep[x] - 1;
for (int l = 0, r = len, mid; l <= r; ) {
mid = (l + r) >> 1;
if ([&](auto dx, auto dy) {
dx.first += mid, dy.first += len - mid + 1;
return dx < dy;
} (dx, di))
at = mid, l = mid + 1;
else
r = mid - 1;
}
// fa: topmost path vertex assigned to to[i] (at + 1 steps below x).
int fa = getfa(i, len - at);
res[to[x]] += siz[getfa(i, len)] - siz[fa];
res[to[i]] += siz[fa] - siz[i];
}
DFS4(i);
}
return;
};
for (int k; q--; ) {
std::cin >> k;
std::vector<int> p(k);
for (int i = 0; i < k; ++i)
std::cin >> p[i], tag[p[i]] = flag[p[i]] = 1;
// Keep the input order for the output.
auto org(p);
std::sort(p.begin(), p.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
std::vector<int> t(p);
for (int i = 1; i < (int)p.size(); ++i) {
int fa = getLCA(p[i - 1], p[i]);
if (!tag[fa])
tag[fa] = 1, t.push_back(fa);
}
// Force the original root into the virtual tree so DFS3(1) starts from all n vertices.
if (!tag[1])
t.push_back(1), tag[1] = 1;
std::sort(t.begin(), t.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
// Stack construction of the virtual-tree edges.
std::vector<int> st;
for (auto i : t) {
if (!st.empty()) {
for (; rfn[st.back()] < dfn[i]; st.pop_back());
g[st.back()].push_back(i);
}
st.push_back(i);
}
DFS1(1), DFS2(1), DFS3(1), DFS4(1);
for (auto i : org)
std::cout << res[i] << ' ';
std::cout << '\n';
// Reset per-query state only on the touched vertices.
for (auto i : t) {
res[i] = 0;
tag[i] = flag[i] = 0;
std::vector<int>().swap(g[i]);
}
}
return 0;
}
F - 大工程
https://www.luogu.com.cn/problem/P4103
虚树上 DP 统计相关信息即可。
#include <bits/stdc++.h>
const long long inf = 0x3f3f3f3f;
// Luogu P4103 (大工程). For each query of k vertices, prints three values computed on
// their virtual tree: the sum of distances over all pairs of query vertices, the minimum
// pairwise distance and the maximum pairwise distance.
// NOTE(review): with k == 1 no pair is ever formed, so rmn/rmx would be printed as the
// +/-inf sentinels; presumably the problem guarantees k >= 2.
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen(".in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
int n, q;
std::cin >> n;
// g1: original tree; g: virtual-tree children (rebuilt per query).
std::vector<std::vector<int> > g(n + 1), g1(n + 1);
for (int i = 1, x, y; i < n; ++i) {
std::cin >> x >> y;
g1[x].push_back(y), g1[y].push_back(x);
}
std::vector<int> top(n + 1), dep(n + 1), dfn(n + 1), rfn(n + 1), fa(n + 1);
{
std::vector<int> siz(n + 1), son(n + 1);
// Heavy-light decomposition, pass 1: depth, parent, subtree size, heavy child.
std::function<void(int, int)> DFS = [&](int x, int faa) {
siz[x] = 1;
for (auto i : g1[x])
if (i != faa) {
dep[i] = dep[x] + 1;
fa[i] = x, DFS(i, x);
siz[x] += siz[i];
if (siz[i] > siz[son[x]])
son[x] = i;
}
return;
};
DFS(1, -1);
int now = 0;
// Pass 2: dfn (heavy child first), chain tops, rfn = largest dfn in the subtree.
DFS = [&](int x, int fa) {
dfn[x] = ++now;
if (son[x])
top[son[x]] = top[x], DFS(son[x], x);
for (auto i : g1[x])
if (i != fa && i != son[x])
top[i] = i, DFS(i, x);
rfn[x] = now;
return;
};
top[1] = 1, DFS(1, -1);
}
// LCA by jumping along heavy chains.
auto getLCA = [&](int x, int y) {
for (; top[x] != top[y]; x = fa[top[x]])
if (dep[top[x]] < dep[top[y]])
std::swap(x, y);
return dep[x] < dep[y] ? x : y;
};
// For virtual vertex x, over the query vertices merged into x so far:
// s[x]: sum of their distances to x; siz[x]: their count (not the HLD subtree size);
// mn[x]/mx[x]: min/max distance from x to one of them (inf / -inf when there is none).
std::vector<long long> s(n + 1);
std::vector<int> mx(n + 1), mn(n + 1);
std::vector<int> tag(n + 1), siz(n + 1), flag(n + 1);
// Per-query answers: max pairwise distance, min pairwise distance, sum of pairwise distances.
int rmx, rmn;
long long rs;
std::function<void(int)> DFS = [&](int x) {
if (flag[x]) {
siz[x] = 1;
mx[x] = mn[x] = s[x] = 0;
}
else {
siz[x] = s[x] = 0;
mn[x] = inf, mx[x] = -inf;
}
for (auto i : g[x]) {
DFS(i);
int len = dep[i] - dep[x];
// Pairs with one vertex in the part already merged into x and one below i meet at x.
rmx = std::max(rmx, mx[x] + mx[i] + len);
mx[x] = std::max(mx[x], mx[i] + len);
rmn = std::min(rmn, mn[x] + mn[i] + len);
mn[x] = std::min(mn[x], mn[i] + len);
rs += siz[x] * (s[i] + (long long)siz[i] * len) + siz[i] * s[x];
s[x] += s[i] + (long long)siz[i] * len;
siz[x] += siz[i];
// printf("%d -> %d, mx = %d, mn = %d, s = %lld\n", x, i, rmx, rmn, rs);
}
return;
};
std::cin >> q;
for (int k; q--; ) {
std::cin >> k;
std::vector<int> p(k);
for (int i = 0; i < k; ++i)
std::cin >> p[i], tag[p[i]] = flag[p[i]] = 1;
std::sort(p.begin(), p.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
// Virtual-tree vertices: query vertices plus LCAs of DFN-adjacent pairs.
std::vector<int> t(p);
for (int i = 1; i < (int)p.size(); ++i) {
int fa = getLCA(p[i - 1], p[i]);
if (!tag[fa])
tag[fa] = 1, t.push_back(fa);
}
std::sort(t.begin(), t.end(), [&](int x, int y) { return dfn[x] < dfn[y]; });
int rt = t.front();
// Stack construction of the virtual-tree edges.
std::vector<int> st;
for (auto i : t) {
if (!st.empty()) {
for (; rfn[st.back()] < dfn[i]; st.pop_back());
g[st.back()].push_back(i);
}
st.push_back(i);
}
rs = 0ll, rmx = -inf, rmn = inf;
DFS(rt);
std::cout << rs << ' ' << rmn << ' ' << rmx << '\n';
// Reset per-query state only on the touched vertices.
for (auto i : t) {
tag[i] = flag[i] = 0;
std::vector<int>().swap(g[i]);
}
}
return 0;
}