CF1699C

发布时间 2023-04-24 19:45:06作者: untitled0

题面

前排提醒:这是本题最烂的做法,代码长达 91 行。

由于题里提到了 \(\operatorname{MEX}\),所以考虑该怎么求 \(\operatorname{MEX}\)

我们使用主席树。由于主席树难以直接处理下标为 \(0\) 的情况,所以给数组中的每个值都加上一个偏移量 \(1\),同时下文提到的 \(\operatorname{MEX}\) 也都是题目里定义的 \(\operatorname{MEX}+1\)

主席树记下区间内元素出现的次数后,在树上对值域二分。如果左儿子中各元素出现的次数小于左儿子代表的值域区间的长度,则说明左边至少有一个元素没有出现,因此 \(\operatorname{MEX}\) 一定在左儿子的区间,向左递归查找,否则向右递归查找。

值得注意的是,一个长为 \(len\) 的区间 \(\operatorname{MEX}\) 可能为 \(len+1\) ,需要进行特判。

这是主席树部分的代码:

struct node {
    int l, r, cnt, ls, rs;
    int mid() { return (l + r) / 2; }
    int len() { return r - l + 1;}
}nd[N * 24];//N是数据范围
int rt[N], tot;
int build(int l, int r) {
    int p = ++tot;
    nd[p].l = l, nd[p].r = r;
    if(l == r) return p;
    int mid = (l + r) / 2;
    nd[p].ls = build(l, mid);
    nd[p].rs = build(mid + 1, r);
    return p;
}
void insert(int& p, int ori, int loc) {
    p = ++tot;
    nd[p] = nd[ori];
    nd[p].cnt++;
    if(nd[p].l == nd[p].r) return;
    if(loc <= nd[p].mid()) insert(nd[p].ls, nd[ori].ls, loc);
    else insert(nd[p].rs, nd[ori].rs, loc);
}
int mex(int lp, int rp) {
    if(nd[lp].l == nd[lp].r) {
        if(nd[rp].cnt - nd[lp].cnt) return nd[lp].l + 1;
        else return nd[lp].l;
    }
    if(nd[nd[rp].ls].cnt - nd[nd[lp].ls].cnt < nd[nd[lp].ls].len()) 
        return mex(nd[lp].ls, nd[rp].ls);
    else 
        return mex(nd[lp].rs, nd[rp].rs);
}

这样,通过调用 mex(rt[l-1],rt[r]) 就可以求出 \(\operatorname{MEX}([a_l,a_{l+1}\ldots a_{r-1},a_r])\) 辣!

接下来进入核心做法部分:

首先 \(1\) 所在位置一定不变,否则 \(b\) 中长为 \(1\) 的区间的 \(\operatorname{MEX}\) 会和原排列不完全相同。

对于 \(a\) 中任意一个区间 \([l,r]\),小于其 \(\operatorname{MEX}\) 值的所有数字都只能在该区间内随意排列。否则得到的新区间的 \(\operatorname{MEX}\) 会比原区间小,不符合题意。

我们锁定一个区间,设当前区间为 \([l,r]\),并且只能在当前区间内排列的数字所贡献的方案数都已经计算完。

设当前区间的 \(\operatorname{MEX}\) 值为 \(m\),寻找 \(m\)\(a\) 中的位置 \(p\),若 \(p<l\) 则将区间 \([l,r]\) 扩张到 \([p,r]\),否则扩张到 \([l,p]\)。设新区间是 \([l',r']\)

对于扩张的新区间,设它的 \(\operatorname{MEX}\) 值为 \(m'\),那么小于 \(m'\) 的所有数字都只能在这个区间内排列。区间长为 \(r'-l'+1\),其中原区间 \([l,r]\) 已经确定了 \(m-1\) 个数字的位置,又因为数字 \(m\) 的位置不能改变,则对答案有贡献的数字共有 \(m'-m-1\) 个,这些数字可以随意排列的空位共有 \(r'-l'+1-m\) 个,因此这些数字对答案的贡献为 \(P_{r'-l'+1-m}^{m'-m-1}\),乘到答案上即可。

区间扩张为整个排列后就不能再继续了,因此这是迭代出口。

正确性可以这样理解:如果选择其他区间,很难得出关于数字位置的信息,也就难以计算。因此选择上述扩张区间的方式,可以找到所有能够确定一些数字的位置的区间,进而进行计算。

时间复杂度 \(O(n\log n)\)

核心代码:

//P(m,n)的作用是求出m!/(m-n)!,即排列数
//loc[i]代表元素i在原数组中的位置
//mod就是题里给的模数1e9+7
//oldmex即上文的m,newmex即上文的m'
ll ans = 1;
int l = loc[1], r = loc[1], oldmex = 2;
while(r - l + 1 < n) {
    if(loc[oldmex] < l) l = loc[oldmex];
    else r = loc[oldmex];
    int newmex = mex(rt[l - 1], rt[r]);
    ans = ans * P(r - l + 1 - oldmex, newmex - oldmex - 1) % mod;
    oldmex = newmex;
}
cout << ans << endl;

谢谢大家能耐心看完我的烂做法,以下是完整代码:(\(91\) 行)

#include<bits/stdc++.h>
#define endl '\n'
using namespace std;
typedef long long ll;
constexpr int N = 1e5 + 5;
constexpr ll mod = 1e9 + 7;
int n, loc[N], rt[N], tot;
ll fact[N];
ll qp(ll a, ll b) { // 快速幂
    ll ans = 1;
    for(; b; b >>= 1, a = a * a % mod)
        if(b & 1) ans = ans * a % mod;
    return ans;
}
ll P(ll m, ll n) { // 排列数,fact[i]即i!
    return fact[m] * qp(fact[m - n], mod - 2) % mod;
}
struct node {
    int l, r, cnt, ls, rs;
    int mid() { return (l + r) / 2; }
    int len() { return r - l + 1;}
}nd[N * 24];
int build(int l, int r) {
    int p = ++tot;
    nd[p].l = l, nd[p].r = r;
    if(l == r) return p;
    int mid = (l + r) / 2;
    nd[p].ls = build(l, mid);
    nd[p].rs = build(mid + 1, r);
    return p;
}
void insert(int& p, int ori, int loc) {
    p = ++tot;
    nd[p] = nd[ori];
    nd[p].cnt++;
    if(nd[p].l == nd[p].r) return;
    if(loc <= nd[p].mid()) insert(nd[p].ls, nd[ori].ls, loc);
    else insert(nd[p].rs, nd[ori].rs, loc);
}
int mex(int lp, int rp) {
    if(nd[lp].l == nd[lp].r) {
        if(nd[rp].cnt - nd[lp].cnt) return nd[lp].l + 1;
        else return nd[lp].l;
    }
    if(nd[nd[rp].ls].cnt - nd[nd[lp].ls].cnt < nd[nd[lp].ls].len()) 
        return mex(nd[lp].ls, nd[rp].ls);
    else 
        return mex(nd[lp].rs, nd[rp].rs);
}
void solve() {
    cin >> n;
    rt[0] = build(1, n);
    for(int i = 1; i <= n; ++i) {
        int x; cin >> x;
        loc[x + 1] = i;
        insert(rt[i], rt[i - 1], x + 1);
    }
    if(n == 1) {
        cout << 1 << endl;
        return;
    }
    ll ans = 1;
    int l = loc[1], r = loc[1], oldmex = 2;
    while(r - l + 1 < n) {
        if(loc[oldmex] < l) l = loc[oldmex];
        else r = loc[oldmex];
        int newmex = mex(rt[l - 1], rt[r]);
        ans = ans * P(r - l + 1 - oldmex, newmex - oldmex - 1) % mod;
        oldmex = newmex;
    }
    cout << ans << endl;
}
void clear() { // 多测不清空,爆零两行泪
    memset(loc + 1, 0, n * sizeof(int));
    memset(rt, 0, (n + 1) * sizeof(int));
    memset(nd + 1, 0, tot * sizeof(node));
    tot = 0;
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    fact[0] = 1;
    for(int i = 1; i < N; ++i)
        fact[i] = fact[i - 1] * i % mod;
    int t; cin >> t;
    while(t--) {
        solve();
        clear();
    }
    return 0;
}