数组分块简述

在某些情况下,我们可能需要暴力地处理数组的每一个数,数组分块提供了一种简单的优化来降低复杂度,使得原本的$O(n)$降为$O(\sqrt n)$,这可能不是最优解,但一定足够简单(事实上,在1e5范围内,$O(\log^2 n)$和$O(\sqrt n)$差距不大)。 这类问题的“整体”性质更弱,有时不存在高效的信息合并化简方法,需要在“批量”和“零散”之间找到平衡点。

  • 基础形式

我们将一个数组分为形如下图的几段,每段大小为B,当然为了规范,如果$B \nmid n$,我们要在数组后面填充$inf$(在后面的一些处理中会用到这个)。对每次操作只需要对左右两个散块单独处理和对中间一些整块处理。

理解了分块的概念以后,你就可以开始着手尝试用分块去解决一些线段树问题了(虽然此时分块对比线段树不占优)。

例题

模板题

题目大意:区间上所有元素加一个数$k$;求区间所有元素的和。 对于这个问题,简单地开三个数组$a,s,v$ 分别表示:数组值;分块中除去标记的数组值和;标记:分块中的每个元素 + k。

参考代码: 每个块的左右边界可以用结构体记录下来,也可以直接计算,这里采用的是直接计算的方式。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
const int INF = 1e9 + 7, MAXN = 1e5 + 10, mod = 998244353;
int BS;
ll a[MAXN], s[405], v[405];
void add(int l, int r, int k) {
    int bl = l / BS, br = r / BS;
    if(bl == br) {
        for(int i = l; i <= r; i ++) {
            a[i] += k, s[bl] += k;
        }
    } else {
        for(int i = l; i < (bl + 1) * BS; i ++) {
            a[i] += k, s[bl] += k;
        }
        for(int i = bl + 1; i < br; i ++) {
            v[i] += k, s[i] += 1ll * BS * k;
        }
        for(int i = br * BS; i <= r; i ++) {
            a[i] += k, s[br] += k;
        }
    }
}
ll query(int l, int r) {
    ll res = 0;
    int bl = l / BS, br = r / BS;
    if(bl == br) {
        for(int i = l; i <= r; i ++) {
            res += a[i] + v[bl];
        }
    } else {
        for(int i = l; i < (bl + 1) * BS; i ++) {
            res += a[i] + v[bl];
        }
        for(int i = bl + 1; i < br; i ++) {
            res += s[i];
        }
        for(int i = br * BS; i <= r; i ++) {
            res += a[i] + v[br];
        }
    }
    return res;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    int n, m;
    cin >> n >> m;
    BS = sqrt(n) + 1;
    for(int i = 0; i < n; i ++) {
        cin >> a[i];
        s[i / BS]  += a[i];
    }
    while(m -- ) {
        int op, l, r, k;
        cin >> op  >> l >> r;
        l --, r --;
        if(op == 1) {
            cin >> k;
            add(l, r, k);
        } else {
            cout << query(l, r) << '\n';
        }
    }

}

分块的大小可以人为确定,需要计算时间复杂度后再确定。此题中理论时间复杂度为$O(q (B + n / B))$, 另$B = \sqrt n$,得到最优时间复杂度$O(q \sqrt n)$。

教主的魔法

题目大意:1、区间加;2、求区间大于$c$的元素个数。

首先来思考一下如何找出一个整块中大于$c$的元素个数 …… 如果一个整块中的元素有序,那么是不是可以二分。

在上题中我们用$s$来维护区间和,现在我们用$s$来维护区间的有序数组。那么一个区间内大于$c$的元素个数就等于散块大于$c$的个数 + 每个整块中大于$c$的个数(这部分用lower_bound来处理)。单次询问时间复杂度为$O((n/B) log B + B)$。

在来思考一下区间加,由于需要额外维护一个有序数组,每次加操作可以对散块进行暴力排序,整块就在标记上加值,单词加的时间复杂度为$O(B log B + n / B)$。

关于B的取值问题,首先应该考虑平衡两种操作的时间复杂度,其次应该考虑整块的时间复杂度和散块部分的时间复杂度,具体而言就是让两者相等求出B。比如此题中对散块重新排序可以优化成提取有序的两部分(有改动和无改动)再merge,此时加操作的时间复杂度就降为$O(B + n / B)$,发现这个复杂度是恒小于询问操作的复杂度的,此时就应该考虑适当增大B的大小,可以对$(n / B) log B + B$ 求导计算出极值,也可以用三分来求,事实上这个值约等于$\sqrt{n\log \sqrt n}$。

参考代码(不加优化):

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
const int INF = 1e9 + 7, MAXN = 1100000, mod = 998244353;
int n, q, BS;
int a[MAXN], v[410], s[MAXN];
void add(int l, int r, int k) {
    int bl = l / BS, br = r / BS;
    if(bl == br) {
        for(int i = l; i <= r; i ++) {
            a[i] += k;
        }
        for(int i = bl * BS; i < (bl + 1) * BS; i ++) {
            s[i] = a[i];
        }
        sort(s + bl * BS, s + (bl + 1) * BS);
    } else {
        for(int i = l; i < (bl + 1) * BS; i ++) {
            a[i] += k;
        }
        for(int i = bl * BS; i < (bl + 1) * BS; i ++) {
            s[i] = a[i];
        }
        sort(s + bl * BS, s + (bl + 1) * BS);
        for(int i = bl + 1; i < br; i ++) {
            v[i] += k;
        }
        for(int i = br * BS; i <= r; i ++) {
            a[i] += k;
        }
        for(int i = br * BS; i < (br + 1) * BS; i ++) {
            s[i] = a[i];
        }
        sort(s + br * BS, s + (br + 1) * BS);
    }
}
int query(int l, int r, int c) {
    int bl = l / BS, br = r / BS;
    int C, res = 0;
    if(bl == br) {
        C = c - v[bl];
        for(int i = l; i <= r; i ++) {
            if(a[i] >= C) res ++;
        }
    } else {
        C = c - v[bl];
        for(int i = l; i < (bl + 1) * BS; i ++) {
            if(a[i] >= C) res ++;
        }
        for(int i = bl + 1; i < br; i ++) {
            C = c - v[i];
            res += BS - (lower_bound(s + BS * i, s + BS * (i + 1), C) - (s + BS * i));
        }
        C = c - v[br];
        for(int i = br * BS; i <= r; i ++) {
            if(a[i] >= C) res ++;
        }
    }
    return res;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    memset(a, 0x3f3f3f3f, sizeof a);
    cin >> n >> q;
    BS = sqrt(n) + 1;
    for(int i = 0; i < n; i ++) {
        cin >> a[i];
        s[i] = a[i];
    }
    for(int i = 0; i * BS <= n ; i ++) {
        sort(s + i * BS, s + (i + 1) * BS);
    }
    while(q --) {
        char op;
        int l, r, x;
        cin >> op >> l >> r >> x;
        l --, r --;
        if(op == 'M') add(l, r, x);
        else cout << query(l, r, x) << '\n';
    }
}

稍稍写一下归并的代码(给一个结构体写的板子)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
int n, m, BS;
ll a[MAXN];
struct BLOCK {
    int l, r;
    ll ex;
    vector<ll> B;
    vector<int> mp;
    BLOCK(int s) {
        B.resize(s);
        ex = 0;
    }
    BLOCK(int _l, int _r) {
        ex = 0;
        l = _l, r = _r;
        int s = r - l + 1;
        B.resize(s);
        mp.resize(s);
        for(int i = l; i <= r; i ++) {
            mp[i - l] = i;
        }
        sort(mp.begin(), mp.end(), [&] (int x, int y)  {
            return a[x] < a[y];
        });
        for(int i = 0; i < s; i ++) {
            B[i] = a[mp[i]];
        }
    }
    void rebuild(int s, int t, int k) {
        vector<int> _1, _2;
        for(int i = 0; i < mp.size(); i ++) {
            int x = mp[i];
            if(x >= s && x <= t) {
                _1.push_back(x);
            } else {
                _2.push_back(x);
            }
        }
        for(int i : _1) a[i] += k;
        int i = 0, j = 0, c = 0;
        while(i < _1.size() && j < _2.size()) {
            if(a[_1[i]] < a[_2[j]]) mp[c ++] = _1[i ++];
            else mp[c ++] = _2[j ++];
        }
        while(i < _1.size()) mp[c ++] = _1[i ++];
        while(j < _2.size()) mp[c ++] = _2[j ++];
        for(int i = 0; i < mp.size(); i ++) B[i] = a[mp[i]];
    }
};
vector<BLOCK> block;
void add(int l, int r, int k) {
    int lb = l / BS, rb = r / BS;
    if(lb == rb) {
        block[lb].rebuild(l, r, k);
    } else {
        block[lb].rebuild(l, block[lb].r, k);
        block[rb].rebuild(block[rb].l, r, k);
        for(int i = lb + 1; i < rb; i ++) {
            block[i].ex += k;
        }
    }
}
由乃打扑克

给你一个长为 $n$ 的序列 $a$,需要支持 $m$ 次操作,操作有两种: 1、查询区间 $[l,r]$ 的第 $k$ 小值。 2、区间 $[l,r]$ 加上 $k$。

分块 + 整体二分

第$k$小的值$x$就相当于区间内存在$k-1$个数小于等于$x$,并且求这个值符合二分的规则,故可以整体二分确定$x$的大小,具体的check方法与上一例题一模一样。

几个优化: 1. 二分前将左右两块散块合并,这样可以用$O(log\ B)$来代替对整个散块遍历的$O(B)$ 2. 二分前计算出二分的上下界,以减少二分次数。 复杂度分析: add 时间复杂度$O(B + n / B)$ query 时间复杂度$O(B + (n / B)log\ B\ log\ w)$ ($log\ w$ 是二分是次数,大概为$20-30$)。

确定块长:理论上复杂度的计算应该是查询时间复杂度 对$B$求偏导得到的,但是按照理论来算的话,这题这么做是可以被卡掉的。于是需要调亿点点块长,下面这份代码块长取$912$可以拿到$82pts$(已经不想扣常数了)。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
const int INF = 1e9 + 7, MAXN = 1e5 + 9, mod = 998244353;
void read() {}
template<typename T,typename... Ts>
inline void read(T &arg,Ts&... args) {
    T x = 0, f = 1;
    char c = getchar();
    while(!isdigit(c)){if(c == '-') f = -1; c = getchar();}
    while(isdigit(c)){x = (x << 3) +(x << 1) + (c - '0');c = getchar();}
    arg = x * f;
    read(args...);
}
int n, m, BS;
ll a[MAXN];
int _1[MAXN], _2[MAXN], t1, t2;
struct BLOCK {
    int l, r;
    ll ex;
    vector<ll> B;
    vector<int> mp;
    BLOCK(int s) {
        B.resize(s);
        ex = 0;
    }
    BLOCK(int _l, int _r) {
        ex = 0;
        l = _l, r = _r;
        int s = r - l + 1;
        B.resize(s);
        mp.resize(s);
        for(int i = l; i <= r; ++ i) {
            mp[i - l] = i;
        }
        sort(mp.begin(), mp.end(), [&] (int x, int y)  {
            return a[x] < a[y];
        });
        for(int i = 0; i < s; i ++) {
            B[i] = a[mp[i]];
        }
    }
    void rebuild(int s, int t, int k) {
        t1 = t2 = 0;
        for(int i = 0; i < mp.size(); ++ i) {
            int x = mp[i];
            if(x >= s && x <= t) _1[t1 ++] = x;
            else _2[t2 ++] = x;
        }
        for(int i = 0; i < t1; ++ i) a[_1[i]] += k;
        int  i = 0, j = 0, c = 0;
        while(i < t1 && j < t2) {
            if(a[_1[i]] < a[_2[j]]) mp[c ++] = _1[i ++];
            else mp[c ++] = _2[j ++];
        }
        while(i < t1) mp[c ++] = _1[i ++];
        while(j < t2) mp[c ++] = _2[j ++];
        for(int i = 0; i < mp.size(); i ++) B[i] = a[mp[i]];
    }
};
vector<BLOCK> block;
inline void add(int l, int r, int k) {
    int lb = l / BS, rb = r / BS;
    if(lb == rb) {
        block[lb].rebuild(l, r, k);
    } else {
        block[lb].rebuild(l, block[lb].r, k);
        block[rb].rebuild(block[rb].l, r, k);
        for(int i = lb + 1; i < rb; i ++) {
            block[i].ex += k;
        }
    }
}
inline ll query(int l, int r, int k) {
    int lb = l / BS, rb = r / BS;
    if(lb == rb) {
        int c = 0;
        for(int i : block[lb].mp) {
            if(i >= l && i <= r) c ++;
            if(c == k) return a[i] + block[lb].ex;
        }
    } else {
        BLOCK _ex(block[lb].r - l + 1 + r - block[rb].l + 1);
        t1 = t2 = 0;
        for(int i : block[lb].mp) {
            if(i >= l) _1[t1 ++] = a[i] + block[lb].ex;
        }
        for(int i : block[rb].mp) {
            if(i <= r) _2[t2 ++] = a[i] + block[rb].ex;
        }
        int c = 0, i = 0, j = 0;
        while(i < t1 && j < t2) {
            if(_1[i] < _2[j]) _ex.B[c ++] = _1[i ++];
            else _ex.B[c ++] = _2[j ++];
        }
        while(i < t1) _ex.B[c ++] = _1[i ++];
        while(j < t2) _ex.B[c ++] = _2[j ++];
        ll L = 1e18, R = -1e18;
        L = min(L, _ex.B[0]);
        R = max(R, _ex.B.back());
        for(int i = lb + 1; i < rb; ++ i) {
            L = min(L, block[i].B[0] + block[i].ex);
            R = max(R, block[i].B.back() + block[i].ex);
        }
        L --, R ++;
        while(R > L + 1) {
            ll M = L + R  >> 1;
            int cnt = 0;
            cnt += lower_bound(_ex.B.begin(), _ex.B.end(), M) - _ex.B.begin();
            for(int i = lb + 1; i < rb; i ++) {
                cnt += lower_bound(block[i].B.begin(), block[i].B.end(), M - block[i].ex) - block[i].B.begin();
            }
            if(cnt >= k) R = M;
            else L = M;
        }
        return R - 1;
    }
}
int main() {
    read(n, m);
    for(int i = 0; i < n; ++ i) {
        read(a[i]);
    }
    BS = 912;
    for(int i = 0; 1ll * i * BS < n; ++ i) {
        int l = i * BS, r = min(n - 1, (i + 1) * BS - 1);
        block.push_back(BLOCK(l, r));
    }
    while(m --) {
        int op, l, r, k;
        read(op, l, r, k);
        l --, r --;
        if(op == 1) {
            if(r - l + 1 < k) {
                cout << "-1\n";
            } else {
                printf("%lld\n", query(l, r, k));
            }
        } else {
            add(l, r, k);
        }
    }
}
弹飞绵羊

有$n$个点,每个点都有一个值$k$,第$i$个点会通过一次弹射到$i + k_i$上,直到点不存在。两种操作,一种询问第$i$个点一共会经历几次弹射,另一种修改$k_i$。

如果不存在修改操作,那么只需要预处理出所有点的答案,每次查询只需要$O(1)$的复杂度就可以求出解。而修改的操作是$O(n)$的,因此需要一种方法来平衡两种操作。发现可以将所有点任意地分割成$x$段,每一小段中的每个点都存在一个$k’_i$从所在段中弹飞,并且落到另一段中的某个点,因此可以记录每个段中每个点从所在段弹飞的所需的次数以及落点,此时每次修改只需要修改所在块即可。故得出解:将$n$个点分成$\sqrt n$段,同时记录次数和位置,修改就重构块,查询就对该点在所有块中的次数求和。两种操作的时间复杂度均为$\sqrt n$。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
const int INF = 1e9 + 7, MAXN = 2e5 + 400, mod = 998244353;
int n, BS, m, B;
int a[MAXN], c[MAXN], to[MAXN];
void upd(int l, int r) {
    for(int i = r; i >= l; i --) {
        if(i + a[i] > r) {
            c[i] = 1;
            to[i] = i + a[i];
        } else {
            c[i] = c[i + a[i]] + 1;
            to[i] = to[i + a[i]];
        }
    }
}
int query(int x) {
    int res = 0, pos = x;
    do {
        res += c[pos];
        pos = to[pos];
    } while(pos < n);
    return res;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n;
    BS = sqrt(n) + 1;
    for(int i = 0; i < n; i ++) {
        cin >> a[i];
    }
    for(int i = 0; i < n; i += BS) {
        upd(i, min(n - 1, i + BS - 1));
    }
    cin >> m;
    while (m --) {
        int op, x, y;
        cin >> op >> x;
        if(op == 1) {
            cout << query(x) << '\n';
        } else {
            cin >> y;
            a[x] = y;
            int b = x / BS;
            upd(b * BS, min(n - 1, (b + 1) * BS - 1));
        }
    }
}