学习笔记:FWT
2025-03-29

活了哥们,复活了。


考虑以下问题:

  • 对于 \(\forall \,i\in U\),求 \(c_i=\sum\limits_{j\cup k=i}a_j\cdot b_k\)

当然这里可以把集合看成二进制状态,那么取并集就是按位或了。

咕咕咕


Emiya 家明天的饭

https://www.luogu.com.cn/problem/P10242

冷知识:Emiya 家明天的饭Emiya 家今天的饭 拥有相同的难度评级 磕头

考虑暴力的做法,如果我们先钦定必须到达的人(这是 \(O(2^n)\) 的),再花费 \(O(nm)\) 的时间依次判定每道菜是否可以被选择,就可以 解决问题。但发现时间是不能承受的,这里我们选择优化 \(O(nm)\) 的判定。

现在已知人员集合 \(s\),对于一道菜 \(i\),设它适配的人员集合为 \(T_i\),那么有 \(s\subseteq T_i\)。所求即为 \(\max\limits_s\{\sum\limits_{T_i\supseteq s}\sum\limits_{j\in s}a_{i,j}\}\)\(O(nm)\) 预处理出 \(f_j(u)=\sum\limits_{T_i=u}a_{j, i}\),那么待求即为 \(\max\limits_{s}\{\sum\limits_{j\in s}\sum\limits_{u\supseteq s}f_j(u)\}\),发现这个东西可以用 FWT 求。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, m, k;
    std::cin >> k >> m, n = 1 << k;
    using arr = std::vector<long long>;
    std::vector<int> t(m + 1);
    std::vector<arr> a(k + 1, arr(m + 1)), f(k + 1, arr(n));
    for (int i = 1; i <= k; ++i)
        for (int j = 1; j <= m; ++j) {
            std::cin >> a[i][j];
            if (a[i][j] >= 0)
                t[j] |= 1 << (i - 1);
        }
    for (int i = 1; i <= k; ++i)
        for (int j = 1; j <= m; ++j)
            if (a[i][j] >= 0)
                f[i][t[j]] += a[i][j];
    {
        std::vector<arr> mT(2, arr(2));
        mT[0][0] = 1ll, mT[0][1] = 1ll, mT[1][0] = 0ll, mT[1][1] = 1ll;
        auto calc = [&](arr a, arr &f, std::vector<arr> &w) {
            f = a;
            for (int len = 2; len <= n; len <<= 1) {
                for (int i = 0; i < n; i += len)
                    for (int p = i, q = i + len / 2; q < i + len; ++p, ++q)
                        std::tie(f[p], f[q]) = std::make_tuple(f[p] * w[0][0] + f[q] * w[0][1], f[p] * w[1][0] + f[q] * w[1][1]);
            }
            return;
        };
        for (int i = 1; i <= k; ++i)
            calc(f[i], f[i], mT);
    }
    long long res = 0ll;
    for (int s = 0; s < n; ++s) {
        long long sum = 0ll;
        for (int i = 1; i <= k; ++i)
            if ((s >> (i - 1)) & 1)
                sum += f[i][s];
        res = std::max(res, sum);
    }
    std::cout << res << '\n';
    return 0;
}

Nim Counting

https://atcoder.jp/contests/abc212/tasks/abc212_h

即,从 \(A_N\) 中有放回地选择 \(\le M\) 个数,问它们异或起来不为 \(0\) 的方案数。

如果令 \(f_{i, j}\) 表示选了 \(i\) 次,异或和为 \(j\) 的方案数,显然 \(f_{1,i}=\sum [a_j=i]\) 为关于 \(a\) 的桶。此时有 \(f_{i,j}=\sum\limits_{k=1}^n f_{i-1,j\oplus a_k}=\sum\limits_{k=0}^V f_{i-1,j\oplus k}\cdot f_{1,k}\),发现把 \(f_1\) 这个桶在 \(f\) 上做 \(N\) 次 xor-FWT 就可以得到 \(f_n\)

但如果直接卷 \(N\) 次是 \(O(N\cdot V\log V)\) 的,不太美好,但我们看看我们实际上需要做什么:

  1. \(f_i\) 的 FWT。
  2. 求初始桶 \(f_1\) 的 FWT。
  3. 对位相乘得到 \(f_{i+1}\) 的 FWT。
  4. 通过 FWT 求得原本的 \(f_{i+1}\)

当这个操作被放在 \(i=1\sim n\) 上依次进行时,我们发现第一步和最后一步会相互抵消,我们只需要求出 \(f_1\) 的 FWT,\(FWT_{i, j}(f)\) 即为 \(FWT_{1, j}(f)^i\)。因为我们要求的是 \(\sum\limits_{i, j}f_{i,j}\) 可以通过等比数列求和求出 \(FWT_j(s)=\sum f_{i, j}\)。由前文推导可知直接做一次逆变换求得 \(s_j\) 即可。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    struct mint {
        const int mod = 998244353;
        long long x;
        mint(): x(0ll) {}
        mint(long long x1): x((x1 + mod) % mod) {}
        mint& operator= (const mint q) {
            x = q.x;
            return *this;
        }
        bool operator== (const mint q) const {
            return x == q.x;
        }
        mint operator* (const mint q) const {
            return x * q.x % mod;
        }
        mint& operator*= (const mint q) {
            return *this = *this * q;
        }
        mint operator+ (const mint q) {
            return (x + q.x) % mod;
        }
        mint& operator+= (const mint q) {
            return *this = *this + q;
        }
        mint operator- (const mint q) {
            return (x + mod - q.x) % mod;
        }
        mint qkp(int y) {
            mint res(1ll), x(this->x);
            for (; y; y >>= 1, x *= x)
                if (y & 1)
                    res *= x;
            return res;
        }
        mint inv(void) {
            return qkp(mod - 2);
        }
    };
    int n, m, k = 16, l = 1 << k;
    std::cin >> m >> n;
    using arr = std::vector<mint>;
    arr a(n + 1), c(l);
    for (int i = 1; i <= n; ++i)
        std::cin >> a[i].x, c[a[i].x] += 1;
    std::vector<arr> mT(2, arr(2)), mI(2, arr(2));
    mT[0][0] = 1ll, mT[0][1] = 1ll, mT[1][0] = 1ll, mT[1][1] = -1ll;
    mI[0][0] = mI[0][1] = mI[1][0] = mint(2ll).inv(), mI[1][1] = mint(-2ll).inv();
    auto calc = [&](arr a, arr &f, std::vector<arr> &w) {
        f = a;
        for (int len = 2; len <= l; len <<= 1)
            for (int i = 0; i < l; i += len)
                for (int p = i, q = i + len / 2; q < i + len; ++p, ++q)
                    std::tie(f[p], f[q]) = std::make_tuple(f[p] * w[0][0] + f[q] * w[0][1], f[p] * w[1][0] + f[q] * w[1][1]);
        return;
    };
    calc(c, c, mT);
    arr s(l);
    for (int i = 0; i < l; ++i)
        if (c[i] == 1ll)
            s[i] = m;
        else
            s[i] = c[i] * (mint(1ll) - c[i].qkp(m)) * (mint(1ll) - c[i]).inv();
    calc(s, s, mI);
    mint res;
    for (int i = 1; i < l; ++i)
        res += s[i];
    std::cout << res.x << '\n';
    return 0;
}

Binary Table

https://codeforces.com/problemset/problem/662/C

发现这个题和 Emiya 家今天的饭 很像,都是有一个很小的维和一个相对比较大的维。

显然,我们所有的操作顺序都可以任意调换;朴素地,我们枚举 \(2^n\) 种给这 \(n\) 行反转的情况;再对于每一列,\(O(n)\) 选择应该反转还是不反转,这样总共是 \(O(2^n\cdot nm)\) 的。

发现把一列初始状态压成一个二进制数 \(a\),假设我们现在枚举的行反转状态为 \(s\),显然用 \(s\oplus a\) 可以得到当前状态。怎么利用这个去 FWT 呢?这引导我们依然枚举 \(s\),用一个和 \(s\oplus a\) 有关的量跟一个和 \(a\) 有关的量相乘得到关于 \(s\) 的答案。

容易发现令 \(f_{s\oplus a}\) 表示 \(s\oplus a\) 这个状态反转和不反转两个选项中可以获取的最少 1 的个数;再令 \(c_s\) 表示状态为 \(s\) 的列的个数,那么 \(\sum\limits_s c_s\times f_{s\oplus a}\) 就可以得到枚举的反转方案为 \(s\) 的答案,求最小即可。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    struct mint {
        const int mod = 998244353;
        long long x;
        mint(): x(0ll) {}
        mint(long long x1): x((x1 + mod) % mod) {}
        mint& operator= (const mint q) {
            x = q.x;
            return *this;
        }
        bool operator== (const mint q) const {
            return x == q.x;
        }
        mint operator* (const mint q) const {
            return x * q.x % mod;
        }
        mint& operator*= (const mint q) {
            return *this = *this * q;
        }
        mint operator+ (const mint q) {
            return (x + q.x) % mod;
        }
        mint& operator+= (const mint q) {
            return *this = *this + q;
        }
        mint operator- (const mint q) {
            return (x + mod - q.x) % mod;
        }
        mint qkp(int y) {
            mint res(1ll), x(this->x);
            for (; y; y >>= 1, x *= x)
                if (y & 1)
                    res *= x;
            return res;
        }
        mint inv(void) {
            return qkp(mod - 2);
        }
    };
    int n, m, l;
    std::cin >> n >> m, l = 1 << n;
    using arr = std::vector<mint>;
    arr c(l), f(l);
    std::vector<std::vector<char> > a(n + 1, std::vector<char> (m + 1));
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j)
            std::cin >> a[i][j];
    for (int j = 1; j <= m; ++j) {
        int s = 0;
        for (int i = 1; i <= n; ++i)
            s = (s << 1) + a[i][j] - '0';
        c[s] += 1;
    }
    for (int i = 0, t; i < l; ++i)
        t = (__builtin_popcount(i)), f[i] = std::min(t, n - t);
    std::vector<arr> mT(2, arr(2)), mI(2, arr(2));
    mT[0][0] = 1ll, mT[0][1] = 1ll, mT[1][0] = 1ll, mT[1][1] = -1ll;
    mI[0][0] = mI[0][1] = mI[1][0] = mint(2ll).inv(), mI[1][1] = mint(-2ll).inv();
    auto calc = [&](arr a, arr &f, std::vector<arr> &w) {
        f = a;
        for (int len = 2; len <= l; len <<= 1)
            for (int i = 0; i < l; i += len)
                for (int p = i, q = i + len / 2; q < i + len; ++p, ++q)
                    std::tie(f[p], f[q]) = std::make_tuple(f[p] * w[0][0] + f[q] * w[0][1], f[p] * w[1][0] + f[q] * w[1][1]);
        return;
    };
    calc(c, c, mT), calc(f, f, mT);
    arr s(l);
    for (int i = 0; i < l; ++i)
        s[i] = c[i] * f[i];
    calc(s, s, mI);
    int res = 0x3f3f3f3f;
    for (int i = 0; i < l; ++i)
        res = std::min(res, (int)s[i].x);
    std::cout << res << '\n';
    return 0;
}

Hard Nim

https://hydro.ac/p/bzoj-P4589

Nim Counting 然后缝了个筛子。

#include <bits/stdc++.h>
const int mod = 1e9 + 7;
const int inv2 = 5e8 + 4;
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    std::vector<int> p, tag(50001);
    for (int i = 2; i <= 50000; ++i)
        if (!tag[i]) {
            p.push_back(i);
            for (int j = 2 * i; j <= 50000; j += i)
                tag[j] = 1;
        }
    auto qkp = [&](long long x, int y) {
        long long res = 1ll;
        for (; y; y >>= 1, (x *= x) %= mod)
            if (y & 1)
                (res *= x) %= mod;
        return res;
    };
    for (int n, m; std::cin >> n >> m; ) {
        int k = std::__lg(m) + 1, l = 1 << k;
        using arr = std::vector<long long>;
        arr c(l);
        for (auto i : p)
            if (i <= m)
                c[i] += 1;
            else
                break;
        std::vector<arr> mT(2, arr(2)), mI(2, arr(2));
        mT[0][0] = 1ll, mT[0][1] = 1ll, mT[1][0] = 1ll, mT[1][1] = mod - 1ll;
        mI[0][0] = mI[0][1] = mI[1][0] = inv2, mI[1][1] = mod - inv2;
        auto calc = [&](arr &f, std::vector<arr> &w) {
            for (int len = 2; len <= l; len <<= 1)
                for (int i = 0; i < l; i += len)
                    for (int p = i, q = i + len / 2; q < i + len; ++p, ++q)
                        std::tie(f[p], f[q]) = std::make_tuple((f[p] * w[0][0] + f[q] * w[0][1]) % mod, (f[p] * w[1][0] + f[q] * w[1][1]) % mod);
            return;
        };
        calc(c, mT);
        arr s(l);
        for (int i = 0; i < l; ++i)
            s[i] = qkp(c[i], n);
        calc(s, mI);
        std::cout << s[0] << '\n';
    }
    return 0;
}

子集卷积

https://www.luogu.com.cn/problem/P6097

这个是在许多 DP 中都可能见到过的结构,相信大家都深有体会。

等价于求 \(c_s=\sum\limits_{i\cap j=\varnothing\land i\cup j=s}a_i\cdot b_j\),首先关注 \(i\cup j=s\),可以用 FWT 解决;对于 \(i\cap j = \varnothing\) 呢?

有一个很聪明的办法,我们发现 \(i\cap j=\varnothing\land i\cup j = s\iff \operatorname{ctz}(i)+\operatorname{ctz}(j)=\operatorname{ctz}(s)\),其中 \(\operatorname{ctz}(i)\) 表示 \(i\)\(1\) 的个数,即 popcount(i)

显然后者是个简单的加法运算,这里又有一个很聪明又很基本的办法,我们令 \(A_{\operatorname{ctz}(i),i}=a_i,B_{\operatorname{ctz}(j),j}=b_j\),那么答案就转化为 \(\sum\limits_{i}\sum\limits_{j\cup k=s}A_{i,j}\cdot B_{\operatorname{ctz}(s)-i,j}\),就可以 \(O(n\log^2n)\) 地解决问题了。

#include <bits/stdc++.h>
const int mod = 1e9 + 9;
using arr = std::vector<long long>;
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n;
    std::cin >> n;
    int l = 1 << n;
    std::vector<arr> a(n + 1, arr(l)), b(n + 1, arr(l)), c(n + 1, arr(l));
    for (int i = 0, x; i < l; ++i)
        std::cin >> x, a[__builtin_popcount(i)][i] = x;
    for (int i = 0, x; i < l; ++i)
        std::cin >> x, b[__builtin_popcount(i)][i] = x;
    std::vector<arr> mT(2, arr(2)), mI(2, arr(2));
    mT[0][0] = 1ll, mT[0][1] = 0ll, mT[1][0] = 1ll, mT[1][1] = 1ll;
    mI[0][0] = 1ll, mI[0][1] = 0ll, mI[1][0] = mod - 1ll, mI[1][1] = 1ll;
    auto calc = [&](arr &f, std::vector<arr> &w) {
        for (int len = 2; len <= l; len <<= 1)
            for (int i = 0; i < l; i += len)
                for (int p = i, q = i + len / 2; q < i + len; ++p, ++q)
                    std::tie(f[p], f[q]) = std::make_tuple((f[p] * w[0][0] + f[q] * w[0][1]) % mod, (f[p] * w[1][0] + f[q] * w[1][1]) % mod);
        return;
    };
    for (int i = 0; i <= n; ++i)
        calc(a[i], mT), calc(b[i], mT);
    for (int i = 0; i <= n; ++i) {
        for (int k = 0; k <= i; ++k)
            for (int j = 0; j < l; ++j)
                (c[i][j] += a[k][j] * b[i - k][j]) %= mod;
        calc(c[i], mI);
    }
    for (int j = 0; j < l; ++j)
        std::cout << c[__builtin_popcount(j)][j] << ' ';
    std::cout << '\n';
    return 0;
}

州区划分

https://www.luogu.com.cn/problem/P4221

人话:把 \(n\) 个点划分成若干个点集,保证每个点集的导出子图不是欧拉回路(每个点的度数为偶且图连通)。

那么枚举每一个点集判定是否合法,再令 \(f_{i}={w_i}^{-p}\times\sum\limits_{j\cup k = i\land j\cap k=\varnothing}f_{j}\times {w_k}^p\)

发现后面那个 sigma 里面是一个简单的子集卷积;但是我们发现前面有一个和 \(i\) 有关的常数,导致没办法直接卷 \(n\) 次。

但我们发现这个 DP 其实是有一定隐含条件的——都是按照 \(\text{ctz}\) 从小到大转移。考虑子集卷积的第一维,得到 \(FWT_{f_i}\) 后先 IFWT 回来,乘上 \({w_i}^{-p}\),再 FWT 回去即可。考察子集卷积的结构,发现这个是可以想办法套进去的。

复杂度 \(O(n^2\cdot 2^n)\)

#include <bits/stdc++.h>
const int mod = 998244353;
using arr = std::vector<long long>;
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, m, p;
    std::cin >> n >> m >> p;
    int l = 1 << n;
    auto qkp = [&](long long x, int y = mod - 2) {
        auto res = 1ll;
        for (; y; (x *= x) %= mod, y >>= 1)
            if (y & 1)
                (res *= x) %= mod;
        return res;
    };
    arr w(n), inv(l);
    std::vector<arr> f(n + 1, arr(l)), s(n + 1, arr(l));
    std::vector<std::vector<int> > g(n);
    for (int x, y; m--; ) {
        std::cin >> x >> y, --x, --y;
        g[x].push_back(y);
    }
    auto check = [&](int s) {
        std::vector<int> f(n), deg(n);
        std::iota(f.begin(), f.end(), 0);
        std::function<int(int)> find = [&](int x) {
            return x == f[x] ? x : f[x] = find(f[x]);
        };
        auto merge = [&](int x, int y) {
            f[find(x)] = find(y);
            return;
        };
        for (int i = 0; i < n; ++i)
            if ((s >> i) & 1)
                for (auto j : g[i])
                    if ((s >> j) & 1)
                        merge(i, j), ++deg[i], ++deg[j];
        int fa = -1;
        for (int i = 0; i < n; ++i)
            if ((s >> i) & 1) {
                if (deg[i] & 1)
                    return 1;
                if (fa == -1 || fa == find(i))
                    fa = find(i);
                else
                    return 1;
            }
        return 0;
    };
    for (int i = 0; i < n; ++i)
        std::cin >> w[i];
    for (int i = 0; i < l; ++i) {
        int ctz = __builtin_popcount(i);
        for (int j = 0; j < n; ++j)
            if ((i >> j) & 1) 
                s[ctz][i] += w[j];
        s[ctz][i] = qkp(s[ctz][i], p);
        inv[i] = qkp(s[ctz][i]);
        if (!check(i))
            s[ctz][i] = 0ll;
    }
    std::vector<arr> mT(2, arr(2)), mI(2, arr(2));
    mT[0][0] = 1ll, mT[0][1] = 0ll, mT[1][0] = 1ll, mT[1][1] = 1ll;
    mI[0][0] = 1ll, mI[0][1] = 0ll, mI[1][0] = mod - 1ll, mI[1][1] = 1ll;
    auto calc = [&](arr &f, std::vector<arr> &w) {
        for (int len = 2; len <= l; len <<= 1)
            for (int i = 0; i < l; i += len)
                for (int p = i, q = i + len / 2; q < i + len; ++p, ++q)
                    std::tie(f[p], f[q]) = std::make_tuple((f[p] * w[0][0] + f[q] * w[0][1]) % mod, (f[p] * w[1][0] + f[q] * w[1][1]) % mod);
        return;
    };
    f[0][0] = 1ll;
    for (int i = 0; i <= n; ++i)
        calc(s[i], mT);
    for (int i = 1; i <= n; ++i) {
        calc(f[i - 1], mT);
        for (int k = 0; k < i; ++k)
            for (int j = 0; j < l; ++j)
                (f[i][j] += f[k][j] * s[i - k][j]) %= mod;
        calc(f[i], mI);
        for (int j = 0; j < l; ++j)
            (f[i][j] *= inv[j]) %= mod;
    }
    std::cout << f[n][l - 1] << '\n';
    return 0;
}

B - Sum the Fibonacci

https://codeforces.com/problemset/problem/914/G

首先,对于每个元素的 \(cnt\) 是好做的——做一次子集卷积得到 \(cnt_1(a\cup b)\),做一次 xor-FWT 得到 \(cnt_2(d\oplus e)\),中间的 \(cnt(c)\) 就是原样。

值得注意的是可以在一次 and-FWT 后把 \(cnt,cnt_1,cnt_2\) 直接卷起来——由矩阵乘法结合律得(哈哈大笑了)。但怎么把 \(f\) 塞进去呢?如果你拥有小学生的数学水平,你可以很容易地想到直接在 \(cnt,cnt_1,cnt_2\) IFWT 后的结果分别对位乘上 \(f\) 即可,很可惜我并没有这样的数学能力

#include <bits/stdc++.h>
const int N = 17;
const int mod = 1e9 + 7;
using arr = std::vector<long long>;
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, l;
    std::cin >> n, l = 1 << N;
    std::vector<int> a(n + 1);
    arr f(l), cnt(l), cnt1(l), cnt2(l);
    for (int i = 1; i <= n; ++i)
        std::cin >> a[i], ++cnt[a[i]];
    f[0] = 0ll, f[1] = 1ll;
    for (int i = 2; i < l; ++i)
        f[i] = (f[i - 1] + f[i - 2]) % mod;
    auto qkp = [&](long long x, int y = mod - 2) {
        auto res = 1ll;
        for (; y; (x *= x) %= mod, y >>= 1)
            if (y & 1)
                (res *= x) %= mod;
        return res;
    };
    const int inv2 = qkp(2ll);
    std::vector<arr> mT(2, arr(2)), mI(2, arr(2));
    auto calc = [&](arr a, arr &f, std::vector<arr> &w) {
        f = a;
        for (int len = 2; len <= l; len <<= 1)
            for (int i = 0; i < l; i += len)
                for (int p = i, q = i + len / 2; q < i + len; ++p, ++q)
                    std::tie(f[p], f[q]) = std::make_tuple((f[p] * w[0][0] + f[q] * w[0][1]) % mod, (f[p] * w[1][0] + f[q] * w[1][1]) % mod);
        return;
    };
    { // orFWT
        mT[0][0] = 1ll, mT[0][1] = 0ll, mT[1][0] = 1ll, mT[1][1] = 1ll;
        mI[0][0] = 1ll, mI[0][1] = 0ll, mI[1][0] = mod - 1ll, mI[1][1] = 1ll;
        std::vector<arr> u(N + 1, arr(l)), d(N + 1, arr(l));
        for (int i = 0; i < l; ++i)
            u[__builtin_popcount(i)][i] = cnt[i];
        for (int i = 0; i <= N; ++i)
            calc(u[i], u[i], mT);
        for (int i = 0; i <= N; ++i) {
            for (int j = 0; j < l; ++j)
                for (int k = 0; k <= i; ++k)
                    (d[i][j] += u[k][j] * u[i - k][j]) %= mod;
            calc(d[i], d[i], mI);
        }
        for (int i = 0; i < l; ++i)
            cnt1[i] = d[__builtin_popcount(i)][i] * f[i] % mod;
    }
    { // xor FWT
        mT[0][0] = 1ll, mT[0][1] = 1ll, mT[1][0] = 1ll, mT[1][1] = mod - 1ll;
        mI[0][0] = mI[0][1] = mI[1][0] = inv2, mI[1][1] = mod - inv2;
        calc(cnt, cnt2, mT);
        for (int i = 0; i < l; ++i)
            (cnt2[i] *= cnt2[i]) %= mod;
        calc(cnt2, cnt2, mI);
        for (int i = 0; i < l; ++i)
            (cnt2[i] *= f[i]) %= mod;
    }
    { // andFWT
        mT[0][0] = 0ll, mT[0][1] = 1ll, mT[1][0] = 1ll, mT[1][1] = 1ll;
        mI[0][0] = mod - 1ll, mI[0][1] = 1ll, mI[1][0] = 1ll, mI[1][1] = 0ll;
        for (int i = 0; i < l; ++i) // 这一步要放在 FWT 之前,原因显然 🤗
            (cnt[i] *= f[i]) %= mod;
        calc(cnt2, cnt2, mT), calc(cnt1, cnt1, mT), calc(cnt, cnt, mT);
        for (int i = 0; i < l; ++i)
            cnt[i] = cnt[i] * cnt1[i] % mod * cnt2[i] % mod % mod;
        calc(cnt, cnt, mI);
    }
    long long res = 0ll;
    for (int i = 0; i < N; ++i)
        (res += cnt[1 << i] % mod) %= mod;
    std::cout << res << '\n';
    return 0;
}

言论