用秦九韶算法求解最大子数组和问题

問有兩尖田一段其尖長不等兩大斜三十九步兩小斜二十五步中廣三十步欲知其積幾何

術曰以少廣求之翻法入之。置半廣自乘為半冪與小斜冪相減相乘為小率 以半冪與大斜冪相減相乘為大率。以二率相減餘自乘為實併二率倍之為從上廉以一為益隅開翻法三乘方得積。

一直以来有些朋友对函数式编程不屑一顾。他们会觉得函数式编程只不过是增加了一些无谓的抽象而代价则是性能的下降。持这种观点的朋友恐怕是肯定没有读过 Richard Bird 所著 PEARLS OF FUNCTIONAL ALGORITHM DESIGN. 这本书讨论的内容确实太过生涩我也没有完整读过一遍。

幸运的是Richard Bird 1989 年写过一篇仅仅只有 5 页的论文 Algebraic identities for program calculation. 这篇论文用一个著名的问题最大子数组和1Maximum subarray problem淋漓尽致地展现了函数式编程思维对算法问题的独特理解最大子数组和问题可以用秦九韶算法求解。

最大子数组和问题

最大子数组和问题的目标是在数列的一维方向找到一个连续的子数列使该子数列的和最大。例如对一个数列 [−2, 1, −3, 4, −1, 2, 1, −5, 4]其连续子数列中和最大的是 [4, −1, 2, 1], 其和为 6 2

显然我们有平凡的算法即枚举所有的子数组然后求最大值

1def mss(arr):
2    r = 0
3    for end in range(len(arr) + 1):
4        for start in range(0, end):
5            s = 0
6            for k in range(start, end):
7                s += arr[k]
8            r = max(s, r)

这个程序用 Haskell 写会简洁一些

1import Data.List (tails, inits)
2
3mss = maximum . map sum . segs
4segs = concat . map tails . inits

其中 tails 后缀inits 前缀

1tails [1, 2, 3] = [[1, 2, 3], [2, 3], [3], []]
2inits [1, 2, 3] = [[], [1], [1, 2], [1, 2, 3]]

易见所有前缀的所有后缀或者所有后缀的所有前缀就是一个列表所有的子列表。

现在我们回到算法本身。这个平凡算法的时间复杂度是 O(n3)O(n^3). 把所有的子数组枚举出来需要 O(n2)O(n^2)对每一个子数组求和又需要 O(n)O(n)所以就构成了 O(n3)O(n^3).

事实上最优秀的算法在 O(n)O(n) 时间复杂度内就可以解决这个问题。Bird 的论文正是在讨论如何从平凡的算法出发推导 出最优秀的算法。首先对算法稍作变形

1def mss(arr):
2    m = 0
3    for end in range(len(arr) + 1):
4        m = max(max_tails(arr, end), m)
5
6def max_tails(arr, end):
7    m = 0
8    for start in range(0, end):
9        s = 0
10        for k in range(start, end):
11            s += arr[k]
12        m = max(m, s)
13    return m

这似乎是一个平凡的变形我只是把三层循环的里面两层拆出来了。max_tails 函数会返回数组 arr [0, end) 区间的最大后缀和。我们用 \uparrow 来表示最大值函数max_tails 的函数式版本或数学式版本可以写成

arr[0:end][x1,x2,,xn]m=(i=1nxi)(i=2nxi)(xn)0 \begin{aligned} \textsf{arr[0:end]} &\equiv [ x_1, x_2, \cdots, x_n ] \\ m &= (\sum_{i=1}^n x_i) \uparrow (\sum_{i=2}^n x_i) \uparrow \cdots \uparrow (x_n) \uparrow 0 \end{aligned}

这第一眼看上去十分令人费解下面用例子来说明。首先考虑 max_tails([1, 2, 3], 3), 也就是 [1, 2, 3] 的所有后缀和的最大值。

max_tails 的计算可以写成

1(1 + 2 + 3) ↑ (2 + 3) ↑ (3) ↑ 0

也就是说max_tails 会计算每个后缀和并把它们的最大值求出来。当把最大值写成一个二元函数 \uparrow 的时候自然就可以用它连接每个后缀和以构成最后的表达式。

在朴素算法中求后缀和的最大值需要 O(n2)O(n^2) 的复杂度。下面我们用秦九韶算法将这个复杂度降低到 O(n)O(n).

秦九韶算法

秦九韶在其著作《数术九章》中提出了一种求高次方程近似解的方法。其中的一个部分被近现代研究者称作秦九韶算法

考虑 nn 次多项式

f(x)=i=0naixi f(x) = \sum_{i=0}^{n} a_i x^{i}

对于确定的 tt, 有几种方法求 f(t)f(t) 的值。最平凡的方法就是分别求每项的值再加起来。这种算法需要 O(n2)O(n^2) 次乘法。秦九韶算法注意到

anxn+an1xn1++a0=((anxn1)+(an1xn2)++a1)x+a0==(((anx+an1)x+)x+a1)x+a0 \begin{aligned} & a_n x^n + a_{n-1} x^{n-1} + \cdots + a_0 \\ = & ((a_n x^{n - 1}) + (a_{n-1} x^{n - 2}) + \cdots + a_{1}) x + a_0 \\ = & \cdots \\ = & (((a_n x + a_{n - 1}) x + \cdots) x + a_{1})x + a_0 \end{aligned}

这可以写成一个迭代形式

s0=ansi+1=sit+ani1 \begin{aligned} s_0 &= a_n \\ s_{i + 1} &= s_i t + a_{n - i - 1} \end{aligned}

这样一来 sn=f(t)s_n = f(t). 在这个计算中乘法计算了 nn 加法也计算了 nn 次。

看上去秦九韶算法只是一个 数值计算 算法。它和 max_tails 函数有什么关系呢我们要考虑广义的秦九韶算法形式。

首先这个算法的输入不一定要是多项式如下的式子仍然可以用秦九韶算法改写

x1x2x3+x2x3+x3+1=(((x1+1)x2)+1)x3+1 \begin{aligned} & x_1x_2x_3 + x_2x_3 + x_3 + 1 \\ =& (((x_1 + 1) x_2) + 1) x_3 + 1 \end{aligned}

其次秦九韶算法考虑的运算不一定是 (+,)(+, \cdot), 这个算法提取公因式的时候其实是在用乘法分配率改写算式。考虑任意的两个函数 (,)(\oplus, \odot)只要它们满足如下条件3那就可以运用秦九韶算法

  • \odot \oplus 分配
    ac+bc=(a+b)c(ac)(bc)=(ab)c \begin{aligned} a \cdot c + b \cdot c &= (a + b) \cdot c \\ (a \odot c) \oplus (b \odot c) &= (a \oplus b) \odot c \\ \end{aligned}
  • 存在 \odot 单位元即存在 aa使得 x,ax=x\forall x, a \odot x = x

以刚才的表达式为例

((x1x2)x3)(x2x3)(x3)(1)=((((x11)x2)1)x3)1 \begin{aligned} & ((x_1 \odot x_2) \odot x_3) \oplus (x_2 \odot x_3) \oplus (x_3) \oplus (1) \\ = & ((((x_1 \oplus 1) \odot x_2) \oplus 1) \odot x_3) \oplus 1 \end{aligned}

其中 11 \odot 的左单位元。

Kadane 算法

刚才已经说过了max_tails 可以写成

arr[0:end][x1,x2,,xn]m=(i=1nxi)(i=2nxi)(xn)0 \begin{aligned} \textsf{arr[0:end]} &\equiv [x_1, x_2, \cdots, x_n ] \\ m &= (\sum_{i=1}^n x_i) \uparrow (\sum_{i=2}^n x_i) \uparrow \cdots \uparrow (x_n) \uparrow 0 \end{aligned}

这里的 \uparrow 可以被看作 \oplus ++ 可以被看作 \odot. 加法存在单位元 00而分配性也是显然的

(ab)+c=(a+c)(b+c) (a \uparrow b) + c = (a + c) \uparrow (b + c)

继续以 [1, 2, 3] 为例

1(1 + 2 + 3) ↑ (2 + 3) ↑ (3) ↑ 0
2= ((1 + 2) ↑ 2 ↑ 0) + 3 ↑ 0
3= (((1 ↑ 0) + 2) ↑ 0) + 3 ↑ 0
4= ((((0 + 1) ↑ 0) + 2 ↑ 0) + 3) ↑ 0

根据这个性质我们用秦九韶算法改写 max_tails:

1def step(s, a):
2    # (s + a) ↑ 0
3    return max(s + a, 0)
4
5def max_tails(arr, end):
6    s = 0
7    for i in range(end):
8        s = step(s, arr[i])

显然改写后的 max_tails 的时间复杂度是 O(n). 不仅如此我们注意到

max_tails(arr,end+1)= step(step(,arr[end1]),arr[end])= step(max_tails(arr, end),arr[end]) \begin{aligned} & \textsf{max\_tails}(\textsf{arr}, \textsf{end} + 1) \\ =\ & \textsf{step}(\textsf{step}(\cdots,\textsf{arr}[\textsf{end} - 1]), \textsf{arr}[\textsf{end}]) \\ =\ & \textsf{step}(\textsf{max\_tails(arr, end)}, \textsf{arr}[\textsf{end}]) \end{aligned}

由于 mss 需要求 [1, end + 1) max_tails所以改写后时间复杂度是 O(n2)O(n^2). 但上面的发现告诉我们max_tails(arr, i) 的值可以由 max_tails(arr, i - 1) 的值求得。这立刻就可以用来得到一个 O(n)O(n) mss 函数

1def step1(state, a):
2    (s, m) = state
3    s_next = step(s, a)
4    return (s_next, max(m, s_next))
5
6def mss(arr):
7    state = (0, 0)
8    for a in arr:
9        state = step1(state, a)
10    return state[1]

毫无疑问这就是大名鼎鼎的 Kadane 算法。

原始推导

Bird 教授将他的研究总结为一段话

I was interested in the specific task of taking a clear but inefficient functional program, a program that acted as a specification of the problem in hand, and using equational reasoning to calculate a more efficient one. 4

也就是说Bird 教授喜欢从一个简单、清晰却不高效的函数式程序出发通过一般的规则对程序进行优化从而得到一个更高效的程序。Bird 教授从最直觉的程序出发一步步推导之后甚至可以得到 KMP 算法这种著名算法。Bird 教授论文中对最大子数组和问题的原始推导是5

1mss = { by definition }                                     
2      max . map sum . segs                                  O(n³)
3    = { by definition of segs }
4      max . map sum . concat . map tails . inits            O(n³)
5    = { map promotion }
6      max . concat . map (map sum) . map tails . inits      O(n³)
7    = { definition of max and fold promotion }               
8      max . map max . map (map sum) . map tails  . inits    O(n³)
9    = { map distributivity }
10      max . map (max . (map sum) . tails) . inits           O(n³)
11    = { Horner's Rule }
12      max . map (foldl (⊗) 0) . inits                       O(n²)
13    = { scan theorem }
14      max . scanl (⊗) 0                                     O(n)
15    = { fold-scan fusion }
16      fst . foldl (⊙) (0, 0)                                O(n)
17    where a ⊗ b = (a + b) ↑ 0
18          (u, v) ⊙ t = (w ↑ u, w), w = v ⊗ t

仅仅用了 8 条等式就得到了最后的算法可以说是神乎其技。更可贵的是无论是秦九韶算法还是最后的优化都在推导过程中被展示为直接的定理。而这些定理就是关于函数式编程中常用函数 map, foldl, concat的恒等式。例如秦九韶算法可以直接表示为

1fold (⊕) b . map (fold (⊗) a) . tails = fold (⊙) a
2u ⊙ v = (u ⊗ v) ⊕ a

那么为什么函数式编程可以方便地进行这种程序推导呢因为函数式编程使得我们更加地注意到程序的 代数 性质。Bird 教授还写过一本书名字就叫做 The algebra of programming.

不可爱的另一面

如果读者尝试着把上面的代码交到 LeetCode 也许会发现它会在一些测试上失败。这当然不是算法错了而是最大子数组和问题的表述不同。

  • (LeetCode): 给定一个数组给出它的最大子数组和。
  • (Programming Pearls): 给定一个数组给出它的最大子数组和。如果数组的最大子数组和为负那么返回 0.

我们的问题是 Programming Pearls 版本的它永远不会返回负的最大子数组和。

如果要处理 LeetCode 版本的题目一个简单的修复是

1mssNeg l
2  | r == 0    = max l
3  | otherwise = r
4    where r = mss l

这种做法略显笨重。我们能够对 Bird 教授的推导过程作修改从而推导出能处理负数情况的算法吗对于这个问题我没有给出什么理想的解决方案。因为 Bird 教授的第一条式子就错了。

1--sum [] = 0
2mss = maximum . map sum . segs

segs 的结果中一定存在空表而空表在用 sum 求值之后就会得到 0. 所以显然这个式子就无法处理负数最大和的情况。又考虑到在这里的秦九韶算法中0 + 的单位元我认为要修改到能处理负数的算法需要的可能是完全推倒重来。

这也许就是函数式程序推导不可爱的一面。函数式程序推导依赖的是程序的代数性质而这种代数性质有时候是脆弱的可能问题稍加修改后原有的代数性质就消失了。对命令式程序来说也许它本身就不依赖这些代数性质所以一点平凡的改变就可以 work但对函数式、特别是用函数式程序推导推导出的程序来说这种修改并不平凡。

可以发邮件问问 Bird 教授的看法吗

也许可以吧。Bird 教授在 2022 4 4 日永远地离开了我们有没有一个能向天堂寄信的邮局呢……