P10009 [集训队互测 2022] 线段树 题解

发布时间 2024-01-02 17:48:11作者: Athanasy

题目链接:P10009 [集训队互测 2022] 线段树

神仙分块题,先给一下出题人的神仙官解:官解

前面还看得懂。后面是啥?这不是 ds 题咋和 dp、轮廓线扯上关系了。看了半天,还是这个启发了我:

其手玩下,在 Excel 里写一下,可以理解到这里其实是想表达的一个核心意思是啥:对于一组序列而言,我们对操作 \(1\) 进行 \(2^k\) 次,很容易发现一个性质此时最终的数组和原数组有以下关系:

\[last_i=a_i \ (i<2^k),last_i=a_i \oplus a_{i-2^k} \ (i>=2^k) \]

证明可以参照 pdf 里面说的,转化为网格图,然后转变为路径数问题,考虑一个点是否计入最终点的贡献。因为异或两次就为 \(0\),所以只需要计算从这个点出发到查询点的路径数的数量再模 \(2\),而路径数:

这点还是很好理解的。而 \(diff_y\) 就是操作数,\(diff_x\) 则是步长。假如操作数为 \(2^k\)。温习下 Lucas 定理:

很显然的是由卢卡斯定理我们有:

\[{diff_y \choose diff_x }\mod 2={2^{k-1} \choose {\lfloor \dfrac{diff_x}{2} \rfloor}} \times {0 \choose diff_x \bmod 2} \mod2 \]

\[={1 \choose \lfloor \dfrac{diff_x}{2^k} \rfloor} \times \prod_{i=0}^{k-1}{0 \choose \lfloor \dfrac{diff_x}{2^i}\rfloor \mod 2} \mod 2 \]

\[=\prod_{i=0}^{k-1}{0 \choose \lfloor \dfrac{diff_x}{2^i}\rfloor \mod 2} \mod 2 \]

显然需要有每个 \(\lfloor \dfrac{diff_x}{2^i}\rfloor \mod 2==0\) 才能使结果为 \(1\)。容易知道 \(diff_x=2^t(t>k-1)\)

接下来只需要证明 \(diff_x=2^t (t>k)\) 都是无贡献的。换句话来说,对于 \(a_i\) 来说,操作了 \(2^k\) 次方次以后,\(a_j(j \in [1,i-2^k)\ )\)\(a_i\) 无贡献。非常简单,考虑 \(a_1\) ,第一次操作会影响 \(a_2\),第二次影响 \(a_3\),以此类推每次影响的对象往右平移一位。其实每次操作可以看做整体往右平移一位再异或。

最初始的 \(a_1\) 会随着每轮操作而影响对象向右移动一位。所以此时此刻的最大的 \(a_j\) 在进行了 \(2^k\) 操作以后,最多只会 \(a_j 影响 a_{i-1} (j==2^k-1)\)。所以上式当且仅当 \(diff_x==2^k\) 时才为 \(1\),才有贡献。

\(2^k\) 次操作该如何解决。答案是倍增/二进制分解,把它分解为若干个 \(2^t\) 次操作累计就行了。这样一来我们就解决了“整块操作”。散块暴力即可。结束。

算法框架及其细节

首先,需要注意一点,既然都上分块了,那么显而易见的 \(new_i=a_i \oplus a_{i-step}\) 里的 \(step\) 最多为 \(\sqrt{n}\) 会跨个一整块。实际上是我们可以考虑每次两块两块地处理,这样写较为方便。而什么时候有可能达到这个数字,显然是当修改达到至少 \(\sqrt{n}\) 次时才有可能,所以我们可以考虑每出现 \(\sqrt{n}\) 次修改再处理查询。

其次,对于查询而言,我们可以遍历每个查询和每个块,记录整块操作数,遇到散块就直接处理完整块的就行了。这一部分外层遍历的复杂度为 \(O(\sqrt{n} 个查询 \times \sqrt{n} 块)=O(n)\),再套一个处理完整块的操作次数显然最坏应该为:

\[ \sqrt{n}+\frac{\sqrt{n}}{2}+\frac{\sqrt{n}}{2^2}+\frac{\sqrt{n}}{2^3}+\frac{\sqrt{n}}{2^4}+....\frac{\sqrt{n}}{2^{\sqrt{n}}}\approx 2\sqrt{n} \]

所以单次查询的最坏复杂度 \(O(\sqrt{n})\)

而又因为这 \(\sqrt{n}\) 个查询,全是最坏的散块处理,整块处理我们可以累计到下次 \(op==2\) 时进行统一处理。实际上摊还分析一下,一次处理 \(\sqrt{n}\) 个查询的复杂度单个是 \(\sqrt{n}\),总复杂度是就是 \(O(n)\)

又因为最多有 $\lceil \frac{q}{\sqrt{n}} \rceil $ 个这种查询,所以理想复杂度是近似于 \(O(n\sqrt{n})\)? \(n\)\(q\) 差不多是一个数量级,把它俩当相等。

参考代码
#include <bits/stdc++.h>

//#pragma GCC optimize("Ofast,unroll-loops")

#define isPbdsFile

#ifdef isPbdsFile

#include <bits/extc++.h>

#else

#include <ext/pb_ds/priority_queue.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/trie_policy.hpp>
#include <ext/pb_ds/tag_and_trait.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/list_update_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/exception.hpp>
#include <ext/rope>

#endif

using namespace std;
using namespace __gnu_cxx;
using namespace __gnu_pbds;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef tuple<int, int, int> tii;
typedef tuple<ll, ll, ll> tll;
typedef unsigned int ui;
typedef unsigned long long ull;
typedef __int128 i128;
#define hash1 unordered_map
#define hash2 gp_hash_table
#define hash3 cc_hash_table
#define stdHeap std::priority_queue
#define pbdsHeap __gnu_pbds::priority_queue
#define sortArr(a, n) sort(a+1,a+n+1)
#define all(v) v.begin(),v.end()
#define yes cout<<"YES"
#define no cout<<"NO"
#define Spider ios_base::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
#define MyFile freopen("..\\input.txt", "r", stdin),freopen("..\\output.txt", "w", stdout);
#define forn(i, a, b) for(int i = a; i <= b; i++)
#define forv(i, a, b) for(int i=a;i>=b;i--)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define endl '\n'
//用于Miller-Rabin
[[maybe_unused]] static int Prime_Number[13] = {0, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

template <typename T>
int disc(T* a, int n)
{
    return unique(a + 1, a + n + 1) - (a + 1);
}

template <typename T>
T lowBit(T x)
{
    return x & -x;
}

template <typename T>
T Rand(T l, T r)
{
    static mt19937 Rand(time(nullptr));
    uniform_int_distribution<T> dis(l, r);
    return dis(Rand);
}

template <typename T1, typename T2>
T1 modt(T1 a, T2 b)
{
    return (a % b + b) % b;
}

template <typename T1, typename T2, typename T3>
T1 qPow(T1 a, T2 b, T3 c)
{
    a %= c;
    T1 ans = 1;
    for (; b; b >>= 1, (a *= a) %= c)if (b & 1)(ans *= a) %= c;
    return modt(ans, c);
}

template <typename T>
void read(T& x)
{
    x = 0;
    T sign = 1;
    char ch = getchar();
    while (!isdigit(ch))
    {
        if (ch == '-')sign = -1;
        ch = getchar();
    }
    while (isdigit(ch))
    {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    x *= sign;
}

template <typename T, typename... U>
void read(T& x, U&... y)
{
    read(x);
    read(y...);
}

template <typename T>
void write(T x)
{
    if (typeid(x) == typeid(char))return;
    if (x < 0)x = -x, putchar('-');
    if (x > 9)write(x / 10);
    putchar(x % 10 ^ 48);
}

template <typename C, typename T, typename... U>
void write(C c, T x, U... y)
{
    write(x), putchar(c);
    write(c, y...);
}


template <typename T11, typename T22, typename T33>
struct T3
{
    T11 one;
    T22 tow;
    T33 three;

    bool operator<(const T3 other) const
    {
        if (one == other.one)
        {
            if (tow == other.tow)return three < other.three;
            return tow < other.tow;
        }
        return one < other.one;
    }

    T3() { one = tow = three = 0; }

    T3(T11 one, T22 tow, T33 three) : one(one), tow(tow), three(three)
    {
    }
};

template <typename T1, typename T2>
void uMax(T1& x, T2 y)
{
    if (x < y)x = y;
}

template <typename T1, typename T2>
void uMin(T1& x, T2 y)
{
    if (x > y)x = y;
}

constexpr int N = 2.5e5 + 10;
constexpr int SIZE = sqrt(N);
constexpr int CNT = (N + SIZE - 1) / SIZE + 1;
int pos[N];
int s[CNT], e[CNT];
int pre[N]; //上一轮序列
int nxt[N]; //下一轮序列
int tmp[N]; //临时序列
int ans[N]; //答案
int n, q;
int siz, cnt; //块大小,块数量
struct Query
{
    int op, l, r, id;
} qu[N];

//整块更新,2^i次方就直接tmp[x]^=tmp[x-2^i],否则二进制拆分倍增
inline void allUpdate(const int id, int updateCnt)
{
    while (updateCnt)
    {
        int t = log2(updateCnt);
        int step = 1 << t; //步长
        //两块两块处理,因为updateCnt最多为sqrt(n)
        forv(i, e[id], s[id-1]+step)tmp[i] ^= tmp[i - step];
        updateCnt -= step;
    }
}

//处理操作
inline void update(const int queryCnt)
{
    //遍历每个块进行贡献更新
    forn(idx, 1, cnt)
    {
        int blockCnt = 0; //整块操作次数
        //两块两块进行处理,对于当前块,同时拿到它之前的块进行一并处理方便a[i]^=a[i-step],step=sqrt(n)。
        forn(i, s[idx-1], e[idx])tmp[i] = pre[i];
        forn(curr, 1, queryCnt)
        {
            if (auto [op,l,r,id] = qu[curr]; op == 1)
            {
                const int L1 = s[idx - 1]; //前一块边界
                const int R2 = e[idx]; //当前块边界
                //完全包括
                if (l <= L1 and R2 <= r)
                {
                    ++blockCnt;
                    continue;
                }
                //部分包括,有交集,暴力
                if (l <= R2 and r >= L1)
                {
                    //先把之前的整块更新改了
                    allUpdate(idx, blockCnt);
                    blockCnt = 0;
                    //起点需要+1
                    forv(i, min(r,R2), max(l+1,L1+1))tmp[i] ^= tmp[i - 1];
                }
            }
            else
            {
                //计算当前块的贡献
                if (s[idx] <= l and l <= e[idx])
                {
                    allUpdate(idx, blockCnt);
                    blockCnt = 0;
                    ans[id] = tmp[l];
                }
            }
        }
        allUpdate(idx, blockCnt); //处理还未处理的整块更新
        forn(i, s[idx], e[idx])nxt[i] = tmp[i]; //下一轮的序列
    }
    forn(i, 1, n)pre[i] = nxt[i];
    forn(i, 1, queryCnt)if (qu[i].op == 2)cout << ans[qu[i].id] << endl;
}

inline void solve()
{
    cin >> n >> q;
    siz = sqrt(n);
    cnt = (n + siz - 1) / siz;
    forn(i, 1, n)cin >> pre[i], pos[i] = (i - 1) / siz + 1;
    s[0] = 1;
    forn(i, 1, cnt)s[i] = (i - 1) * siz + 1, e[i] = i * siz;
    e[cnt] = n;
    int updateCnt = 0; //修改次数
    int queryCnt = 0; //待处理的操作次数
    forn(i, 1, q)
    {
        cin >> qu[++queryCnt].op;
        qu[queryCnt].id = i;
        if (qu[queryCnt].op == 1)
        {
            cin >> qu[queryCnt].l >> qu[queryCnt].r;
            if (++updateCnt == siz)update(queryCnt), updateCnt = 0, queryCnt = 0;
        }
        else cin >> qu[queryCnt].l;
    }
    update(queryCnt);
    forn(i, 1, n)cout << pre[i] << endl;
}

signed int main()
{
    Spider
    //------------------------------------------------------
    int test = 1;
    //    read(test);
    cin >> test;
    test = 1;
    forn(i, 1, test)solve();
    //    while (cin >> n, n)solve();
    //    while (cin >> test)solve();
}

PS:复杂度部分可能最终分析的有点乱,但实测下来确实跑得非快。官解剩余的神仙dp确实晦涩,后续如果看懂也会补充做法。文章证明如果有瑕疵或者不正确的地方,欢迎指出。