学习笔记:后缀数组
2025-02-02

非常后知后觉地意识到 SA(Suffix Array) 和 SAM(Suffix Automaton) 的 A 不是同一个 A


定义

显而易见一个长度为 \(n\) 的字符串中有 \(n\) 个长度分别为 \(1\sim n\) 的后缀,如果我们对其按字典序排序,分别存储下排名 \(i\) 的后缀 \(sa_i\) 和每个后缀 \(i\) 的排名 \(rk_i\)。虽然看着挺没头没尾的,但是很有用。

求解

哈希 + 排序

直接把所有后缀拿来排序的话,字符串比较是 \(O(n)\) 的。如果我们用哈希 + 二分优化比较过程,就可以把整个排序优化到 \(O(n\log^2 n)\)

倍增

先对所有后缀按 第一个字符 排序,记排序后排名序列为 \(a\)

那么怎么按 前两个字符 排序呢?对于第 \(i\) 组字符,我们用 \((a_i,a_{i+1})\) 双关键字排序即可。记此时排名序列为 \(b\),那么如果需要按照前四个字符排序,用 \((b_i,b_{i+2})\) 进行双关键字排序即可。总共需要进行 \(\log n\) 次排序。复杂度为 \(O(n\log^2n)\)

此时我们注意到排名数组的值域为 \(n\),那么我们用桶排就能少一个 \(\log\)

实现

哈希很好实现,这里就按下不表,主要讲解倍增法的实现。

描述起来很简单,实现起来很要命。OI wiki 上的实现算是相对好理解的:

首先了解双关键字桶排的方法,首先用单关键字桶排完成对 第二关键字 的排序;对于第一关键字,令桶 \(i\) 记录前 \(i\) 个元素的数量;遍历排序后的第二关键字数组,将元素放到桶中记录数值对应的下标中,并将桶中数值 \(-1\)。实际上桶 \(c\) 充当计算下标范围的作用,\((c_{i-1},c_i]\) 即为 \(i\) 分布的范围。

显然,当且仅当排名种类为 \(n\),即没有并列排名时,排序完成。设本轮区间长度为 \(w\),对于一轮操作:

  1. 计算每个区间按后半段 \(\frac w2\) 长度字符排序的结果:\((n-w,n]\) 开头的区间后半段均为空,直接放在序列首端;接着按照上一轮 \(sa\) 结果,把能够作为后半段的元素依次放入。
  2. 依照上一轮的 \(rk\) 作为前半段排名,进行双关键字桶排。
  3. 依照 \(sa\) 和第二关键字(处理并列),求出 \(rk\)
std::vector<int> la(n + 2);
std::copy(s.begin(), s.end(), rk.begin());
int m = 128;
{
    std::vector<int> c(m + 1);
    for (int i = 1; i <= n; ++i)
        ++c[rk[i]];
    std::partial_sum(c.begin(), c.end(), c.begin());
    for (int i = n; i; --i)
        sa[c[rk[i]]--] = i;
} 
for (int w = 1, p; ; w <<= 1, m = p) {
    std::vector<int> id(1);
    for (int i = n - w + 1; i <= n; ++i)
        id.push_back(i);
    for (int i = 1; i <= n; ++i)
        if (sa[i] > w)
            id.push_back(sa[i] - w);
    std::vector<int> c(m + 1);
    for (int i = 1; i <= n; ++i)
        ++c[rk[i]];
    std::partial_sum(c.begin(), c.end(), c.begin());
    for (int i = n; i; --i)
        sa[c[rk[id[i]]]--] = id[i];
    p = 0;
    std::copy(rk.begin(), rk.end(), la.begin());
    for (int i = 1; i <= n; ++i)
        if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
            rk[sa[i]] = p;
        else
            rk[sa[i]] = ++p;
    if (p == n)
        break;
}

纯 SA 的应用

最小表示法

模板:https://www.luogu.com.cn/problem/P1368

对于循环位移相关要求,首先考虑将字符串重复一遍。

\(ss\) 中找到排名第一个 \(sa_i\le n\) 即为答案。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n;
    std::cin >> n;
    std::vector<int> s(2 * n + 1);
    for (int i = 1; i <= n; ++i)
        std::cin >> s[i], s[n + i] = s[i];
    std::vector<int> sa(2 * n + 1), rk(s);
    {
        int m = 29;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= 2 * n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = 2 * n; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(1);
            for (int i = 2 * n - w + 1; i <= 2 * n; ++i)
                id.push_back(i);
            for (int i = 1; i <= 2 * n; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= 2 * n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = 2 * n; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            auto la(rk);
            p = 0;
            for (int i = 1; i <= 2 * n; ++i)
                if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == 2 * n)
                break;
        }
    }
    for (int i = 1; i <= 2 * n; ++i)
        if (sa[i] <= n) {
            for (int j = sa[i]; j < n + sa[i]; ++j)
                std::cout << s[j] << ' ';
            std::cout << '\n';
            break;
        }
    return 0;
}

字符串匹配

二分,复杂度 \(O(|S|\log |T|)\)。求出现次数则二分左右边界。

太麻烦了且没有实际应用价值,代码略。


height 数组

定义 \(h_i=\text {lcp}(sa_i, sa_i-1)\),特别地,\(h_1=0\)

有引理:\(h_{rk_i}\ge h_{rk_{i-1}}-1\)

假设已经求出 \(h_{rk_{i-1}}\),那么可以从 \(h_{rk_{i-1}}-1\) 出发暴力看下一个字符是否相等得到答案。那么我们会发现从前往后 \(h\) 值每次最多 \(-1\),所以复杂度摊下来是 \(O(n)\) 的。

记住记住一定是 \(rk_{i-1}\) 而不是下意识的 \(rk_i-1\)!!!所以为了保证求解顺序循环枚举的一定是下标而非排名。但是注意定义却是和 \(rk_i-1\) 的 lcp!!!所以求 height 的写法是相对固定的,不能觉得好像是对的就随便乱改。


height 数组的应用

相当于背板子,因为应用太多且形式大多固定。

求任意两个后缀的 lcp

易得 \(\text{lcp}(sa_i, sa_j)=\min\{h_{i+1}, \cdots, h_j\}\)故应将一些复杂 lcp 问题的解决方式和 RMQ 联系起来


子串大小关系

即比较 \(S_{l_1, r_1}\)\(S_{l_2, r_2}\) 的大小关系。比较导致 lcp 不能继续延伸的元素大小即可。


本质不同子串数量

子串等价于「后缀的前缀」。按顺序枚举每个后缀,减去和已枚举的所有后缀的 lcp 即可。鉴于 \(\min\{h_{j+1},\cdots,h_i\}\) 单调不减,直接减去 \(h_i\) 即可。

最后答案即为 \(\frac {n(n-1)}2 - \sum\limits_{i=2}^nh_i\)


至少出现 \(k\) 次子串的最大长度

模板:https://www.luogu.com.cn/problem/P2852

出现 \(k\)\(\iff\) 在后缀数组中连续出现 \(k\)\(\iff\) 是任意连续 \(k-1\)\(h\) 的最小值,需要最大化该最小值,考虑滑动窗口。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen("P2852_7.in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, k;
    std::cin >> n >> k, --k;
    std::vector<int> s(n + 1);
    for (int i = 1; i <= n; ++i)
        std::cin >> s[i];
    std::vector<int> sa(n + 1), rk(s), h(n + 1);
    {
        int m = 1000001;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(1);
            for (int i = n - w + 1; i <= n; ++i)
                id.push_back(i);
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            auto la(rk);
            p = 0;
            for (int i = 1; i <= n; ++i)
                if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n)
                break;
        }
        for (int i = 1, to = 0; i <= n; ++i)
            if (rk[i]) {
                for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
                h[rk[i]] = to;
            }
    }
    std::vector<int> q(n + 1);
    int res = 0;
    for (int i = 1, l = 1, r = 0; i <= n; ++i) {
        // printf("%d\n", h[i]);
        for (; l <= r && i - q[l] >= k; ++l);
        for (; l <= r && h[i] <= h[q[r]]; --r);
        q[++r] = i;
        if (i >= k)
            res = std::max(res, h[q[l]]);
    }
    std::cout << res << '\n';
    return 0;
}

最长不重叠多次出现子串

bb:定式太多太杂以至于让人怀疑某些定式是否存在应用场景

发现满足单调性,二分子串长度 \(len\),那么显然 \(\text {lcp}\ge len\);将 \(h\) 划分为连续 \(\ge len\) 的段,在每段内找到下标极差与 \(len\) 比较即可。

也可以用于判定是否存在不重叠多次出现子串。

甚至可以考虑限制至少出现次数为 \(k\),那大概多个 \(\log\),看看一段里有没有 \(\ge k\) 个相互相差 \(\ge len\) 的。排序贪心求解。

那么上面的至少出现 \(k\) 次子串也可以用这个方法来解,但是多个 \(\log\) 没必要。

也可以限制多次出现但长度至少为 \(len\),那甚至少了二分的 \(\log\),直接跑一遍 check 即可。

???到底为什么会有这么多奇怪的定式,是因为真的有题这么出吗???


最长公共子串问题

\(S\)\(T\) 的最长公共子串(注意不是 LCS)。设 \(S\) 长为 \(n\)\(T\) 长为 \(m\),那么将 \(S\)\(T\) 拼接,答案就是 \(\max \{\text{lcp}(i,j)\},i\le n<j\)

但这里不直接枚举 \(i\)\(j\),还是照例先从 \(h\) 下手再卡条件,若 \(sa_{i-1}\le n<sa_{i}\)(或者反过来),就可以用 \(h_i\) 更新答案。容易证明这样总可以找到最大值。

eg1. 找相同字符

https://www.luogu.com.cn/problem/P3181

要求方案数,那么答案为 \(\text{lcp}(i,j),i\le n<j\)。(我已经帮你们试过了容斥比直接做更麻烦),考虑用单调栈维护左 / 右侧区间 lcp 求解右 / 左侧答案。关于单调栈的描述可见 本页后部内容

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, n1;
    std::string s, t;
    std::cin >> s >> t;
    n = (int)s.length(), n1 = n + (int)t.length() + 1;
    s = "#" + s + "$" + t;
    std::vector<int> sa(n1 + 1), rk(n1 + 1), h(n1 + 1);
    {
        std::copy(s.begin() + 1, s.end(), rk.begin() + 1);
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(w + 1);
            std::iota(id.begin() + 1, id.end(), n1 - w + 1);
            for (int i = 1; i <= n1; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            auto la(rk);
            for (int i = 1; i <= n1; ++i)
                if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n1)
                break;
        }
        for (int i = 1, to = 0; i <= n1; ++i) {
            for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
            h[rk[i]] = to;
        }
    }
    std::vector<std::pair<int, long long> > q1, q2;
    std::vector<int> tot1(n1 + 1), tot2(n1 + 1);
    for (int i = 1; i <= n1; ++i) {
        tot1[i] = tot1[i - 1] + (sa[i] <= n);
        tot2[i] = tot2[i - 1] + (sa[i] > n + 1);
    }
    long long res = 0ll;
    q1.emplace_back(1, 0ll), q2.emplace_back(1, 0ll);
    for (int i = 1; i <= n1; ++i) {
        for (; !q1.empty() && h[i] < h[q1.back().first]; q1.pop_back());
        q1.emplace_back(i, (tot1[i - 1] - tot1[q1.back().first - 1]) * h[i] + q1.back().second);
        if (sa[i] > n + 1)
            res += q1.back().second;
        for (; !q2.empty() && h[i] < h[q2.back().first]; q2.pop_back());
        q2.emplace_back(i, (tot2[i - 1] - tot2[q2.back().first - 1]) * h[i] + q2.back().second);
        if (sa[i] <= n)
            res += q2.back().second;
    }
    std::cout << res << '\n';
    return 0;
}

eg2. 公共串

https://www.luogu.com.cn/problem/P5546

要求多串最长公共子串,仍然考虑将多个串拼在一起。仿照前面二分的方式处理,问题转化为找到最长的 \(len\),使得存在一段最小值 \(\ge len\) 的区间,其覆盖了 \(n\) 段串。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, l = 0, r = 0;
    std::cin >> n;
    std::string s;
    std::vector<std::pair<int, int> > lim(n + 1);
    for (int i = 1; i <= n; ++i) {
        std::string t;
        std::cin >> t;
        lim[i] = { (int)s.length() + 1, s.length() + t.length() };
        s += "#" + t;
        r = std::max(r, (int)t.length());
        // printf("[%d, %d]\n", lim[i].first, lim[i].second);
    }
    int n1 = lim.back().second;
    std::vector<int> sa(n1 + 1), rk(n1 + 1), h(n1 + 1);
    {
        std::copy(s.begin() + 1, s.end(), rk.begin() + 1);
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(w + 1);
            std::iota(id.begin() + 1, id.end(), n1 - w + 1);
            for (int i = 1; i <= n1; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            auto la(rk);
            for (int i = 1; i <= n1; ++i)
                if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n1)
                break;
        }
        for (int i = 1, to = 0; i <= n1; ++i) {
            for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
            h[rk[i]] = to;
        }
    }
    // for (int i = 1; i <= n1; ++i)
    //     printf("h[%d] = %d\n", sa[i], h[i]);
    int res = 0;
    auto check = [&](int len) {
        // printf("check %d: \n", len);
        std::vector<int> cnt(n + 1);
        for (int i = 1; i <= n1; ++i) {
            if (h[i] < len) {
                if (*std::min_element(cnt.begin() + 1, cnt.end()))
                    return 1;
                cnt.assign(n + 1, 0);
            }
            else
                for (int j = 1; j <= n; ++j) {
                    if (lim[j].first <= sa[i - 1] && sa[i - 1] <= lim[j].second)
                        cnt[j] = 1;
                    if (lim[j].first <= sa[i] && sa[i] <= lim[j].second)
                        cnt[j] = 1;
                }
        }
        // printf("\n%d\n", *std::min_element(cnt.begin() + 1, cnt.end()));
        return *std::min_element(cnt.begin() + 1, cnt.end());
    };
    for (int mid; l <= r; ) {
        mid = (l + r) >> 1;
        if (check(mid))
            l = mid + 1, res = mid;
        else
            r = mid - 1;
    }
    std::cout << res << '\n';
    return 0;
}

但是看了题解发现居然还有线性做法(当然不看建 SA 的 \(\log\)),对于覆盖全部 \(n\) 段串找区间最小值,发现需要最小化区间,考虑双指针。

区间最小值用单调队列求解,细想可能会觉得不太对劲,但是容易证明答案不大于队首且不小于最大队首,所以最大队首就是答案。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, l = 0, r = 0;
    std::cin >> n;
    std::string s;
    std::vector<std::pair<int, int> > lim(n + 1);
    for (int i = 1; i <= n; ++i) {
        std::string t;
        std::cin >> t;
        lim[i] = { (int)s.length() + 1, s.length() + t.length() };
        s += (char)('A' + i - 1) + t;
        r = std::max(r, (int)t.length());
    }
    int n1 = lim.back().second;
    std::vector<int> sa(n1 + 1), rk(n1 + 1), h(n1 + 1);
    {
        std::copy(s.begin() + 1, s.end(), rk.begin() + 1);
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(w + 1);
            std::iota(id.begin() + 1, id.end(), n1 - w + 1);
            for (int i = 1; i <= n1; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            auto la(rk);
            for (int i = 1; i <= n1; ++i)
                if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n1)
                break;
        }
        for (int i = 1, to = 0; i <= n1; ++i) {
            for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
            h[rk[i]] = to;
        }
    }
    int res = 0;
    std::vector<int> q(n1 + 1), cnt(n + 1);
    // for (int i = 1; i <= n1; ++i)
    //     printf("%d: %d\n", sa[i], h[i]);
    for (int l = 1, r = 0, ql = 1, qr = 0; l <= n1; ++l) {
        for (; r < n1 && !*std::min_element(cnt.begin() + 1, cnt.end()); ) {
            ++r;
            for (int i = 1; i <= n; ++i)
                if (lim[i].first <= sa[r] && sa[r] <= lim[i].second) {
                    ++cnt[i];
                    break;
                }
            for (; ql <= qr && h[r] <= h[q[qr]]; --qr);
            q[++qr] = r;
        }
        if (*std::min_element(cnt.begin() + 1, cnt.end())) {
            // printf("[%d, %d]: %d\n", l, r, h[q[ql]]);
            res = std::max(res, h[q[ql]]);
        }
        for (; ql <= qr && q[ql] <= l; ++ql);
        if (l != 1) {
            for (int i = 1; i <= n; ++i)
                if (lim[i].first <= sa[l - 1] && sa[l - 1] <= lim[i].second) {
                    --cnt[i];
                    break;
                }
        }
    }
    std::cout << res << '\n';
    return 0;
}

\(\texttt {AA}\) 式子串处理

即对于连续相同子串问题的处理,有一个定的思路,由例题分析。

eg1. 优秀的拆分

https://www.luogu.com.cn/problem/P1117

还是从中间分开,按前后分别处理。这里有个 trick,我们枚举 \(\texttt B\) 的长度 \(len\),在 \(S\) 中每隔 \(len\) 打一个标记。那么显然,任意一个长度为 \(2\times len\) 的子串都会经过恰好两个标记(充分的),这样就可以筛选出所有可能的串。

我们枚举所有连续两个标记(总复杂度为调和级数),求它们对应后缀的 lcp 和对应前缀的 lcs(翻转求 SA 即可),如果二者加起来 \(\ge len\) 就说明存在这样的 \(\texttt {AA}\)。在 \(\text {lcs}+\text {lcp}\) 中任取 \(len\) 长度即为一对 \(\texttt {AA}\)。用差分给可能的起点和终点区间加即可。

小细节:lcp 和 lcs 均需要对 \(len\)\(\min\),否则取到的串可能不会经过当前选中的两个标记。

#include <bits/stdc++.h>
class SA {
public:
    std::vector<int> sa, rk, h;
    std::vector<std::vector<int>  > st;
    SA(int n, std::string s): sa(n + 1), rk(n + 2), h(n + 1), st(20, std::vector<int> (n + 1)) {
        std::vector<int> la(n + 2);
        std::copy(s.begin(), s.end(), rk.begin());
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(1);
            for (int i = n - w + 1; i <= n; ++i)
                id.push_back(i);
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            std::copy(rk.begin(), rk.end(), la.begin());
            for (int i = 1; i <= n; ++i)
                if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n)
                break;
        }
        for (int i = 1, to = 0; i <= n; ++i)
            if (rk[i]) {
                for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
                h[rk[i]] = to;
            }
        for (int i = 1; i <= n; ++i)
            st[0][i] = h[i];
        for (int j = 1; (1 << j) <= n; ++j)
            for (int i = 1; i + (1 << j) - 1 <= n; ++i)
                st[j][i] = std::min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
        rk.emplace_back();
        return;
    }
private:
    int ask(int l, int r) {
        // fprintf(stderr, "l = %d, r = %d\n", l, r);
        int k = std::__lg(r - l + 1);
        return std::min(st[k][l], st[k][r - (1 << k) + 1]);
    }
public:
    int lcp(int l, int r) {
        return ask(std::min(rk[l], rk[r]) + 1, std::max(rk[l], rk[r]));
    }
};
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int T;
    for (std::cin >> T; T--; ) {
        std::string s;
        std::cin >> s;
        int n = (int)s.length();
        s = "#" + s;
        SA p(n, s);
        std::reverse(s.begin() + 1, s.end());
        SA q(n, s);
        std::vector<int> f(n + 2), g(n + 2);
        for (int len = 1; len <= n / 2; ++len)
            for (int i = len; i + len <= n; i += len) {
                int l = i, r = i + len, lcp = std::min(len, p.lcp(l, r)), lcs = std::min(len - 1, q.lcp(n - l + 2, n - r + 2));
                if (lcp + lcs >= len) {
                    int t = lcp + lcs - len + 1;
                    // fprintf(stderr, "(%d, %d), %d, %d\n", l, r, lcp, lcs);
                    ++g[l - lcs], --g[l - lcs + t], ++f[r + lcp - t], --f[r + lcp];
                }
            }
        std::partial_sum(f.begin(), f.end(), f.begin());
        std::partial_sum(g.begin(), g.end(), g.begin());
        long long res = 0ll;
        for (int i = 1; i < n; ++i)
            res += (long long)f[i] * g[i + 1];
        std::cout << res << '\n';
    }
    return 0;
}

eg2. tandem

https://www.codechef.com/problems/TANDEM

注意到多了一个限制,前一个好处理,找到经过 \(3\) 个标记的串即可。对于后一个限制,画图可以发现对于 interesting ones,每次只会出现最多一个;当且仅当 \(\text {lcp}>len\) 时不存在。

对于 uninteresting ones,用每次能提供的总数减去 interesting ones 的数量即可。

#include <bits/stdc++.h>
class SA {
public:
    std::vector<int> sa, rk, h;
    std::vector<std::vector<int>  > st;
    SA(int n, std::string s): sa(n + 1), rk(n + 2), h(n + 1), st(20, std::vector<int> (n + 1)) {
        std::vector<int> la(n + 2);
        std::copy(s.begin(), s.end(), rk.begin());
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(1);
            for (int i = n - w + 1; i <= n; ++i)
                id.push_back(i);
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            std::copy(rk.begin(), rk.end(), la.begin());
            for (int i = 1; i <= n; ++i)
                if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n)
                break;
        }
        for (int i = 1, to = 0; i <= n; ++i)
            if (rk[i]) {
                for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
                h[rk[i]] = to;
            }
        for (int i = 1; i <= n; ++i)
            st[0][i] = h[i];
        for (int j = 1; (1 << j) <= n; ++j)
            for (int i = 1; i + (1 << j) - 1 <= n; ++i)
                st[j][i] = std::min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
        rk.emplace_back();
        return;
    }
private:
    int ask(int l, int r) {
        // fprintf(stderr, "l = %d, r = %d\n", l, r);
        int k = std::__lg(r - l + 1);
        return std::min(st[k][l], st[k][r - (1 << k) + 1]);
    }
public:
    int lcp(int l, int r) {
        return ask(std::min(rk[l], rk[r]) + 1, std::max(rk[l], rk[r]));
    }
};
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    std::string s;
    std::cin >> s;
    int n = (int)s.length();
    s = "#" + s;
    SA p(n, s);
    std::reverse(s.begin() + 1, s.end());
    SA q(n, s);
    std::reverse(s.begin() + 1, s.end());
    long long res1 = 0ll, res2 = 0ll;
    for (int len = 1; len <= n; ++len)
        for (int i = len, j = 2 * len, k = 3 * len; k <= n; i += len, j += len, k += len) {
            int lcp = std::min(p.lcp(i, j), p.lcp(j, k)), lcs = std::min({ len - 1, q.lcp(n - i + 2, n - j + 2), q.lcp(n - j + 2, n - k + 2) });
            if (std::min(len, lcp) + lcs >= len) {
                // printf("(%d, %d, %d), %d, %d, %d\n", i, j, k, lcs, lcp, len);
                int t = (lcp <= len);
                res1 += t, res2 += std::min(len, lcp) + lcs - len + 1 - t;
            }
            // else
            //     printf("# (%d, %d, %d), %d, %d, %d\n", i, j, k, lcs, lcp, len);
        }
    std::cout << res1 << ' ' << res2 << '\n';
    return 0;
}

eg3. repeats

https://www.spoj.com/problems/REPEATS/

重复次数最多,只需经过标记点最多。显然经过标记点的数量就是该字符串长除以 \(len\) 向下取整就可以得到重复次数减 \(1\) 的值。

选择两个连续标记点,对于 lcp 和 lcs(显然此时不需要对 \(len\)\(\min\)),计算 \(\dfrac {\text{lcp}+\text{lcs}}{len}+1\) 取最大即可。

#include <bits/stdc++.h>
class SA {
public:
    std::vector<int> sa, rk, h;
    std::vector<std::vector<int>  > st;
    SA(int n, std::string s): sa(n + 1), rk(n + 2), h(n + 1), st(20, std::vector<int> (n + 1)) {
        std::vector<int> la(n + 2);
        std::copy(s.begin(), s.end(), rk.begin());
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(1);
            for (int i = n - w + 1; i <= n; ++i)
                id.push_back(i);
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            std::copy(rk.begin(), rk.end(), la.begin());
            for (int i = 1; i <= n; ++i)
                if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n)
                break;
        }
        for (int i = 1, to = 0; i <= n; ++i)
            if (rk[i]) {
                for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
                h[rk[i]] = to;
            }
        for (int i = 1; i <= n; ++i)
            st[0][i] = h[i];
        for (int j = 1; (1 << j) <= n; ++j)
            for (int i = 1; i + (1 << j) - 1 <= n; ++i)
                st[j][i] = std::min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
        rk.emplace_back();
        return;
    }
private:
    int ask(int l, int r) {
        // fprintf(stderr, "l = %d, r = %d\n", l, r);
        int k = std::__lg(r - l + 1);
        return std::min(st[k][l], st[k][r - (1 << k) + 1]);
    }
public:
    int lcp(int l, int r) {
        return ask(std::min(rk[l], rk[r]) + 1, std::max(rk[l], rk[r]));
    }
};
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int T;
    for (std::cin >> T; T--; ) {
        int n;
        std::cin >> n;
        std::string s = "#";
        for (int i = 1; i <= n; ++i) {
            char t;
            std::cin >> t;
            s.push_back(t);
        }
        SA p(n, s);
        std::reverse(s.begin() + 1, s.end());
        SA q(n, s);
        int res = 0;
        for (int len = 1; len <= n; ++len)
            for (int i = len, j = 2 * len; j <= n; i += len, j += len) {
                int lcp = p.lcp(i, j), lcs = q.lcp(n - i + 2, n - j + 2);
                if (lcp + lcs >= len)
                    res = std::max(res, (lcp + lcs) / len + 1);
            }
        std::cout << res << '\n';
    }
    return 0;
}

结合并查集


结合单调栈


一言 - Hitokoto