AtCoder Beginner Contest 317 F - Nim

发布时间：2023-08-28 18:12:02　作者：ccz9729

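思路：数位 DP。按 N 的二进制位从高到低逐位确定 x1、x2 的当前位，x3 的该位取二者的异或（从而保证 x1 ⊕ x2 ⊕ x3 = 0）；状态中同时记录各数模 b_i 的余数、是否仍贴着 N 的上界（保证 x_i ≤ N），以及是否已经出现过 1（保证 x_i ≥ 1）。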

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

namespace {

constexpr int kMod = 998244353;

// Number of ordered triples (x1, x2, x3) such that
//   1 <= x_i <= n,  x_i is a multiple of b_i,  and  x1 ^ x2 ^ x3 == 0,
// taken modulo 998244353.
//
// Digit DP over the binary digits of n, most significant bit first. After the
// top k bits of every x_i have been chosen, a state records:
//   r_i       -- x_i's chosen prefix modulo b_i;
//   tight_i   -- x_i's prefix equals n's prefix, so its next bit may not exceed
//                n's next bit (this is what enforces x_i <= n);
//   nonzero_i -- x_i's prefix already contains a 1 bit (enforces x_i >= 1).
// Only the bits of x1 and x2 are enumerated; x3's bit is forced to their XOR,
// which keeps x1 ^ x2 ^ x3 == 0 bit by bit.
//
// Only two DP layers of b1 * b2 * b3 * 64 counters are kept, so b_i has no
// fixed upper bound. Returns 0 when n < 1 (no x_i can exist) or when any
// b_i < 1 (outside the problem's constraints; avoids a modulo by zero).
int count_nim_triples(long long n, int b1, int b2, int b3) {
  if (n < 1 || b1 < 1 || b2 < 1 || b3 < 1) return 0;

  std::vector<int> bits;  // binary digits of n, most significant first
  for (long long v = n; v > 0; v >>= 1) bits.push_back(static_cast<int>(v & 1));
  std::reverse(bits.begin(), bits.end());

  // Flag word layout: bit i is tight_i, bit (3 + i) is nonzero_i, for i = 0..2.
  constexpr int kFlagCount = 1 << 6;
  constexpr int kAllTight = 0b000111;
  constexpr int kAllNonzero = 0b111000;

  const auto index = [b2, b3](int r1, int r2, int r3, int flags) {
    return ((static_cast<std::size_t>(r1) * b2 + r2) * b3 + r3) * kFlagCount + flags;
  };
  const std::size_t state_count = static_cast<std::size_t>(b1) * b2 * b3 * kFlagCount;
  std::vector<int> cur(state_count, 0);
  std::vector<int> nxt(state_count, 0);

  // Empty prefix: every x_i is 0 so far and still equal to n's (empty) prefix.
  cur[index(0, 0, 0, kAllTight)] = 1;

  const int divisor[3] = {b1, b2, b3};
  for (const int limit : bits) {
    std::fill(nxt.begin(), nxt.end(), 0);
    for (int r1 = 0; r1 < b1; ++r1) {
      for (int r2 = 0; r2 < b2; ++r2) {
        for (int r3 = 0; r3 < b3; ++r3) {
          for (int flags = 0; flags < kFlagCount; ++flags) {
            const int ways = cur[index(r1, r2, r3, flags)];
            if (ways == 0) continue;
            const int rem[3] = {r1, r2, r3};
            for (int d1 = 0; d1 <= 1; ++d1) {
              for (int d2 = 0; d2 <= 1; ++d2) {
                const int digit[3] = {d1, d2, d1 ^ d2};
                int next_rem[3] = {0, 0, 0};
                int next_flags = 0;
                bool fits = true;
                for (int i = 0; i < 3; ++i) {
                  const bool tight = (flags >> i) & 1;
                  // A tight number may not place a bit above n's bit here.
                  if (tight && digit[i] > limit) {
                    fits = false;
                    break;
                  }
                  if (tight && digit[i] == limit) next_flags |= 1 << i;
                  if (((flags >> (3 + i)) & 1) || digit[i] != 0) next_flags |= 1 << (3 + i);
                  next_rem[i] = (rem[i] * 2 + digit[i]) % divisor[i];
                }
                if (!fits) continue;
                int& slot = nxt[index(next_rem[0], next_rem[1], next_rem[2], next_flags)];
                slot += ways;  // both operands < kMod, so no int overflow
                if (slot >= kMod) slot -= kMod;
              }
            }
          }
        }
      }
    }
    std::swap(cur, nxt);
  }

  // Accept: every x_i divisible by b_i and nonzero; tightness is irrelevant.
  long long total = 0;
  for (int tight = 0; tight < 8; ++tight) total += cur[index(0, 0, 0, kAllNonzero | tight)];
  return static_cast<int>(total % kMod);
}

}  // namespace

// Reads "N A1 A2 A3" from standard input and prints the number of valid
// triples modulo 998244353.
int main() {
  long long n = 0;  // initialized so a failed read yields a defined answer (0)
  int b1 = 0;
  int b2 = 0;
  int b3 = 0;
  std::cin >> n >> b1 >> b2 >> b3;
  std::cout << count_nim_triples(n, b1, b2, b3) << '\n';
  return 0;
}