快速排序(英文:Quick Sort),也称为分区交换排序(Partition Exchange Sort),是一种排序算法。快速排序最早由 Tony Hoare 于 1961 年提出。其平均算法复杂度为 $O(n \log{n})$,然而最坏情况下算法复杂度为 $O(n^2)$,当然这种情况是非常极端的,大部分时候是不会发生这种事情的。

快速排序的主要排序思想在于:选出一个基准值,然后将所有比这个值小的值放置在左边,比这个值大的放在右边。接下来对左右两边的数值继续执行这个过程。

比如:[5, 4, 2, 7, 3] 这组数字,假设我们选择基准为数字3,那么第一趟下来大致是这样的:[2, 3, 5, 4, 7]。那么在以 3 作为基准数字的一趟排序后,数组被分为了三个部分,[2], 3, 和 [5,4,7]. 3已经被确定下来位置了,那么接下来就继续使用以上操作分别对 [2] 和 [5,4,7] 进行排序即可。

在快速排序算法中有两个核心函数:sort()partition()

  • sort() 从上述能看出,快速排序是一个具有递归结构的算法,我们在设计的时候可以设计成 void sort(int a[], size_t n); 表示对数组 a[] 进行排序,那么为什么还要 size_t n 这个参数呢?我们都知道,在 C 语言中,传递的一个数组变量那只是传递了这个数组的首地址,所以还需要传递数组的大小。这里有同学可能会说,那分区之后如何对分区后的左右两个部分进行排序呢?并不是所有的分区都是从 a[] 的首地址开始的呀。技巧就是:我们只需要进行地址偏移就可以了。
  • partition() 函数,对数组进行分区交换操作。设计的时候可以设计成 size_t partition(int a[], size_t n); 注意:这个函数是有返回值的。因为我们要告诉调用者分区最后基准值落在哪个位置,有了这个位置,就可以定位到分区之后被分成的左右两个部分,进而继续调用 sort() 函数进行排序了。

1. sort() 函数

实现上,首先判断一下 $n$ 的大小,没必要排序的情况下直接结束。然后接下来三个步骤:

  1. 调用 partition() 函数进行分区,并获取基准值最后停下来的位置。
  2. 对分区后左部分进行排序。
  3. 对分区后右部分进行排序。

这里需要注意的就是分区后左部分和右部分各自的起始地址和大小。

void sort(int a[], size_t n) {
    if (n <= 1) {
        return;
    }
    size_t mid = partition(a, n);
    sort(a, mid);
    sort(a + mid + 1, n - mid - 1);
}

2. partition() 函数

分区函数,其实这才是快速排序最最核心的部分了。分区操作基准值选择需要分区的数组段最后一个值。首先将除去最后一个值的剩余部分分成左右两部分(左部分所有值小于基准值,右部分所有值大于基准值),然后将分区后右半部分的第一个与分区的数组段最后一个值(基准值)进行互换。这里其实有一个小操作,就是我们并不要着急去规定基准值最后的位置,毕竟咋也不知道具体多少个值是小于基准值的,多少个值是大于基准值的。而是我们直接将除了基准值剩下的分成两个部分,然后将右部分第一个与基准值换个地址就好了,反正这个值往右放并没有改变右边所有值都大于基准值的事实。

size_t partition(int a[], size_t n) {
    if (n <= 1) {
        return 0;
    }
    size_t cursor_bottom = -1;
    for (size_t i = 0; i < n - 1; i++) {
        if (a[i] < a[n - 1]) {
            cursor_bottom++;
            swap(&a[cursor_bottom], &a[i]);
        }
    }
    cursor_bottom++;
    swap(&a[cursor_bottom], &a[n-1]);
    return cursor_bottom;
}

3. 完整代码

#include <stdio.h>
#include <stdlib.h>

void print_arr(int a[], size_t n) {
    for (size_t i = 0; i < n; i++) {
        printf("%2d ", a[i]);
    }
    printf("\n");
}

void swap(int *a, int *b) {
    int tmp;
    tmp = *a;
    *a = *b;
    *b = tmp;
}

size_t partition(int a[], size_t n) {
    if (n <= 1) {
        return 0;
    }
    size_t cursor_bottom = -1;
    for (size_t i = 0; i < n - 1; i++) {
        if (a[i] < a[n - 1]) {
            cursor_bottom++;
            swap(&a[cursor_bottom], &a[i]);
        }
    }
    cursor_bottom++;
    swap(&a[cursor_bottom], &a[n - 1]);
    return cursor_bottom;
}

void sort(int a[], size_t n) {
    if (n <= 1) {
        return;
    }
    size_t mid = partition(a, n);
    sort(a, mid);
    sort(a + mid + 1, n - mid - 1);
}

int main() {
    int a[] = {4, 6, 3, 2, 7, 4, 2, 5, 6, 2};
    print_arr(a, 10);
    sort(a, 10);
    print_arr(a, 10);
    return EXIT_SUCCESS;
}