同余最短路 学习心得

发布时间 2023-10-18 17:24:15作者: DZhearMins

同余最短路用于解决几个数的和或差来统计能凑出多少数的问题,比如:

P3403 跳楼机

P2371 墨墨的等式

C231017C诡异键盘

OI Wiki讲同余最短路讲得好。接下来我们来通过一道例题分析同余最短路的做法。

P3403 跳楼机

题目大意:给定 3 个数,问你能用这三个数(可重复用可不用)凑出 [ \(1\) , \(h\) ] 区间内多少个数。

我们设给的三个数为 \(a<b<c\) ,我们把最小的 \(a\) 设为模数,然后所有 \(+a\) 的操作就可以被忽略(因为 \(+a =\) 在原地徘徊)。

我们把 \(\forall i \in [0,a-1]\) 都连上 \(i \stackrel{b}{\longrightarrow} (i+b)\% a\) 的边(代表从所有模 \(a\)\(i\) 的楼层要往上爬 \(b\) 层可以到达 \((i+b)\%a\) 楼层,没错,这样的建边方式很像分层图。

注意到第 \(i\) 号点的距离为 \(dis_i\) 的实际意义是: \(dis_i\) 层以前的模 \(a\)\(i\) 的楼层都到达不了;而 \(dis_i\) 层一直到第 \(h\) 层以内最大的模 \(a\)\(i\) 的楼层都是可以到达的。 \(dis_i\%a\) 的值一定为 \(i\) ,原因就是建边的方式决定了。

因此对于一个 \(dis_i\) ,它带来的贡献就是

\[\lfloor{\frac{h-dis_i}{a}}\rfloor+1 \]

怎么求 \(dis_i\) 呢?跑 dijkstra 就行哩。

#include<bits/stdc++.h>  
using namespace std;  
#define ll long long  
#define N 300005  
ll tot,fir[N],nxt[N*2],to[N*2],w[N*2];  
void add(ll x,ll y,ll z){  
    nxt[++tot]=fir[x];  
    fir[x]=tot;  
    to[tot]=y;  
    w[tot]=z;  
}  
ll a,b,c,H,dis[N],ans;  
bool vis[N];  
void dijkstra(){  
    priority_queue<pair<ll,ll> >q;  
    q.push(make_pair(-dis[1],1));  
    while(!q.empty()){  
        ll x=q.top().second;  
        q.pop();  
        vis[x]=1;  
        for(ll e=fir[x];e;e=nxt[e]){  
            ll u=to[e];  
            if(u>H)continue;  
            if(dis[x]+w[e]<dis[u]){  
                dis[u]=dis[x]+w[e];  
                if(vis[u]==0)q.push(make_pair(-dis[u],u));  
            }  
        }  
    }  
    return;  
}  
int main(){  
    memset(dis,0x3f,sizeof dis);  
    memset(vis,0,sizeof vis);  
    scanf("%lld %lld %lld %lld",&H,&a,&b,&c);  
    if(a>b)swap(a,b);  
    if(b>c)swap(b,c);  
    if(a>c)swap(a,c);  
    if(a>b)swap(a,b);//硬核排序  
    if(a==1){  
        printf("%lld\n",H);  
        return 0;  
    }//模数不能为1  
    for(ll i=0;i<a;++i){  
        add(i,(i+b)%a,b);  
        add(i,(i+c)%a,c);  
    }  
    dis[1]=1;//注意是从1楼开始爬的,0号点的意思是第a楼、第2a楼、……  
    dijkstra();  
    for(ll i=0;i<a;++i)  
        if(dis[i]<=H)  
            ans+=(H-dis[i])/a+1;  
    printf("%lld\n",ans);  
    return 0;  
}

P2371 墨墨的等式

看起来就跟上一道题一模一样,实际上也是一模一样。当 \(dis_i \in [l,r]\) (必须是左闭右闭区间)的时候解有 \(\lfloor{\frac{r-dis_i}{P}}\rfloor+1\) 个,而当 \(dis_i < l\) 的时候解有 \((\lfloor{\frac{r-dis_i}{P}}+1\rfloor)-(\lfloor{\frac{l-1-dis_i}{P}}\rfloor+1)\) 个(就是 \(r\) 楼及以下满足条件的数的数量减去 \(l\) 楼以下满足条件的数的数量),轻松切掉紫色水题。

这道题还有一个注意的地方,就是 \(dis\)\(l,r\) 区间关系分类时,\(dis_i=l\) 的情况必须放在 解有 \(\lfloor{\frac{r-dis_i}{P}}\rfloor+1\) 个的分类里面,不然要出锅。原因是因为 \(l-1\) 往前没有解了,但是如果放进另一个分类里的话,是默认 \(l-1\) 以前至少有一个解的。

#include<bits/stdc++.h>  
using namespace std;  
#define T 13  
#define N 500005  
#define M 6500065    //边的个数是 N * n 条  
#define ll long long  
ll n,l,r;  
ll tot,fir[N],nxt[M],to[M],w[M];  
ll a[T],P,ans;  
void add(ll x,ll y,ll z){  
	...(略)
}  
ll dis[N];  
bool vis[N];  
void dijkstra(){  
    priority_queue<pair<ll,ll> >q;  
    q.push(make_pair(0,0));  
    dis[0]=0;  
    while(!q.empty()){  
        ll x=q.top().second;  
        q.pop();  
        vis[x]=1;  
        for(ll e=fir[x];e;e=nxt[e]){  
            ll u=to[e];  
            if(dis[x]+w[e]<dis[u]){  
                dis[u]=dis[x]+w[e];  
                if(vis[u]==0){  
                    vis[u]=1;  
                    q.push(make_pair(-dis[u],u));  
                }  
            }  
        }  
    }  
    return;  
}  
int main(){  
    memset(dis,0x3f,sizeof dis);  
    memset(vis,0,sizeof vis);  
    scanf("%lld %lld %lld",&n,&l,&r);  
    for(int i=0;i<n;++i){  
        scanf("%lld",&a[i]);  
    }  
    sort(a,a+n);  
    P=a[0];  
    for(int i=1;i<n;++i){  
        for(int st=0;st<P;++st){  
            add(st,(st+a[i])%P,a[i]);  
        }  
    }  
    dijkstra();  
    for(int i=0;i<P;++i){  
        if(dis[i]>=l&&dis[i]<=r){    //这里的条件要注意  
            ans+=(r-dis[i])/P+1;  
        }else if(dis[i]<l){  
            ans+=(r-dis[i])/P-(l-1-dis[i])/P;  
        }  
    }  
    printf("%lld",ans);  
    return 0;  
}

ABC077D Small Multiple

这道题的笔记抄的 OI wiki

本题可以使用循环卷积优化完全背包在 \(O(n\log^2 n)\) 的时间内解决,但我们希望得到线性的算法。

观察到任意一个正整数都可以从 \(1\) 开始,按照某种顺序执行乘 \(10\) 加 \(1\) 的操作,最终得到 \(K\) ,而其中加 \(1\) 操作的次数就是这个数的数位和。这提示我们使用最短路。

对于所有 \(0\le k\le n-1\) ,从 \(k\) 向 \(10k\) 连边权为 \(0\) 的边;从 \(k\) 向 \(k+1\) 连边权为 \(1\) 的边。(点的编号均在模 \(K\) 意义下)

每个 \(K\) 的倍数在这个图中都对应了 \(1\) 号点到 \(0\) (其实是 \(K\),但是 \(K\%K=1\))号点的一条路径,求出 \(1\) 到 \(0\) 的最短路即可。某些路径不合法(如连续走了 \(10\) 条边权为 \(1\) 的边),但这些路径产生的答案一定不优,不影响答案。

时间复杂度 \(O(n)\)

#include<bits/stdc++.h>  
using namespace std;  
#define N 100005  
#define M 200005  
#define ll long long  
ll n;  
ll tot,fir[N],to[M],nxt[M],w[M];  
void add(ll x,ll y,ll z){  
	...(略)
}  
ll dis[N];  
bool vis[N];  
void dijkstra(){  
	...(略)
}  
int main(){  
    scanf("%lld",&n);  
    memset(dis,0x3f,sizeof dis);  
    memset(vis,0,sizeof vis);  
    for(int i=0;i<n;++i){  
        add(i,(i+1)%n,1);    //+1操作增加1的数位值  
        add(i,(i*10)%n,0);    //*10操作不增加数位值  
    }  
    dis[1]=1;                //1的数位值不是0而是1  
    dijkstra();  
    printf("%lld",dis[0]);    //到n的正整数倍的最短距离  
    return 0;  
}                            //Маленькая страна

SuperOJ - 诡异键盘 (外网链接)

这道题也是一道同于最短路的题,虽然不明显。

约定:按键上的字符串 \(s_1\)\(s_n\) 的长度为 \(len_i\)

显然我们可以将题目的添加和删除两种操作揉成一个操作:添加一段长度的前缀。这个前缀的长度从 \(1\)\(len_i\) 都可以。

添加前缀的操作等于一个字符串末尾加上一些字符串来垫背,再按下几次删除,删除掉垫背和这个字符串的一段后缀,留下这个字符串的一段指定长度前缀。

也就说,添加前缀 \(=\) 找垫背 \(+\) 找前缀


找垫背

我们发现删除相等长度的后缀可以用相同的字符串来做垫背。因此我们可以用 \(ms_i\) 来表示凑出长度为 \(i\) 的垫背字符串所需的步数。

那么要怎么算呢?没错,同余最短路。我们把长度为 \(0\) ~ \(k-1\) 的垫背算成 \(k\) 个点,\(ms_i\) 就是 \(0\) 号点到 \(i\) 号点的距离。

由于我们要将垫背的长度限制在 \(k\) 以内,所以每次添加一个字符串以后都要考虑删除,考虑向点 \(x\) 添加字符串 \(s_i\) ,相当于连边 $$x;{\xrightarrow{\huge{\lfloor{\frac{len_i+x}{k}}\rfloor+1}}};(x+len_i)%k$$
长度 \(+1\) 是因为添加一个字符串也算一个操作。

我们写一个带优先队列优化的 dijkstra ,然后就 TLE

\(\downarrow\) 低效的代码

#include<bits/stdc++.h>  
using namespace std;  
#define N 5005  
#define L 5005  
#define K 2005  
#define S 1000005  
#define P 175194560498040823ull                                               
#define pi pair<int,int>  
#define ll long long  
#define ull unsigned ll  
int T,n,kk,ms[K],len[N],dis[N],ans,f[N],pretime;  
ull hsh[N][L],hshof[L][L];  
char s[N][L],ss[L];  
bool vis[K];  
unordered_map<ull,int>mp;
void getms(){  
    memset(vis,0,sizeof vis);  
    ms[0]=0;  
    for(int i=1;i<kk;++i)ms[i]=0x2f2f2f2f;  
    priority_queue<pi >q;  
    q.push(make_pair(0,0));  
    while(!q.empty()){  
        int x=q.top().second;  
        vis[x]=1;  
        q.pop();  
        for(int i=1;i<=n;++i){  
            int newms=ms[x]+(len[i]+x)/kk+1;  
            int newi=(x+len[i])%kk;  
            if(ms[newi]>newms){  
                ms[newi]=newms;  
                if(vis[newi]==0)q.push(make_pair(-ms[newi],newi));  
            }  
        }  
    }  
    return;  
}

原来,这里的 dijkstra 复杂度是 \(O(m\log m)\) 的,而 \(m=nk\le 10^7\) ( \(k\) 个点,每个点可以通过 \(n\) 个字符串转移,也就是 \(n\) 条边) ,\(T\times 10^7 \log 10^7\) 肯定是要 TLE 的。

事实上,我们可以用未带优化的 \(O(n^2)\) 算法跑 dijkstra ,因为 \(n\le5000\) ,实在是乐死了。

\(\downarrow\) 高效的代码

void getms(){  
    memset(vis,0,sizeof vis);  
    ms[0]=0;  
    for(int i=1;i<kk;++i)ms[i]=0x2f2f2f2f;  
    ms[2003]=0x2f2f2f2f;  
    while(1){  
        int x=2003;  
        for(int i=0;i<kk;++i)  
            if(vis[i]==0&&ms[i]<ms[x])x=i;  
        if(x==2003)return;  
        vis[x]=1;  
        for(int i=1;i<=n;++i){  
            int newms=ms[x]+(len[i]+x)/kk+1;  //新的可能的ms
            int newi=(x+len[i])%kk;           //新的点的序号
            if(ms[newi]>newms)ms[newi]=newms;  
        }  
    }  
    return;  
}
void work(){  
    scanf("%d %d",&n,&kk);  
    ans=0;  
    for(int i=1;i<=n;++i){  
        scanf("%s",s[i]+1);  
        len[i]=strlen(s[i]+1);  
    }  
    scanf("%s",ss+1);  
    len[0]=strlen(ss+1);  
	...(略)
    getms();  
	...(略)
}

找前缀

现在我们可以实现得到高质量垫背了,接下来我们来实现得到高质量前缀:

我们先把 \(s\) 的前缀哈希掉,\(hsh[i][j]\) 代表 \(s[i]\) 的前 \(j\) 位的哈希值,\(rest\) 代表要删除的后缀的长度,我们用 \(mp[\,\text{hash}\,(s)]\) 来记录任一字符串 \(s\) 需要多少步才能打出来。

添加目标字符串的前 \(j\) 位,相当于先添加一个目标字符串,再添加一个长度为 \((k-(len_i-j)\%k)\%k\) 的垫背,再按下 \(\lfloor\frac{rest+k-1}{k}\rfloor\) 次删除。我们 \(O(n\sum_{1}^{n}len_i)\) 暴力枚举就行了。

为什么要取两次模呢?因为当 \((len_i-j)\%k=0\) 的时候,程序以为垫背的长度是 \((k-0)=k\) ,但其实长度是 \(0\) ,所以还要再取一次模。

void getmap(){  
    mp.clear();  
    for(int i=1;i<=n;++i){  
        for(int j=1;j<=len[i];++j){  
            int rest=len[i]-j;  
            if(mp[hsh[i][j]]==0)  
                mp[hsh[i][j]]=1+(rest+kk-1)/kk+ms[(kk-rest%kk)%kk];                        //bug #2 这里要取两次模  
            else  
                mp[hsh[i][j]]=min(mp[hsh[i][j]],1+(rest+kk-1)/kk+ms[(kk-rest%kk)%kk]);    //bug #2 这里要取两次模  
        }  
    }  
}

算答案

现在我们可以用这些前缀来拼接答案啦

我们设目标字符串为 \(ss\)\(ss\) 拼到第 \(i\) 位,且第 \(i\) 位后面不跟东西需要 \(f_i\) 步,那么就有递推式:

\[f_i=min(f_{j-1},mp[\,\text{hash}\,(ss[j : i])) \]

\(ss[j:i]\) 代表 \(ss\) 的第 \(j\)(包含) 至 \(i\)(包含) 位。

同样 \(O(len_{ss}^2)\) 暴力动归就行了

\(\downarrow\) 暴力的代码(这段代码不开 O2 慢得离谱,开了又快得离谱,我也不知道为啥 )

void getans(){  
    memset(f,0x2f,sizeof f);  
    f[0]=0;  
    for(int i=1;i<=len[0];++i)  
        for(int j=1;j<=i;++j)                //bug #1:这里可以取等  
            if(mp[hshof[j][i]]!=0)  
                f[i]=min(f[i],f[j-1]+mp[hshof[j][i]]);  
}  

因此将所有代码拼接起来就得到了正解:

展开
/*  
  bug:1.getans()中j循环的范围可以取等  
      2.getmap()中ms的下标要取两次模  
 */  
#include<bits/stdc++.h>  
using namespace std;  
#define N 5005  
#define L 5005  
#define K 2005  
#define S 1000005  
#define P 175194560498040823ull                                               
#define pi pair<int,int>  
#define ll long long  
#define ull unsigned ll  
int T,n,kk,ms[K],len[N],dis[N],ans,f[N],pretime;  
ull hsh[N][L],hshof[L][L];  
char s[N][L],ss[L];  
bool vis[K];  
unordered_map<ull,int>mp;  
void gethash(){  
    for(int i=1;i<=n;++i){  
        hsh[i][0]=0;  
        for(int j=1;j<=len[i];++j)  
            hsh[i][j]=(hsh[i][j-1]*131ull+(s[i][j]-'a'+1ull))%P;  
    }  
    for(int i=1;i<=len[0];++i){  
        ull tmp=0;  
        for(int j=i;j<=len[0];++j){  
            tmp=(tmp*131ull+ss[j]-'a'+1ull)%P;  
            hshof[i][j]=tmp;  
        }  
    }  
    return;  
}  
void getms(){  
    memset(vis,0,sizeof vis);  
    ms[0]=0;  
    for(int i=1;i<kk;++i)ms[i]=0x2f2f2f2f;  
    ms[2003]=0x2f2f2f2f;  
    while(1){  
        int x=2003;  
        for(int i=0;i<kk;++i)  
            if(vis[i]==0&&ms[i]<ms[x])x=i;  
        if(x==2003)return;  
        vis[x]=1;  
        for(int i=1;i<=n;++i){  
            int newms=ms[x]+(len[i]+x)/kk+1;  
            int newi=(x+len[i])%kk;  
            if(ms[newi]>newms)ms[newi]=newms;  
        }  
    }  
    return;  
}  
void getmap(){  
    mp.clear();  
    for(int i=1;i<=n;++i){  
        for(int j=1;j<=len[i];++j){  
            int rest=len[i]-j;  
            if(mp[hsh[i][j]]==0)  
                mp[hsh[i][j]]=1+(rest+kk-1)/kk+ms[(kk-rest%kk)%kk];                        //bug #2 这里要取两次模  
            else  
                mp[hsh[i][j]]=min(mp[hsh[i][j]],1+(rest+kk-1)/kk+ms[(kk-rest%kk)%kk]);    //bug #2 这里要取两次模  
        }  
    }  
}  
void getans(){  
    memset(f,0x2f,sizeof f);  
    f[0]=0;  
    for(int i=1;i<=len[0];++i)  
        for(int j=1;j<=i;++j)                //bug #1:这里可以取等  
            if(mp[hshof[j][i]]!=0)  
                f[i]=min(f[i],f[j-1]+mp[hshof[j][i]]);  
}  
void work(){  
    scanf("%d %d",&n,&kk);  
    ans=0;  
    for(int i=1;i<=n;++i){  
        scanf("%s",s[i]+1);  
        len[i]=strlen(s[i]+1);  
    }  
    scanf("%s",ss+1);  
    len[0]=strlen(ss+1);  
    gethash();  
    getms();  
    getmap();  
    getans();  
    if(f[len[0]]>=0x2f2f2f2f)  
        printf("-1\n");  
    else  
        printf("%d\n",f[len[0]]);  
    return;  
}  
int main(){  
    freopen("keyboard.in","r",stdin);  
    freopen("keyboard.out","w",stdout);  
    scanf("%d",&T);  
    while(T--)work();  
    return 0;  
}

以及一份 TLE 的代码:

展开
/*  
  bug:1.getans()中j循环的范围可以取等  
      2.竟然会TLE(?)  
 */  
#include<bits/stdc++.h>  
using namespace std;  
#define N 5005  
#define L 5005  
#define S 1000005  
#define K 2005  
#define P 175194560498040823ull                                             
#define pi pair<int,int>  
#define ll long long  
#define ull unsigned ll  
int T,n,kk,ms[K],len[N],dis[N],ans,f[S];  
ull hsh[N][N],hshof[N][L];  
char s[N][L],ss[S];  
bool vis[N];  
map<ull,int>mp;  
void gethash(){  
    for(int i=1;i<=n;++i){  
        hsh[i][0]=0;  
        for(int j=1;j<=len[i];++j)  
            hsh[i][j]=(hsh[i][j-1]*131ull+(s[i][j]-'a'+1ull))%P;  
    }  
    for(int i=1;i<=len[0];++i){  
        hshof[i][0]=0;  
        for(int j=i;j<=len[0];++j)  
            hshof[i][j]=(hshof[i][j-1]*131ull+ss[j]-'a'+1ull)%P;  
    }  
    return;  
}  
void getms(){  
    memset(vis,0,sizeof vis);  
    ms[0]=0;  
    for(int i=1;i<kk;++i)ms[i]=0x3f3f3f3f;  
    priority_queue<pi >q;  
    q.push(make_pair(0,0));  
    while(!q.empty()){  
        int x=q.top().second;  
        vis[x]=1;  
        q.pop();  
        for(int i=1;i<=n;++i){  
            int newms=ms[x]+(len[i]+x)/kk+1;  
            int newi=(x+len[i])%kk;  
            if(ms[newi]>newms){  
                ms[newi]=newms;  
                if(vis[newi]==0)q.push(make_pair(-ms[newi],newi));  
            }  
        }  
    }  
    return;  
}  
void getmap(){  
    mp.clear();  
    for(int i=1;i<=n;++i){  
        for(int j=1;j<=len[i];++j){  
            int rest=len[i]-j;  
            if(mp[hsh[i][j]]==0)mp[hsh[i][j]]=1+(rest+kk-1)/kk+ms[(kk-rest%kk)%kk];  
            else mp[hsh[i][j]]=min(mp[hsh[i][j]],1+(rest+kk-1)/kk+ms[(kk-rest%kk)%kk]);//这里记得要取模2次  
        }  
    }  
}  
void getans(){  
    memset(f,0x3f,sizeof f);  
    f[0]=0;  
    for(int i=1;i<=len[0];++i)  
        for(int j=1;j<=i;++j)                //bug #1:这里可以取等  
            if(mp[hshof[j][i]]!=0)  
                f[i]=min(f[i],f[j-1]+mp[hshof[j][i]]);  
}  
void work(){  
    scanf("%d %d",&n,&kk);  
    ans=0;  
    for(int i=1;i<=n;++i){  
        scanf("%s",s[i]+1);  
        len[i]=strlen(s[i]+1);  
    }  
    scanf("%s",ss+1);  
    len[0]=strlen(ss+1);  
    gethash();  
    getms();  
    getmap();  
    getans();  
    if(f[len[0]]>=0x3f3f3f3f)printf("-1\n");  
    else printf("%d\n",f[len[0]]);  
    return;  
}  
int main(){  
    freopen("keyboard.in","r",stdin);  
    freopen("keyboard.out","w",stdout);  
    scanf("%d",&T);  
    while(T--)work();  
    return 0;  
}