最短路计数

发布时间 2023-12-24 09:15:50作者: CQWYB

前置知识

  • 最短路的一个很好的性质:从\(s\)\(t\)的最短路上的一个节点\(k\),都满足\(s\)\(k\)的路径是关于\(s\)单源最短路的最短路

证明:

反证法,假设\(s\)\(k\)的路径不为最短路,但\(s \to k \to t\)为到\(t\)的最短路,那么\(s \to k \to t\)的路径一定不会比\(s\)\(k\)的最短路再加上\(k \to t\)的路径优,所以\(s \to t\)就不为最短路,与上文假设条件矛盾。

故得证。

  • \(s\)\(t\)的最短路径上一定不存在环

证明:

考虑反证法,因为这个图的边权非负,所以环的权值为非负数,所以一定不会比不经过这个环更优。

故得证。

注: 上述结论是在图的边权非负时才成立的,如果图的边权为负数上述结论就不一定成立了


例题[HAOI2012] 道路

直接队每一条边进行枚举肯定会\(TLE\),但如果只计算边的贡献就好像可以优化一点时间复杂度。

\(cnt1_v\)表示从\(s\)到达\(v\)的最短路径条数,\(cnt2_v\)表示在最短路图上以\(v\)作为起点的最短路径条数。

很显然,对于一条边\(u \to v\),它的贡献就为以\(1 \sim n\)每个点作为起点\(cnt1_u \times cnt2_v\)的和。

\(cnt1\)可以在\(dijkstra\)后进行判断哪些边在最短路图上再进行\(topo\)求得,显然\(cnt1_s=1\)

\(cnt2\)可以在拓扑序的逆序上求得,对于一个点\(u\)\(cnt2_u=1+cnt2_v\),其中\(v\)\(u\)的相邻节点。

代码如下

#include<bits/stdc++.h>
using namespace std;
const int N=1550,M=5050;
const int mod=1e9+7;
int n,m,head[N],cnt,to[M],nxt[M],w[M],dis[N],fro[M],in[N],cnt1[N],cnt2[N],s[N],ans[M];
bool vis[N],mark[M];
void add(int u,int v,int f)
{
	to[++cnt]=v;
	nxt[cnt]=head[u];
	head[u]=cnt;
	w[cnt]=f;
	fro[cnt]=u;
}
struct node{
	int val,pos;
	bool operator >(const node &x)const{
		return val>x.val;
	}
};
void dij(int s)
{
	priority_queue<node,vector<node>,greater<node> >q;
	memset(dis,0x3f,sizeof(dis));
	memset(vis,0,sizeof(vis));
	memset(mark,0,sizeof(mark));
	dis[s]=0;
	q.push(node{dis[s],s});
	while(!q.empty())
	{
		node t=q.top();q.pop();
		if(vis[t.pos])continue;
		vis[t.pos]=1;
		for(int i=head[t.pos];i;i=nxt[i])
		{
			int v=to[i];
			if(dis[v]>dis[t.pos]+w[i])
			{
				dis[v]=dis[t.pos]+w[i];
				q.push(node{dis[v],v});
			}
		}
	}
	for(int i=1;i<=m;i++)
		if(dis[to[i]]==dis[fro[i]]+w[i])mark[i]=true;
	return;
}
void topo(int fs)
{
	memset(cnt1,0,sizeof(cnt1));
	memset(cnt2,0,sizeof(cnt2));
	memset(in,0,sizeof(in));
	queue<int>q;
	for(int i=1;i<=m;i++)
	{
		if(mark[i]==false)continue;
		in[to[i]]++;
	}
	q.push(fs);
	cnt1[fs]=1;
	int tag=0;
	while(!q.empty())
	{
		int x=q.front();q.pop();
		s[++tag]=x;
		for(int i=head[x];i;i=nxt[i])
		{
			if(mark[i]==false)continue;
			in[to[i]]--;
			if(in[to[i]]==0)q.push(to[i]);
			cnt1[to[i]]+=cnt1[x];
			cnt1[to[i]]%=mod;
		}
	}
	for(int i=tag;i;i--)
	{
		int x=s[i];
		cnt2[x]++;
		for(int j=head[x];j;j=nxt[j])
		{
			if(mark[j]==false)continue;
			cnt2[x]+=cnt2[to[j]];
			cnt2[x]%=mod;
		}
	}
}
int main()
{
	scanf("%d %d",&n,&m);
	for(int i=1,u,v,w;i<=m;i++)
	{
		scanf("%d %d %d",&u,&v,&w);
		add(u,v,w);
	}
	for(int i=1;i<=n;i++)
	{
		dij(i);topo(i);
		for(int i=1;i<=m;i++)
		{
//		cout<<cnt1[fro[i]]<<" "<<cnt2[to[i]]<<endl;
			if(mark[i])
			{
				ans[i]+=(cnt1[fro[i]]%mod*cnt2[to[i]]%mod)%mod;
				ans[i]%=mod;
			}
		}
	}
	for(int i=1;i<=m;i++)printf("%d\n",ans[i]);
	return 0;
}

时间复杂度\(O(nmlogm)\)