CF113B Petr# 题解

发布时间 2023-06-12 21:41:28作者: 霜木_Atomic

最近在做字符串的题,正好就给我随机了一道这个(

题意

给你一个字符串 \(s\) 以及一个开头串 \(s_{begin}\) 和结尾串 \(s_{end}\),问该字符串中有多少个不同的子串,满足以 \(s_{begin}\) 开头,以 \(s_{end}\) 结尾。两个子串不同,当且仅当两个子串长度不同,或长度相同但至少有一个位置上的字符不同,而与位置无关。

分析

首先,我们注意到这是一个找字符串中某一个模式串出现位置的题,考虑 KMP(事实上,这道题的数据范围甚至都可以暴力匹配位置);然后,发现我们只关注字符串的组成元素不同,考虑哈希。到这里了这道题实际上做完了。

然鹅,这道题个细节。比如给定字符串 abcdefghi,然后 \(s_{begin} =\) abcdef,$s_{end} = $ cd,这时候就需要注意,一定要保证你的子串是以 \(s_{end}\) 结尾的。

代码:

这里给出两种哈希思路(第二种是从题解中找到的qwq,太弱了导致我只会map暴力搞。)

第一种:用 unordered_map

因为如果直接用 map 会 T

#include<bits/stdc++.h>
#define ull unsigned long long
using namespace std;
const int N = 2050, pp = 33;//这里23会被卡,33就可以了。哈希真的玄学……

ull pnn[N];
char s[N], ls[N], rs[N];
int n, la, lb;
int nxt1[N], nxt2[N];
int pa[N], pb[N], tota, totb;
ull  hashb[N];

void init(){
    pnn[0] = 1;
    for(int i = 1; i<=n; ++i){
        hashb[i] = (hashb[i-1]*pp+s[i]);
        pnn[i] = pnn[i-1]*pp;
    }
}
unordered_map<ull, int> mb;
inline int ull gethashb(int l, int r){
    return (hashb[r]-hashb[l-1]*pnn[r-l+1]);
}
void KMP(){
    int j = 0;
    for(int i = 2; i<=la; i++){
        while(j && ls[j+1] != ls[i]) j = nxt1[j];
        if(ls[j+1] == ls[i]) j++;
        nxt1[i] = j;
    }
    j = 0;
    for(int i = 2; i<=lb; i++){
        while(j && rs[j+1] != rs[i]) j = nxt2[j];
        if(rs[j+1] == rs[i]) j++;
        nxt2[i] = j;
    }
    j = 0;
    for(int i = 1; i<=n; i++){
        while(j && ls[j+1] != s[i]) j = nxt1[j];
        if(ls[j+1] == s[i]) j++;
        if(j == la){
            pa[++tota] = i-la+1;
        }
    }
    j = 0;
    for(int i = 1; i<=n; i++){
        while(j && rs[j+1] != s[i]) j = nxt2[j];
        if(rs[j+1] == s[i]) j++;
        if(j == lb){
            pb[++totb] = i-lb+1;
        }
    }
    //前面就是匹配+找位置
    int tl, tr;
    int ans = 0;
    for(int i = 1; i<=tota; i++){
        for(int j = 1; j<=totb; j++){
            if(pb[j]>=pa[i]&&((pa[i]+la-1)<=(pb[j]+lb-1))){
                tl = pa[i], tr = pb[j]+lb-1;
                ull tmpb = gethashb(tl, tr);
                if(mb[tmpb]){
                    continue;
                } else{
                    mb[tmpb]++;
                    ans++;
                }
            }
        }
    }
    printf("%d\n", ans);
}
int main(){
    scanf("%s%s%s", s+1, ls+1, rs+1);
    n = strlen(s+1);
    la = strlen(ls+1);
    lb = strlen(rs+1);
    init();
    KMP();
    return 0;
}

第二种:用 unique 去重

这个也是最快的。

#include<bits/stdc++.h>
#define ull unsigned long long
#define ll long long
using namespace std;
const int N = 2050,  pp = 33;

ull pnn[N];
char s[N], ls[N], rs[N];
int n, la, lb;
int nxt1[N], nxt2[N];
int pa[N], pb[N], tota, totb;
ull hashb[N];
ull ans[N*N];
void init(){
    pnn[0] = 1;
    for(int i = 1; i<=n; ++i){
        hashb[i] = (hashb[i-1]*pp+s[i]);
        pnn[i] = pnn[i-1]*pp;
    }
}

inline int ull gethashb(int l, int r){
    return (hashb[r]-hashb[l-1]*pnn[r-l+1]);
}
void KMP(){
    int j = 0;
    for(int i = 2; i<=la; i++){
        while(j && ls[j+1] != ls[i]) j = nxt1[j];
        if(ls[j+1] == ls[i]) j++;
        nxt1[i] = j;
    }
    j = 0;
    for(int i = 2; i<=lb; i++){
        while(j && rs[j+1] != rs[i]) j = nxt2[j];
        if(rs[j+1] == rs[i]) j++;
        nxt2[i] = j;
    }
    j = 0;
    for(int i = 1; i<=n; i++){
        while(j && ls[j+1] != s[i]) j = nxt1[j];
        if(ls[j+1] == s[i]) j++;
        if(j == la){
            pa[++tota] = i-la+1;
        }
    }
    j = 0;
    for(int i = 1; i<=n; i++){
        while(j && rs[j+1] != s[i]) j = nxt2[j];
        if(rs[j+1] == s[i]) j++;
        if(j == lb){
            pb[++totb] = i-lb+1;
        }
    }
    int tl, tr;
    int tans = 0;
    for(int i = 1; i<=tota; i++){
        for(int j = 1; j<=totb; j++){
            if(pb[j]>=pa[i]&&((pa[i]+la-1)<=(pb[j]+lb-1))){
                tl = pa[i], tr = pb[j]+lb-1;
                ull tmpb = gethashb(tl, tr);
                ans[++tans] = tmpb;
            }
        }
    }
    sort(ans+1, ans+tans+1);
    tans = unique(ans+1, ans+tans+1)-ans-1;
    printf("%d\n", tans);
}
int main(){
    scanf("%s%s%s", s+1, ls+1, rs+1);
    n = strlen(s+1);
    la = strlen(ls+1);
    lb = strlen(rs+1);
    init();
    KMP();
    return 0;
}