最长公共上升子序列

发布时间 2023-03-22 21:09:20作者: onlyblues

最长公共上升子序列

熊大妈的奶牛在小沐沐的熏陶下开始研究信息题目。

小沐沐先让奶牛研究了最长上升子序列,再让他们研究了最长公共子序列,现在又让他们研究最长公共上升子序列了。

小沐沐说,对于两个数列 $A$ 和 $B$,如果它们都包含一段位置不一定连续的数,且数值是严格递增的,那么称这一段数是两个数列的公共上升子序列,而所有的公共上升子序列中最长的就是最长公共上升子序列了。

奶牛半懂不懂,小沐沐要你来告诉奶牛什么是最长公共上升子序列。

不过,只要告诉奶牛它的长度就可以了。

数列 $A$ 和 $B$ 的长度均不超过 $3000$。

输入格式

第一行包含一个整数 $N$,表示数列 $A$,$B$ 的长度。

第二行包含 $N$ 个整数,表示数列 $A$。

第三行包含 $N$ 个整数,表示数列 $B$。

输出格式

输出一个整数,表示最长公共上升子序列的长度。

数据范围

$1 \leq N \leq 3000$,序列中的数字均不超过 $2^{31}−1$。

输入样例:

4
2 2 1 3
2 1 2 3

输出样例:

2

 

解题思路

  有点像最长上升子序列与最长公共子序列问题的结合。

  首先容易想到状态定义$f(i,j)$表示所有由$a$序列的前$i$个数且以$a_i$结尾,和$b$序列的前$j$个数且以$b_j$结尾的公共上升子序列所构成的集合,属性就是公共上升子序列的最大长度。现在公共上升子序列的最后一个数已经确定了(即$a_i$和$b_j$),那么根据前一个数的选择来进行状态划分。状态转移方程就是$$f(i,j) = \max\limits_{\begin{align*} 1 \leq u < i &, \  1 \leq v < j \\ a_u = b_v &, \ a_u < a_i \end{align*}} \ \left\{ {f(u,v)} \right\} + 1$$

  当然还有个前提条件就是$a_i = b_j$,上面的状态转移方程成立。

  很明显上面的dp做法的时间复杂度为$O(n^4)$,同时发现状态转移的部分不好进行优化,因此我们需要改变状态的定义。

  可以发现在上面的状态定义中,只有当$a_i = b_j$时$f(i,j)$才是一个合法的状态,即如果确定以$b_j$为结尾,那么对应的$a_i$也就确定了。因此我们尝试只确定以$b_j$为结尾,第一维的$i$只考虑$a$序列的前$i$个数是否可行(反过来考虑也可以,即确定以$a_i$为结尾,第二维的$j$只考虑$b$序列的前$j$个数)。

  因此定义状态$f(i,j)$表示所有由$a$序列的前$i$个数,和$b$序列的前$j$个数且以$b_j$结尾的公共上升子序列所构成的集合。此时根据公共上升子序列是否包含$a_i$来进行状态划分。如果不包含$a_i$,那么对应的状态集合就是$f(i-1,j)$。如果包含$a_i$,由于是公共上升子序列,此时应该保证$a_i = b_j$,然后在公共上升子序列的$b$序列中前一个数的选择$b_k$继续划分:

  因此考虑所有满足条件的$b_k$所对应的状态集合就是$\bigcup\limits_{\begin{array}{center} k = 1 \\ b_k < b_j \end{array}}^{j-1} {f(i-1,k)}$。

  因此状态转移方程就是$$f(i,j) = \max \left\{ {f(i-1,j), \ \max\limits_{\begin{array}{center} 1 \leq k < j \\ b_k < b_j \end{array}} \left\{ {f(i-1,k)} \right\} } \right\}$$

  然而上面做法的时间复杂度为$O(n^3)$,还是会超时。但发现可以对状态转移的部分进行优化。在枚举$k$的时候本质就是找到满足$b_k < b_j$的$f(i-1,k)$,因此对于固定的$i$,可以开个权值树状数组,在$b_j$处维护$f(i-1,j)$的最大值,每次枚举到$f(i,j)$,那么就在树状数组中所有小于$b_j$的$b_k$中查询得到最大的$f(i-1,k)$。时间复杂度就降到了$O(n^2 \cdot \log{n})$,但由于时间限制为$1s$,有可能会超时。事实上还是会超时,不过还是把代码贴出来:

  TLE代码如下,时间复杂度为$O(n^2 \cdot \log{n})$:

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 const int N = 3010;
 5 
 6 int a[N], b[N];
 7 int f[N][N];
 8 int xs[N], sz;
 9 int tr[N];
10 
11 int lowbit(int x) {
12     return x & -x;
13 }
14 
15 void add(int x, int c) {
16     for (int i = x; i <= sz; i += lowbit(i)) {
17         tr[i] = max(tr[i], c);
18     }
19 }
20 
21 int query(int x) {
22     int ret = 0;
23     for (int i = x; i; i -= lowbit(i)) {
24         ret = max(ret, tr[i]);
25     }
26     return ret;
27 }
28 
29 int find(int x) {
30     int l = 1, r = sz;
31     while (l < r) {
32         int mid = l + r >> 1;
33         if (xs[mid] >= x) r = mid;
34         else l = mid + 1;
35     }
36     return l;
37 }
38 
39 int main() {
40     int n;
41     scanf("%d", &n);
42     for (int i = 1; i <= n; i++) {
43         scanf("%d", a + i);
44     }
45     for (int i = 1; i <= n; i++) {
46         scanf("%d", b + i);
47         xs[++sz] = b[i];
48     }
49     sort(xs + 1, xs + sz + 1);
50     sz = unique(xs + 1, xs + sz + 1) - xs - 1;
51     for (int i = 1; i <= n; i++) {
52         memset(tr, 0, sizeof(tr));
53         for (int j = 1; j <= n; j++) {
54             f[i][j] = f[i - 1][j];
55             if (a[i] == b[j]) f[i][j] = max(f[i][j], query(find(b[j]) - 1) + 1);
56             add(find(b[j]), f[i - 1][j]);
57         }
58     }
59     int ret = 0;
60     for (int i = 1; i <= n; i++) {
61         ret = max(ret, f[n][i]);
62     }
63     printf("%d", ret);
64     
65     return 0;
66 }

  有个trick就是可以发现每次枚举$j$时,只有$b_j = a_i$时才考虑包含$a_i$的那个集合并进行状态转移,然后$b_k < b_j$就等价于$b_k < a_i$,因此我们只需维护一个前缀最大值$\text{maxf} = \max\limits_{\begin{array}{center} 1 \leq k < j \\ b_k < a_i \end{array}} \{ f(i-1,k) \}$就可以了,当枚举到$j$且$b_j = a_i$,那么就会有$f(i,j) = \text{maxf} + 1$。而如果有$b_j < a_i$则更新$\text{maxf} = \max \{ \text{maxf}, \ f(i-1,j) \}$。直接把状态转移的时间复杂度降到$O(1)$。

  AC代码如下,时间复杂度为$O(n^2)$:

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 const int N = 3010;
 5 
 6 int a[N], b[N];
 7 int f[N][N];
 8 
 9 int main() {
10     int n;
11     scanf("%d", &n);
12     for (int i = 1; i <= n; i++) {
13         scanf("%d", a + i);
14     }
15     for (int i = 1; i <= n; i++) {
16         scanf("%d", b + i);
17     }
18     for (int i = 1; i <= n; i++) {
19         int maxf = 0;
20         for (int j = 1; j <= n; j++) {
21             f[i][j] = f[i - 1][j];
22             if (b[j] == a[i]) f[i][j] = max(f[i][j], maxf + 1);
23             else if (b[j] < a[i]) maxf = max(maxf, f[i - 1][j]);
24         }
25     }
26     int ret = 0;
27     for (int i = 1; i <= n; i++) {
28         ret = max(ret, f[n][i]);
29     }
30     printf("%d", ret);
31     
32     return 0;
33 }

 

参考资料

  AcWing 272. 最长公共上升子序列(算法提高课):https://www.acwing.com/video/364/