0%

主席树

主席树,全名可持久化线段树,又名hjt tree,是一种保留了多个历史版本的权值线段树。插入节点的时候去和上一个版本的权值线段树做连接,计算前缀。询问的时候带着两个版本一起跳儿子,统计前缀的差来得到区间信息。

下面是几个运用

区间第k大元素

题目跳转

值域到了int,先做离散化。

这里先算出左儿子区间上的点的个数 x,如果排名 k 小于 x ,那么就往做左儿子跳,否则往右儿子跳找排名为 k-x 的节点,和平衡树的查找很像。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <bits/stdc++.h>

using namespace std;

int read(){
int res=0,sign=1;char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')sign=-sign;
for(;ch>='0'&&ch<='9';ch=getchar())res=(res<<3)+(res<<1)+(ch^'0');
return res*sign;
}

const int N=2e5+10;

int n,m;
int tot,sum[N<<5],rt[N],ls[N<<5],rs[N<<5];

int a[N],idx[N],len;

#define MID int m=s+((t-s)>>1)

int build(int s,int t){
int p=++tot;
if(s==t)return p;
MID;
ls[p]=build(s,m);
rs[p]=build(m+1,t);
return p;
}

int ins(int s,int t,int x,int _p){
int p=++tot;
ls[p]=ls[_p],rs[p]=rs[_p],sum[p]=sum[_p]+1;
if(s==t)return p;
MID;
if(x<=m)ls[p]=ins(s,m,x,ls[p]);
else rs[p]=ins(m+1,t,x,rs[p]);
return p;
}

int query(int s,int t,int u,int v,int k){
int x=sum[ls[v]]-sum[ls[u]];
if(s==t)return s;
MID;
if(k<=x)return query(s,m,ls[u],ls[v],k);
else return query(m+1,t,rs[u],rs[v],k-x);
}

int getid(int x){
return lower_bound(idx+1,idx+1+len,x)-idx;
}

void init(){
for(int i=1;i<=n;++i)idx[i]=a[i];
sort(idx+1,idx+1+n);
len=unique(idx+1,idx+1+n)-idx-1;
rt[0]=build(1,len);
for(int i=1;i<=n;++i){
rt[i]=ins(1,len,getid(a[i]),rt[i-1]);
}
}

void solve(){
for(int i=1;i<=m;++i){
int l=read(),r=read(),k=read();
int ans=idx[query(1,len,rt[l-1],rt[r],k)];
printf("%d\n",ans);
}
}

int main(){
n=read(),m=read();
for(int i=1;i<=n;++i){
a[i]=read();
}
init();
solve();
return 0;
}

区间[l, r]中属于某一个值域[L, R]中的点的个数

题目跳转

一个序列 \(a_i\,(i=1,2,...,n)\),在 \(a_l,a_{l+1},...,a_r\) 中有多少个 \(a_i\) 满足 \(L <= a_i <= R\)

带着两个版本的树跳,然后统计区间中满足的点的个数就好了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <bits/stdc++.h>

using namespace std;

const int N=1e5+10;

int read(){
int res=0,sign=1;
char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-'){sign=-sign;}
for(;ch>='0'&&ch<='9';ch=getchar()){res=(res<<3)+(res<<1)+(ch^'0');}
return res*sign;
}

int n,q;

vector<int> G[N];
int ver[N];
int tin[N],tout[N],timer;

void dfs(int u,int fno){
tin[u]=++timer;
for(auto&& v:G[u])if(v!=fno){
dfs(v,u);
}
tout[u]=timer;
}

int ls[N<<5],rs[N<<5],rt[N<<5],sum[N<<5],tot;

#define MID int m=s+((t-s)>>1)

int ins(int s,int t,int x,int _p){
int p=++tot;
ls[p]=ls[_p],rs[p]=rs[_p],sum[p]=sum[_p]+1;
if(s==t)return p;
MID;
if(x<=m)ls[p]=ins(s,m,x,ls[p]);
else rs[p]=ins(m+1,t,x,rs[p]);
return p;
}

int query(int s,int t,int u,int v,int l,int r){
if(l<=s&&t<=r)return sum[v]-sum[u];
MID;
int res=0;
if(l<=m)res+=query(s,m,ls[u],ls[v],l,r);
if(r>m)res+=query(m+1,t,rs[u],rs[v],l,r);
return res;
}

void solve(){
n=read(),q=read();
for(int i=1;i<=n;++i)G[i].clear();
tot=0,timer=0;
for(int i=1,u,v;i<=n-1;++i){
u=read(),v=read();
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,0);
for(int i=1;i<=n;++i)ver[i]=read(),rt[i]=ins(1,n,tin[ver[i]],rt[i-1]);
for(int _=1;_<=q;++_){
int l=read(),r=read(),x=read();
int L=tin[x],R=tout[x];
int res=query(1,n,rt[l-1],rt[r],L,R);
if(res>=1)puts("YES");
else puts("NO");
}
puts("");
}

int main(){
int t=read();
while(t--)solve();
return 0;
}