Python `itertools.permutations` 使用的排列算法

- 算法

TL;DR:这本质上是一个基于回溯,利用元素交换的递归排列生成算法,但被重写成了循环形式(可能出于效率考量)。

引子

最近在算法复健,刷到了排列相关的题目。恰巧 Python 内置了一个非常实用的工具库 itertools,其中有一个 permutations(iterable, r) 方法,可以对一个给定的 iterable 生成所有大小为 r 的排列,且输出按照字典序排列。

>>> list(permutations('ABCD', 2))
[('A', 'B'), ('A', 'C'), ('A', 'D'), 
 ('B', 'A'), ('B', 'C'), ('B', 'D'), 
 ('C', 'A'), ('C', 'B'), ('C', 'D'), 
 ('D', 'A'), ('D', 'B'), ('D', 'C')]

在我之前所接触的算法中,排列生成要么是基于回溯,要么是基于字典序,但无论哪种都只能生成全排列,而无法生成这样的部分排列(指生成的排列长度 r 和原输入长度 n 不同)。另一条思路是先生成所有长度为 r 的组合,然后再在每个组合内生成全排列,但这样无法保证输出按字典序(除非先手动收集再排序)。

于是我打开了 Python 的 itertools 的官方文档,其中提供了与 CPython 实现等价的 Python 代码,permutations 方法的代码如下(🔗):

def permutations(iterable, r=None):
    # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) --> 012 021 102 120 201 210
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return
    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return

尝试初步理解

虽然顶部有两行注释,不过这也只是描述了这个方法的效果,对其原理并没有任何解释。往后继续看,可以发现算法首先构造了两个 list:indicescycles,且之后每次输出结果(yield)实际上都是将 indices 中的前 r 个输出。再往后算法进入了一个神秘的 for 循环,对 cycles 中的元素做了一些修改,用 cycles 的值对 indices 中的一些元素做了交换。仅从代码层面出发,对算法的理解似乎也就止步于此了。然而这并没有回答一个重要问题:为什么这个算法能工作?

寻找相关信息

考虑到 itertools 库是在 Python 2.3 (2003 年 7 月)被引入标准库的,这个算法已经算得上历史悠久了。再加之 itertools 库的广泛使用,这个算法的原理应该是早已被详尽记录的。带着这样的期望,我开始用各种关键词组合搜索相关信息。可惜的是,除了 Stack Overflow 上一个 2010 年的问题(🔗),和一个知乎提问(🔗),就没有任何其他相关的网页了,甚至连当年的提交记录都找不到。

既然如此,那就只能从这两个链接入手了。

自己来

已有的信息似乎不是很充分。看来我只能自己来了。在加了一堆 print 并在纸上手动模拟了多次这个算法之后,我认为我可能大概理解它的工作原理,并且可以证明其正确性了。下文将详述我的理解。

算法

起步

开始前,需要统一一下后文使用的记号:

并回顾我们先前阅读算法得到的理解:

我们将按照如下步骤理解这个算法:

  1. 理解 cycles 的变化
  2. 理解 indices 的变化,并尝试说明这个算法的正确性
  3. 尝试重新实现这个算法的「原始」递归版本

cycles

我们首先从 cycles 变量入手,理解它在这个算法中是如何变化的。这个阶段我们暂时先不考虑 indices

可以先指定一些具体的输入,然后尝试加一些 print 语句。以 iterable="ABCD", r=2 作为输入,在 ifelse 两个分支执行前后中都插入 print,可以得到如下结果:(其中中括号说明算法有输出 yield,大括号部分算法无输出)

[4,3] -> [4,2] -> [4,1] -> {4,0} -> {4,3} -> 
[3,3] -> [3,2] -> [3,1] -> {3,0} -> {3,3} -> 
[2,3] -> [2,2] -> [2,1] -> {2,0} -> {2,3} -> 
[1,3] -> [1,2] -> [1,1] -> {1,0} -> {1,3} -> {0,3} -> {4,3}

我们可以直观感受到,似乎 cycles 变量就像一个「倒计时」,或者说「带借位的减法」。

从这个具体的示例出发,我们可以这样理解 cycles 的变化:

有了这一直观感受,就可以为 cycles 找出一个可能的解释(「物理含义」)了。我认为,cycles 代表的是 「每个位置上剩余的可用选择数」 。如果将 cycles 视作一个变进制数,则 cycles 也代表 「总体剩余还没有输出的排列数」 。理由如下:

其实 cycles 的变化,无论是 Stack Overflow 上的回答,还是知乎上的回答,都有相对详尽的描述。在此我只是尝试以自己的语言重述了一次而已。但接下来对 indices 的理解就大部分是我自己的了。

indices

现在我们来看看 indices 是如何变化的。和之前对 cycles 的探索一样,我们也先从一个具体的例子开始:iterable="ABCDE",r=3,并关注一个子问题:前 3 个输出(ABC, ABD, ABE)是如何产生的。为便于展示,这里我直接使用具体元素(字母)代替 index。加了一些 print 后,我们可以得到如下的变化过程。

py_permutation

这个图稍微有些复杂。以下是进一步解释。

可以发现,这部分执行过程,恰好满足了回溯算法的正确性要求:

虽然图中仅描述了一个子问题(i=2,或者说i=r-1),但不难发现对于其他的 i[0, r-1] 这一讨论都是成立的。这也(不严格地)说明了这一算法的确可以遍历所有的可能排列。输出顺序为字典序,则是因为每个 tick 中交换元素时都维护了 backlog 中的相对顺序。

这部分讨论有些复杂,如果不太理解(或者不完全信服)的话,可以自己多加点 print ,或者手动在纸上执行感受一下。

重新实现

现在我们已经了解了这个算法的原理,重新实现其原始递归版本也就不难了。

以下是一个可能的 Python 重新实现。

## a reimplementation of `itertools.permutation`

# helpers
def swap(list, i, j):
    list[i], list[j] = list[j], list[i]

def move_to_last(list, i):
    list[i:] = list[i+1:] + [list[i]]

def print_first_n_element(list, n):
    print("".join(list[:n]))

# backtracking dfs
def permutations(list, r, changing_index):
    if changing_index == r:
        # we've reached the deepest level
        print_first_n_element(list, r)
        return
    
    # a pseudo `tick`
    # process initial permutation
    # which is just doing nothing (using the initial value)
    permutations(list, r, changing_index + 1)

    # note: initial permutaion has been outputed, thus the minus 1
    remaining_choices = len(list) - 1 - changing_index
    # for (i=1;i<=remaining_choices;i++)
    for i in range(1, remaining_choices+1):
        # `tick` phases
        
        # make one swap
        swap_idx = changing_index + i
        swap(list, changing_index, swap_idx)
        # finished one move at current level, now go deeper
        permutations(list, r, changing_index + 1)
    
    # `reset` phase
    move_to_last(list, changing_index)

# wrapper
def permutations_wrapper(list, r):
    permutations(list, r, 0)

# main
if __name__ == "__main__":
    my_list = ["A", "B", "C", "D"]
    permutations_wrapper(my_list, 2)

递归转循环优化

出于性能和安全(防止爆栈)的考量,我们会想将这个算法的递归版本转换成循环版本。这需要我们用栈手动维护每一层递归的相关状态,包括递归中的变量和下一次执行的开始位置。幸运的是,对这个算法而言,我们需要维护的状态并不多。

基于上文分析,可以发现我们需要维护的栈有两个特点:

回头看看,这实际上就是 cycles。在「剩余可能数」的身份之外,cycles 也承担起了维护递归状态的职责。而作者巧妙利用了 Python 列表索引可以为负数从后往前的特性,统一了 cycles 的两面。

至此,我们完成了对这一算法的分析。🎉

相关链接