「数据结构」平衡树 - 替罪羊树

#主体思想

与 Treap 采用随机化、Splay 采用统计学玄学原理不同,替罪羊的主题思想就是直接将不平衡的子树树暴力重建为一棵尽可能平衡的子树。这样,我们就需要一个平衡因子 \(\alpha\in[0.5,1]\) 来判定这颗子树是否平衡,设 \(s_x\) 表示以 \(x\) 为根的子树大小,定义以 \(x\) 为根的子树平衡当且仅当该子树的两个儿子子树的大小都不超过 \(\alpha \cdot s_x\),当一个子树不平衡时就直接将它重建,可以证明(见下文 #复杂度证明),各种插入删除等影响平衡性的操作的复杂度都是 \(O(\log n)\).

#基础实现

#结构基础

先把结构基础及基础操作摆在这里,不多讲。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
struct Node {int w, cnt, ls, rs, siz_single, siz_all, siz_without_del;};
/*w 为该点权值, cnt 为重复权值的个数, ls, rs 分别是左右儿子编号*/
/*siz_single 为该子树内所有节点个数(每个节点只计算一次)*/
/*siz_all 为子树内的元素的个数,包含重复元素的次数*/
/*siz_without_del 为子树内不包含已经删除的点的个数*/
struct ScapeGoat {
Node p[N]; int rt, cnt, rub[N], rcnt, q[N], qcnt;

inline int new_ind() {return rcnt ? rub[rcnt --] : ++ cnt;}
inline void del_node(int k) {rub[++ rcnt] = k;}

inline int new_node(int w) {
int k = new_ind(); p[k].w = w, p[k].ls = p[k].rs = 0;
p[k].cnt = p[k].siz_single = 1;
p[k].siz_all = p[k].siz_without_del = 1; return k;
}

inline void pushup(int k) {
int ls = p[k].ls, rs = p[k].rs;
p[k].siz_single = p[ls].siz_single + p[rs].siz_single + 1;
p[k].siz_all = p[ls].siz_all + p[rs].siz_all + p[k].cnt;
p[k].siz_without_del = p[ls].siz_without_del + p[rs].siz_without_del + (p[k].cnt ? 1 : 0);
}
/*...something others...*/
};

注意到我们用 siz_without_del 记录子树内已不包含经被删除的节点的个数,这是因为替罪羊树采用惰性删除,也就是在删除时只减对应节点的个数。不难发现,如果一个子树内的已删除节点占比过高,会严重影响操作效率,所以如果我们发现一个子树内未被删除的节点占比达不到 \(\alpha\),那么我们就需要考虑重构。

#重建

我们首先来考虑一个子树什么时候需要重建:

  • 刚经历过可能影响平衡性的操作;

  • 满足上文中的两个条件(不平衡或空节点过多)任意一个;

考虑到各种影响平衡性的操作一定是递归到某个点,于是我们可以直接在回溯时调用以下函数判断是否需要重构:

1
2
3
inline bool check(int k) {return p[k].w && (alpha * p[k].siz_single 
<= 1.0 * Max(p[p[k].ls].siz_single, p[p[k].rs].siz_single)
|| 1.0 * p[k].siz_without_del <= p[k].siz_single * alpha);}

然后我们来思考这样一个问题:如何 \(O(x)\) 地重建一棵树(\(x\) 为子树大小)?怎样建树最优?

首先,显然二分地建树得到的树是最平衡的,即将原本的树转化为中序遍历,这样不会破坏平衡树的顺序性,然后每次选择中点作为当前区间的根,然后两侧递归,正确性显然;至于时间复杂度,注意到递归树上最多有 \(\log n\) 层,整体形态与线段树接近,显然相同深度时一棵满二叉树的节点最多,此时节点个数为

\[ \sum_{i=0}^{\log n}2^i=2\cdot n-1, \]

于是总的建树复杂度为 \(O(子树大小)\),具体实现分两部分:展开和重建。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void unfold(int k) {
if (!k) return;
unfold(p[k].ls);
if (p[k].cnt) q[++ qcnt] = k;
else del_node(k);
unfold(p[k].rs);
}

int build(int l, int r) {
if (l > r) return 0; int mid = l + r >> 1;
p[q[mid]].ls = build(l, mid - 1);
p[q[mid]].rs = build(mid + 1, r);
pushup(q[mid]); return q[mid];
}

void rebuild(int &k) {qcnt = 0; unfold(k); k = build(1, qcnt);}

#常用操作

其实也没什么好说的,还是插入、删除、前趋、后继那些东西,只简单地提两点:

  • 插入删除回溯时记得检查重建;

  • 新增几个操作:

    • 严格大于 \(x\) 的最小的数的最小排名;

    • 严格小于 \(x\) 的最大的数的最大排名;以上两个操作

    以上两个操作都需要注意节点为空时的特殊贡献;

  • 查找排名为 \(k\) 的值时需要注意一个节点上相同权值的重复次数;

  • 前趋、后继可用上面的操作组合得到;

上面的所有代码实现见 #Code

#复杂度证明

这里我们取 \(\alpha=0.75\),使用时可以根据组题情况进行调节。

#插入

假设当前已经有了一个大小为 \(x\) 的已平衡的子树,那么为了让其尽快达到不平衡状态,我们一定是一直向一边插入节点,不妨设插入 \(k\) 次后,当前这颗树不再平衡,那么应当有

\[ \left\lfloor\dfrac x 2\right\rfloor+k\geq\alpha\cdot(x + k), \]

可以解得此时有

\[ k\geq\dfrac{\alpha-\frac 1 2}{1-\alpha}\cdot x, \]

于是应当是每插入 \(\frac{\alpha-\frac 1 2}{1-\alpha}\cdot x\) 次后,进行一次重构,重构一次的时间复杂度为 \(\Theta(x)\),我们将这一次重构的时间复杂度均摊到导致这次重构的所有插入操作中,时间复杂度为

\[ \dfrac{\Theta(x)}{\frac{\alpha-\frac 1 2}{1-\alpha}\cdot x}=\Omega(1),(\alpha=0.75) \]

于是我们可以将重构操作均摊为 \(\Omega(1)\) 的时间复杂度,由于树平衡,每次插入时查询的时间复杂度为 \(O(\log n)\),于是插入操作的总体时间复杂度为 \(O(\log n)\).

#删除

与证明插入时间复杂度同样的思路,设删除 \(k\) 个数据,每次恰使一个节点变为空节点后,当前这颗大小为 \(x\) 的平衡树中的空节点数量过大,也就是有

\[ x-k\leq\alpha\cdot x, \]

于是可以解得

\[ k\geq(1-\alpha)\cdot x, \]

于是每 \((1-\alpha)\cdot x\) 次删除(每次删除恰好使一个节点为空),那么均摊的复杂度为

\[ \dfrac{\Theta(x)}{(1-\alpha)\cdot x}=\Omega(1),(\alpha=0.75) \]

于是总的删除复杂度为 \(O(\log n)\),常数略大。

看起来 \(\alpha\) 越小,删除时的重建效率就越高,那为什么不能单独给删除定一个 \(\beta\),让这个 \(\beta\) 尽可能小呢?注意到,单独定一个 \(\beta\) 是可以的,但是我们删除重建的目的是去掉冗余的节点,让每次访问时的复杂度降低,如果直接把 \(\beta\) 定为 \(0\),那么冗余节点会一直存在,影响查找的复杂度。当然可以适当的将 \(\beta\) 调小到 \(0.5\) 左右。

#更多操作

由于替罪羊树维护平衡并没有用到太多特殊的性质,就是暴力重建,所以很多其他操作都可以用熟悉的套路实现,时间复杂度也多为 \(O(\log n)\).

#Code

题目为 Luogu3369 【模板】普通平衡树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
const int N = 500010;
const double alpha = 0.75;
const double del_alpha = 0.5;
const int INF = 0x3fffffff;

template <typename T> inline void read(T &x) {
x = 0; int f = 1; char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
x *= f;
}

template <typename T> inline T Max(T x, T y) {return x > y ? x : y;}

struct Node {int w, cnt, ls, rs, siz_single, siz_all, siz_without_del;};
struct ScapeGoat {
Node p[N]; int rt, cnt, rub[N], rcnt, q[N], qcnt;

inline int new_ind() {return rcnt ? rub[rcnt --] : ++ cnt;}
inline void del_node(int k) {rub[++ rcnt] = k;}

inline int new_node(int w) {
int k = new_ind(); p[k].w = w, p[k].ls = p[k].rs = 0;
p[k].cnt = p[k].siz_single = 1;
p[k].siz_all = p[k].siz_without_del = 1; return k;
}

inline void pushup(int k) {
int ls = p[k].ls, rs = p[k].rs;
p[k].siz_single = p[ls].siz_single + p[rs].siz_single + 1;
p[k].siz_all = p[ls].siz_all + p[rs].siz_all + p[k].cnt;
p[k].siz_without_del = p[ls].siz_without_del + p[rs].siz_without_del + (p[k].cnt ? 1 : 0);
}

inline bool check(int k) {return p[k].w && (alpha * p[k].siz_single
<= 1.0 * Max(p[p[k].ls].siz_single, p[p[k].rs].siz_single)
|| 1.0 * p[k].siz_without_del <= p[k].siz_single * alpha);}

void unfold(int k) {
if (!k) return;
unfold(p[k].ls);
if (p[k].cnt) q[++ qcnt] = k;
else del_node(k);
unfold(p[k].rs);
}

int build(int l, int r) {
if (l > r) return 0; int mid = l + r >> 1;
p[q[mid]].ls = build(l, mid - 1);
p[q[mid]].rs = build(mid + 1, r);
pushup(q[mid]); return q[mid];
}

void rebuild(int &k) {qcnt = 0; unfold(k); k = build(1, qcnt);}

void insert(int &k, int x) {
if (!k) {k = new_node(x); return;}
if (p[k].w == x) ++ p[k].cnt;
else if (p[k].w < x) insert(p[k].rs, x);
else insert(p[k].ls, x);
pushup(k); if (check(k)) rebuild(k);
}

void del(int &k, int x) {
if (!k) return;
if (p[k].w == x && p[k].cnt) -- p[k].cnt;
else if (p[k].w < x) del(p[k].rs, x);
else if (p[k].w > x) del(p[k].ls, x);
pushup(k); if (check(k)) rebuild(k);
}

int upper_grade(int k, int x) {
if (!k) return 1;
if (p[k].w == x && p[k].cnt) return p[p[k].ls].siz_all + p[k].cnt + 1;
if (p[k].w > x) return upper_grade(p[k].ls, x);
else return p[p[k].ls].siz_all + p[k].cnt + upper_grade(p[k].rs, x);
}

int lower_grade(int k, int x) {
if (!k) return 0;
if (p[k].w == x && p[k].cnt) return p[p[k].ls].siz_all;
if (p[k].w > x) return lower_grade(p[k].ls, x);
else return p[p[k].ls].siz_all + p[k].cnt + lower_grade(p[k].rs, x);
}

int kth_value(int k, int x) {
if (!k) return 0;
if (p[p[k].ls].siz_all < x && x <= p[p[k].ls].siz_all + p[k].cnt) return p[k].w;
if (p[p[k].ls].siz_all >= x) return kth_value(p[k].ls, x);
else return kth_value(p[k].rs, x - p[k].cnt - p[p[k].ls].siz_all);
}

inline int pre_value(int x) {return kth_value(rt, lower_grade(rt, x));}
inline int nxt_value(int x) {return kth_value(rt, upper_grade(rt, x));}
inline int get_grade(int x) {return lower_grade(rt, x) + 1;}
} t;

int n;

int main() {
read(n);
while (n --) {
int opt = 0, x = 0; read(opt), read(x);
if (opt == 1) t.insert(t.rt, x);
else if (opt == 2) t.del(t.rt, x);
else if (opt == 3) printf("%d\n", t.get_grade(x));
else if (opt == 4) printf("%d\n", t.kth_value(t.rt, x));
else if (opt == 5) printf("%d\n", t.pre_value(x));
else printf("%d\n", t.nxt_value(x));
}
return 0;
}

参考文章