Ring AllReduce

March 26, 2024

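Recently I needed to work on multi-GPU communication and implement a few kernels, mostly as a homage to Nvidia's NCCL. Unfortunately I couldn't make sense of a single line of it, but luckily there is plenty of material on the Ring AllReduce algorithm online.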

Core Idea

NCCL is the work of masters, with templates and PTX juggled back and forth all over the place.

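Taking 4 GPUs as an example, Ring AllReduce runs in two phases: Reduce-Scatter and AllGather. In Reduce-Scatter, the data on each of the N GPUs is split into N chunks of size 1/N. Each chunk travels around the ring, and every GPU it passes reduces its own copy of that chunk into it, so after N - 1 steps each GPU holds 1/N of the final result. In AllGather, each GPU then passes its finished chunk around the ring to all the other GPUs, which takes another N - 1 steps.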
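In practice "1/N of the data" has to be pinned down when the buffer length is not a multiple of N. One simple way, sketched in Rust under my own assumptions (the helper `chunk_range` is hypothetical and not necessarily how NCCL splits its buffers), is to give the first `len % n` chunks one extra element:

```rust
use std::ops::Range;

/// Hypothetical helper: element range of chunk `c` when a buffer of `len`
/// elements is split into `n` nearly equal chunks. The first `len % n`
/// chunks get one extra element; trailing chunks may be empty if len < n.
fn chunk_range(len: usize, n: usize, c: usize) -> Range<usize> {
    assert!(n > 0 && c < n);
    let base = len / n;
    let extra = len % n;
    let start = c * base + c.min(extra);
    let size = base + if c < extra { 1 } else { 0 };
    start..start + size
}

fn main() {
    // 10 elements over 4 devices: chunk sizes 3, 3, 2, 2
    for c in 0..4 {
        println!("chunk {}: {:?}", c, chunk_range(10, 4, c));
    }
}
```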

Reduce-Scatter:

          A                 B                 C                 D      
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a1     | --> |    b1     |     |    c1     |     |    d1     |
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a2     |     |    b2     | --> |    c2     |     |    d2     |
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a3     |     |    b3     |     |    c3     | --> |    d3     |
    +-----------+     +-----------+     +-----------+     +-----------+
+-> |    a4     |     |    b4     |     |    c4     |     |    d4     | -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+

          A                 B                 C                 D      
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a1     |     |   a1+b1   | --> |    c1     |     |    d1     |   
    +-----------+     +-----------+     +-----------+     +-----------+   
    |    a2     |     |    b2     |     |   b2+c2   | --> |    d2     |   
    +-----------+     +-----------+     +-----------+     +-----------+   
+-> |    a3     |     |    b3     |     |    c3     |     |   c3+d3   | -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   |   d4+a4   | --> |    b4     |     |    c4     |     |    d4     |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+

          A                 B                 C                 D      
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a1     |     |   a1+b1   |     | a1+b1+c1  | --> |    d1     |   
    +-----------+     +-----------+     +-----------+     +-----------+   
+-> |    a2     |     |    b2     |     |   b2+c2   |     | b2+c2+d2  | -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   | c3+d3+a3  | --> |    b3     |     |    c3     |     |   c3+d3   |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   |   d4+a4   |     | d4+a4+b4  | --> |    c4     |     |    d4     |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+
  
AllGather:

          A                 B                 C                 D      
    +-----------+     +-----------+     +-----------+     +-----------+
+-> |    a1     |     |   a1+b1   |     | a1+b1+c1  |     |a1+b1+c1+d1| -+   
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   |b2+c2+d2+a2| --> |    b2     |     |   b2+c2   |     | b2+c2+d2  |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   | c3+d3+a3  |     |c3+d3+a3+b3| --> |    c3     |     |   c3+d3   |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   |   d4+a4   |     | d4+a4+b4  |     |d4+a4+b4+c4| --> |    d4     |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+

          A                 B                 C                 D      
    +-----------+     +-----------+     +-----------+     +-----------+
    |a1+b1+c1+d1| --> |   a1+b1   |     | a1+b1+c1  |     |a1+b1+c1+d1|      
    +-----------+     +-----------+     +-----------+     +-----------+   
    |b2+c2+d2+a2|     |b2+c2+d2+a2| --> |   b2+c2   |     | b2+c2+d2  |   
    +-----------+     +-----------+     +-----------+     +-----------+   
    | c3+d3+a3  |     |c3+d3+a3+b3|     |c3+d3+a3+b3| --> |   c3+d3   |
    +-----------+     +-----------+     +-----------+     +-----------+   
+-> |   d4+a4   |     | d4+a4+b4  |     |d4+a4+b4+c4|     |d4+a4+b4+c4| -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+

          A                 B                 C                 D      
    +-----------+     +-----------+     +-----------+     +-----------+
    |a1+b1+c1+d1|     |a1+b1+c1+d1| --> | a1+b1+c1  |     |a1+b1+c1+d1|      
    +-----------+     +-----------+     +-----------+     +-----------+   
    |b2+c2+d2+a2|     |b2+c2+d2+a2|     |b2+c2+d2+a2| --> | b2+c2+d2  |   
    +-----------+     +-----------+     +-----------+     +-----------+   
+-> | c3+d3+a3  |     |c3+d3+a3+b3|     |c3+d3+a3+b3|     |c3+d3+a3+b3| -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |  
|   |d4+a4+b4+c4| --> | d4+a4+b4  |     |d4+a4+b4+c4|     |d4+a4+b4+c4|  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+

          A                 B                 C                 D      
    +-----------+     +-----------+     +-----------+     +-----------+
    |a1+b1+c1+d1|     |a1+b1+c1+d1|     |a1+b1+c1+d1|     |a1+b1+c1+d1|      
    +-----------+     +-----------+     +-----------+     +-----------+   
    |b2+c2+d2+a2|     |b2+c2+d2+a2|     |b2+c2+d2+a2|     |b2+c2+d2+a2|   
    +-----------+     +-----------+     +-----------+     +-----------+   
    |c3+d3+a3+b3|     |c3+d3+a3+b3|     |c3+d3+a3+b3|     |c3+d3+a3+b3|
    +-----------+     +-----------+     +-----------+     +-----------+  
    |d4+a4+b4+c4|     |d4+a4+b4+c4|     |d4+a4+b4+c4|     |d4+a4+b4+c4|
    +-----------+     +-----------+     +-----------+     +-----------+
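As a sanity check on the diagrams, the Rust sketch below replays both phases symbolically: each cell is a string, and every reduction is written as "incoming+own", the same order as the pictures. The function `ring_allreduce_trace`, the 0-based device/chunk numbering, and the assumption that all sends within a step happen simultaneously are mine. For N = 4, every device should end up with the four entries of the last diagram.

```rust
/// Replays Ring AllReduce symbolically: cells[k][c] is what device k holds
/// for chunk c, written like the diagrams ("c3+d3+a3"). Devices are named
/// a, b, c, ... and chunks are printed 1-based; indices in code are 0-based.
fn ring_allreduce_trace(n: usize) -> Vec<Vec<String>> {
    assert!((2..=26).contains(&n));
    let name = |k: usize| (b'a' + k as u8) as char;
    let mut cells: Vec<Vec<String>> = (0..n)
        .map(|k| (0..n).map(|c| format!("{}{}", name(k), c + 1)).collect::<Vec<String>>())
        .collect();
    // Reduce-Scatter: at step s device k sends chunk (k - s) mod n to k + 1,
    // which stores "incoming+own".
    for s in 0..n - 1 {
        let msgs: Vec<(usize, usize, String)> = (0..n)
            .map(|k| ((k + 1) % n, (k + n - s) % n, cells[k][(k + n - s) % n].clone()))
            .collect();
        for (dst, c, v) in msgs {
            cells[dst][c] = format!("{}+{}", v, cells[dst][c]);
        }
    }
    // AllGather: device k now owns the finished chunk (k + 1) mod n; at step s
    // it forwards chunk (k + 1 - s) mod n and the receiver overwrites its copy.
    for s in 0..n - 1 {
        let msgs: Vec<(usize, usize, String)> = (0..n)
            .map(|k| ((k + 1) % n, (k + 1 + n - s) % n, cells[k][(k + 1 + n - s) % n].clone()))
            .collect();
        for (dst, c, v) in msgs {
            cells[dst][c] = v;
        }
    }
    cells
}

fn main() {
    // every device prints: a1+b1+c1+d1 | b2+c2+d2+a2 | c3+d3+a3+b3 | d4+a4+b4+c4
    for (k, row) in ring_allreduce_trace(4).iter().enumerate() {
        println!("device {}: {}", k, row.join(" | "));
    }
}
```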

Implementation

The diagrams make the overall structure of the algorithm obvious at a glance:

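// k: device index, n: total number of devices
fn kernel(k: u64, n: u64, rx: Receiver, tx: Sender) {
    tx.send(this[?]); // on the first step there is no data from the previous device yet
    for i in 1 .. n {
        this[?] = reduce(this[?], rx.recv()); // reducing n values takes n - 1 operations
        tx.send(this[?]);
    }
    for i in 0 .. n - 2 { // the result must reach the other n - 1 devices, but the loop above already sent it once
        this[?] = rx.recv();
        tx.send(this[?]);
    }
    this[?] = rx.recv(); // the last chunk only needs to be received
}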
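The skeleton assumes device k owns an `rx` fed by device k - 1 and a `tx` into device k + 1. To experiment on the CPU, one can let threads stand in for GPUs and `std::sync::mpsc` channels stand in for the links; this is purely an illustration, not how NCCL moves data. The hypothetical `make_ring` below does the wiring, and `pass_ids` checks it by sending each device's id once around the ring:

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

/// Hypothetical helper: device k gets (rx from device k - 1, tx to device k + 1).
fn make_ring<T>(n: usize) -> Vec<(Receiver<T>, Sender<T>)> {
    // Channel k carries data into device k.
    let (txs, rxs): (Vec<Sender<T>>, Vec<Receiver<T>>) = (0..n).map(|_| channel()).unzip();
    rxs.into_iter()
        .enumerate()
        .map(|(k, rx)| (rx, txs[(k + 1) % n].clone()))
        .collect()
}

/// Each device sends its own id to its successor, then returns what it received.
fn pass_ids(n: usize) -> Vec<usize> {
    let handles: Vec<_> = make_ring::<usize>(n)
        .into_iter()
        .enumerate()
        .map(|(k, (rx, tx))| {
            thread::spawn(move || {
                tx.send(k).unwrap(); // unbounded channel: never blocks
                rx.recv().unwrap() // id of device k - 1
            })
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}

fn main() {
    println!("{:?}", pass_ids(4)); // [3, 0, 1, 2]
}
```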

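Drawing on years of experience guessing exam answers, look at which chunk is handled at each step: the chunk index simply walks backwards (all indices are taken modulo n):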

fn kernel(k: u64, n: u64, rx: Receiver, tx: Sender) {
    tx.send(this[k]);
    for i in 1 .. n {
        this[k - i] = reduce(this[k - i], rx.recv());
        tx.send(this[k - i]);
    }
    for i in 0 .. n - 2 {
        this[k - n - i] = rx.recv(); // k - n - i === k - i (mod n)
        tx.send(this[k - n - i]);
    }
    this[k - n - (n - 2)] = rx.recv(); // k - n - (n - 2) === k + 2 (mod n)
}
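Because `k - i` goes negative (and would underflow a `u64`), the indices only make sense modulo n. A small sketch, assuming a hypothetical helper `chunk(k, n, step)` = (k - step) mod n computed with `rem_euclid`, checks the two identities in the comments and lists the chunks device A (k = 0) touches for n = 4. It should match the diagrams: a1 is sent first, then chunks 4, 3, 2 are reduced and chunks 1, 4, 3 are gathered (1-based).

```rust
/// Hypothetical helper: the chunk device `k` handles at `step` in a ring of
/// `n` devices, i.e. (k - step) mod n. `rem_euclid` keeps it non-negative.
fn chunk(k: i64, n: i64, step: i64) -> i64 {
    (k - step).rem_euclid(n)
}

/// Chunks touched by device k: step 0 is the first send, steps 1..=n-1 the
/// Reduce-Scatter receives, steps n..=2n-2 the AllGather receives.
fn schedule(k: i64, n: i64) -> Vec<i64> {
    (0..=2 * n - 2).map(|s| chunk(k, n, s)).collect()
}

fn main() {
    let n: i64 = 4;
    for k in 0..n {
        for i in 0..n - 2 {
            // AllGather loop: k - n - i === k - i (mod n)
            assert_eq!(chunk(k, n, n + i), chunk(k, n, i));
        }
        // final receive: k - n - (n - 2) === k + 2 (mod n)
        assert_eq!(chunk(k, n, 2 * n - 2), (k + 2).rem_euclid(n));
    }
    println!("{:?}", schedule(0, n)); // [0, 3, 2, 1, 0, 3, 2]
}
```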

In fact, the intermediate Reduce results never need to be written back to device memory; only the final Reduce has to be stored. This also makes an in-place AllReduce easy.

fn kernel(k: u64, n: u64, rx: Receiver, tx: Sender) {
    tx.send(this[k]);
    for i in 1 .. n - 1 {
        tx.send(reduce(this[k - i], rx.recv())); // partial sum is only forwarded, not stored
    }
    this[k - (n - 1)] = reduce(this[k - (n - 1)], rx.recv()); // final sum is stored
    tx.send(this[k - (n - 1)]);
    for i in 0 .. n - 2 {
        this[k - i] = rx.recv();
        tx.send(this[k - i]);
    }
    this[k + 2] = rx.recv();
}
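To make sure the indices and message counts of this version really work, here is a CPU-only Rust sketch. Assumptions: one thread per device, `std::sync::mpsc` channels as the ring links, each chunk a `Vec<i64>` so results can be compared exactly, and `reduce` as element-wise addition. `ring_allreduce` and the layout `data[k][c]` (chunk c on device k) are hypothetical names of mine.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

type Chunk = Vec<i64>;

/// Assumed reduction: element-wise sum.
fn reduce(a: &Chunk, b: &Chunk) -> Chunk {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Device k's kernel, following the pseudocode above with indices taken mod n.
fn kernel(k: usize, n: usize, this: &mut [Chunk], rx: Receiver<Chunk>, tx: Sender<Chunk>) {
    let at = |i: usize| (k + 2 * n - i) % n; // (k - i) mod n, valid for i <= 2n
    let recv = || rx.recv().expect("predecessor hung up");
    tx.send(this[k].clone()).unwrap();
    for i in 1..n - 1 {
        // partial sum is forwarded only, never written back
        tx.send(reduce(&this[at(i)], &recv())).unwrap();
    }
    let c = at(n - 1); // == (k + 1) % n: the chunk this device finishes
    this[c] = reduce(&this[c], &recv());
    tx.send(this[c].clone()).unwrap();
    for i in 0..n - 2 {
        this[at(i)] = recv();
        tx.send(this[at(i)].clone()).unwrap();
    }
    this[(k + 2) % n] = recv();
}

/// Runs the kernel on one thread per device and returns every device's chunks.
fn ring_allreduce(data: Vec<Vec<Chunk>>) -> Vec<Vec<Chunk>> {
    let n = data.len();
    assert!(n >= 2 && data.iter().all(|d| d.len() == n), "n devices x n chunks");
    // Channel k feeds device k; device k sends into channel (k + 1) % n.
    let (txs, rxs): (Vec<Sender<Chunk>>, Vec<Receiver<Chunk>>) = (0..n).map(|_| channel()).unzip();
    let handles: Vec<_> = data
        .into_iter()
        .zip(rxs)
        .enumerate()
        .map(|(k, (mut this, rx))| {
            let tx = txs[(k + 1) % n].clone();
            thread::spawn(move || {
                kernel(k, n, &mut this, rx, tx);
                this
            })
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}

fn main() {
    // Device k holds chunk c = [10k + c, 1]; the expected result is [60 + 4c, 4].
    let data: Vec<Vec<Chunk>> = (0..4)
        .map(|k| (0..4).map(|c| vec![10 * k + c, 1]).collect())
        .collect();
    let out = ring_allreduce(data);
    println!("{:?}", out[0]); // [[60, 4], [64, 4], [68, 4], [72, 4]]
}
```

Because `mpsc` channels are unbounded, every device can send before it receives without deadlocking; each device sends and receives exactly 2(n - 1) messages.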

As you can see, this splits into five small parts. Abstract it a little and... OK, a perfect homage to NCCL!

fn kernel(k: u64, n: u64, rx: Receiver, tx: Sender) {
    send(k);
    // NCCL's ringIx corresponds to k + 1 here
    for i in 1 .. n - 1 {
        recvReduceSend(k - i);
    }
    recvReduceCopySend(k + 1); // k - (n - 1) === k + 1 (mod n)
    for i in 0 .. n - 2 {
        recvCopySend(k - i);
    }
    recv(k + 2);
}
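The same kernel written against the five primitives, again as a self-contained CPU sketch under the same assumptions (threads as devices, `mpsc` channels as links, `Vec<i64>` chunks, addition as the reduction). The `Peer` struct and its snake_case methods are hypothetical stand-ins named after the pseudocode; NCCL's real primitives take buffer offsets and element counts rather than chunk indices.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

type Chunk = Vec<i64>;

/// Hypothetical per-device state: its n chunks plus the two ring links.
struct Peer {
    buf: Vec<Chunk>,
    rx: Receiver<Chunk>, // from device k - 1
    tx: Sender<Chunk>,   // to device k + 1
}

fn add(a: &Chunk, b: &Chunk) -> Chunk {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

impl Peer {
    /// Maps a possibly negative chunk index onto 0..n.
    fn at(&self, c: i64) -> usize {
        c.rem_euclid(self.buf.len() as i64) as usize
    }
    fn incoming(&self) -> Chunk {
        self.rx.recv().expect("predecessor hung up")
    }
    fn send(&self, c: i64) {
        self.tx.send(self.buf[self.at(c)].clone()).unwrap();
    }
    fn recv_reduce_send(&self, c: i64) {
        // forward the partial sum without storing it
        self.tx.send(add(&self.buf[self.at(c)], &self.incoming())).unwrap();
    }
    fn recv_reduce_copy_send(&mut self, c: i64) {
        let i = self.at(c);
        self.buf[i] = add(&self.buf[i], &self.incoming()); // final sum, stored
        self.send(c);
    }
    fn recv_copy_send(&mut self, c: i64) {
        let i = self.at(c);
        self.buf[i] = self.incoming();
        self.send(c);
    }
    fn recv(&mut self, c: i64) {
        let i = self.at(c);
        self.buf[i] = self.incoming();
    }
}

/// The kernel, written with the five primitives as in the pseudocode.
fn kernel(k: i64, n: i64, p: &mut Peer) {
    p.send(k);
    for i in 1..n - 1 {
        p.recv_reduce_send(k - i);
    }
    p.recv_reduce_copy_send(k + 1);
    for i in 0..n - 2 {
        p.recv_copy_send(k - i);
    }
    p.recv(k + 2);
}

fn ring_allreduce(data: Vec<Vec<Chunk>>) -> Vec<Vec<Chunk>> {
    let n = data.len();
    assert!(n >= 2 && data.iter().all(|d| d.len() == n));
    let (txs, rxs): (Vec<Sender<Chunk>>, Vec<Receiver<Chunk>>) = (0..n).map(|_| channel()).unzip();
    let handles: Vec<_> = data
        .into_iter()
        .zip(rxs)
        .enumerate()
        .map(|(k, (buf, rx))| {
            let mut p = Peer { buf, rx, tx: txs[(k + 1) % n].clone() };
            thread::spawn(move || {
                kernel(k as i64, n as i64, &mut p);
                p.buf
            })
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}

fn main() {
    let data: Vec<Vec<Chunk>> = (0..4)
        .map(|k| (0..4).map(|c| vec![10 * k + c]).collect())
        .collect();
    println!("{:?}", ring_allreduce(data)[2]); // [[60], [64], [68], [72]]
}
```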