Hello?

如何思考传统算法问题——以 Union-Find 算法为例

2019.07.05

我们常有这样的困扰:

“怎样系统地思考一个算法问题?”  
“为什么我只会解决见过的算法题,对陌生的算法题不知道怎样着手?”  

之所以说传统算法问题,是为了区别于机器学习算法,我们这里所说的算法,只涉及传统的数据结构与算法。
《算法(第4版)》在第一章就解决了这个新手的困扰,介绍思考任意传统算法问题的科学思路。本文做一个梳理,通过Union-find算法作为例子,总结面对一个传统算法场景的思考方向。

总体思路

 定义基本操作API -> 定义测试用例 -> 实现一个简单方案 -> 分析算法效率 -> 优化迭代 

以 Union-Find 问题为例:

什么是 Union-Find 问题?

假如有很多点,构成这样的集合。

每个点右下角是角标,代表点的位置。
简单来说问题就是判断其中任意2个点是否连接,具体如下:
1. 输入一串数字,每两个数字为一组,比如 tinyUF.txt:

» cat tinyUF.txt
15
4 3
3 8
6 5
9 4
2 1
8 9
5 0
7 2
6 1
1 0
6 7

第一个数字表示数据集合大小,只用来方便测试,我们先忽略它。
2. 从第二行开始每次读2个数字为一组,判断这两个数字代表的点是否连接。(这里定义的连接可以传递,比如a连接b,b连接c,那么a与c也连接。)
3. 如果不连接,实现算法让其连接;如果连接,返回true,然后继续读下一组数。

Step 1. 定义API

一般来讲,我们要先把问题明确为一个函数,规定入参和出参。或者,把问题明确为一个类,规定类中包括哪些成员变量和方法。
当我们想好函数的入参和出参(或这个类的每个方法参数和返回值),就相当于明确了我们要解决什么问题。
这样用“封装”的思想,把问题定义在这个函数或类里,我们只需要思考怎么实现函数就行了。
(据我观察,工作中遇到的很多问题,都是因为没把问题本身想清楚,当你明确了问题之后,问题自然解决了。)

定义 API 就是抽象问题的过程
  • 用 int[] 表示所有点的集合。
    这相当于把每个点抽象成数组里的一个元素,元素的角标用来确定这个点的位置,元素的值代表这个点的某种信息,比如所在分量的标识。
  • 把互相连接的点的集合称为分量
  • 接下来定义解决这个问题的类:
class UF{
    //    所有点的集合
    private int[] id;
    //    分量数量
    private int count;
    //    初始化
    public UF(int N) { }
    //    查找分量标识符
    public int find(int p) { }
    //    两个触点是否连接,也就是是否在同一个分量
    public boolean connected(int p, int q) { }
    //    连接两个触点
    public void union(int p, int q){ }
    public int count() { }
}

我们接下来的任务就是实现这个类里面每个方法。 其实,我们只需要优化这两个方法:

//根据角标找到这个点的值
public int find(int p) {}
//连接两个点,参数为角标
public void union(int p, int q) {}

因为其他方法可以基于这两个方法实现。
为了之后方便定义多种实现,把上面的类改为抽象类,同时先实现其他方法的逻辑:

public abstract class UF {
    //    所有点的集合
    private int[] id;
    //    分量数量
    private int count;
    //    初始化
    public UF(int N) {
        count = N;
        id = new int[N];
        for (int i = 0; i < N; i++) {
            id[i] = i;
        }
    }
    //    查找分量标识符
    abstract int find(int p);
    //    两个点是否连接,也就是是否在同一个分量
    public boolean connected(int p, int q) {
        return find(p) == find(q);
    }
    //todo    连接两个点,需要在接下来实现
    abstract public void union(int p, int q);
    public int count() {
        return count;
    }
}

这样,之后每次写出新的实现,只需继承同一个抽象类 UF,实现 find() 和 union()。

Step 2. 定义测试用例

在真正写实现逻辑之前,首先写测试代码是比较好的习惯。这既能检测上面定义的API是否合理,也可以方便验证效果。
定义一个实现类,用来写具体逻辑:

public class UFQuickFind extends UF {
// 具体需要实现的方法
    public UFQuickFind(int N) {
        super(N);
    }
//    查找分量标识符
    public int find(int p) {
        return id[p];
    }
//具体需要实现的方法
    public void union(int p, int q) {}
//测试用例
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int count = scanner.nextInt();
        UFQuickFind qf = new UFQuickFind(count);
        while (scanner.hasNextInt()) {
            int p = scanner.nextInt();
            int q = scanner.nextInt();
//            判断是否连接
            if (qf.connected(p, q)) {
                continue;
            }
//            如果没连接,则归并分量
            qf.union(p, q);
//            打印连接
            System.out.println(p + " " + q);
        }
        System.out.println(qf.count() + "components");
    }
}

做这种小实验,个人喜欢摆脱 IDE,用命令行编译:

测试数据 tinyUF.txt 可以在《算法》官网找到 https://algs4.cs.princeton.edu/15uf/

Step 3. 简单实现:quick-find 算法

首先我们需要先定义,什么叫连接? 我们姑且用两个元素的值相等来代表连接。显然,同一个分量中每个元素值都相等。 比如:

id[0] == id[5]  

代表0和5连接,也就是说属于同一分量。

实现 find() 方法:

上面在 UFQuickFind 类中已经写了这种 find() 方法,很简单,就是拿到数组角标对应的值,代表所属分量。

public int find(int p) {
    return id[p];
}
实现 union() 方法:

union() 方法用来连接两个点,只需要遍历每个元素,把一个点所属分量合并到另一个点所属的分量。

    public void union(int p, int q) {
//        找p和q各自所属分量
        int pGroup = find(p);
        int qGroup = find(q);
        if (pGroup == qGroup) {
            return;
        }
//        遍历所有元素,把与q同一个分量的元素放到p所属的分量
        for (int t = 0; t < id.length; t++) {
            if (id[t] == qGroup) {
                id[t] = pGroup;
            }
        }
        count--;
    }

因为这个算法的 find() 查找很快,所以叫 quick-find 算法。
测试结果:

Step 4. 分析效率

显然,quick-find 算法效率不高:

我们在这里把访问数组的次数作为衡量效率的指标,如果不管数据规模如何,某个算法只需访问数组常数次,则复杂度O(1)。这样规定指标的做法叫成本模型,本文这个算法问题的成本模型是访问数组的次数。

  • find() 时间复杂度 O(1)
  • union() 时间复杂度 O(N),N就是数组长度
  • 如果想用 quick-find 算法把所有数据连通,最后只剩下一个分量,最坏情况需要调用 N-1次 union(),那么很明显,这种情况下,时间复杂度 O(N^2)。

对于这个算法,靠直觉就能发现 union() 之所以效率低,是因为遍历了所有元素,受原始数据规模影响。接下来的优化思路,自然是想办法避免这种遍历。

Step 5. 优化:quick-union 算法

如果让一个分量的元素之间构成某种关联关系,能不能避免遍历所有数据?
比如让元素的值来指向另一个元素:
规定id[] 数组每个元素的值是同分量另一个元素的角标
这样,我们根据某个元素,可以追溯同分量的另一个元素,直到最后一个元素,这时,规定最后一个元素的值等于自己的角标。
这样,相当于每个元素的值都像一个“指针”,指向另一个元素。最后一个元素的“指针”指向自己。
想象一下,每个分量构成一个树状结构。

这种设计下,怎么表示两个点“连接”?

显然,当两个点属于同一分量,那么两个点属于同一个“树”,有同一个“根节点”。所以可以用根节点的值表示所属分量,也就是 find() 的返回值。

实现 find()

根据参数角标,拿到对应元素值,指向了下一个元素角标,然后访问下一个元素的值,就是下下个元素角标……不断向前追溯,直到某个元素的值等于自己角标值,说明这个元素是根节点。
我写了两种 find() 方法,一种是递归:

public int find(int p) {
    int pVal = id[p];
    if (pVal == p) {
        return p;
    }
    return find(pVal);
}

一种用 while() 循环:

public int find(int p) {
    while (p != id[p]) {
        p = id[p];
    }
    return p;
}
实现 union()

用 find() 找到两个点各自所属的树的根节点,把一个根节点指向另一个根节点,这样,就合并两个树。

public void union(int p, int q) {
    int pRoot = find(p);
    int qRoot = find(q);
    if (pRoot == qRoot) {
        return;
    }
    id[pRoot] = qRoot;
    count--;
    return;
}

因为这种算法的 union() 效率很高,所以叫 quick-union 算法。

重复 Step 4 - Step 5 (分析效率 -> 优化算法 -> 分析效率 -> 优化算法……)

分析 quick-union 效率

find() 方法时间复杂度

只看用 while 循环的方法。find() 查找速度受节点与根节点距离(也就是深度)影响,如果节点的父节点就是根节点,那么显然只需要访问数组2次。但如果整个原始数据构成一张树,且节点深度等于 N-1,这是 find() 遇到的最坏情况,如图:

如果调用find(0),则需要访问 2 * (N-1) (每次 while 循环访问2次)。
什么样的输入会导致这种最坏情况?
比如在输入:

0-1 0-2 0-3 0-4 0-5 0-6 0-7 0-8 0-9 

这样的序列,find()复杂度为 O(N),union中会调用2次 find(),复杂度也是 O(N)。
如果想把所有元素连通,最坏情况总的复杂度依然是 O(N^2)。
之所以会有这样的情况,是因为每次连接两个树,总是朝固定的一侧合并根节点,可能让本就比较高的树更高,不能限制树的规模。所以接下来思考方向是,怎么控制合并的树的规模,尽量均匀分配,不形成“一棵独大”的树。

继续优化 quick-union 算法:weighted-quick-union 算法

既然我们想控制树的高度(也就是节点中的最大深度),就需要先思考这个问题:树的高度是什么时候增加的?
答案是 union() 中把树 A 根节点指向树 B 根节点的时候,树 A 的高度增加1,树 B 的高度不变。
于是优化的方法呼之欲出:union() 每次合并树的时候,都把高度小的树往高度大的树上合并。使得合并后只增高小树。
我们可以定义一个数组记录根节点对应的树的高度。合并树的时候,读取两个根节点对应的树的高度,找出哪棵树更高:

private int[] height;
//构造方法中初始化 height
public UFWeightedQuickUnion(int N) {
    count = N;
    id = new int[N];
    for (int i = 0; i < N; i++) {
        id[i] = i;
    }        
    height = new int[N];
    for (int i = 0; i < N; i++) {
        height[i] = i;
    }
}

union() 方法:

void union(int p, int q) {
    int pRoot = find(p);
    int qRoot = find(q);
    if (pRoot == qRoot) {
        return;
    }
    int pH = height[pRoot];
    int qH = height[qRoot];
    if (pH > qH) {
        id[qRoot] = pRoot;
        height[pRoot]++;
    } else {
        id[pRoot] = qRoot;
        height[qRoot]++;
    }
    count--;
    return;
}

因为这种算法通过记录树的高度作为一种加权(weight),控制合并方向,所以可以叫加权quick-union 算法。

如果不用树的高度作为加权,而是用树的大小,也就是节点数做加权呢?这也是一种加权方法,这里就不介绍了。

分析 weighted-quick-union 算法

  • find() 方法复杂度 O(lgN)
  • union() 方法复杂度 O(lgN)
  • 如果要把所有点合到一个分量,最坏情况的复杂度 O((N -1) * lgN) ~ O(N * lgN)

最优算法:路径压缩 weighted-quick-union 算法

虽然加权 quick-union 算法已经把树的高度控制在 lgN,但毕竟 find() 中追溯根节点的时候还是需要经过其他节点。数据规模越大,节点深度越大,路径越长,如果能缩短这条路径,显然就能大大提升 find() 效率。
干脆让每个节点指向根节点好了。
在 find() 中,找到根节点后,再把路径上每个节点都指向根节点,不就减少树高度了吗?

int find(int p) {
        int start = p;
//        让 p 取到根节点的值
        while (p != id[p]) {
            p = id[p];
        }
//        重新遍历,把路径上的节点都指向根节点
        while (start != p) {
            int temp = start;
            start = id[start];
            id[temp] = p;
        }
        return p;
}

分析路径压缩 weighted-quick-union 算法

表面上看压缩路径算法的 find() 中加了一个 while 循环,好像反而增加了复杂度。
可换个角度思考,当树慢慢扩大的时候,高度却没有增加。这样避免了速度受数据规模影响太大。虽然表面上看增加了一个循环,但如果大多数情况节点深度只有2,while 循环也不过只有2次而已。
前面分析时间复杂度都是考虑最坏情况,实际上均摊成本也很重要。路径压缩加权 quick-union 算法均摊成本已经非常接近常数了。具体可以参考《算法(第四版)》,不过书里也只是粗略的介绍这个概念,在我看来不够严谨,以后如果有时间可以进一步分析这个问题。