【题解】P4593 [TJOI2018] 教科书般的亵渎

发布时间 2023-09-02 11:24:09作者: kymru

之前整理的时候忘记写,现在补上。

思路

拉插求自然数幂和。

关于自然数幂和 \(\sum\limits_{i = 1}^n i^k\),已知是关于 \(n\)\(k + 1\) 次多项式,可以用伯努利数 \(O(k \log k)\) 求,也可以直接拉插 \(O(k)\) 求。

拉插结论:若一个 \(n\) 次多项式 \(f\) 经过 \(n + 1\) 个点,则 \(f(x) = \sum\limits_{i = 1}^{n + 1} y_i \prod\limits_{j \neq i} \frac{x - x_j}{x_i - x_j}.\)

于是只需要求出 \(f = \sum\limits_{i = 1}^n i^k\)\(n\)\(1\)\(k + 2\) 时的点值,就可以拉插计算得到答案。

回到这题,考虑到每次「亵渎」作用于全局,可以猜到最终使用「亵渎」的次数是一定的,观察得到使用次数 \(k = m + 1\).

假设没有不存在的血量,最终的分数是 \(\sum\limits_{i = 1}^n \sum\limits_{j = 1}^{n - i + 1} j^k\).

不存在的血量意味着需要在此处多使用一次亵渎,并且这里无法贡献分数,在上面的基础上容斥。方便起见,认为在 \(0\) 处有一头不存在的怪兽,编号为 \(a_0\)。最终答案是:

\(\sum\limits_{i = 0}^m \sum\limits_{j = 1}^{n - a_i} j^k - \sum\limits_{j = i + 1} (a_j - a_i)^k\).

朴素的复杂度是 \(O(m^2 \log |V| + mk)\).

代码

#include <cstdio>
#include <algorithm>
using namespace std;

#define int long long

const int maxn = 3.5e6 + 5;
const int maxm = 50 + 5;
const int mod = 1e9 + 7;

int t, n, m, k;
int a[maxm];
int pre[maxn], suf[maxn], fac[maxn];

int qpow(int base, int power)
{
	int res = 1;
	while (power)
	{
		if (power & 1) res = 1ll * res * base % mod;
		base = 1ll * base * base % mod, power >>= 1;
	}
	return res;
}

int solve(int n, int k)
{
	if (n <= 0) return 0;
	int y = 0, res = 0;
	pre[0] = fac[0] = suf[k + 3] = 1;
	for (int i = 1; i <= k + 2; i++) pre[i] = 1ll * pre[i - 1] * (n - i) % mod;
	for (int i = k + 2; i >= 1; i--) suf[i] = 1ll * suf[i + 1] * (n - i) % mod;
	for (int i = 1; i <= k + 2; i++) fac[i] = 1ll * fac[i - 1] * i % mod;
	for (int i = 1; i <= k + 2; i++)
	{
		y = (y + qpow(i, k)) % mod;
		int a = 1ll * pre[i - 1] * suf[i + 1] % mod;
		int b = fac[i - 1] * ((k - i) & 1 ? -1ll : 1ll) * fac[k + 2 - i] % mod;
		res = (res + 1ll * y * a % mod * qpow(b, mod - 2) % mod) % mod;
	}
	res = (res % mod + mod) % mod;
	return res;
}

signed main()
{
	scanf("%lld", &t);
	a[0] = 0;
	while (t--)
	{
		scanf("%lld%lld", &n, &m), k = m + 1;
		for (int i = 1; i <= m; i++) scanf("%lld", &a[i]);
		sort(a + 1, a + m + 1);
		int ans = 0;
		for (int i = 0; i <= m; i++)
		{
			ans = (ans + solve(n - a[i], k)) % mod;
			for (int j = i + 1; j <= m; j++) ans = ((ans - qpow(a[j] - a[i], k)) % mod + mod) % mod;
		}
		printf("%lld\n", ans);
	}
	return 0;
}