线段树的基本操作

  1. pushup() 通过子节点更新父节点
  2. pushdown()
  3. build()将区间初始化为线段树
  4. modify() 单点修改 & 区间修改
  5. query()

线段树的存储结构

  1. 使用类似完全二叉树逻辑结构,故我们可以采取一维数组的存储结构存储线段树

  2. 编号(类比堆结构)

    父节点:x/2(整除) 或者 x>>1

    左子节点:2*x 或者 x<<1

    右子节点:2*x+1 或者 x<<1|1

  3. 存储线段树总节点个数(需要开的空间大小):4*x


我们以一个线段树的具体应用来讲解一下集体操作的实现:

最大数

题面描述:给定一个正整数数列 a[1]~a[n],每一个数都在 0∼p−1 之间。

可以对这列数进行两种操作:

添加操作:向序列后添加一个数,序列长度变成 n+1;
询问操作:询问这个序列中最后 L 个数中最大的数是多少。
程序运行的最开始,整数序列为空。

一共要对整数序列进行 m 次操作。

写一个程序,读入操作的序列,并输出询问操作的答案。

原题链接:AcWing 1275. 最大数

线段树的节点结构

1
2
3
4
5
//一般情况将线段树的节点个数声明为需要维护的总区间的大小的4倍
struct Node{
int l, r; //维护区间[l, r]
int v;//区间[l, r]中的最大值
}tr[4*N];

初始化建树

1
2
3
4
5
6
7
8
9
10
11
12
//u是待构造的树节点编号,而其需要维护的就是区间[l, r]
//显式地指明初始化的区间值的大小,是因为本题的初始化值就是0
void build(int u, int l, int r){
//先直接将树节点初始化 等价于 tr[u] = {l, r, 0};
tr[u] = {l, r};
//l==r 表明这里的节点是叶节点
//不需要继续递归处理,直接返回
if(l == r) return;
//依次初始化左右节点
int mid = l+r>>1;
build(u<<1, l, mid), build(u<<1|1, mid+1, r);
}

查询操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
//从树的节点u开始查找[l, r]区间内的最大值
int query(int u, int l, int r){
//树中节点已经完全包含在[l, r]中了
//也就意味着这里的树节点u维护的区间再递归下去已经没意义了
if(tr[u].l >= l && tr[u].r <= r) return tr[u].v;

//继续递归左右子节点维护的区间
int mid = tr[u].l + tr[u].r >> 1;
int v = 0;
if(l <= mid) v = query(u<<1, l, r);
if(r > mid) v = max(v, query(u<<1|1, l, r));

return v;
}

push_up操作

1
2
3
4
5
//由子节点的信息来计算父节点的信息
//核心操作
void pushup(int u){
tr[u].v = max(tr[u<<1].v, tr[u<<1|1].v);
}

修改操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
//从树的节点u开始递归地查找到区间[x, x]代表的节点
//并将其节点所代表的值修改为 v 而后在一步步地回溯中将父节点的值更新
void modify(int u, int x, int v){
//如果已经查找到了这个区间那么就将这个节点所代表的值直接置为v
//否则继续递归地查找其左右子节点
if(tr[u].l == x && tr[u].r == x) tr[u].v = v;
else{
//左右子节点的区间分界为mid
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid) modify(u<<1, x , v);
else modify(u<<1|1, x, v);
//已经更改完子树的值后还要更新自己的值
pushup(u);
}
}

由于这里的区间修改的值还和最近一次查找的结果有关所以是个非常典型的动态过程

完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

const int N = 200010;
int m, p;
struct Node{
int l, r;
int v;//区间[l, r]中的最大值
}tr[4*N];

//由子节点的信息来计算父节点的信息
void pushup(int u){
tr[u].v = max(tr[u<<1].v, tr[u<<1|1].v);
}

void build(int u, int l, int r){
tr[u] = {l, r};
if(l == r) return;
int mid = l+r>>1;
build(u<<1, l, mid), build(u<<1|1, mid+1, r);
}

int query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r) return tr[u].v; //树中节点已经完全包含在[l, r]中了

int mid = tr[u].l + tr[u].r >> 1;
int v = 0;
if(l <= mid) v = query(u<<1, l, r);
if(r > mid) v = max(v, query(u<<1|1, l, r));

return v;
}

void modify(int u, int x, int v){
if(tr[u].l == x && tr[u].r == x) tr[u].v = v;
else{
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid) modify(u<<1, x , v);
else modify(u<<1|1, x, v);
pushup(u);
}
}

int main(){
int n = 0, last = 0;
scanf("%d%d",&m, &p);
build(1, 1, m);

int x;
char op[2];

while(m--){
scanf("%s%d", op, &x);
if(op[0] == 'Q'){
last = query(1, n-x+1, n);
printf("%d\n", last);
}
else{
modify(1, n+1, ((long long)last+x)%p);
n++;
}
}

return 0;
}

你能回答这些问题吗

题面描述:给定长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:

  1. 1 x y,查询区间 [x,y] 中的最大连续子段和,即 maxxlryi=lrA[i]max_{x≤l≤r≤y}{\sum_{i=l}^rA[i]}
  2. 2 x y,把 A[x] 改成 y。

对于每个查询指令,输出一个整数表示答案。

原题链接:AcWing 245. 你能回答这些问题吗

解题思路:

  1. 按照常规思路先构造节点的结构,由题目所给的条件先得出最朴素的节点结构

    1
    2
    3
    4
    struct Node{
    int l, r; //表示当前节点维护的区间左右端点
    int max; //表示问题指向的最大的连续子段和
    }
  2. 然后再思考我们构造出来的节点是否完备,即父节点的值是否能通过子节点的值完全推导出来,我们这里构造的最原始的结构是无法从子节点得到父节点的所有值的,两个子节点分别的连续最大子段和并不能推导得到福节点所代表的最大连续子段和(父节点的连续子段可能同时从两个子节点代表的区间中去取值)

  3. 我们现在可以思考父节点的最大连续子段和能如何充分利用当前的信息得到,而又有哪些信息是欠缺的:父节点的连续子段和包括了两个左右子节点的最大连续子段和还有一种情况就是在左节点取一半值也在右节点取一半的值也就是左节点的最大后缀和加上右节点的最大前缀和,这三种情况取最大值就是父节点的最大连续子段和的大小了

  4. 我们为求出父节点的最大连续子段和增加了两个节点信息(实际上是一种信息)即当前节点的最大前缀和、最大后缀和,也就是我们又需要维护这样的两种信息,同样的我们思考需要如何完备地通过子节点求出父节点维护的信息

    1
    2
    3
    4
    5
    6
    struct Node{
    int l, r; //表示当前节点维护的区间左右端点
    int tmax; //表示问题指向的最大的连续子段和
    int lmax; //表示当前节点代表的区间的最大前缀和
    int rmax; //表示当前节点代表的区间的最大后缀和
    }
  5. 父节点的最大前缀和就是左子节点这一边的区间的代表的最大前缀和加上右子节点这半边的区间代表的最大前缀和取最大值,而我们需要注意的是左半边的区间的最大前缀和就是左子节点的lmax,而右半边的区间的最大前缀和是左半边的区间和加上右半边的最大前缀和,也就是这里我们还需要维护一个区间和的信息(感觉上像是进入了循环😂)后缀和的信息和前缀和的维护几乎一致就不必细讲

  6. 为了维护前缀和、后缀和信息我们又需要引入区间和这一信息取维护

    1
    2
    3
    4
    5
    6
    7
    struct Node{
    int l, r; //表示当前节点维护的区间左右端点
    int tmax; //表示问题指向的最大的连续子段和
    int lmax; //表示当前节点代表的区间的最大前缀和
    int rmax; //表示当前节点代表的区间的最大后缀和
    int sum; //表示当前节点所代表的区间的区间和
    }
  7. 而父节点的区间和可以直接通过两个子节点的区间和相加得到,所以到这里我们就可以通过已知的子节点的信息完备地推出父节点的信息,即子节点的最后形态已经确定了下来,并且经过我们上面的分析也清楚了最关键的操作pushup(),后面的操作套用模板即可

完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;

const int N = 500010;

int n, m;
int w[N];
struct Node{
int l, r;
int sum, lmax, rmax, tmax;
}tr[N*4];

void pushup(Node &u, Node &l, Node &r){
u.sum = l.sum+r.sum;
u.lmax = max(l.lmax, l.sum + r.lmax);
u.rmax = max(r.rmax, r.sum + l.rmax);
u.tmax = max(max(l.tmax, r.tmax), l.rmax+r.lmax);
}

void pushup(int u){
pushup(tr[u], tr[u<<1], tr[u<<1 | 1]);
}

void build(int u, int l, int r){
if(l == r){
tr[u] = {l, r, w[r], w[r], w[r], w[r]};
}
else{
tr[u] = {l, r};
int mid = l+r >> 1;
build(u<<1, l, mid), build(u<<1|1, mid+1, r);
pushup(u);
}
}

void modlfy(int u, int x, int v){
if(tr[u].l == x && tr[u].r == x) tr[u] = {x, x, v, v, v, v};
else{
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid) modlfy(u<<1, x, v);
else modlfy(u<<1|1, x, v);
pushup(u);
}
}

Node query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r) return tr[u];
else{
int mid = tr[u].l + tr[u].r >> 1;
if(r <= mid) return query(u<<1, l, r);
else if(l > mid) return query(u<<1|1, l, r);
else{
auto left = query(u<<1, l, r);
auto right = query(u<<1|1, l, r);
Node ans;
pushup(ans, left, right);
return ans;
}
}
}

int main(){
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) scanf("%d", &w[i]);
build(1, 1, n);

int k, x, y;
while(m--){
scanf("%d%d%d", &k, &x, &y);
if(k == 1){
if(x > y) swap(x, y);
printf("%d\n", query(1, x, y).tmax);
}
else modlfy(1, x, y);
}
return 0;
}

区间最大公约数

题面描述:给定一个长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:

  1. C l r d,表示把 A[l],A[l+1],,A[r]A[l],A[l+1],…,A[r] 都加上 dd
  2. Q l r,表示询问 A[l],A[l+1],,A[r]A[l],A[l+1],…,A[r] 的最大公约数(GCDGCD)。

对于每个询问,输出一个整数表示答案。

原题链接:AcWing 246. 区间最大公约数

解题思路:

  1. 照例先构造需要维护的节点结构,有题目给出最朴素的节点结构

    1
    2
    3
    4
    struct Node{
    int l, r; //表示当前节点维护的区间左右端点
    int d; //区间[l, r]的最大公约数
    }
  2. 如果仅仅只查询而仅考虑单点修改,实际上维护这些数据就已经完备了并且pushup()操作中父节点维护的区间的最大公约数就是两个子区间的维护的最大公约数的最大公约数

    但是当前的需要解决的更新操作是区间操作并且是区间内加上同一个数,首先分析 gcd(a,b,c)gcd(a, b, c)gcd(a+t,b+t,c+t)gcd(a+t, b+t, c+t) 之间是否有任何联系,很遗憾的是并不存在明显的联系

  3. 下面考虑欧几里得定理: gcd(a,b)=gcd(a,bna)gcd(a, b) = gcd(a, b-n*a) ,由其原理可以推论的得到 gcd(a1, a2, ... , an)=gcd(a1, a2a1, ... ,anan1)gcd(a_1,\ a_2,\ ...\ ,\ a_n) = gcd(a_1,\ a_2-a_1,\ ...\ , a_n-a_{n-1}) 而看等式右边我们可以发现,这其实是 a1, a2, ... ,ana_1,\ a_2,\ ...\ , a_n​ 序列的差分数组,对应我们当前的问题中的更新操作 “将区间内的每个数都加上d” 就是差分数组的典型操作 “区间加减”

    这样通过维护差分数组来得到最大公约数的做法,在维护两个单点修改( al+da_l+d​ 和 ar+1da_{r+1}-d​ )的情况下处理区间修改的问题了

  4. 小结一下前面所述:通过欧几里得定理我们得到推论“原序列的最大公因数等价于其差分序列的最大公因数” 我们通过维护差分序列将原序列的区间修改转变为差分序列的单点修改实现从而降低复杂度

  5. 通过上面的结论我们可以发现(这里将原序列记做 aia_i 将差分序列记做 wiw_i ):

    1. 对于指令C l r d而言,就是对差分数组进行 wl+dw_l+dwr+1dw_{r+1}-d 两步单点修改操作即可
    2. 对于指令Q l r而言,就是求出 gcd(al, wl+1, ... , wr)gcd(a_l,\ w_{l+1},\ ...\ ,\ w_r) 需要注意的是这里第一个数是 ala_l ,这样组成的序列任然是一个以 ala_l 为起始原始的差分序列,但是将 ala_l 转换为 wlw_lalal1a_l-a_{l-1} 就会发现:gcd(al, al+1, ... , ar)gcd(wl, wl+1, ... ,wr)gcd(a_l,\ a_{l+1},\ ...\ ,\ a_r) \neq gcd(w_l,\ w_{l+1},\ ...\ , w_r) 然而一般的差分数组要求单点值的复杂度是 OnO_{n} 的也就是 ai=w1+w2+ ... +wi=wia_i = w_1+w_2+\ ...\ +w_i = \sum w_i ,为了使得求 aia_i 的复杂度降低我们很容易想在线段树的节点上维护 sum(l, r)sum(l,\ r) 的信息这样就能实现快速求 aia_i
  6. 需要维护的完整节点结构

    1
    2
    3
    4
    5
    struct Node{
    int l, r; //表示当前节点维护的区间左右端点
    LL sum; //区间[l, r]的差分序列的和
    LL d; //区间[l, r]的最大公约数
    } tr[N*4];
  7. 初始建树

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    void build(int u, int l, int r){
    tr[u].l = l;
    tr[u].r = r;

    //叶节点时,最大公约数和区间总和就是当前w[l]
    if(l == r){
    tr[u].d = tr[u].sum = w[l];
    }

    //否则需要先求得子节点再pushup()
    else{
    int mid = l+r >> 1;
    build(u<<1, l, mid);
    build(u<<1|1, mid+1, r);
    pushup(u);
    }
    }
  8. push_up操作

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    //由于维护的值实际上很简单,故这里的pushup实际上很简单
    LL gcd(LL a, LL b){
    return b ? gcd(b, a%b) : a;
    }
    void pushup(Node &u, Node &l, Node &r){
    u.sum = l.sum+r.sum;
    u.d = gcd(l.d, r.d);
    }

    void pushup(int u){
    pushup(tr[u], tr[u<<1], tr[u<<1|1]);
    }
  9. 查询操作

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    Node query(int u, int l, int r){
    //避免段错误
    if(l > r) return {0};
    //查询区间包含了节点维护的区间
    if(tr[u].l >= l && tr[u].r <= r) return tr[u];
    else{
    int mid = tr[u].l + tr[u].r >> 1;
    //全在左子节点表示的区间中
    if(r <= mid) return query(u<<1, l, r);
    //全在右子节点表示的区间中
    else if(l > mid) return query(u<<1|1, l, r);
    //在左右节点区间各占一半
    else{
    auto left = query(u<<1, l, r);
    auto right = query(u<<1|1, l, r);
    Node ans;
    //结果等于左右节点的整合
    pushup(ans, left, right);
    return ans;
    }
    }
    }
  10. 修改操作

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    void modify(int u, int x, LL v){
    //找到了代表当前位点x的叶节点u
    if(tr[u].l == x && tr[u].r == x){
    LL b = tr[u].sum + v;
    tr[u] = {x, x, b, b};
    }
    //否则需要修改后更新父节点
    else{
    int mid = tr[u].l + tr[u].r >> 1;
    if(x <= mid) modify(u<<1, x, v);
    else modify(u<<1|1, x, v);
    pushup(u);
    }
    }

完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;

typedef long long LL;

const int N = 500010;

int n, m;
LL w[N];

struct Node{
int l, r;
LL sum, d;
} tr[N*4];

LL gcd(LL a, LL b){
return b ? gcd(b, a%b) : a;
}

void pushup(Node &u, Node &l, Node &r){
u.sum = l.sum+r.sum;
u.d = gcd(l.d, r.d);
}

void pushup(int u){
pushup(tr[u], tr[u<<1], tr[u<<1|1]);
}

void build(int u, int l, int r){
tr[u].l = l;
tr[u].r = r;
if(l == r){
tr[u].d = tr[u].sum = w[l];
}
else{
int mid = l+r >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid+1, r);
pushup(u);
}
}

void modify(int u, int x, LL v){
if(tr[u].l == x && tr[u].r == x){
LL b = tr[u].sum + v;
tr[u] = {x, x, b, b};
}
else{
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid) modify(u<<1, x, v);
else modify(u<<1|1, x, v);
pushup(u);
}
}

Node query(int u, int l, int r){
if(l > r) return {0};
if(tr[u].l >= l && tr[u].r <= r) return tr[u];
else{
int mid = tr[u].l + tr[u].r >> 1;
if(r <= mid) return query(u<<1, l, r);
else if(l > mid) return query(u<<1|1, l, r);
else{
auto left = query(u<<1, l, r);
auto right = query(u<<1|1, l, r);
Node ans;
pushup(ans, left, right);
return ans;
}
}
}

int main(){
scanf("%d%d",&n, &m);
for(int i = 1; i <= n; i++){
LL a;
scanf("%lld", &a);
w[i] += a;
w[i+1] -= a;
}
build(1, 1, n);

int l, r;
LL d;
char op[2];

while(m--){
scanf("%s%d%d", op, &l, &r);
if(op[0] == 'Q'){
auto left = query(1, 1, l);
auto right = query(1, l+1, r);
printf("%lld\n", abs(gcd(left.sum, right.d)));
}
else{
scanf("%lld", &d);
modify(1, l, d);
if(r+1 <= n) modify(1, r+1, -d);
}
}
}