树状数组和线段树

发布时间 2023-09-27 23:59:28作者: wxzcch

今天太幸运了!硬啃把模板啃下来!

树状数组

解决的本质问题

树状数组解决的本质问题只有一个:

  • 单点改动、区间求值

其他的问题,都是可以转化到该问题上的。

代码板子重点操作

  1. lowbit操作
int lowbit(int x){ return x & -x;}
  1. add添加一个值操作
void add(int a,int v){for(int i = a; i <= n; i += lowbit(i) tr[i] += v;}
  1. query求区间值操作
int query(int r)
{
   int sum = 0;
   for(int i = a; i ; i -= lowbit(i)) sum += tr[i];
   return sum;
}

例题——动态求区间和

#include <iostream>
#include <algorithm>
#include <unordered_map>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
#include <queue>
#include <map>
#include <set>
#define x first
#define y second
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;

const int Mod = 1e9 + 7;
const int INF = 0x3f3f3f3f;

ll qmi(ll a, ll b){
    ll res = 1;
    while(b)
    {
        if(b & 1) res = res * a % Mod;
        a = a * a % Mod;
        b >>= 1;
    }
    return res;
}

ll inv(ll p){ return qmi(p, Mod - 2);}
ll mo(ll p){ return (p % Mod + Mod) % Mod;}

const int N = 1e6 + 10;

int n, m;
int w[N];
struct Node
{
    int l, r;
    int val;
}tr[N];


void pushup(int u){tr[u].val = tr[u << 1].val + tr[u << 1 | 1].val;}



void build(int u, int l, int r)
{
    if(l == r) tr[u] = {l, r, w[r]};    // 如果已经是叶子节点了
    else{
        tr[u] = {l, r};
        int mid = (l + r) >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

void modify(int u, int a, int v)
{
    if(tr[u].l == tr[u].r) tr[u].val += v;
    else{
        int mid = tr[u].l + tr[u].r >> 1;
        if(a <= mid) modify(u << 1, a, v);
        else modify(u << 1 | 1, a, v);
        pushup(u);
    }
}

int query(int u, int l, int r)
{
    if(tr[u].l >= l && tr[u].r <= r) return tr[u].val;

    int mid = tr[u].l + tr[u].r >> 1;
    int sum = 0;

    // 如果区间是在左边就继续找,在右边就加
    if(l <= mid) sum = query(u << 1, l, r);
    if(r > mid) sum += query(u << 1 | 1, l, r);
    return sum;
}

int main() {

    cin >> n >> m;

    for(int i = 1; i <= n; i++) cin >> w[i];

    build(1, 1, n);
    int k, a, b;
    while(m--)
    {
        cin >> k >> a >> b;
        if(k == 0) cout << query(1, a, b) << endl;
        else modify(1, a, b);
    }
    return 0;
}

线段树

解决的问题

线段树可以解决特别多的,区间问题,区间和、区间最大值、区间......题面的板子性强,一定得多敲,才能熟练

重点操作 —— 这里以动态区间和为例

  1. build构造线段树
void build(int u, int l, int r)
{
    if(l == r) tr[u] = {l, r, w[r]};    // 如果已经是叶子节点了
    else{
        tr[u] = {l, r};
        int mid = (l + r) >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}
  1. query查询区间操作
int query(int u, int l, int r)
{
    if(tr[u].l >= l && tr[u].r <= r) return tr[u].val;

    int mid = tr[u].l + tr[u].r >> 1;
    int sum = 0;

    // 如果区间是在左边就继续找,在右边就加
    if(l <= mid) sum = query(u << 1, l, r);
    if(r > mid) sum += query(u << 1 | 1, l, r);
    return sum;
}
  1. pushup操作,更新值
void pushup(int u){tr[u].val = tr[u << 1].val + tr[u << 1 | 1].val;}
  1. modify修改线段树操作
void modify(int u, int a, int v)
{
    if(tr[u].l == tr[u].r) tr[u].val += v;
    else{
        int mid = tr[u].l + tr[u].r >> 1;
        if(a <= mid) modify(u << 1, a, v);
        else modify(u << 1 | 1, a, v);
        pushup(u);
    }
}

例题 —— 动态求连续区间和线段树版本

#include <iostream>
#include <algorithm>
#include <unordered_map>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
#include <queue>
#include <map>
#include <set>
#define x first
#define y second
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;

const int Mod = 1e9 + 7;
const int INF = 0x3f3f3f3f;

ll qmi(ll a, ll b){
    ll res = 1;
    while(b)
    {
        if(b & 1) res = res * a % Mod;
        a = a * a % Mod;
        b >>= 1;
    }
    return res;
}

ll inv(ll p){ return qmi(p, Mod - 2);}
ll mo(ll p){ return (p % Mod + Mod) % Mod;}

const int N = 1e6 + 10;

int n, m;
int w[N];
struct Node
{
    int l, r;
    int val;
}tr[N];


void pushup(int u){tr[u].val = tr[u << 1].val + tr[u << 1 | 1].val;}



void build(int u, int l, int r)
{
    if(l == r) tr[u] = {l, r, w[r]};    // 如果已经是叶子节点了
    else{
        tr[u] = {l, r};
        int mid = (l + r) >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

void modify(int u, int a, int v)
{
    if(tr[u].l == tr[u].r) tr[u].val += v;
    else{
        int mid = tr[u].l + tr[u].r >> 1;
        if(a <= mid) modify(u << 1, a, v);
        else modify(u << 1 | 1, a, v);
        pushup(u);
    }
}

int query(int u, int l, int r)
{
    if(tr[u].l >= l && tr[u].r <= r) return tr[u].val;

    int mid = tr[u].l + tr[u].r >> 1;
    int sum = 0;

    // 如果区间是在左边就继续找,在右边就加
    if(l <= mid) sum = query(u << 1, l, r);
    if(r > mid) sum += query(u << 1 | 1, l, r);
    return sum;
}

int main() {

    cin >> n >> m;

    for(int i = 1; i <= n; i++) cin >> w[i];

    build(1, 1, n);
    int k, a, b;
    while(m--)
    {
        cin >> k >> a >> b;
        if(k == 0) cout << query(1, a, b) << endl;
        else modify(1, a, b);
    }
    return 0;
}