数位DP

发布时间 2023-04-15 12:01:15作者: ShadowAA

数位DP

数位是指把一个数字按照个、十、百、千等等一位一位地拆开,关注它每一位上的数字。如果拆的是十进制数,那么每一位数字都是 0~9,其他进制可类比十进制。

数位 DP:用来解决一类特定问题,这种问题比较好辨认,一般具有这几个特征:

  1. 要求统计满足一定条件的数的数量(即,最终目的为计数);
  2. 这些条件经过转化后可以使用「数位」的思想去理解和判断;
  3. 输入会提供一个数字区间(有时也只提供上界)来作为统计的限制;
  4. 上界很大(比如 10^{18}),暴力枚举验证会超时。

例题

不要62

https://loj.ac/p/10167

#include<bits/stdc++.h>
using namespace std;
int n,m,a[100],f[100][3];
int sc(int len,int t,int ff)
{
	int end,i;
	if (len==0)
		return 1;
	if ((ff)and(f[len][t]!=-1))
		return f[len][t];
	ff==false?end=a[len]:end=9;
	int s=0;
	for (i=0;i<=end;i++)
	{
		if (i==4) continue;
		if (i==6)
			s=s+sc(len-1,1,ff or i<a[len]);
		else
			if (!((i==2)and(t==1)))
				s=s+sc(len-1,0,ff or i<a[len]);
	}
	if (ff)
		f[len][t]=s;
	return s;
}
int solve(int x)
{
	int len=0;
	while (x>0)
	{
		a[++len]=x%10;
		x=x/10;
	}
	memset(f,-1,sizeof(f));
	return sc(len,0,0);
}
int main()
{
	while (cin>>n>>m)
	{
		if ((n==0)and(m==0))
			break;
		cout<<solve(m)-solve(n-1)<<endl;
	}
}

数字计数

统计n到m中所有数位出现的次数

https://loj.ac/p/10169

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll n,m,i,g[20],c[20],f[100][2][20],a[100];
ll ff[20];
void build()
{
	ll x=1;
	for (int i=0;i<=12;i++)
	{
		ff[i]=x;
		x=x*10;
	}
}
void sc(int len,int t,int w,ll x)
{
	if (len==0)
		return;
	if (t and(f[len][w][0]!=-1))
	{
		for (int i=0;i<=9;i++)
			g[i]=g[i]+f[len][w][i];
		return;
	}
	int end,i;ll s=0;
	t==0?end=a[len]:end=9;
	ll b[10];
	for (i=0;i<=9;i++)
		b[i]=g[i];
	for (i=0;i<=end;i++)
	{
		sc(len-1,t or i<a[len],w or i>0,x);
		if (!((w==0)and(i==0))or(len==1))
		{
			if (t or i<a[len])
				g[i]=g[i]+ff[len-1];
			else g[i]=g[i]+x%ff[len-1]+1;
		}
	}
	if (t)
		for (i=0;i<=9;i++)
			f[len][w][i]=g[i]-b[i];
}
void solve(ll x)
{
	ll len=0,y=x;
	while (x>0)
	{
		a[++len]=x%10;
		x=x/10;
	}
	memset(g,0,sizeof(g));
	if (len==0) g[0]=1;
	else sc(len,0,0,y);
}
int main()
{
	cin>>n>>m;
	build();
	memset(f,-1,sizeof(f));
	solve(m);
	for (i=0;i<=9;i++)
		c[i]=g[i];
	solve(n-1);
	for (i=0;i<=8;i++)
		cout<<c[i]-g[i]<<' ';
	cout<<c[9]-g[9]<<endl;
}

恨7不是妻

https://vjudge.net/problem/LibreOJ-10168

题解:https://www.cnblogs.com/graytido/p/12202754.html

#include <bits/stdc++.h>
using namespace std;
/*    freopen("k.in", "r", stdin);
    freopen("k.out", "w", stdout); */
//clock_t c1 = clock();
//std::cerr << "Time:" << clock() - c1 <<"ms" << std::endl;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#define de(a) cout << #a << " = " << a << endl
#define rep(i, a, n) for (int i = a; i <= n; i++)
#define per(i, a, n) for (int i = n; i >= a; i--)
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
typedef vector<int, int> VII;
#define inf 0x3f3f3f3f
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll MAXN = 1e6 + 7;
const ll MAXM = 1e6 + 7;
const ll MOD = 1e9 + 7;
const double eps = 1e-6;
const double pi = acos(-1.0);
int a[105];
/*   1、整数中某一位是7;
  2、整数的每一位加起来的和是7的整数倍;
  3、这个整数是7的整数倍; */
struct node
{
    ll sum;   //与7无关的数的个数
    ll qsum;  //与7无关的数和
    ll sqsum; //ans
    node(ll _sum = -1, ll _qsum = 0, ll _sqsum = 0) { sum = _sum, qsum = _qsum, sqsum = _sqsum; }
};
node dp[30][15][15];
ll c[20];
node dfs(int pos, int sta1, int sta2, bool lim) //sta1各位数和%7 sta2前面%7
{
    if (pos < 0)
        return node(sta1 && sta2, 0, 0);
    if (!lim && dp[pos][sta1][sta2].sum != -1)
        return dp[pos][sta1][sta2];
    int up = lim ? a[pos] : 9;
    node ret = node(0, 0, 0);
    for (int i = 0; i <= up; i++)
    {
        if (i != 7)
        {
            node t = dfs(pos - 1, (sta1 + i) % 7, (sta2 * 10 + i) % 7, lim && i == a[pos]);
            ret.sum += t.sum;
            ret.sum %= MOD;
            ret.qsum += (((c[pos] * i % MOD) * t.sum % MOD) + t.qsum) % MOD;
            ret.qsum %= MOD;
            ret.sqsum += t.sqsum % MOD;
            ret.sqsum %= MOD;
            ret.sqsum += ((i * i * c[pos] % MOD) * c[pos] % MOD) * t.sum % MOD;
            ret.sqsum %= MOD;
            ret.sqsum += ((i * 2 * c[pos] % MOD) * t.qsum) % MOD;
            ret.sqsum %= MOD;
        }
    }
    if (!lim)
        dp[pos][sta1][sta2] = ret;
    return ret;
}
ll solve(ll x)
{
    int pos = -1;
    while (x)
    {
        a[++pos] = x % 10;
        x /= 10;
    }
    return dfs(pos, 0, 0, true).sqsum;
}
void init()
{
    c[0] = 1;
    for (int i = 1; i < 20; i++)
        c[i] = (c[i - 1] * 10) % MOD;
}

int main()
{
    int t;
    init();
    scanf("%d", &t);
    while (t--)
    {
        ll L, R;
        scanf("%lld%lld", &L, &R);
        printf("%lld\n", ((solve(R) - solve(L - 1)) % MOD + MOD) % MOD);
    }
    return 0;
}