跳至主要內容

倍增法与ST算法

Izudia大约 5 分钟算法算法竞赛进阶指南

倍增,顾名思义就是翻倍。它能够使线性的处理转化为对数级的处理,大大地优化时间复杂度。

这个方法在很多算法中均有应用,其中最常用的是 RMQ 问题的ST算法和求 LCA(最近公共祖先)树上倍增。

倍增法

倍增法是一种快速计算一个数的幂的算法。该算法基于指数的二进制分解,并利用指数的二进制表示的特性,通过不断平方和乘法来快速计算幂。该算法的时间复杂度为$O(\log n)$。
通过递推状态空间在2的整数次幂位置上的值作为代表,当需要其他位置上的值时,可以将任意整数表示成若干个2的次幂项之和,即二进制划分,使用已经记录的代表值拼接,因此使用倍增法需要我们的递推问题状态空间关于2 的次幂具有可划分性。

例题 AcWing109 天才ACM

Description

给定一个整数 M,对于任意一个整数集合 S,定义“校验值”如下:
从集合 S 中取出 M 对数(即 2×M 个数,不能重复使用集合中的数,如果 S 中的整数不够 M 对,则取到不能取为止),使得“每对数的差的平方”之和最大,这个最大值就称为集合 S 的“校验值”。

现在给定一个长度为 N 的数列 A 以及一个整数 T。

我们要把 A 分成若干段,使得每一段的“校验值”都不超过 T。

求最少需要分成几段。

Input

第一行输入整数 K,代表有 K 组测试数据。

对于每组测试数据,第一行包含三个整数 N,M,T 。

第二行包含 N 个整数,表示数列A1,A2…AN。

Output

对于每组测试数据,输出其答案,每个答案占一行。

Range

$1≤K≤12,1≤N,M≤500000,0≤T≤1018,0≤Ai≤220$

Sample Input

2
5 1 49
8 2 1 7 9
5 1 64
8 2 1 7 9

Sample Output

2
1

"""
二分区间可以求解,但扩展速度太慢,需要一种与右端点R扩展长度相适应的算法
倍增过程如下:
1. 初始化  p=1, R=L
2. 求出  [L, R+p]  这一段区间的校验值:
    - 若校验值  \leq T  ,则  R+=p, p^{*}=2
    - 若校验值  >T  ,则  \mathrm{p} /=2
3. 重复上一步,直到  p  变为 0 为止,此时  R  即为所求
"""


def check(l, r):
    sub = sorted(A[l:r])
    tot = 0
    i, j = 0, r - l - 1
    while i < j and i < M:
        tot += (sub[i] - sub[j]) ** 2
        i += 1
        j -= 1
    return tot <= T

"""
优化:二路归并,只排序r-r+p的新序列,与原排序好的序列合并
"""

def merge(sub_old, sub_new):
    i = j = 0
    merged = []
    while i < len(sub_old) and j < len(sub_new):
        if sub_old[i] > sub_new[j]:
            merged.append(sub_new[j])
            j += 1
        else:
            merged.append(sub_old[i])
            i += 1
    merged += sub_old[i:]
    merged += sub_new[j:]
    return merged


K = int(input())
while K:
    N, M, T = map(int, input().split())
    A = list(map(int, input().split()))
    old = []
    l, cnt = 0, 0
    while l < N:
        p, r = 1, l
        while p:
            if r + p <= N and check(l, r + p):
                r += p
                p *= 2
            else:
                p //= 2
        cnt += 1
        l = r
    print(cnt)
    K -= 1

ST算法

在RMQ问题(区间最值问题)中,ST(Sparse Table)算法是倍增的衍生。

一般来说,ST算法与线段树算法等一起提及,线段树算法是对一维序列递归二分的一棵树形结构,用于维护区间信息。ST算法也有一些限制,它要求查询的区间大小不能超过预处理区间的大小,而线段树算法没有这个限制,可以处理更为灵活的查询。

给定一个长度为N的数列A,ST算法能在$O(N \log N)$时间预处理后,以$O(1)$的时间复杂度在线回答”数列A中下标在l~r之间的数的最大值是多少“这样的问题。

设f[i, j]表示数列A中下标在子区间[i, i + $2^j$ - 1]里的最大值,即从i开始的$2^j$个数的最大值。递推边界为f[i, 0] = A[i],即数列A在子区间[i, i]里的最大值。

在递推过程中,将子区间的长度成倍增长,公式f[i, j] = max(f[i, j - 1], f[i + $2^{j - 1}$, j - 1]),即长度为$2j$的子区间的最大值是左右两半长度为$2$的子区间的最大值中较大的一个。

当询问任意区间[l, r]最值时,首先计算出k,满足$2^k$ <= r - l + 1 < $2^k$,即在2的k次幂小于区间长度的前提下最大的k。
那”从l开始的$2k$个数“和“以r结尾的$2k$个数”这两段一定能够覆盖整个区间[l, r],这两段的最大值分别是f[l][k]和f[r - 2 ** k + 1][k],而这种较大的值便是整个区间[l, r]的最值,虽然区间存在重叠,但覆盖了整个区间,所以不存在影响。

def build():
    for i in range(1, N + 1):
        f[i][0] = A[i - 1]
    for j in range(1, LOG_N):
        for i in range(1, (N - 2 ** j + 1) + 1):
            f[i][j] = max(f[i][j - 1], f[i + 2 ** (j - 1)][j - 1])
    return


def query(l, r):
    k = (r - l + 1).bit_length() - 1
    return max(f[l][k], f[r - 2 ** k + 1][k])

N = int(input())
l, r = map(int, input().split())
log_N = N.bit_length()
f = [[0] * (log_N + 1) for _ in range(N + 1)]
A = [0] + list(range(1, N + 1))
build()
query(l, r)