高级数据结构——树状数组

发布时间 2023-10-03 17:45:11 作者: chfychin

一维树状数组

单点修改-区间查询

点击查看代码
#include <bits/stdc++.h>
#define IOS ios::sync_with_stdio(false);cin.tie(nullptr),cout.tie(nullptr);
#define int long long

using namespace std;

// Capacity of the 1-based arrays below (valid indices are 1..n, n < N).
const int N = 1000000 + 10;

int a[N];   // input values a[1..n]
int tr[N];  // Fenwick tree built over a
int n;      // number of elements
int m;      // number of operations to process

// Returns the value of the lowest set bit of x (e.g. 12 -> 4); this is the
// size of the range a Fenwick node at index x is responsible for.
int lowbit(int x)
{
	const int neg = -x;  // two's complement: every bit above the lowest 1 is flipped
	return neg & x;
}

// Point update: adds y to element x by updating every tree node whose
// range contains x (walking upward by lowbit until past n).
void add(int x, int y)
{
	while(x <= n)
	{
		tr[x] += y;
		x += lowbit(x);
	}
}

// Prefix sum of elements 1..x, accumulated by stripping the lowest set
// bit of the index until it reaches 0.
int sum(int x)
{
	int res = 0;
	while(x != 0)
	{
		res += tr[x];
		x -= lowbit(x);
	}
	return res;
}

// Range sum of elements l..r, computed as the difference of two prefix sums.
int query(int l, int r)
{
	const int upto_r = sum(r);
	const int before_l = sum(l - 1);
	return upto_r - before_l;
}

// Reads n values and m operations and processes them:
//   "1 x k"  -> add k to a[x]
//   otherwise "f l r" -> print a[l] + ... + a[r]
void solve()
{
	cin >> n >> m;
	for(int i = 1; i <= n; i ++)
	{
		cin >> a[i];
		add(i, a[i]);  // build the tree incrementally
	}

	int f, x, y;
	while(m --)
	{
		cin >> f >> x >> y;
		if(f == 1)
			add(x, y);  // point update: a[x] += y
		else
			// '\n' instead of endl: endl flushes on every query, which is
			// very slow for large m and inconsistent with the other programs.
			cout << query(x, y) << '\n';
	}
}

// Entry point: enables fast iostreams, then runs solve() once per test case.
// Only one case by default; uncomment the read to take a case count.
signed main()
{
	IOS;
	int cases = 1;
	// cin >> cases;
	for(; cases --; )
		solve();
	return 0;
}

区间修改-单点查询

点击查看代码
#include <bits/stdc++.h>
#define IOS ios::sync_with_stdio(false);cin.tie(nullptr),cout.tie(nullptr);
#define int long long

using namespace std;

// Capacity of the 1-based arrays below (valid indices are 1..n, n < N).
const int N = 1000000 + 10;

int a[N];   // initial values a[1..n]; a[0] stays 0
int tr[N];  // Fenwick tree over the difference array a[i] - a[i-1]
int n;      // number of elements
int m;      // number of operations to process

// Value of the lowest set bit of x (e.g. 12 -> 4).
int lowbit(int x)
{
	const int neg = -x;  // two's complement negation keeps only the lowest 1 in common
	return x & neg;
}

// Adds y at position x of the underlying (difference) array, updating every
// Fenwick node that covers x.
void add(int x, int y)
{
	while(x <= n)
	{
		tr[x] += y;
		x += lowbit(x);
	}
}

// Prefix sum of positions 1..x of the stored array. Because solve() stores
// differences, this equals the current value a[x].
int sum(int x)
{
	int res = 0;
	while(x != 0)
	{
		res += tr[x];
		x -= lowbit(x);
	}
	return res;
}

// Range-add / point-query driver. The tree stores d[i] = a[i] - a[i-1], so a
// range add is two point updates and a[p] is a prefix sum of d.
//   "1 l r x"         -> add x to a[l..r]
//   otherwise "f p"   -> print a[p]
void solve()
{
	cin >> n >> m;

	for(int i = 1; i <= n; i ++)
		cin >> a[i];

	for(int i = 1; i <= n; i ++)
	{
		const int diff = a[i] - a[i - 1];
		add(i, diff);
	}

	int op, lo, hi, delta;
	while(m --)
	{
		cin >> op;
		if(op != 1)
		{
			cin >> lo;
			cout << sum(lo) << '\n';
			continue;
		}
		cin >> lo >> hi >> delta;
		add(lo, delta);
		add(hi + 1, -delta);
	}
}

// Entry point: enables fast iostreams, then runs solve() once per test case.
signed main()
{
	IOS;
	int cases = 1;
	// cin >> cases;
	for(; cases --; )
		solve();
	return 0;
}

区间修改-区间查询

点击查看代码
#include <bits/stdc++.h>
#define IOS ios::sync_with_stdio(false);cin.tie(nullptr),cout.tie(nullptr);
#define int long long

using namespace std;

// Capacity of the 1-based arrays below (valid indices are 1..n, n < N).
const int N = 1000000 + 10;

int a[N];      // initial values a[1..n]; a[0] stays 0
int tr[N];     // Fenwick tree over d[i] = a[i] - a[i-1]
int pretr[N];  // Fenwick tree over i * d[i]
int n;         // number of elements
int m;         // number of operations to process

// Value of the lowest set bit of x (e.g. 12 -> 4).
int lowbit(int x)
{
	const int neg = -x;
	return x & neg;
}

// Point update on the Fenwick tree passed as tr (either the global tr or
// pretr): adds y at position x.
void add(int tr[], int x, int y)
{
	while(x <= n)
	{
		tr[x] += y;
		x += lowbit(x);
	}
}

// Prefix sum of positions 1..x in the Fenwick tree passed as tr.
int sum(int tr[], int x)
{
	int res = 0;
	while(x != 0)
	{
		res += tr[x];
		x -= lowbit(x);
	}
	return res;
}

// Prefix sum a[1] + ... + a[x] under range updates, using
//   sum_{i<=x} a[i] = (x + 1) * sum_{i<=x} d[i] - sum_{i<=x} i * d[i]
// where tr stores d[i] and pretr stores i * d[i].
int ask(int x)
{
	const int plain = sum(tr, x);
	const int weighted = sum(pretr, x);
	return plain * (x + 1) - weighted;
}

// Range-add / range-sum driver.
//   "1 l r x"          -> add x to a[l..r]
//   otherwise "f l r"  -> print a[l] + ... + a[r]
// tr holds d[i] = a[i] - a[i-1] and pretr holds i * d[i], so ask(x) can
// recover prefix sums of a.
void solve()
{
	cin >> n >> m;
	for(int i = 1; i <= n; i ++)
		cin >> a[i];
	
	for(int i = 1; i <= n; i ++)
	{
		int x = a[i] - a[i - 1];
		add(tr, i, x);
		add(pretr, i, x * i);
	}

	int l, r, x, f;
	while(m --)
	{
		cin >> f >> l >> r;
		if(f == 1)
		{
			cin >> x;
			// Difference-array update at l and r + 1 in both trees.
			add(tr, l, x);
			add(tr, r + 1, -x);
			add(pretr, l, l * x);
			add(pretr, r + 1, (r + 1) * (-x));
		}
		else
		{
			// Output via cout, matching the input side: IOS turns off
			// C-stdio synchronisation, so mixing printf with iostreams
			// is fragile.
			cout << ask(r) - ask(l - 1) << '\n';
		}
	}
}

// Entry point: enables fast iostreams, then runs solve() once per test case.
signed main()
{
	IOS;
	int cases = 1;
	// cin >> cases;
	for(; cases --; )
		solve();
	return 0;
}

二维树状数组

单点修改-区间查询

点击查看代码
#include<bits/stdc++.h>
#define IOS ios::sync_with_stdio(false);cin.tie(nullptr),cout.tie(nullptr)
#define int long long

using namespace std;

// Capacity per dimension of the 1-based grid (rows 1..n, columns 1..m).
const int N = 5000 + 10;

int tr[N][N];  // 2D Fenwick tree over the grid
int n;         // number of rows
int m;         // number of columns

// Value of the lowest set bit of x (e.g. 12 -> 4).
int lowbit(int x)
{
    const int neg = -x;
    return neg & x;
}

// Point update: adds k to cell (x, y) of the n x m grid, touching every
// 2D Fenwick node whose row range covers x and column range covers y.
void get(int x, int y, int k)
{
    for(int row = x; row <= n; row += lowbit(row))
    {
        int col = y;
        while(col <= m)
        {
            tr[row][col] += k;
            col += lowbit(col);
        }
    }
}

// Prefix sum of the grid over rows 1..x and columns 1..y.
// Both dimensions are walked independently by stripping the lowest set bit
// of the *current* index.
int sum(int x, int y)
{
    int ans = 0;
    for(int i = x; i; i -= lowbit(i))
        // Bug fix: the step must be lowbit(j), not lowbit(y). With lowbit(y)
        // the loop visits y, y-lowbit(y), ... and double-counts overlapping
        // nodes (e.g. y = 6 summed columns 6, 4 and 2).
        for(int j = y; j; j -= lowbit(j))
            ans += tr[i][j];
    return ans;
}

// Sum over the rectangle rows x1..x2, columns y1..y2, by inclusion-exclusion
// on four prefix sums.
int query(int x1, int y1, int x2, int y2)
{
    const int whole = sum(x2, y2);
    const int left = sum(x2, y1 - 1);
    const int top = sum(x1 - 1, y2);
    const int corner = sum(x1 - 1, y1 - 1);
    return whole - left - top + corner;
}

// Reads the grid size n x m, then processes operations until input ends:
//   "1 x y k"                  -> add k to cell (x, y)
//   otherwise "f x1 y1 x2 y2"  -> print the sum of that rectangle
void solve()
{
    int f = 0;
    int x = 0, y = 0, k = 0;
    int x1 = 0, y1 = 0, x2 = 0, y2 = 0;
    cin >> n >> m;
    while(cin >> f)
    {
        if(f != 1)
        {
            cin >> x1 >> y1 >> x2 >> y2;
            cout << query(x1, y1, x2, y2) << '\n';
            continue;
        }
        cin >> x >> y >> k;
        get(x, y, k);
    }
}

// Entry point: enables fast iostreams, then runs solve() once per test case.
signed main()
{
    IOS;
    int cases = 1;
    // cin >> cases;
    for(; cases --; )
        solve();
    return 0;
}

区间修改-区间查询

点击查看代码
#include<bits/stdc++.h>
#define IOS ios::sync_with_stdio(false);cin.tie(nullptr),cout.tie(nullptr)
#define int long long

using namespace std;

// Capacity per dimension of the 1-based grid (rows 1..n, columns 1..m).
const int N = 2000 + 100;

// Four 2D Fenwick trees over the difference array d (see add()):
int tr1[N][N];  // d[i][j]
int tr2[N][N];  // i * d[i][j]
int tr3[N][N];  // j * d[i][j]
int tr4[N][N];  // i * j * d[i][j]
int n;          // number of rows
int m;          // number of columns

// Value of the lowest set bit of x (e.g. 12 -> 4).
int lowbit(int x)
{
    const int neg = -x;
    return x & neg;
}

// Adds d to the 2D difference array at corner (x, y). The weighted copies
// (x*d, y*d, x*y*d) use the corner's coordinates, not the visited node's,
// so sum() can rebuild prefix sums of the original grid.
void add(int x, int y, int d)
{
    for(int i = x; i <= n; i += lowbit(i))
    {
        int j = y;
        while(j <= m)
        {
            tr1[i][j] += d;
            tr2[i][j] += x * d;
            tr3[i][j] += y * d;
            tr4[i][j] += x * y * d;
            j += lowbit(j);
        }
    }
}

// Prefix sum of the grid over rows 1..x and columns 1..y, using
//   sum d[i][j] * ((x+1)(y+1) - (y+1)*i - (x+1)*j + i*j)
// with tr1 = d, tr2 = i*d, tr3 = j*d, tr4 = i*j*d.
int sum(int x, int y)
{
    int total = 0;
    for(int i = x; i != 0; i -= lowbit(i))
    {
        int j = y;
        while(j != 0)
        {
            total += (x + 1) * (y + 1) * tr1[i][j]
                   - (x + 1) * tr3[i][j]
                   - (y + 1) * tr2[i][j]
                   + tr4[i][j];
            j -= lowbit(j);
        }
    }
    return total;
}

// Sum over the rectangle rows x1..x2, columns y1..y2, by inclusion-exclusion
// on four prefix sums.
int query(int x1, int y1, int x2, int y2)
{
    const int whole = sum(x2, y2);
    const int left = sum(x2, y1 - 1);
    const int top = sum(x1 - 1, y2);
    const int corner = sum(x1 - 1, y1 - 1);
    return whole - left - top + corner;
}

// Adds x to every cell of the rectangle rows x1..x2, columns y1..y2 via
// four corner updates on the 2D difference array.
void get(int x1, int y1, int x2, int y2, int x)
{
    const int row_end = x2 + 1;
    const int col_end = y2 + 1;
    add(x1, y1, x);
    add(x1, col_end, -x);
    add(row_end, y1, -x);
    add(row_end, col_end, x);
}

// Reads the grid size n x m, then processes operations until input ends:
//   "1 x1 y1 x2 y2 k"          -> add k to every cell of that rectangle
//   otherwise "f x1 y1 x2 y2"  -> print the sum of that rectangle
void solve()
{
    // The unused locals x and y from the original declaration were removed.
    int x1, y1, x2, y2;
    int k;
    int f;
    x1 = y1 = x2 = y2 = k = f = 0;
    cin >> n >> m;
    while(cin >> f)
    {
        if(f == 1)
        {
            cin >> x1 >> y1 >> x2 >> y2 >> k;
            get(x1, y1, x2, y2, k);
        }
        else
        {
            cin >> x1 >> y1 >> x2 >> y2;
            cout << query(x1, y1, x2, y2) << '\n';
        }
    }
}

// Entry point: enables fast iostreams, then runs solve() once per test case.
signed main()
{
    IOS;
    int cases = 1;
    // cin >> cases;
    for(; cases --; )
        solve();
    return 0;
}