中国石油大学(北京)第三届“骏码杯”程序设计竞赛题解

发布时间 2023-03-25 16:13:59作者: propane

中国石油大学(北京)第三届“骏码杯”程序设计竞赛题解

感谢大家的参与,我是本次比赛所有$10$道题目的出题人,在接下来的题解中,所有C++Python的标程均由我本人编写,因为我本人完全不懂Java,所以Java标程是由ChatGPT根据我其他语言的标程改编的,过程中由我提出改进建议,最终ChatGPT圆满完成任务,写出的代码可以通过所有的$10$道题.

另外提一嘴,赛前我把所有题目都喂给了ChatGPT让它做,遗憾的是ChatGPT在我的提示下只完成了AD两道题,题意只要不直接他直接就懵了.尤其是做M题的时候,他给了我一大堆公式和证明,看得我一愣一愣的,结果一运行样例都过不了...

A. 云影密码

题解 按照题目的要求模拟即可.
C++ code by me
#include<iostream>
#include<cstring>
using namespace std;
using LL = long long;
 
int main(){
 
	cin.tie(0);
	cout.tie(0);
	ios::sync_with_stdio(0);
 
	int T;
	cin >> T;
	while(T--){
		string s;
		cin >> s;
		string ans;
		for(int i = 0; i < s.size(); i++){
			int j = i;
			int sum = s[i] - '0';
			while(j + 1 < s.size() && s[j + 1] != '0')
				sum += s[++j] - '0';
			ans += char('a' + sum - 1);
			i = j + 1;
		}
		cout << ans << '\n';
	}
 
}
Java code by ChatGPT
import java.util.*;
 
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
 
        int t = Integer.parseInt(sc.nextLine().trim());
 
        for (int i = 0; i < t; i++) {
            String[] nums = sc.nextLine().trim().split("0");
            StringBuilder sb = new StringBuilder();
 
            for (String num : nums) {
                int sum = 0;
 
                for (int j = 0; j < num.length(); j++) {
                    sum += Integer.parseInt(String.valueOf(num.charAt(j)));
                }
 
                char c = (char) (sum + 'a' - 1);
                sb.append(c);
            }
 
            System.out.println(sb.toString());
        }
    }
}
Python code by me
import sys
input = sys.stdin.readline
 
for _ in range(int(input())):
    print("".join(chr(ord("a") - 1 + sum(int(num) for num in s)) for s in input().strip().split("0")))

B.上分

题解

注意到$n$很小,所以我们可以直接枚举做题顺序的全排列,计算出每种做题顺序能获得的分值然后求最大值即是答案.

根据实现的方式不同,算法的时间复杂度为$O(\sum{n\cdot n!)}$或者$O(\sum n!)$.

Bonus: 当$n \le 20$时如何解决?

C++ code by me
#include<iostream>
#include<cstring>
#include<vector>
#include<array>
#include<algorithm>
using namespace std;
using LL = long long;
 
int main(){
 
	cin.tie(0);
	cout.tie(0);
	ios::sync_with_stdio(0);
 
	int T;
	cin >> T;
	while(T--){
		int n, m;
		cin >> n >> m;
		vector<array<int, 3> > p(n);
		for(int i = 0; i < n; i++){
			int a, b, c;
			cin >> a >> b >> c;
			p[i] = {a, b, c};
		}
		vector<int> id(n);
		for(int i = 0; i < n; i++) id[i] = i;
		int ans = 0;
		do{
			int sum = 0, t = 0;
			for(int i = 0; i < n; i++){
				auto [a, b, c] = p[id[i]];
				t += b;
				if (t <= m){
					int delta = a / 250;
					sum += max(a - t * delta - 50 * c, 3 * a / 10);
				}
			}
			ans = max(ans, sum);
		}
		while(next_permutation(id.begin(), id.end()));
		cout << ans << '\n';
	}
 
}
Java code by ChatGPT
import java.util.*;
 
public class Main {
    public static void main(String[] args) {
 
        Scanner in = new Scanner(System.in);
        int T = in.nextInt();
        while (T-- > 0) {
            int n = in.nextInt(), m = in.nextInt();
            ArrayList<int[]> p = new ArrayList<int[]>();
            for (int i = 0; i < n; i++) {
                int a = in.nextInt(), b = in.nextInt(), c = in.nextInt();
                p.add(new int[]{a, b, c});
            }
            int[] id = new int[n];
            for (int i = 0; i < n; i++) id[i] = i;
            int ans = 0;
            do {
                int sum = 0, t = 0;
                for (int i = 0; i < n; i++) {
                    int[] item = p.get(id[i]);
                    int a = item[0], b = item[1], c = item[2];
                    t += b;
                    if (t <= m) {
                        int delta = a / 250;
                        sum += Math.max(a - t * delta - 50 * c, 3 * a / 10);
                    }
                }
                ans = Math.max(ans, sum);
            } while (nextPermutation(id));
            System.out.println(ans);
        }
 
    }
 
    public static boolean nextPermutation(int[] nums) {
        int n = nums.length, i = n - 2, j = n - 1;
        while (i >= 0 && nums[i] >= nums[i + 1]) i--;
        if (i < 0) return false;
        while (j >= 0 && nums[j] <= nums[i]) j--;
        swap(nums, i, j);
        reverse(nums, i + 1, n - 1);
        return true;
    }
 
    public static void swap(int[] nums, int i, int j) {
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }
 
    public static void reverse(int[] nums, int i, int j) {
        while (i < j) {
            swap(nums, i, j);
            i++;
            j--;
        }
    }
}
Python code by me
from itertools import permutations
import sys
input = sys.stdin.readline
 
for _ in range(int(input())):
    n, m = map(int, input().split())
    p = []
    for i in range(n):
        p.append(tuple(map(int, input().split())))
    ans = 0
    for id in permutations(range(n)):
        sum, t = 0, 0
        for i in range(n):
            a, b, c = p[id[i]]
            t += b
            if t <= m:
                sum += max(a - t * a // 250 - 50 * c, 3 * a // 10)
        ans = max(ans, sum)
    print(ans)

C.小菲爱数数

题解

对$a_{1\sim n}$分解质因子,考虑计算每个质因子对答案的贡献.

不妨设质因子$p$出现的位置有$b_1,b_2,b_3,\cdots ,b_k$,对于任意一个区间只要包含任意的$b_i$就会对答案产生$1$的贡献.

正难则反,我们可以计算有多少个区间不包含任何$b_i$,然后用总区间数减去不包含任何$b_i$的区间数得到质因子对答案的贡献.

显然只有处于任意两个相邻$b$之间的区间是不包含任何$b_i$的.

对于每个质因子我们先给答案加上$\frac{n(n+1)}{2}$,然后对于所有的相邻$b$之间的区间,假设长度为$len$,我们让答案减去$\frac{len(len + 1)}{2}$.

当然,你可能需要通过加一些哨兵或者特判来解决边界问题.

根据分解质因子的方法不同,时间复杂度为$O(n\sqrt{A})$或者$O(n\log{A})$.

因为含有多组测试样例,所以你需要用类似mapdict之类的关联容器来存储每个质因子出现的位置,以保证你枚举质因子的时间复杂度是$O(n\log{A})$.

C++ code by me
#include<iostream>
#include<cstring>
#include<map>
#include<vector>
using namespace std;
using LL = long long;
 
int main(){
 
    cin.tie(0);
    cout.tie(0);
    ios::sync_with_stdio(0);
 
    auto g = [](int x){
        return 1LL * x * (x + 1) / 2;
    };
 
    int T;
    cin >> T;
    while(T--){
        int n;
        cin >> n;
        map<int, vector<int> > mp;
        for(int i = 1; i <= n; i++){
            int x;
            cin >> x;
            for(int j = 2; j * j <= x; j++){
                if (x % j == 0){
                    mp[j].push_back(i);
                    while(x % j == 0) x /= j;
                }
            }
            if (x > 1) mp[x].push_back(i);
        }
        LL ans = 0;
        for(auto &[x, v] : mp){
            ans += g(n);
            v.push_back(n + 1);
            int last = 0;
            for(auto u : v){
                ans -= g(u - last - 1);
                last = u;
            }
        }
        cout << ans << '\n';
    }
 
}

Java code by ChatGPT
import java.util.*;
 
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
 
        int T = sc.nextInt();
        while(T-- > 0){
            int n = sc.nextInt();
            Map<Integer, List<Integer>> mp = new HashMap<>();
            for(int i = 1; i <= n; i++){
                int x = sc.nextInt();
                for(int j = 2; j * j <= x; j++){
                    if (x % j == 0){
                        List<Integer> list = mp.getOrDefault(j, new ArrayList<>());
                        list.add(i);
                        mp.put(j, list);
                        while(x % j == 0) x /= j;
                    }
                }
                if (x > 1){
                    List<Integer> list = mp.getOrDefault(x, new ArrayList<>());
                    list.add(i);
                    mp.put(x, list);
                }
            }
            long ans = 0;
            for(Map.Entry<Integer, List<Integer>> entry : mp.entrySet()){
                int x = entry.getKey();
                List<Integer> v = entry.getValue();
                ans += g(n);
                v.add(n + 1);
                int last = 0;
                for(int u : v){
                    ans -= g(u - last - 1);
                    last = u;
                }
            }
            System.out.println(ans);
        }
        sc.close();
    }
 
    static long g(int x){
        return 1L * x * (x + 1) / 2;
    }
}
Python code by me
import sys
from collections import defaultdict
input = sys.stdin.readline
 
primes = [[] for i in range(1000001)]
for i in range(2, 1000001):
    if len(primes[i]) == 0:
        j = i
        while j <= 1000000:
            primes[j].append(i)
            j += i
 
for _ in range(int(input())):
    n = int(input())
    a = tuple(map(int, input().split()))
    mp = defaultdict(list)
    for i in range(n):
        for p in primes[a[i]]:
            mp[p].append(i)
    ans = 0
    for key, lst in mp.items():
        ans += n * (n + 1) // 2
        lst.append(n)
        last = -1
        for pos in lst:
            ans -= (pos - last - 1) * (pos - last) // 2
            last = pos
    print(ans)

D.引流

题解 格式化输出即可.
C++ code by me
#include<iostream>
#include<cstring>
using namespace std;
using LL = long long;
 
int main(){
 
	char s[40];
	scanf("%s", s);
	printf("guan zhu %s miao, guan zhu %s xie xie miao!", s, s);
 
}
Java code by ChatGPT
import java.util.Scanner;
 
public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        String s = scanner.next();
        System.out.printf("guan zhu %s miao, guan zhu %s xie xie miao!", s, s);
    }
}
Python code by me
s = input()
print(f"guan zhu {s} miao, guan zhu {s} xie xie miao!")

E.Construction Complete!

题解

首先以所有的x为起点,用广度优先搜索(BFS)求取多源最短路$dist$.

对于任意一个矩形,首先要满足矩形内所有点都是空地,其次要满足矩形内存在一个$(x, y)$点满足$dist_{x,y} \le d$.

对于第一个条件,我们令所有.的位置为$1$,其他位置为$0$,然后处理出二维前缀和.查询时只需判断矩形和是否等于矩形面积即可.

对于第二个条件,我们令所有$dist_{x,y} \le d$的点$(x, y)$为$1$,其他位置为$0$,同样处理出二维前缀和,查询时只需判断矩形和是否大于$0$即可.

然后就可以直接枚举矩形左上角的位置利用处理出的二维前缀和$O(1)$判断是否合法.时间复杂度和空间复杂度均为$O(\sum{nm})$.

C++ code by me
#include<iostream>
#include<cstring>
#include<vector>
#include<queue>
using namespace std;
using LL = long long;
 
int main(){
 
	cin.tie(0);
	cout.tie(0);
	ios::sync_with_stdio(0);
 
	const int dx[4] = {-1, 0, 1, 0};
	const int dy[4] = {0, -1, 0, 1};
 
	int T;
	cin >> T;
	while(T--){
		int n, m, r, c, dist;
		cin >> n >> m >> r >> c >> dist;
		vector<string> g(n + 1);
		for(int i = 1; i <= n; i++){
			g[i].resize(m + 1);
			for(int j = 1; j <= m; j++)
				cin >> g[i][j];
		}
		vector<vector<int> > d(n + 1, vector<int>(m + 1, n + m));
		queue<pair<int, int> > q;
		for(int i = 1; i <= n; i++)
			for(int j = 1; j <= m; j++)
				if (g[i][j] == 'x'){
					d[i][j] = 0;
					q.push({i, j});
				}
		while(!q.empty()){
			auto [x, y] = q.front(); q.pop();
			for(int u = 0; u < 4; u++){
				int nx = x + dx[u], ny = y + dy[u];
				if (nx < 1 || nx > n || ny < 1 || ny > m) continue;
				if (d[nx][ny] < n + m) continue;
				d[nx][ny] = d[x][y] + 1;
				q.push({nx, ny});
			}
		}
		vector<vector<int> > s1(n + 1, vector<int>(m + 1, 0)), s2(n + 1, vector<int>(m + 1, 0));
		for(int i = 1; i <= n; i++)
			for(int j = 1; j <= m; j++){
				s1[i][j] = (d[i][j] <= dist) + s1[i - 1][j] + s1[i][j - 1] - s1[i - 1][j - 1];
				s2[i][j] = (g[i][j] == '.') + s2[i - 1][j] + s2[i][j - 1] - s2[i - 1][j - 1];
			}
 
		auto query = [&](const vector<vector<int> > &s, int x1, int y1, int x2, int y2){
			return s[x2][y2] - s[x2][y1 - 1] - s[x1 - 1][y2] + s[x1 - 1][y1 - 1];
		};
 
		int ans = 0;
		for(int i = 1; i + r - 1 <= n; i++)
			for(int j = 1; j + c - 1 <= m; j++)
				if (query(s1, i, j, i + r - 1, j + c - 1) > 0 && query(s2, i, j, i + r - 1, j + c - 1) == r * c)
					ans++;
		cout << ans << '\n';
	}
 
}
Java code by ChatGPT
import java.util.*;
import java.io.*;
 
public class Main {
    static int n, m, r, c, dist;
    
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int t = Integer.parseInt(br.readLine());
 
        while (t-- > 0) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            n = Integer.parseInt(st.nextToken());
            m = Integer.parseInt(st.nextToken());
            r = Integer.parseInt(st.nextToken());
            c = Integer.parseInt(st.nextToken());
            dist = Integer.parseInt(st.nextToken());
 
            int[][] d = new int[n + 1][m + 1];
            int[][] s1 = new int[n + 1][m + 1];
            int[][] s2 = new int[n + 1][m + 1];
 
            for (int i = 0; i <= n; i++) {
                Arrays.fill(d[i], n + m);
            }
 
            char[][] g = new char[n + 1][];
            Queue<int[]> q = new LinkedList<>();
 
            for (int i = 1; i <= n; i++) {
                g[i] = (" " + br.readLine()).toCharArray();
                for (int j = 1; j <= m; j++) {
                    if (g[i][j] == 'x') {
                        q.add(new int[]{i, j});
                        d[i][j] = 0;
                    }
                }
            }
 
            int[][] directions = {{-1, 0}, {0, -1}, {1, 0}, {0, 1}};
 
            while (!q.isEmpty()) {
                int[] curr = q.poll();
                int x = curr[0], y = curr[1];
                for (int[] dir : directions) {
                    int nx = x + dir[0], ny = y + dir[1];
                    if (nx < 1 || nx > n || ny < 1 || ny > m || d[nx][ny] < n + m) continue;
                    d[nx][ny] = d[x][y] + 1;
                    q.add(new int[]{nx, ny});
                }
            }
            for (int i = 1; i <= n; i++) {
                for (int j = 1; j <= m; j++) {
                    s1[i][j] = (d[i][j] <= dist ? 1 : 0) + s1[i - 1][j] + s1[i][j - 1] - s1[i - 1][j - 1];
                    s2[i][j] = (g[i][j] == '.' ? 1 : 0) + s2[i - 1][j] + s2[i][j - 1] - s2[i - 1][j - 1];
                }
            }
 
            int ans = 0;
            for (int i = 1; i <= n - r + 1; i++) {
                for (int j =1; j <= m - c + 1; j++) {
                    if (query(s1, i, j, i + r - 1, j + c - 1) > 0 && query(s2, i, j, i + r - 1, j + c - 1) == r * c) {
                        ans += 1;
                    }
                }
            }
            System.out.println(ans);
        }
    }
    static int query(int[][] s, int x1, int y1, int x2, int y2) {
        return s[x2][y2] - s[x2][y1 - 1] - s[x1 - 1][y2] + s[x1 - 1][y1 - 1];
    }   
}
Python code by me
import sys
from collections import deque
input = sys.stdin.readline
 
def query(s, x1, y1, x2, y2):
    return s[x2][y2] - s[x2][y1 - 1] - s[x1 - 1][y2] + s[x1 - 1][y1 - 1]
 
for _ in range(int(input())):
    n, m, r, c, dist = map(int, input().split())
 
    d = [[n + m for j in range(m + 1)] for i in range(n + 1)]
    s1 = [[0 for j in range(m + 1)] for i in range(n + 1)]
    s2 = [[0 for j in range(m + 1)] for i in range(n + 1)]
    q = deque()
 
    g = [[]]
    for i in range(1, n + 1):
        g.append(" " + input())
        for j in range(1, m + 1):
            if g[i][j] == 'x':
                q.append((i, j))
                d[i][j] = 0
    while q:
        x, y = q.popleft()
        for dx, dy in ((-1, 0), (0, -1), (1, 0), (0, 1)):
            nx, ny = x + dx, y + dy
            if nx < 1 or nx > n or ny < 1 or ny > m or d[nx][ny] < n + m: continue
            d[nx][ny] = d[x][y] + 1
            q.append((nx, ny))
 
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            s1[i][j] = (d[i][j] <= dist) + s1[i - 1][j] + s1[i][j - 1] - s1[i - 1][j - 1]
            s2[i][j] = (g[i][j] == '.') + s2[i - 1][j] + s2[i][j - 1] - s2[i - 1][j - 1]
    
    ans = 0
    for i in range(1, n - r + 2):
        for j in range(1, m - c + 2):
            if query(s1, i, j, i + r - 1, j + c - 1) > 0 and query(s2, i, j, i + r - 1, j + c - 1) == r * c:
                ans += 1
    print(ans)

F.最小异或对

题解1

首先有一个结论,最小异或对一定是集合排序后相邻的数.

我们只需要证明对于任意的$a < b < c$,都满足最小值一定是$a\oplus b$或者$b \oplus c$,而不可能是$a \oplus c$.

满足这个条件后,如果我们选取的数在排序后的集合中不相邻,通过选相邻的数一定能使答案更小.

下面我们证明这个结论.

对于$a\oplus c$,找到其二进制表示中最高的为$1$的位为第$x$位,则$a$和$c$高于$x$的位一定相同,第$x$位一定不同,因为$a < c$,一定是$a$第$x$位为$0$,$c$第$x$位为$1$.

因为$a < b < c$,高于$x$的位$b$一定也和$a,c$都相同,$b$的第$x$位可能为$0$或者为$1$.

所以$a\oplus b$和$b \oplus c$中一定有一个数第$x$位为$0$,而$a\oplus c$第$x$位为$1$,即一定有一个数是小于$a \oplus c$的,由此结论得证.

所以我们需要维护原集合的有序序列和有序序列中所有相邻数字的异或值.

每次加入一个数$x$时,找到有序集合中$x$的前驱$pre$和后继$nxt$,在维护异或值的数据结构中删除$nxt\oplus pre$,加入$x \oplus pre$和$x \oplus nxt$.

删除一个数$x$时类似,删除$x$后同样找到$x$的前驱$pre$和后继$nxt$,加入$nxt\oplus pre$,删除$x \oplus pre$和$x \oplus nxt$即可.

每次查询输出所有相邻数字的异或值的最小值即可.

这要求我们需要快速找到有序序列中一个数字的前驱和后继,可以使用平衡树来维护.

用平衡树数维护原序列,然后使用一个支持增删和查询最小值的数据结构维护所有所有相邻数的异或值,同样可以使用平衡树.

当然你需要通过特判处理一些前驱或者后继不存在的情况.

增删查的时间复杂度均为$O(\log{n})$总时间复杂度为$O(n\log{n})$,空间复杂度为$O(n)$.

题解2

如果你不知道以上的结论,有另一种01Trie的做法.

从高位到低位建立01Trie,每个结点都维护$3$个信息:

  1. 当前子树内能凑出的最小异或对的答案
  2. 子树内的元素个数
  3. 如果子树内只有一个元素,记录这个元素的值.

首先对于叶子结点,如果包含的元素数量不大于$1$,那么当前答案为无穷大,否则答案为$0$.

01Trie上,除了叶子结点外每个结点都有两个儿子.那么对于当前结点选取两个元素有$3$种可能,在左儿子中选取两个元素,在右儿子中选取两个元素,或者左右儿子各选一个元素.

如果能在同一个儿子里选择两个元素答案一定好于在两个儿子中各选一个元素,所以可以用儿子的答案更新当前结点.

如果两个儿子结点包含的元素数量都不大于$1$,那么我们只能在两个儿子中各选一个元素然后更新答案.

当然如果当前子树中包含的元素数量不大于$1$,答案为无穷大.

添加和删除时使用类似线段树的方式从下往上更新答案,每次查询时只需输出根节点的答案即可.

可以发现这里的01Trie完全等价于动态开点的权值线段树.

每次插入和删除的复杂度为01Trie的高度,查询的复杂度为$O(1)$,总的时间复杂度为$O(n\log{A})$,空间复杂度也为$O(n\log{A})$.

C++ code with multiset by me
#include<iostream>
#include<cstring>
#include<set>
using namespace std;
using LL = long long;
 
int main(){
 
    cin.tie(0);
    cout.tie(0);
    ios::sync_with_stdio(0);
 
    multiset<int> s, v;
    int n;
    cin >> n;
    while(n--){
        string op;
        cin >> op;
        if (op[0] == 'A'){
            int x;
            cin >> x;
            auto it = s.lower_bound(x);
            if (it != s.end())
                v.insert(*it ^ x);
            if (it != s.begin()){
                v.insert(*prev(it) ^ x);
            }
            if (it != s.begin() && it != s.end()){
                v.erase(v.lower_bound(*it ^ *prev(it)));
            }
            s.insert(x);
        }
        else if (op[0] == 'D'){
            int x;
            cin >> x;
            s.erase(s.lower_bound(x));
            auto it = s.lower_bound(x);
            if (it != s.end())
                v.erase(v.lower_bound(*it ^ x));
            if (it != s.begin()){
                v.erase(v.lower_bound(*prev(it) ^ x));
            }
            if (it != s.begin() && it != s.end()){
                v.insert(*it ^ *prev(it));
            }
        }
        else cout << *v.begin() << '\n';
    }
 
}
C++ code with 01Trie by me
#include<iostream>
#include<cstring>
#include<vector>
#include<numeric>
#include<functional>
using namespace std;
using LL = long long;
const int maxn = 2e5 + 5, INF = 1 << 30;
 
struct Info {
    int dp, cnt, val;
    Info() : dp(INF), cnt(0), val(-1) {} 
    Info(int x) : dp(INF), cnt(1), val(x) {} 
    Info(int dp, int cnt, int val) : dp(dp), cnt(cnt), val(val) {}
}tr[maxn * 32];
int son[maxn * 32][2];
int root, idx;
 
Info operator+(const Info &a, const Info &b){
    Info ret = Info();
    ret.cnt = a.cnt + b.cnt;
    ret.dp = min(a.dp, b.dp);
    if (a.cnt == 1 && b.cnt == 1) ret.dp = a.val ^ b.val;
    if (ret.cnt == 1) ret.val = a.cnt ? a.val : b.val;
    return ret;
}
 
void insert(int &u, int x, int dep){
    if (!u) u = ++idx;
    if (dep == -1){
        ++tr[u].cnt;
        if (tr[u].cnt == 1) tr[u] = Info(x);
        else if (tr[u].cnt == 2) tr[u].dp = 0;
        return;
    }
    int bit = (x >> dep & 1);
    insert(son[u][bit], x, dep - 1);
    tr[u] = tr[son[u][0]] + tr[son[u][1]];
}
 
void del(int &u, int x, int dep){
    if (dep == -1){
        --tr[u].cnt;
        if (tr[u].cnt == 0) tr[u] = Info();
        else if (tr[u].cnt == 1) tr[u] = Info(x);
        return;
    }
    int bit = (x >> dep & 1);
    del(son[u][bit], x, dep - 1);
    tr[u] = tr[son[u][0]] + tr[son[u][1]];
}
 
int main(){
    
    cin.tie(0);
    cout.tie(0);
    ios::sync_with_stdio(0);
 
    int n;
    cin >> n;
    int root = ++idx;
    while(n--){
        string op;
        cin >> op;
        if (op[0] == 'A'){
            int x;
            cin >> x;
            insert(root, x, 29);
        }
        else if (op[0] == 'D'){
            int x;
            cin >> x;
            del(root, x, 29);
        }
        else cout << tr[root].dp << '\n';
    }
 
}
Java code by ChatGPT
import java.util.*;
 
public class Main {
    public static void main(String[] args) {
 
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        in.nextLine();
        TreeMap<Integer, Integer> s = new TreeMap<>();
        TreeMap<Integer, Integer> v = new TreeMap<>();
        int cnt = 0;  // 记录s中数量大于等于2的元素的个数
        s.put(Integer.MAX_VALUE, 0);
        while (n-- > 0) {
            String op = in.next();
            if (op.charAt(0) == 'A') {
                int x = in.nextInt();
                if (!s.containsKey(x)) {
                    Integer it = s.higherKey(x);
                    if (it != null) {
                        int xor = it ^ x;
                        v.put(xor, v.getOrDefault(xor, 0) + 1);
                    }
                    if (it != null && s.lowerKey(it) != null) {
                        int xor = s.lowerKey(it) ^ it;
                        v.put(xor, v.getOrDefault(xor, 0) - 1);
                        if (v.get(xor) == 0) {
                            v.remove(xor);
                        }
                    }
                    if (s.lowerKey(it) != null) {
                        int xor = s.lowerKey(it) ^ x;
                        v.put(xor, v.getOrDefault(xor, 0) + 1);
                    }
                    s.put(x, 1);
                    if (s.get(x) > 1) {
                        cnt++;
                    }
                } else {
                    s.put(x, s.get(x) + 1);
                    if (s.get(x) == 2) {
                        cnt++;
                    }
                }
            } else if (op.charAt(0) == 'D') {
                int x = in.nextInt();
                if (s.containsKey(x)) {
                    int nums = s.get(x);
                    if (nums == 1) {
                        s.remove(x);
                        Integer it = s.higherKey(x);
                        if (it != null && s.lowerKey(it) != null) {
                            int xor = s.lowerKey(it) ^ it;
                            v.put(xor, v.getOrDefault(xor, 0) + 1);
                        }
                        if (s.lowerKey(it) != null && it != null) {
                            int xor = s.lowerKey(it) ^ x;
                            v.put(xor, v.getOrDefault(xor, 0) - 1);
                            if (v.get(xor) == 0) {
                                v.remove(xor);
                            }
                        }
                        if (it != null) {
                            int xor = it ^ x;
                            v.put(xor, v.getOrDefault(xor, 0) - 1);
                            if (v.get(xor) == 0) {
                                v.remove(xor);
                            }
                        }
                    } else {
                        s.put(x, nums - 1);
                        if (nums == 2) {
                            cnt--;
                        }
                    }
                }
            } else {
                if (cnt > 0) {
                    System.out.println(0);
                } else {
                    System.out.println(v.isEmpty() ? 0 : v.firstKey());
                }
            }
        }
    }
}
Python code by me
import sys
input = sys.stdin.readline   
import random
 
class TreapMultiSet(object):
    root = 0
    size = 0
 
    def __init__(self, data=None):
        if data:
            data = sorted(data)
            self.root = treap_builder(data)
            self.size = len(data)
 
    def add(self, key):
        self.root = treap_insert(self.root, key)
        self.size += 1
 
    def remove(self, key):
        self.root = treap_erase(self.root, key)
        self.size -= 1
 
    def discard(self, key):
        try:
            self.remove(key)
        except KeyError:
            pass
 
    def ceiling(self, key):
        x = treap_ceiling(self.root, key)
        return treap_keys[x] if x else None
 
    def higher(self, key):
        x = treap_higher(self.root, key)
        return treap_keys[x] if x else None
 
    def floor(self, key):
        x = treap_floor(self.root, key)
        return treap_keys[x] if x else None
 
    def lower(self, key):
        x = treap_lower(self.root, key)
        return treap_keys[x] if x else None
 
    def max(self):
        return treap_keys[treap_max(self.root)]
 
    def min(self):
        return treap_keys[treap_min(self.root)]
 
    def __len__(self):
        return self.size
 
    def __nonzero__(self):
        return bool(self.root)
 
    __bool__ = __nonzero__
 
    def __contains__(self, key):
        return self.floor(key) == key
 
    def __repr__(self):
        return "TreapMultiSet({})".format(list(self))
 
    def __iter__(self):
        if not self.root:
            return iter([])
        out = []
        stack = [self.root]
        while stack:
            node = stack.pop()
            if node > 0:
                if right_child[node]:
                    stack.append(right_child[node])
                stack.append(~node)
                if left_child[node]:
                    stack.append(left_child[node])
            else:
                out.append(treap_keys[~node])
        return iter(out)
 
 
class TreapSet(TreapMultiSet):
    def add(self, key):
        self.root, duplicate = treap_insert_unique(self.root, key)
        if not duplicate:
            self.size += 1
 
    def __repr__(self):
        return "TreapSet({})".format(list(self))
 
 
class TreapHashSet(TreapMultiSet):
    def __init__(self, data=None):
        if data:
            self.keys = set(data)
            super(TreapHashSet, self).__init__(self.keys)
        else:
            self.keys = set()
 
    def add(self, key):
        if key not in self.keys:
            self.keys.add(key)
            super(TreapHashSet, self).add(key)
 
    def remove(self, key):
        self.keys.remove(key)
        super(TreapHashSet, self).remove(key)
 
    def discard(self, key):
        if key in self.keys:
            self.remove(key)
 
    def __contains__(self, key):
        return key in self.keys
 
    def __repr__(self):
        return "TreapHashSet({})".format(list(self))
 
 
class TreapHashMap(TreapMultiSet):
    def __init__(self, data=None):
        if data:
            self.map = dict(data)
            super(TreapHashMap, self).__init__(self.map.keys())
        else:
            self.map = {}
 
    def __setitem__(self, key, value):
        if key not in self.map:
            super(TreapHashMap, self).add(key)
        self.map[key] = value
 
    def __getitem__(self, key):
        return self.map[key]
 
    def add(self, key):
        raise TypeError("add on TreapHashMap")
 
    def get(self, key, default=None):
        return self.map.get(key, default=default)
 
    def remove(self, key):
        self.map.pop(key)
        super(TreapHashMap, self).remove(key)
 
    def discard(self, key):
        if key in self.map:
            self.remove(key)
 
    def __contains__(self, key):
        return key in self.map
 
    def __repr__(self):
        return "TreapHashMap({})".format(list(self))
 
 
left_child = [0]
right_child = [0]
treap_keys = [0]
treap_prior = [0.0]
 
 
def treap_builder(sorted_data):
    """Build a treap in O(n) time using sorted data"""
    def build(begin, end):
        if begin == end:
            return 0
        mid = (begin + end) // 2
        root = treap_create_node(sorted_data[mid])
        left_child[root] = build(begin, mid)
        right_child[root] = build(mid + 1, end)
 
        # sift down the priorities
        ind = root
        while True:
            lc = left_child[ind]
            rc = right_child[ind]
 
            if lc and treap_prior[lc] > treap_prior[ind]:
                if rc and treap_prior[rc] > treap_prior[lc]:
                    treap_prior[ind], treap_prior[rc] = treap_prior[rc], treap_prior[ind]
                    ind = rc
                else:
                    treap_prior[ind], treap_prior[lc] = treap_prior[lc], treap_prior[ind]
                    ind = lc
            elif rc and treap_prior[rc] > treap_prior[ind]:
                treap_prior[ind], treap_prior[rc] = treap_prior[rc], treap_prior[ind]
                ind = rc
            else:
                break
        return root
 
    return build(0, len(sorted_data))
 
 
def treap_create_node(key):
    treap_keys.append(key)
    treap_prior.append(random.random())
    left_child.append(0)
    right_child.append(0)
    return len(treap_keys) - 1
 
 
def treap_split(root, key):
    left_pos = right_pos = 0
    while root:
        if key < treap_keys[root]:
            left_child[right_pos] = right_pos = root
            root = left_child[root]
        else:
            right_child[left_pos] = left_pos = root
            root = right_child[root]
    left, right = right_child[0], left_child[0]
    right_child[left_pos] = left_child[right_pos] = right_child[0] = left_child[0] = 0
    return left, right
 
 
def treap_merge(left, right):
    where, pos = left_child, 0
    while left and right:
        if treap_prior[left] > treap_prior[right]:
            where[pos] = pos = left
            where = right_child
            left = right_child[left]
        else:
            where[pos] = pos = right
            where = left_child
            right = left_child[right]
    where[pos] = left or right
    node = left_child[0]
    left_child[0] = 0
    return node
 
 
def treap_insert(root, key):
    if not root:
        return treap_create_node(key)
    left, right = treap_split(root, key)
    return treap_merge(treap_merge(left, treap_create_node(key)), right)
 
 
def treap_insert_unique(root, key):
    if not root:
        return treap_create_node(key), False
    left, right = treap_split(root, key)
    if left and treap_keys[left] == key:
        return treap_merge(left, right), True
    return treap_merge(treap_merge(left, treap_create_node(key)), right), False
 
 
def treap_erase(root, key):
    if not root:
        raise KeyError(key)
    if treap_keys[root] == key:
        return treap_merge(left_child[root], right_child[root])
    node = root
    while root and treap_keys[root] != key:
        parent = root
        root = left_child[root] if key < treap_keys[root] else right_child[root]
    if not root:
        raise KeyError(key)
    if root == left_child[parent]:
        left_child[parent] = treap_merge(left_child[root], right_child[root])
    else:
        right_child[parent] = treap_merge(left_child[root], right_child[root])
 
    return node
 
 
def treap_ceiling(root, key):
    while root and treap_keys[root] < key:
        root = right_child[root]
    if not root:
        return 0
    min_node = root
    min_key = treap_keys[root]
    while root:
        if treap_keys[root] < key:
            root = right_child[root]
        else:
            if treap_keys[root] < min_key:
                min_key = treap_keys[root]
                min_node = root
            root = left_child[root]
    return min_node
 
 
def treap_higher(root, key):
    while root and treap_keys[root] <= key:
        root = right_child[root]
    if not root:
        return 0
    min_node = root
    min_key = treap_keys[root]
    while root:
        if treap_keys[root] <= key:
            root = right_child[root]
        else:
            if treap_keys[root] < min_key:
                min_key = treap_keys[root]
                min_node = root
            root = left_child[root]
    return min_node
 
 
def treap_floor(root, key):
    while root and treap_keys[root] > key:
        root = left_child[root]
    if not root:
        return 0
    max_node = root
    max_key = treap_keys[root]
    while root:
        if treap_keys[root] > key:
            root = left_child[root]
        else:
            if treap_keys[root] > max_key:
                max_key = treap_keys[root]
                max_node = root
            root = right_child[root]
    return max_node
 
 
def treap_lower(root, key):
    while root and treap_keys[root] >= key:
        root = left_child[root]
    if not root:
        return 0
    max_node = root
    max_key = treap_keys[root]
    while root:
        if treap_keys[root] >= key:
            root = left_child[root]
        else:
            if treap_keys[root] > max_key:
                max_key = treap_keys[root]
                max_node = root
            root = right_child[root]
    return max_node
 
 
def treap_min(root):
    if not root:
        raise ValueError("min on empty treap")
    while left_child[root]:
        root = left_child[root]
    return root
 
 
def treap_max(root):
    if not root:
        raise ValueError("max on empty treap")
    while right_child[root]:
        root = right_child[root]
    return root
 
s, v = TreapMultiSet(), TreapMultiSet()
 
for _ in range(int(input())):
    op = input()
    if op[0] == 'Q':
        print(v.min())
    else:
        x = int(op[4:].strip())
        if op[0] == 'A':
            nxt = s.higher(x)
            pre = s.floor(x)
            if nxt is not None and pre is not None: 
                v.remove(nxt ^ pre)
            if pre is not None:
                v.add(x ^ pre)
            if nxt is not None:
                v.add(nxt ^ x)
            s.add(x)
        else:
            s.remove(x)
            nxt = s.higher(x)
            pre = s.floor(x)
            if nxt is not None and pre is not None:
                v.add(nxt ^ pre)
            if pre is not None:
                v.remove(x ^ pre)
            if nxt is not None:
                v.remove(nxt ^ x)

G.魔术卡片

题解

对于$1 \sim n$我们都用一个$8$位的二进制数表示其状态,第$i$位取$0/1$表示第$i$张卡片上是否有这个数字,然后我们考虑这些数之间需要具有怎样的性质.

对于任意数字$x$,假设其没有出现在第$c_1, c_2, c_3,\cdots$张卡片上,那么除了$x$以外的所有数字都至少在第$c_1,c_2,c_3,\cdots$上出现过一次.即第$c_1,c_2,c_3,\cdots$上至少有一位为$1$.所以所有除了$x$以外的数字的二进制表示一定不是$x$的子集.

推广到所有数字上,一定没有一个数字的二进制表示是另一个数字二进制表示的子集.所以我们只要找出一些二进制数,满足没有一个数是其他任何数的子集即可.

一种简单的构造方法是在所有$8$位的二进制数字里选取二进制表示中$1$的数量为$4$的数来构造,这些数显然不会有一个数是另一个数的子集.并且这样的数有$C_8^4 = 70$个,而$n \le 60$.

C++ code by me
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
using LL = long long;
 
int main(){
 
    cin.tie(0);
    cout.tie(0);
    ios::sync_with_stdio(0);
 
    int T;
    cin >> T;
    while(T--){
        const int N = 8;
        int n;
        cin >> n;
        int cnt = 0;
        vector<vector<int> > ans(N);
        for(int i = 0; i < 1 << N; i++){
            if (__builtin_popcount(i) != 4) continue;
            if (++cnt > n) break;
            for(int j = 0; j < N; j++)
                if (i >> j & 1) ans[j].push_back(cnt);
        }
        cout << N << '\n';
        for(int i = 0; i < N; i++){
            cout << ans[i].size();
            for(auto x : ans[i]) cout << ' ' << x;
            cout << '\n';
        }
    }
 
}
Java code by ChatGPT
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
 
public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int T = scanner.nextInt();
 
        while (T-- > 0) {
            final int N = 8;
            int n = scanner.nextInt();
            int cnt = 0;
 
            List<List<Integer>> ans = new ArrayList<>();
            for (int i = 0; i < N; i++) {
                ans.add(new ArrayList<>());
            }
 
            for (int i = 0; i < (1 << N); i++) {
                if (Integer.bitCount(i) != 4) continue;
                if (++cnt > n) break;
                for (int j = 0; j < N; j++) {
                    if ((i >> j & 1) == 1) {
                        ans.get(j).add(cnt);
                    }
                }
            }
 
            System.out.println(N);
            for (int i = 0; i < N; i++) {
                System.out.print(ans.get(i).size());
                for (int x : ans.get(i)) {
                    System.out.print(" " + x);
                }
                System.out.println();
            }
        }
    }
}
Python code by me
cand = []
 
for i in range(1 << 8):
    cnt = 0
    for j in range(8):
        cnt += (i >> j & 1)
    if cnt == 4:
        cand.append(i)
 
for _ in range(int(input())):
    n = int(input())
    ans = [[] for i in range(8)]
    for i in range(n):
        for j in range(8):
            if cand[i] >> j & 1: 
                ans[j].append(i + 1)
    print(8)
    for i in range(8):
        print(len(ans[i]), *ans[i])

H.推队

题解

我们用$dp_{i, a, b, c}$表示击退对手$i$只精灵,攻击等级为$a$,防御等级为$b$,速度等级为$c$时能剩余的最大体力,如果无法达到该状态则为$-1$.

考虑状态转移,因为$k$很小,所以我们可以枚举每种属性当前的等级和该回合准备给每个属性提升的等级进行转移.

关键在于计算某种策略下击败对手第$i$只精灵最少会损失多少体力.

虽然提升等级可能会导致属性值降低,但是我们可以发现在最优策略中,击败第$i$只精灵后每项属性一定不会低于对战第$i$只精灵之前的属性.

因为强化某项属性值但是使得这项属性降低对我们当前的对战是没有好处的,我们完全可以选择该回合强化到更高的属性值或者这回合不强化全部留到与对手下一只精灵对战时强化.这两种策略中一定至少有一种不劣于使某项属性值降低的策略.

所以当我们确定对战第$i$只精灵时,最优策略的操作顺序一定是先强化防御,然后强化攻击和速度,然后再使用攻击技能.

因为$k$很小,对于强化过程中受到的伤害我们可以直接暴力模拟,时间复杂度为$O(k)$,强化完之后属性值固定,可以$O(1)$计算.

总时间复杂度为$O(\sum{nk^7})$.

C++ code by me
#include<iostream>
#include<cstring>
using namespace std;
using LL = long long;
const int maxn = 1e4 + 5;
int f[maxn][4][4][4];
 
int main(){
 
    cin.tie(0);
    cout.tie(0);
    ios::sync_with_stdio(0);
 
    int T;
    cin >> T;
    while(T--){
        int n, k;
        cin >> n >> k;
        int H, A[4], D[4], S[4];
        cin >> H;
        for(int i = 0; i <= k; i++) cin >> A[i];
        for(int i = 0; i <= k; i++) cin >> D[i];
        for(int i = 0; i <= k; i++) cin >> S[i];
 
        for(int i = 0; i <= n; i++){
            for(int a = 0; a <= k; a++)
            for(int b = 0; b <= k; b++)
            for(int c = 0; c <= k; c++)
                f[i][a][b][c] = -1;
        }
 
        f[0][0][0][0] = H;
        for(int i = 0; i < n; i++){
            int h, a, d, s;
            cin >> h >> a >> d >> s;
 
            auto update = [&](int l1, int l2, int l3){
				int cur = f[i][l1][l2][l3];
 
                auto get = [&](int c1, int c2, int c3){
					int remain = cur;
                    for(int i = 1; i <= c2; i++){
						if (S[l3] >= s){
							remain -= max(0, a - D[l2 + i]);
							if (remain <= 0) return -1;
						}
						else{
							remain -= max(0, a - D[l2 + i - 1]);
							if (remain <= 0) return -1;
						}
					}
 
					if (A[l1 + c1] <= d) return -1;
					int dec1 = max(0, a - D[l2 + c2]);
					int dec2 = A[l1 + c1] - d;
 
					if (1LL * (c1 + c3) * dec1 >= remain) return -1;
					if (dec1 == 0) return remain;
 
					remain -= (c1 + c3) * dec1;
 
					int cnt = (h + dec2 - 1) / dec2 - (S[l3 + c3] >= s);
 
					if (1LL * cnt * dec1 >= remain) return -1;
					remain -= cnt * dec1;
					
					return remain;
                };
 
                for(int c1 = 0; c1 <= k - l1; c1++)
                for(int c2 = 0; c2 <= k - l2; c2++)
                for(int c3 = 0; c3 <= k - l3; c3++)
                    f[i + 1][l1 + c1][l2 + c2][l3 + c3] = max(f[i + 1][l1 + c1][l2 + c2][l3 + c3], get(c1, c2, c3));
            };
 
            for(int a = 0; a <= k; a++)
            for(int b = 0; b <= k; b++)
            for(int c = 0; c <= k; c++)
                if (f[i][a][b][c] > 0)
                    update(a, b, c);
        }
 
        int ans = -1;
        for(int a = 0; a <= k; a++)
        for(int b = 0; b <= k; b++)
        for(int c = 0; c <= k; c++)
            ans = max(ans, f[n][a][b][c]);
        cout << ans << '\n';
    }
 
}
Java code by ChatGPT
import java.util.*;
 
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
 
        int T = Integer.parseInt(sc.nextLine().trim());
        while (T-- > 0) {
            int n = sc.nextInt();
            int k = sc.nextInt();
            int H = sc.nextInt();
            int[] A = new int[k + 1];
            int[] D = new int[k + 1];
            int[] S = new int[k + 1];
            for (int i = 0; i <= k; i++) {
                A[i] = sc.nextInt();
            }
            for (int i = 0; i <= k; i++) {
                D[i] = sc.nextInt();
            }
            for (int i = 0; i <= k; i++) {
                S[i] = sc.nextInt();
            }
            int[][][][] f = new int[n + 1][k + 1][k + 1][k + 1];
            for (int i = 0; i <= n; i++) {
                for (int l1 = 0; l1 <= k; l1++) {
                    for (int l2 = 0; l2 <= k; l2++) {
                        for (int l3 = 0; l3 <= k; l3++) {
                            f[i][l1][l2][l3] = -1;
                        }
                    }
                }
            }
            f[0][0][0][0] = H;
            for (int i = 0; i < n; i++) {
                int h = sc.nextInt();
                int a = sc.nextInt();
                int d = sc.nextInt();
                int s = sc.nextInt();
                for (int l1 = 0; l1 <= k; l1++) {
                    for (int l2 = 0; l2 <= k; l2++) {
                        for (int l3 = 0; l3 <= k; l3++) {
                            if (f[i][l1][l2][l3] <= 0) continue;
                            for (int c1 = 0; c1 <= k - l1; c1++) {
                                for (int c2 = 0; c2 <= k - l2; c2++) {
                                    for (int c3 = 0; c3 <= k - l3; c3++) {
                                        int remain = f[i][l1][l2][l3];
                                        for (int t = 1; t <= c2; t++) {
                                            if (S[l3] >= s) {
                                                remain -= Math.max(0, a - D[l2 + t]);
                                            } else {
                                                remain -= Math.max(0, a - D[l2 + t - 1]);
                                            }
                                            if (remain <= 0) break;
                                        }
                                        if (remain <= 0 || A[l1 + c1] <= d) continue;
                                        int dec1 = Math.max(0, a - D[l2 + c2]);
                                        int dec2 = A[l1 + c1] - d;
                                        if (1L * (c1 + c3) * dec1 >= remain) continue;
                                        remain -= (c1 + c3) * dec1;
                                        int cnt = (h + dec2 - 1) / dec2 - (S[l3 + c3] >= s ? 1 : 0);
                                        if (1L * cnt * dec1 >= remain) continue;
                                        remain -= cnt * dec1;
                                        f[i + 1][l1 + c1][l2 + c2][l3 + c3] = Math.max(f[i + 1][l1 + c1][l2 + c2][l3 + c3], remain);
                                    }
                                }
                            }
                        }
                    }
                }
            }
            int ans = -1;
            for (int a = 0; a <= k; a++) {
                for (int b = 0; b <= k; b++) {
                    for (int c = 0; c <= k; c++) {
                        ans = Math.max(ans, f[n][a][b][c]);
                    }
                }
            }
            System.out.println(ans);
        }
        sc.close();
    }
}
Python code by me
import sys
input = sys.stdin.readline    
 
for _ in range(int(input())):
    n, k = map(int, input().split())
    H = int(input())
    A = tuple(map(int, input().split()))
    D = tuple(map(int, input().split()))
    S = tuple(map(int, input().split()))
    f = [[[[-1 for a in range(k + 1)] for b in range(k + 1)] for c in range(k + 1)] for d in range(n + 1)]
    f[0][0][0][0] = H
    for i in range(n):
        h, a, d, s = map(int, input().split())
    
        for l1 in range(k + 1):
            for l2 in range(k + 1):
                for l3 in range(k + 1):
                    if f[i][l1][l2][l3] <= 0: continue
                    for c1 in range(k - l1 + 1):
                        for c2 in range(k - l2 + 1):
                            for c3 in range(k - l3 + 1):
                                remain = f[i][l1][l2][l3]
                                for t in range(1, c2 + 1):
                                    if S[l3] >= s:
                                        remain -= max(0, a - D[l2 + t])
                                    else:
                                        remain -= max(0, a - D[l2 + t - 1])
                                    if remain <= 0: 
                                        break
                                if remain <= 0 or A[l1 + c1] <= d: continue
                                dec1, dec2 = max(0, a - D[l2 + c2]), A[l1 + c1] - d
                                if (c1 + c3) * dec1 >= remain: continue
                                remain -= (c1 + c3) * dec1
                                cnt = (h + dec2 - 1) // dec2 - (S[l3 + c3] >= s)
                                if cnt * dec1 >= remain: continue
                                remain -= cnt * dec1
                                f[i + 1][l1 + c1][l2 + c2][l3 + c3] = max(f[i + 1][l1 + c1][l2 + c2][l3 + c3], remain)
    print(max(f[n][a][b][c] for a in range(k + 1) for b in range(k + 1) for c in range(k + 1)))

I.最大公约数求和

题解1

$f(x) = x^k$与$g(x) = \gcd(x, k)$都是积性函数,可以在线性筛的过程中计算.

当$x$为质数,直接计算$f(x)$与$g(x)$,计算这两个函数的复杂度都为$O(\log_2{k})$.

当$x$为合数,在线性筛的过程中,$x$会被其最小质因子筛掉,设其最小质因子为$p$.

$f(x) = x^k$是完全积性函数,因此有$f(x) = f(p) \cdot f(\frac{x}{p})$.

对于$g(x) = \gcd(x, k)$,只需判断是否有$\frac{k}{g(\frac{x}{p})} \equiv 0 (\bmod \ p)$,如果是则$g(x) = p \cdot g(\frac{x}{p})$,否则$g(x) = g(\frac{x}{p})$.

质数密度$\pi(n) \approx \frac{n}{\ln{n}}$,所以总时间复杂度为$O(\frac{\log_2{k}}{\ln{n}} \cdot n + n) \approx O(n)$.

题解2

$f(x) = x^k$很容易发现是积性函数,但是观察到$g(x) = \gcd(x, k)$是积性函数还是有一定难度的.

$f(x) = x^k$通过解法$1$中的方式计算.

而对于$g(x) = \gcd(x, k)$,我们发现$g(x)$的取值是非常有限的,只可能是$k$小于等于$n$的因子.因此我们可以枚举所有因子的倍数更新$g(x)$的值.

枚举$1 \sim n$所有倍数的时间复杂度是$O(n\log{n})$,但因为$k$的因子在$1 \sim n$中的分布是非常稀疏的,所以实际复杂度远小于这个上界,时间复杂度不明,但足以通过本题.

C++ code by me
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn = 2e7 + 5, mod = 1e9 + 7;
typedef long long LL;
int primes[maxn], cnt;
int x[maxn], g[maxn];
bool isPrime[maxn];
 
int qpow(int a, int b, int mod){
    int res = 1;
    while (b){
        if (b & 1) res = 1LL * res * a % mod;
        a = 1LL * a * a % mod;
        b >>= 1;
    }
    return res;
}
 
int f(int n, LL k){
    int res = 1;
    int up = k % (mod - 1);
    for(int i = 2; i <= n; i++){
        if (!isPrime[i]){
            primes[cnt++] = i;
            x[i] = qpow(i, up, mod);
            g[i] = __gcd(1LL * i, k);
        }
        for(int j = 0; i * primes[j] <= n; j++){
            isPrime[i * primes[j]] = 1;
            x[i * primes[j]] = 1LL * x[i] * x[primes[j]] % mod;
            if (i % primes[j] == 0){
                g[i * primes[j]] = g[i];
                if ((k / g[i]) % primes[j] == 0) g[i * primes[j]] *= primes[j];
                break;
            }
            g[i * primes[j]] = g[i] * g[primes[j]];
        }
        res = (res + 1LL * g[i] * x[i]) % mod;
    }
    return res;
}
 
int main(){
 
    int n; LL k;
    cin >> n >> k;
    cout << f(n, k) << '\n';
}
Java code by ChatGPT
import java.util.*;
 
public class Main {
    static final int maxn = 20000005;
    static final int mod = 1000000007;
    static int[] primes = new int[maxn];
    static int[] x = new int[maxn];
    static int[] g = new int[maxn];
    static boolean[] isPrime = new boolean[maxn];
    static int cnt;
 
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
 
        int n = sc.nextInt();
        long k = sc.nextLong();
 
        int ans = f(n, k);
        System.out.println(ans);
    }
 
    public static int qpow(int a, int b, int mod){
        int res = 1;
        while (b > 0){
            if ((b & 1) != 0) {
                res = (int) ((1L * res * a) % mod);
            }
            a = (int) ((1L * a * a) % mod);
            b >>= 1;
        }
        return res;
    }
 
    public static int f(int n, long k){
        int res = 1;
        int up = (int) (k % (mod - 1));
        for(int i = 2; i <= n; i++){
            if (!isPrime[i]){
                primes[cnt++] = i;
                x[i] = qpow(i, up, mod);
                g[i] = gcd(1L * i, k);
            }
            for(int j = 0; i * primes[j] <= n; j++){
                isPrime[i * primes[j]] = true;
                x[i * primes[j]] = (int) ((1L * x[i] * x[primes[j]]) % mod);
                if (i % primes[j] == 0){
                    g[i * primes[j]] = g[i];
                    if ((k / g[i]) % primes[j] == 0) g[i * primes[j]] *= primes[j];
                    break;
                }
                g[i * primes[j]] = g[i] * g[primes[j]];
            }
            res = (int) ((res + 1L * g[i] * x[i]) % mod);
        }
        return res;
    }
 
    public static int gcd(long a, long b) {
        return b == 0 ? (int) a : gcd(b, a % b);
    }
}
Python code by me
from math import gcd
from array import array
 
MOD = 1000000007
 
n, k = map(int, input().split())
 
primes = array('i')
cnt = 0
x = array('i', [0] * (n + 1))
g = array('i', [0] * (n + 1))
is_prime = array('b', [False] * (n + 1))
 
res = 1
up = k % (MOD - 1)
for i in range(2, n + 1):
    if not is_prime[i]:
        primes.append(i)
        cnt += 1
        x[i] = pow(i, up, MOD)
        g[i] = gcd(i, k)
 
    for p in primes:
        if i * p > n: break
        is_prime[i * p] = True
        x[i * p] = x[i] * x[p] % MOD
        if i % p == 0:
            g[i * p] = g[i]
            if (k // g[i]) % p == 0:
                g[i * p] *= p
            break
        g[i * p] = g[i] * g[p]
    res = (res + g[i] * x[i]) % MOD
print(res)

J.三门问题

题解

在最初每扇门后有轿车的概率都是$\frac{1}{n}$,选手选了一扇门,因此轿车不在选手选的门后的概率为$\frac{n-1}{n}$.

主持人打开了一扇门,此时相当于剩下的$n-2$个选手没有选的门均分了$\frac{n-1}{n}$的概率,所以应该换门,新的开走轿车的概率为$\frac{n-1}{n(n-2)}$.

C++ code by me
#include<iostream>
#include<cstring>
#include<cassert>
#include<algorithm>
using namespace std;
using LL = long long;
 
int main(){
 
    cin.tie(0);
    cout.tie(0);
    ios::sync_with_stdio(0);
 
    int T;
    cin >> T;
    while(T--){
        int n;
        cin >> n;
        assert(__gcd(n - 1, n * (n - 2)) == 1);
        cout << n - 1 << ' ' << n * (n - 2) << '\n';
    }
 
}
Java code by ChatGPT
import java.util.*;
import java.lang.*;
 
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
 
        int T = sc.nextInt();
        while(T-- > 0){
            int n = sc.nextInt();
            assert gcd(n - 1, n * (n - 2)) == 1;
            System.out.println((n - 1) + " " + n * (n - 2));
        }
        sc.close();
    }
 
    static int gcd(int a, int b){
        if (b == 0) return a;
        return gcd(b, a % b);
    }
}
Python code by me
for _ in range(int(input())):
    n = int(input())
    print(n - 1, n * (n - 2))