【数学】LGV 引理

发布时间 2023-11-22 19:55:33作者: The_Last_Candy

题目描述

这是一道模板题。

有一个 \(n\times n\) 的棋盘,左下角为 \((1,1)\),右上角为 \((n,n)\),若一个棋子在点 \((x,y)\),那么走一步只能走到 \((x+1,y)\)\((x,y+1)\)

现在有 \(m\) 个棋子,第 \(i\) 个棋子一开始放在 \((a_i,1)\),最终要走到 \((b_i,n)\)。问有多少种方案,使得每个棋子都能从起点走到终点,且对于所有棋子,走过路径上的点互不相交。输出方案数 \(\bmod\ 998244353\) 的值。

两种方案不同当且仅当存在至少一个棋子所经过的点不同。

\(1 \leq n \leq 10^6,1 \leq m \leq 100\)

算法描述

LGV 引理,用于解决 DAG 上一个起点集合到一个终点集合不交路径的统计问题。

下面给出定义:

首先有一个排列 \(p\)\(p_i\) 表示起点 \(a_i\) 最后走到终点 \(b_{p_i}\)

\(w(i)\) :表示从 \(a_i\)\(b_{p_i}\) 经过边权的乘积。

\(\sigma(p)\) :排列 \(p\) 的逆序对数。

\(P\) : 一个路径集合,代表对于当前排列 \(p\) ,每一个 \(a_i\) 都找到了一条 \(a_i \to b_{p_i}\) 的路径。

定义 \(P\) 的权值为 \(\sum_i w(i)\)

\(e(a_i,b_j)\) 代表从 \(a_i\)\(b_j\) 所有路径 \(w\) 的和。

引理:

\[det\left | \begin{matrix} e(a_1,b_1) & e(a_1,b_2) & \dots & e(a_1,b_n)\\ \dots\\ e(a_n,b_1) & e(a_n,b_2) & \dots & e(a_n,b_n) \end{matrix} \right | = \sum_{P : A \to B,p}(-1)^{\sigma (p)} \prod_{i = 1}^n w(i) \]

也就是这个矩阵的行列式等于所有 \(A \to B\) 不交路径集权值的带权和,权值为 \(-1\) 的逆序对数次方。

证明不会。

也就是说,右边这个所有路径集合的抽象统计问题被转化成了左边这个东西的行列式,而左边的东西只需要求出一对点之间所有路径的权值和即可,有些时候这是可求的。

行列式怎么求自行搜索。

但是由于右边这个排列的逆序对数次方太过抽象,通常我们只用弱化版的结论,比如说这道模板题里面只有 \(\sigma (p) = 0\) 的时候才统计入答案,因为只要 \(a_i\) 不对应 \(b_i\) ,路线一定会相交。

对于本题,也就是计数问题,我们发现一条路径贡献 \(1\) 的答案,边权要乘起来,直接赋为 \(1\) 即可。

至于 \(a_i\)\(b_j\) 路径条数,没有了限制,直接组合数算就行了,是 \([a_i \leq b_j]\binom {n - 1 + b_j - a_i}{b_j - a_i}\)

注意起点数量要和终点数量相等。

套 LGV 模板即可。

#include<bits/stdc++.h>
using namespace std;
const int N = 2e6 + 5,M = 105,MOD = 998244353;
typedef long long ll;
ll frac[N],inv[N];
int n,m,a[M],b[M],c[M][M];
inline ll C(int y,int x)
{
	if(y < 0 || x < 0 || y < x) return 0;
	return frac[y] * inv[x] % MOD * inv[y - x] % MOD;
}
inline ll ksm(ll base,int pts)
{
	ll ret = 1;
	for(;pts > 0;pts >>= 1,base = base * base % MOD)
		if(pts & 1)
			ret = ret * base % MOD;
	return ret;
}
inline ll getdet()
{
	ll ret = 1;
	for(int i = 1;i <= m;i++)
	{
		int tmp = i;
		while(tmp <= m && c[tmp][i] == 0) ++tmp;
		if(tmp == m + 1) return 0ll;
		swap(c[tmp],c[i]);
        if(tmp ^ i) ret = MOD - ret;
		ret = ret * c[i][i] % MOD;
		for(int j = i + 1;j <= m;j++)
		{
			ll rate = c[j][i] * ksm(c[i][i],MOD - 2) % MOD;
			for(int k = 1;k <= m;k++) c[j][k] = (c[j][k] - c[i][k] * rate % MOD + MOD) % MOD;
		}
	}
	return ret;
}
int main()
{
	frac[0] = inv[0] = 1;
	for(int i = 1;i < N;i++) frac[i] = frac[i - 1] * i % MOD;
	inv[N - 1] = ksm(frac[N - 1],MOD - 2);
	for(int i = N - 2;i >= 1;i--) inv[i] = inv[i + 1] * (i + 1) % MOD;
	int T;
	cin>>T;
	while(T--)
	{
		cin>>n>>m;
		for(int i = 1;i <= m;i++) cin>>a[i]>>b[i];
		for(int i = 1;i <= m;i++)
			for(int j = 1;j <= m;j++)
			{
				if(a[i] > b[j]) c[i][j] = 0;
				else c[i][j] = C(n - 1 + b[j] - a[i],b[j] - a[i]);
			}
		ll ans = getdet();
		cout<<ans<<endl;
	}
	return 0;
}

本人自认为下面这道题更好地利用了 LGV 引理的性质:

[NOI2021] 路径交点

题面太长,不在这里挂。

考虑要求从第 \(1\) 层走到第 \(k\) 层的路径集合的交点数量,并且点不重复。

没法做诶,起终点关系又没法直接和中间的交点数量对应...

再读一边题,发现求的是奇偶性个数统计。

观察起终点位置关系对奇偶性的影响,发现一旦 \(a_i,a_j\) 是起点,对应 \(b_{p_i},b_{p_j}\)\(p_i,p_j\) 之间构成了逆序对的话,一定有奇数个交点。

否则一定是偶数个。(自己画图研究可以发现)

所以 LGV 引理中 "逆序对数量奇偶性" 就派上了作用,容易发现逆序对数量奇偶性,就是交点数量奇偶性。

所以 LGV 引理即可,至于 \(a_i\)\(b_j\) 的路径数量,可以 bfs 解决,时间复杂度 \(\Theta(n^3)\)

#include<bits/stdc++.h>
using namespace std;
const int N = 205,MOD = 998244353;
typedef long long ll;
ll c[N][N];
int k,n[N],m[N],w[N][N][N],num[N][N];
inline ll ksm(ll base,int pts)
{
	ll ret = 1;
	for(;pts > 0;pts >>= 1,base = base * base % MOD)
		if(pts & 1)
			ret = ret * base % MOD;
	return ret;
}
inline void bfs(int S)
{
	queue <pair<int,int> > q;
	q.push(make_pair(1,S));
	num[1][S] = 1;
	while(!q.empty())
	{
		int x = q.front().second,d = q.front().first; q.pop();
		if(d == k) continue;
		for(int i = 1;i <= n[d + 1];i++) 
			if(w[d][x][i])
			{
				if(!num[d + 1][i]) q.push(make_pair(d + 1,i));
				num[d + 1][i] = (num[d + 1][i] + num[d][x]) % MOD;
			}
	}
}
inline ll getdet()
{
	ll ret = 1;
	for(int i = 1;i <= n[1];i++)
	{
		int tmp = i;
		while(tmp <= n[1] && c[tmp][i] == 0) ++tmp;
		if(tmp == n[1] + 1) return 0ll;
		for(int j = 1;j <= n[1];j++) swap(c[tmp][j],c[i][j]);
		if(tmp ^ i) ret = MOD - ret;
		ret = ret * c[i][i] % MOD;
		for(int j = i + 1;j <= n[1];j++)
		{
			ll rate = c[j][i] * ksm(c[i][i],MOD - 2) % MOD;
			for(int k = 1;k <= n[1];k++) c[j][k] = (c[j][k] - c[i][k] * rate % MOD + MOD) % MOD;
		}
	}
	return ret;
}
int main()
{
	int T;
	cin>>T;
	while(T--)
	{
		memset(w,0,sizeof(w));
		cin>>k;
		for(int i = 1;i <= k;i++) cin>>n[i];
		for(int i = 1;i <= k - 1;i++) cin>>m[i];
		for(int i = 1;i <= k - 1;i++)
			for(int j = 1,x,y;j <= m[i];j++)
			{
				cin>>x>>y;
				w[i][x][y] = 1;
			}
		for(int i = 1;i <= n[1];i++)
		{
			memset(num,0,sizeof(num));
			bfs(i);
			for(int j = 1;j <= n[k];j++) c[i][j] = num[k][j];
		}
		ll res = getdet();
		cout<<res<<endl;
	}
	return 0;
 }