[算法学习笔记] O(nlogn)求最长上升子序列

发布时间 2023-08-22 19:49:08作者: SXqwq

朴素 dp 求最长上升子序列

大家应该都会朴素 dp 求最长上升子序列,简单回忆一下。

我们令 \(f_i\) 表示以 第 \(i\) 位元素为结尾的最长上升子序列长度。满足 \(\forall j < i\),则有:

\(f_i = max(f_i,f_j+1)[a_j < a_i]\)

Explanation : \(a_i\) 前面若有多个可以拼接的序列,则拼一个长度最大的。

实现方式如下:

朴素 dp 求最长上升子序列代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#define N 100010
using namespace std;
int n;
int a[N];
int f[N];
int maxn = -1;
int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]),f[i] = 1;
	for(int i=1;i<=n;i++)
	{
		for(int j=1;j<=i;j++) 
		{
			if(a[j] < a[i]) f[i] = max(f[i],f[j]+1);
		}
		maxn = max(maxn,f[i]);
	}
	cout<<maxn<<endl;
	return 0;
 } 

O(nlogn) 二分优化

上述方式求最长上升子序列的时间复杂度是 \(O(n^2)\) 的。我们考虑如何优化。

不妨令 \(f_i\) 表示最长上升子序列的长度为 \(i\) 的时候序列末尾最小的元素。这里渗透了贪心的思想,我们一定希望末尾元素最小,以便更好的拼接后面的元素。

考虑转移,设 \(len\) 表示当前的最长上升子序列长度。

  • 对于 \(a_i > f_{len}\),我们先拼接,因为在前 \(i\) 位中这就是最优解。也就是 \(f_{++len}=a_i\)

  • 对于 \(a_i \leq f_{len}\),我们显然需要在 \(f\) 中找到一个可以拼上它的地方。由于是最长上升子序列,\(f\) 数组中的内容一定是 单调递增的,因此可以二分查找。这也是把 \(O(n^2)\) 算法优化成 \(O(nlogn)\) 算法的关键。

具体地,若 \(a_i < f_l\),则 \(f_l=min(f_l,a_i)\)

最后输出 \(len\) 即可。时间复杂度 \(O(nlogn)\)

实现代码
    for(int i=1;i<=n;i++)
    {
        int l = 0,r = len;
        if(b[i] > f[len]) f[++len] = b[i];
        else
        {
            int mid;
            while(l < r) //二分查找
            {
                mid = (l+r) / 2;
                if(f[mid] > b[i]) r = mid;
                else l = mid + 1;
            }
            f[l] = min(f[l],b[i]);
        }
    }

典例:Luogu P1439 【模板】最长公共子序列

Problem

首先介绍一下朴素 \(O(n^2)\) 求最长公共子序列,然后再讲解本题。

设计状态: \(f_{i,j}\) 表示 \(a\) 数组前 \(i\) 位,\(b\) 数组前 \(j\) 位的最长公共子序列长度。

转移:\(n^2\) 枚举 \(a,b\)。对于 \(a_i,b_j\)

  • \(a_i = b_j\) ,考虑拼接。即 \(f_{i,j}=f_{i-1,j-1}+1\)

  • \(a_i \ne b_j\),无法拼接,考虑继承。即 \(f_{i,j}=max(f_{i-1,j},f_{i,j-1})\)

实现代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#define N 10001
using namespace std;
int a[N],b[N];
int n;
int f[N][N];
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    for(int i=1;i<=n;i++) scanf("%d",&b[i]);
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=n;j++) 
        {
            if(a[i] ==b[j]) f[i][j] = f[i-1][j-1] + 1;
            else f[i][j] = max(f[i-1][j],f[i][j-1]);
        }
    }
    cout<<f[n][n]<<endl;
    return 0;
}

回到本题。

对于 \(100\%\) 的数据, \(n \le 10^5\)

朴素 \(n^2\) 求最长公共子序列显然无法接受。

如何处理呢?本题保证了输入的数据是排列,即内容为 \(1,2,3...n\)(不保证顺序)的数组。
这样显然数组 \(a,b\) 的内容是完全相同的,只是位置不一定相同。

不妨设 \(num_i\) 表示 \(a_i\)\(b_i\) 的编号。

显然 \(b\) 的编号是严格递增的,若 \(num\) 的一部分严格递增,则这部分可以纳入 LCS(也就是说这部分属于公共子序列)。这样我们就只需要求 \(num\) 数组的最长公共子序列即可。

实现代码
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 1000010;
const int INF = 0x3f3f3f3f;
int num[N];
int a[N],b[N];
int f[N];
int n;
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin>>n;
    for(int i=1;i<=n;i++) cin>>a[i],num[a[i]] = i;
    for(int i=1;i<=n;i++) cin>>b[i],f[i] = INF;
    int len = 0;
    f[0] = 0;
    for(int i=1;i<=n;i++)
    {
        int l = 0,r = len;
        if(num[b[i]] > f[len]) f[++len] = num[b[i]];
        else
        {
            int mid;
            while(l < r)
            {
                mid = (l+r) / 2;
                if(f[mid] > num[b[i]]) r = mid;
                else l = mid + 1;
            }
            f[l] = min(f[l],num[b[i]]);
        }
    }
    cout<<len<<endl;
    return 0;
}