树状数组套主席树

之前学了主席树,解决了静态区间第k小问题,于是想顺便学习如何求解动态区间第k小(支持单点修改)。奈何太菜,看了一晚上也没能很好地理解,主要卡在了询问操作上,第二天终于弄懂了。
这里给出两种方法

  • 方法1

对原序列按前缀建立主席树(和静态主席树一样),修改操作只作用在树状数组上:树状数组只保存修改带来的增量。因此查询时,每一层计算左子树计数 sum 都要把两部分加起来:原主席树中的值 + 树状数组中的增量。

  • 方法2

按树状数组的方式建树:root[i] 对应的线段树只统计区间 [i-lowbit(i)+1, i] 内各个值的出现次数,更新和查询都按树状数组的方式进行。由于树状数组直接保存了修改后的结果,询问时只需查询树状数组即可。

方法1的代码在洛谷和ZOJ上都能通过;方法2只能通过洛谷,在ZOJ上一直段错误(SF)。可能的原因有两个。一是方法2的 main 只读入一组数据,没有像方法1那样先读入数据组数 t。二是它的节点数组 T[N<<9] 约有五千一百万个节点,内存占用很大。

方法1

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<string>
#include<cmath>
#include<string>
#include<string.h>
#include<vector>
#include<map>
#include<queue>
#include<cstdlib>
using namespace std;
// Lowest set bit of x. Every use of the argument is parenthesized so that
// compound arguments such as lowbit(i + 1) expand correctly.
#define lowbit(x) ((x) & (-(x)))
typedef long long ll;
const int N = 6e4 + 10;  // capacity for the sequence length n (1-based arrays)
const int M = 2500000;   // NOTE(review): not referenced anywhere in this program
// a[1..n]: current sequence; b[1..m]: sorted distinct values (initial values
// plus every value introduced by a change), used for coordinate compression.
int a[N], b[N * 2], tot, n, m, q, numx, numy;
// One buffered operation: op == 1 is a query (l, r, v = k),
// op == 2 is a change (l = position, r = new value).
struct query{
	int op, l, r, v;
}p[10005];
// Node pool shared by all segment trees; tot is the last index handed out.
// NOTE(review): update() allocates a fresh root-to-leaf path on every call,
// including the per-position copies made by add(), so N * 32 nodes may be
// too few for change-heavy inputs. Confirm against the problem's limits.
struct tree{
	int lc, rc, cnt;
}T[N * 32];
// root[i]: static persistent tree counting the initial values of a[1..i].
int root[N];
// Fenwick delta-tree nodes for the prefixes [1, l-1] (ux) and [1, r] (uy)
// of the current query; only about log2(n) slots are ever used.
int ux[50], uy[50];
// S[i]: Fenwick-indexed trees holding only the deltas caused by changes.
int S[N];
/*
 * Builds a complete segment tree over the value range [l, r] with every
 * count set to zero, and writes the new root index into rt.
 * main uses it once per test case to create root[0], the empty version
 * that the prefix trees and the delta trees start from.
 */
void build(int &rt, int l, int r)
{
	const int id = ++tot;
	rt = id;
	T[id].cnt = 0;
	T[id].lc = 0;
	T[id].rc = 0;
	if (l != r)
	{
		const int mid = (l + r) >> 1;
		build(T[id].lc, l, mid);
		build(T[id].rc, mid + 1, r);
	}
}
/*
 * Persistent point update: creates a new version of the tree rooted at pre
 * in which the count of value rank pos (within [l, r]) is changed by val.
 * The new root index is stored in cur. Only the nodes on the root-to-leaf
 * path are copied; every other subtree is shared with pre.
 */
void update(int &cur, int pre, int l, int r, int pos, int val)
{
	int *slot = &cur;  // the child link that must point at the next new node
	for (;;)
	{
		const int node = ++tot;
		*slot = node;
		T[node] = T[pre];
		T[node].cnt += val;
		if (l == r)
			break;
		const int mid = (l + r) >> 1;
		if (pos <= mid)
		{
			slot = &T[node].lc;
			pre = T[pre].lc;
			r = mid;
		}
		else
		{
			slot = &T[node].rc;
			pre = T[pre].rc;
			l = mid + 1;
		}
	}
}
/*
 * Changes the count of the current value a[pos] by val (main passes -1 or +1)
 * in every Fenwick-indexed delta tree S[i] that covers position pos.
 * Each touched S[i] is replaced by a new version produced by update().
 */
void add(int pos, int val)
{
	int rank = lower_bound(b + 1, b + m + 1, a[pos]) - b;
	int i = pos;
	while (i <= n)
	{
		update(S[i], S[i], 1, m, rank, val);
		i += lowbit(i);
	}
}
/*
 * Returns the rank (index into b) of the k-th smallest value in a[ql..qr].
 *
 * x and y are the static persistent roots root[ql-1] and root[qr]. Before the
 * call, ux[0..numx) and uy[0..numy) must hold the Fenwick delta-tree nodes for
 * the prefixes [1, ql-1] and [1, qr]. At each level the left-subtree count
 * combines the static difference and the Fenwick deltas.
 * Side effect: ux and uy are advanced in place while descending.
 */
int query(int l, int r, int x, int y, int k)
{
	while (l != r)
	{
		const int mid = (l + r) >> 1;
		int sum = T[T[y].lc].cnt - T[T[x].lc].cnt;
		for (int j = 0; j < numy; j++) sum += T[T[uy[j]].lc].cnt;
		for (int j = 0; j < numx; j++) sum -= T[T[ux[j]].lc].cnt;
		if (k > sum)
		{
			// The answer is in the right half; skip the left half's elements.
			k -= sum;
			l = mid + 1;
			x = T[x].rc;
			y = T[y].rc;
			for (int j = 0; j < numx; j++) ux[j] = T[ux[j]].rc;
			for (int j = 0; j < numy; j++) uy[j] = T[uy[j]].rc;
		}
		else
		{
			r = mid;
			x = T[x].lc;
			y = T[y].lc;
			for (int j = 0; j < numx; j++) ux[j] = T[ux[j]].lc;
			for (int j = 0; j < numy; j++) uy[j] = T[uy[j]].lc;
		}
	}
	return l;
}
/*
 * Entry point for the method-1 solver.
 *
 * Input: t test cases. Each has n and q, the n initial values, then q
 * operations. "Q l r k" prints the k-th smallest value in a[l..r]. Any other
 * opcode (e.g. "C i v") sets a[i] = v. All operations are buffered first so
 * that every value that can appear is known before coordinate compression
 * into b[1..m].
 *
 * Malformed or truncated input stops the program rather than continuing
 * with uninitialized or stale data.
 */
int main()
{
	int t;
	if (scanf("%d", &t) != 1) return 0;
	while (t--)
	{
		tot = 0;
		m = 0;
		if (scanf("%d%d", &n, &q) != 2) return 0;
		for (int i = 1; i <= n; i++)
		{
			if (scanf("%d", &a[i]) != 1) return 0;
			b[++m] = a[i];
		}
		for (int i = 1; i <= q; i++)
		{
			char op[5];
			// The width limit keeps an over-long token from overflowing op.
			if (scanf("%4s", op) != 1) return 0;
			if (op[0] == 'Q')
			{
				if (scanf("%d%d%d", &p[i].l, &p[i].r, &p[i].v) != 3) return 0;
				p[i].op = 1;
			}
			else
			{
				if (scanf("%d%d", &p[i].l, &p[i].r) != 2) return 0;
				b[++m] = p[i].r;
				p[i].op = 2;
			}
		}
		// Coordinate compression: b[1..m] becomes the sorted distinct values.
		sort(b + 1, b + m + 1);
		m = unique(b + 1, b + m + 1) - b - 1;
		// Static persistent trees: root[i] counts the initial values of a[1..i].
		build(root[0], 1, m);
		for (int i = 1; i <= n; i++)
		{
			int pos = lower_bound(b + 1, b + m + 1, a[i]) - b;
			update(root[i], root[i - 1], 1, m, pos, 1);
		}
		// Every Fenwick delta tree starts out as the empty tree.
		for (int i = 0; i <= n; i++)
			S[i] = root[0];
		for (int i = 1; i <= q; i++)
		{
			if (p[i].op == 1)
			{
				// Collect the delta-tree nodes for prefixes [1, l-1] and [1, r].
				numx = 0, numy = 0;
				for (int j = p[i].l - 1; j > 0; j -= lowbit(j)) ux[numx++] = S[j];
				for (int j = p[i].r; j > 0; j -= lowbit(j)) uy[numy++] = S[j];
				int ans = query(1, m, root[p[i].l - 1], root[p[i].r], p[i].v);
				printf("%d\n", b[ans]);
			}
			else
			{
				// A change removes the old value, then inserts the new one.
				add(p[i].l, -1);
				a[p[i].l] = p[i].r;
				add(p[i].l, 1);
			}
		}

	}

	return 0;
}

方法2(顺便给出非递归的查询和更新)

// luogu-judger-enable-o2
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<string>
#include<cmath>
#include<string>
#include<string.h>
#include<vector>
#include<map>
#include<queue>
#include<cstdlib>
using namespace std;
// Lowest set bit of x. Every use of the argument is parenthesized so that
// compound arguments such as lowbit(i + 1) expand correctly.
#define lowbit(x) ((x) & (-(x)))
typedef long long ll;
const int N = 1e5+ 10;   // capacity for the sequence length n (1-based arrays)
const int M = 2500000;   // NOTE(review): not referenced anywhere in this program
// a[1..n]: current sequence; b[1..m]: sorted distinct values (initial values
// plus every value introduced by a change), used for coordinate compression.
int a[N], b[N*2], tot, n, m, q,numx,numy;
// One buffered operation: op == 1 is a query (l, r, v = k),
// op == 2 is a change (l = position, r = new value).
struct query{
    int op, l, r, v;
}p[100005];
// Node pool for all segment trees; tot is the last index handed out.
// Node 0 is never allocated and stays all-zero, so it serves as the
// empty tree and as the null child.
// NOTE(review): N << 9 is about 51 million nodes (several hundred MB);
// confirm this fits the target judge's memory limit.
struct tree{
    int lc, rc, cnt;
}T[N<<9];
// root[i]: Fenwick-indexed tree counting the values of a[i-lowbit(i)+1 .. i].
int root[N];
// Fenwick roots for the prefixes [1, l-1] (ux) and [1, r] (uy) of the current
// query; only about log2(n) slots are ever used.
int ux[50], uy[50];
/*
 * Allocates a complete segment tree over [l, r] with all counts zero and
 * stores its root index in rt.
 * Not called by main in this program: the Fenwick roots start at node 0,
 * which already acts as the empty tree.
 */
void build(int &rt,int l,int r)
{
    const int id = ++tot;
    rt = id;
    T[id].lc = 0;
    T[id].rc = 0;
    T[id].cnt = 0;
    if (l == r)
        return;
    const int mid = (l + r) >> 1;
    build(T[id].lc, l, mid);
    build(T[id].rc, mid + 1, r);
}
/*
 * Iterative persistent point update: builds a new version of the tree
 * rooted at pre in which the count of value rank pos (within [l, r]) is
 * changed by val, and stores the new root index in cur.
 * Only the root-to-leaf path is copied; the sibling at each level is
 * shared with pre.
 */
void update(int &cur, int pre,int l, int r,int pos,int val)
{
    int now = ++tot;
    cur = now;
    T[now] = T[pre];
    T[now].cnt += val;
    while (l < r)
    {
        const int mid = (l + r) >> 1;
        const int fresh = ++tot;  // the copy of the child on the path
        if (pos > mid)
        {
            T[now].rc = fresh;
            T[now].lc = T[pre].lc;
            pre = T[pre].rc;
            l = mid + 1;
        }
        else
        {
            T[now].lc = fresh;
            T[now].rc = T[pre].rc;
            pre = T[pre].lc;
            r = mid;
        }
        now = fresh;
        T[now].cnt = T[pre].cnt + val;
    }
}
/*
 * Changes the count of the current value a[pos] by val (main passes +1 or -1)
 * in every Fenwick-indexed tree root[i] whose range contains pos.
 * Each touched root[i] is replaced by the new version from update().
 */
void add(int pos, int val)
{
    const int rank = lower_bound(b + 1, b + m + 1, a[pos]) - b;
    for (int idx = pos; idx <= n; idx += lowbit(idx))
    {
        update(root[idx], root[idx], 1, m, rank, val);
    }
}
/*
 * Returns the rank (index into b) of the k-th smallest value in the queried
 * range, searching the value interval [l, r].
 * Before the call, ux[0..numx) must hold the Fenwick roots for prefix
 * [1, ql-1] and uy[0..numy) those for prefix [1, qr].
 * Side effect: ux and uy are advanced in place while descending.
 */
int query(int l,int r,int k)
{
    for (; l < r; )
    {
        const int mid = (l + r) >> 1;
        int leftCount = 0;
        for (int j = 0; j < numy; ++j) leftCount += T[T[uy[j]].lc].cnt;
        for (int j = 0; j < numx; ++j) leftCount -= T[T[ux[j]].lc].cnt;
        const bool goRight = k > leftCount;
        if (goRight)
        {
            k -= leftCount;
            l = mid + 1;
        }
        else
        {
            r = mid;
        }
        for (int j = 0; j < numx; ++j) ux[j] = goRight ? T[ux[j]].rc : T[ux[j]].lc;
        for (int j = 0; j < numy; ++j) uy[j] = goRight ? T[uy[j]].rc : T[uy[j]].lc;
    }
    return l;
}
/*
 * Entry point for the method-2 solver (single test case).
 *
 * Input: n and q, the n initial values, then q operations. "Q l r k" prints
 * the k-th smallest value in a[l..r]. Any other opcode (e.g. "C i v") sets
 * a[i] = v. Operations are buffered first so every value that can appear is
 * known before coordinate compression into b[1..m].
 *
 * Malformed or truncated input stops the program rather than continuing
 * with uninitialized data.
 */
int main()
{
    tot = 0;
    m = 0;
    if (scanf("%d%d", &n, &q) != 2) return 0;
    for (int i = 1; i <= n; i++)
    {
        if (scanf("%d", &a[i]) != 1) return 0;
        b[++m] = a[i];
    }
    for (int i = 1; i <= q; i++)
    {
        char op[5];
        // The width limit keeps an over-long token from overflowing op.
        if (scanf("%4s", op) != 1) return 0;
        if (op[0] == 'Q')
        {
            if (scanf("%d%d%d", &p[i].l, &p[i].r, &p[i].v) != 3) return 0;
            p[i].op = 1;
        }
        else
        {
            if (scanf("%d%d", &p[i].l, &p[i].r) != 2) return 0;
            b[++m] = p[i].r;
            p[i].op = 2;
        }
    }
    // Coordinate compression: b[1..m] becomes the sorted distinct values.
    sort(b + 1, b + m + 1);
    m = unique(b + 1, b + m + 1) - b - 1;
    // Insert the initial values; every root[] starts at the empty node 0.
    for (int i = 1; i <= n; i++)
        add(i, 1);
    for (int i = 1; i <= q; i++)
    {
        if (p[i].op == 1)
        {
            // Collect the Fenwick roots for prefixes [1, l-1] and [1, r].
            numx = 0, numy = 0;
            for (int j = p[i].l-1; j > 0; j -= lowbit(j)) ux[numx++] = root[j];
            for (int j = p[i].r; j > 0; j -= lowbit(j)) uy[numy++] = root[j];
            int ans = query(1, m, p[i].v);
            printf("%d\n", b[ans]);
        }
        else
        {
            // A change removes the old value, then inserts the new one.
            add(p[i].l, -1);
            a[p[i].l] = p[i].r;
            add(p[i].l, 1);
        }
    }

    return 0;
}