数位DP

发布时间 2023-11-10 01:59:45作者: Qiansui

数位DP

数位DP用于数字的数位统计,如果题目和数位统计有关,那么可以用数位DP思想,把低位的统计结果记录下来,在高位计算的时候直接使用低位的结果

感觉很抽象啊,而且还有点难啊啊啊。下面给出的几个例题再理解理解!

第一个例题 数位和

请问 0 到 9 这些数字在 \([l, r](1 \le l \le r \le 10 ^ {16})\) 的数位中分别出现了多少次。

显然,\(ans_{l, r} = ans_{1,r} - ans_{1, l - 1}\),则问题可转化为统计 0 到 9 在区间 [1, n] 中分别出现了多少次;

考虑将 [1, n] 中的数字分类,我们把小于 \(n\) 的数字 \(s\) 按照以下方法进行分类:如果 \(s\) 从左到右数第一个小于 \(n\) 对应位置的数位是第 \(i\) 位,那么 \(s\) 就会被分到第 i 类。

\(n = 12345\) 时,小于 \(n\) 的数字会被分成 5 类:

  • 第一类(第一位不同):[1, 9999]
  • 第二类(第二位不同):[10000,11999]
  • 第三类(第三位不同):[12000, 12299]
  • 第四类(第四位不同):[12300,12399]
  • 第五类(第五位不同):[12340,12344]

对于第 \(i\) 类数,数字 s 的前 \(i - 1\) 位与原数相同。问题可以转变为区间内每个位置等于 x 的数字的个数的和。

那么对于 \(n = a_1 a_2 \dots a_m)\) 时( \(a_i\) 表示 n 从左到右第 i 个数位上的值,其中 m 表示 n 是一个 m 位数),考虑第 \(i\) 类数,我们可以得知每个位置 j 的情况:

  • 对于 \(j < i\) 的情况,有 \(a_i \times 10^{m - i}\) (前面的数字均相等,后面第 i 位仅有 \(a_i\) 中选择,剩下的均有 0 ~ 9 每个位置一共十种选择)个数字的第 j 位是 \(a_j\)
  • 对于 \(j = i\) 的情况,\(a_j\) 可以是 0 到 \(a_i - 1\) 中选择,每种选择共有 \(10^{m - i}\) (仅考虑后面 m - i 个位置,每个位置 0 ~ 9) 种数字;
  • 对于 \(j > i\) 的情况,对于数字 x,一共有 m - i 个位置可以出现,选定一个位置后一共有 \(a_i \times 10^{m - i - 1}\) 个数字,所以每种数字 x 一共出现 \(a_i \times 10^{m - i - 1} \times (m - i)\) 个。

但是对于第一类数,这时会出现问题,因为会把前导 0 也算进去。那么第一类中 0 的出现次数得单独统计,我们可以枚举第一个非 0 位置,该位置后面的 0 才需要考虑进去。有两种情况:

  • 假如第一个非零位是第 1 位,满足 \(j > 1\) 的位置 j 是 0 的数字有 \((a_1 - 1) \times 10^{m - 2}\),又后面还有 m - 2 个位置,所以此时 0 的出现次数为 \((a_1 - 1) \times 10^{m - 2} \times (m - 1)\)

  • 假如第一个非零位是第 \(i (i \ne 1)\) 位,满足 \(j > i\) 的位置 j 是 0 的数字有 \(9 \times 10^{m - i - 1}\),又 i 后一共有 \(m - i\) 个位置,所以 0 的出现次数为 \(9 \times 10^{m - i - 1} \times (m - i)\)

至此,所有情况均考虑完毕。

int a[21];
ll l, r, ans[10], f[17];// f[i] 表示 10 ^ i

void calc(ll n, int xs){
	// 将数按数位放入数组中
	int m = 0;
	for(; n; n /= 10){
		a[++ m] = n % 10;
	}
	for(int i = 1, j = m; i < j; ++ i, -- j)// 因为从前往后考虑,所以需要翻转数组
		swap(a[i], a[j]);

	for(int i = 1; i <= m; ++ i){
		// 考虑 j < i 的情况
		for(int j = 1; j < i; ++ j)
			ans[a[j]] += xs * a[i] * f[m - i];
		// 考虑 j = i 的情况
		for(int j = 1; j < a[i]; ++ j)
			ans[j] += xs * f[m - i];
		if(i != 1 && a[i])
			ans[0] += xs * f[m - i];
		// 考虑 j > i 的情况
		if(m != i){
			for(int j = 1; j < 10; ++ j)
				ans[j] += xs * f[m - i - 1] * (m - i) * a[i];
			if(i != 1)
				ans[0] += xs * f[m - i - 1] * (m - i) * a[i];
		}
		// 单独考虑第一类对于 0 的贡献,找第一个非零位
		if(i == 1){
			// 第一个为非零位
			if(m >= 2)
				ans[0] += xs * (a[1] - 1) * (m - 1) * f[m - 2];
			// 第一个非零位为第 j 位
			for(int j = 2; j < m; ++ j)
				ans[0] += xs * 9 * (m - j) * f[m - j - 1];
		}
	}
	// 把 n 考虑进去
	for(int i = 1; i <= m; ++ i)
		ans[a[i]] += xs;
	return ;
}

void solve(){
	// 预处理 10 的幂次方
	f[0] = 1;
	for(int i = 1; i < 17; ++ i) f[i] = f[i - 1] * 10;
	cin >> l >> r;
	// ans[l, r] = ans[1, r] - ans[1, l - 1]
	calc(r, 1);
	calc(l - 1, -1);
	for(int i = 0; i < 10; ++ i)
		cout << ans[i] << ' ';
	return ;
}

第二个例题 数数

询问区间 [l, r] 中有多少个数字满足数字中任意两个相邻数位的差的绝对值不超过2?

依然,\(ans_{l, r} = ans_{1, r} - ans_{1, l - 1}\)。依旧先考虑 \([1, n)\) 的数字,再考虑 n 的贡献。

\(n = a_1a_2\dots a_n\),考虑前 i - 1位和 n 一样,前 i - 1 位是固定的。那么第 i 位的值 \(\in [0, a_i - 1]\),后面可以放的数任意。

\(f[i][j][0/1]\) 表示考虑了前 i 位,第 i 位的数字为 j,前 i 位是否全是 0(0表示全是,1表示不全是)的情况下有多少满足条件的数字。

  • 如果最后一维是 1,那么我们可以枚举第 i + 1 位的值 k,如果满足 \(|j - k| \le 2\),即可进行转移: \(f[i + 1][k][1] += f[i][j][1]\)
  • 如果最后一维是 0,那么有两种选择:第 i + 1 位是 0,\(f[i + 1][0][0] += f[i][0][0]\);第 i + 1 位不是 0,那么可以 1~9 枚举 k,\(f[i + 1][k][1] += f[i][0][0]\)

最后答案等于 \(\displaystyle \sum_{x = 0}^{9} f[m][x][1]\)

//未验证,仅摘抄
int a[21];
ll l, r, f[21][10][2];

ll calc(ll n){
	// n = 0 时满足条件的个数也等于 0
	if(n == 0) return 0;
	int m = 0;
	while(n){
		a[++ m] = n % 10;
		n /= 10;
	}
	for(int i = 1, j = m; i < j; ++ i, -- j) swap(a[i], a[j]);
	ll res = 0;
	bool ok = true;
	for(int i = 1; i <= m && ok; ++ i){// 第 i 类数字
		// 枚举第 i 位数字 j
		for(int j = 0; j < a[i]; ++ j){
			if(i != 1 && abs(j - a[i - 1]) > 2) continue;
			// 设置动态规划的初始状态
			mem(f, 0);
			if(i == 1 && j == 0) f[i][j][0] = 1;
			else f[i][j][1] = 1;
			// 枚举后面位置上的值
			for(int k = i + 1; k <= m; ++ k){// 利用前 k - 1 个位置推前 k 个位置的状态
				for(int l = 0; l < 10; ++ l){// 第 k - 1 位的值
					for(int r = 0; r < 2; ++ r){
						if(f[k - 1][l][r]){
							for(int x = 0; x < 10; ++ x){// 枚举第 k 个位置的值
								// r != 0 时后面放的数需要和前面的差的绝对值小于等于2
								if(r && abs(l - x) <= 2)
									f[k][x][r] += f[k - 1][l][r];
								if(!r){// r = 0 后面随意放数即可
									// x = 0 时延续前面全 0
									if(!x) f[k][0][0] += f[k - 1][0][0];
									else f[k][x][1] += f[k - 1][0][0];
								}
							}
						}
					}
				}
			}
			for(int j = 0; j < 10; ++ j) res += f[m][j][1];
		}
		if(i != 1 && abs(a[i] - a[i - 1]) > 2) ok = false;
	}
	if(ok) ++ res;
	return res;
}

void solve(){
	cin >> l >> r;
	cout << calc(r) - calc(l - 1) << '\n';
	return ;
}

上面的代码写法是 \(O(m ^ 2)\) 的,时间复杂度较高

可以在DP的时候多开一个维度解决这个问题。定义 \(f[i][j][0/1][0/1]\) 表示考虑了前 i 位,第 i 位的数字是 j ,前 i 位的数字是否全是 0(0表示全零,1表示不全零),这个数字的前 i 位和 n 的前 i 位是否相等(0 表示不相等,1 表示相等)的情况下有多少满足条件的数字。

  • 如果最后一维是 0,后面想放几放几;
  • 否则,有两种选择:放 \([0, a_{i + 1} - 1]\),最后一维变成 0;放 \(a_{i + 1}\),最后一维仍然是 1;

最后答案等于 \(\displaystyle \sum_{x = 0}^{9} (f[m][x][1][0] + f[m][x][1][1])\)

int a[21];
long long l, r, f[21][10][2][2];

long long calc(long long n) {
    if (!n)
        return 0;
    int m = 0;
    for (; n; n /= 10)
        a[++m] = n % 10;
    for (int i = 1, j = m; i < j; i++, j--)
        swap(a[i], a[j]);
    long long res = 0;
    memset(f, 0, sizeof(f));
	// 初始状态,考虑了 0 个数,相当于前面都是前导 0
    f[0][0][0][1] = 1;
    for (int i = 1; i <= m; i++)// 考虑每一类数
        for (int j = 0; j < 10; j++)
            for (int k = 0; k < 2; k++)
                for (int l = 0; l < 2; l++)
                    if (f[i - 1][j][k][l])
                        for (int x = 0; x < 10; x++) {
							// l = 1 时前 i 位均相等,此时新位必须严格小于
                            if (l && x > a[i])
                                continue;
                            if (l) {// l = 1 时前 i 位均相等
                                if (x < a[i])// 此时最后一维为 0
                                    if (!k) {
                                        if (!x)
                                            f[i][0][0][0] += f[i - 1][j][k][l];
                                        else
                                            f[i][x][1][0] += f[i - 1][j][k][l];
                                    } else {
                                        if (abs(j - x) <= 2)
                                            f[i][x][1][0] += f[i - 1][j][k][l];
                                    }
                                else// 此时最后一维为 1
                                    if (!k)
                                        f[i][x][1][1] += f[i - 1][j][k][l];
                                    else
                                        if (abs(j - x) <= 2)
                                            f[i][x][1][1] += f[i - 1][j][k][l];
                            } else {// l = 0 时前 i 位不相等,后面可以任意放
                                if (!k) {// k = 0 时,前面全是 0
                                    if (!x)// x = 0 则延续 0
                                        f[i][0][0][0] += f[i - 1][j][k][l];
                                    else
                                        f[i][x][1][0] += f[i - 1][j][k][l];
                                } else
									// k = 1 则需要放入绝对值小于等于2的数
                                    if (abs(j - x) <= 2)
                                        f[i][x][1][0] += f[i - 1][j][k][l];
                            }
                        }
    for (int i = 0; i < 10; i++)
        res += f[m][i][1][0] + f[m][i][1][1];
    return res;
}

void solve(){
	cin >> l >> r;
	cout << calc(r) - calc(l - 1) << '\n';
	return ;
}

相关资料

例题

综合运用

模板题

const int N = 21 + 5, inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f, mod = 998244353;
int num[N], now;// now 表示当前统计的是哪一个数字
ll dp[N][N];// dp[i][j] 表示数字后面i个数字里面有j个now时,共有多少个数字now

// pos表示当前处理到 pos 位,sum表示前面共有的now的数字的个数
// lead 为 true 时表示有前导零,limit为 true 时表示数字范围有限制
ll dfs(int pos, int sum, bool lead, bool limit){
	if(pos == 0) return sum;// 所有数位均处理完了
	if(!lead && !limit && dp[pos][sum] != -1) return dp[pos][sum];
	int up = (limit ? num[pos] : 9);
	ll ans = 0;
	for(int i = 0; i <= up; ++ i){
		// 延续前导 0
		if(i == 0 && lead) ans += dfs(pos - 1, sum, true, limit && i == up);
		// 考虑当前位多一个now
		else if(i == now) ans += dfs(pos - 1, sum + 1, false, limit && i == up);
		// 考虑剩下的数字
		else if(i != now) ans += dfs(pos - 1, sum, false, limit && i == up);
	}
	if(!lead && !limit) dp[pos][sum] = ans;// 记忆化搜索
	return ans;
}

ll calc(ll n){
	int len = 0;
	while(n){
		num[++ len] = n % 10;
		n /= 10;
	}
	mem(dp, -1);
	return dfs(len, 0, true, true);// 从len开始即从高位向低位进行处理
}

void solve(){
	ll l, r;
	cin >> l >> r;
	for(int i = 0; i < 10; ++ i){
		now = i;
		cout << calc(r) - calc(l - 1) << ' ';
	}
	return ;
}
const int N = 20 + 5, inf = 0x3f3f3f3f;
int a[N], dp[N][N], now;
// dp[i][j] 表示数字长度为 i,前一位是 j 的情况下无数位限制的 Windy数总数

int dfs(int pos, int last, int lead, int limit){
	if(pos == 0) return 1;
	if(!lead && !limit && dp[pos][last] != -1) return dp[pos][last];
	int ans = 0;
	int up = limit ? a[pos] : 9;
	for(int i = 0; i <= up; ++ i){
		if(abs(i - last) < 2) continue;// 不符合条件
		// 延续前导0
		if(lead && i == 0) ans += dfs(pos - 1, -2, true, limit && i == up);
		// 遍历剩下的情况
		else ans += dfs(pos - 1, i, false, limit && i == up);
	}
	if(!lead && !limit) dp[pos][last] = ans;
	return ans;
}

int calc(int n){
	int len = 0;
	while(n){
		a[++ len] = n % 10;
		n /= 10;
	}
	mem(dp, -1);
	return dfs(len, -2, true, true);
}

void solve(){
	int a, b;
	cin >> a >> b;
	cout << calc(b) - calc(a - 1) << '\n';
	return ;
}
const int N = 20 + 5, inf = 0x3f3f3f3f;
int num[21];
// dp[pos][u][v][state][n8][n4]表示数字长度为pos,前一位数字为u,前前位数字为v
// state表示是否出现了3个连续相同数字,n8 n4 表示是否出现数字8 数字4
ll dp[21][10][10][2][2][2];

ll dfs(int pos, int u, int v, bool state, bool n8, bool n4, bool limit){
	if(n8 && n4) return 0;
	if(pos == 0) return state;
	if(!limit && dp[pos][u][v][state][n8][n4] != -1)
		return dp[pos][u][v][state][n8][n4];
	int up = limit ? num[pos] : 9;
	ll ans = 0;
	for(int i = 0; i <= up; ++ i)
		ans += dfs(pos - 1, i, u, state || (i == u && i == v), n8 || (i ==8), n4 || i == 4, limit && i == up);
	if(!limit) dp[pos][u][v][state][n8][n4] = ans;
	return ans;
}

ll calc(ll n){
	int len = 0;
	while(n){
		num[++ len] = n % 10;
		n /= 10;
	}
	if(len != 11) return 0;// 手机号必须是11位数字
	mem(dp, -1);
	ll ans = 0;
	for(int i = 1; i <= num[len]; ++ i)// 枚举首位,避开前导0
		ans += dfs(len - 1, i, 0, 0, i == 8, i == 4, i == num[len]);
	return ans;
}

void solve(){
	ll l, r;
	cin >> l >> r;
	cout << calc(r) - calc(l - 1) << '\n';
	return ;
}