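Recently I've needed to do some multi-GPU communication work and implement a few kernels, mostly by "paying homage" to Nvidia's NCCL. Unfortunately I can't make sense of a single line of it, but luckily there is plenty of material online about the Ring AllReduce algorithm.
Core idea
NCCL's craftsmanship is on another level, with templates and PTX turned every which way.
Take 4 GPUs as an example. Ring AllReduce consists of two phases, Reduce-Scatter and AllGather. In Reduce-Scatter, each of the N GPUs splits its data into N chunks of 1/N the size. Each chunk travels around the ring, and every GPU it passes reduces it with its own copy of that chunk, so that in the end each GPU holds one fully reduced 1/N chunk. AllGather then passes each GPU's finished chunk around the ring to all the other GPUs. With N GPUs, each phase takes N - 1 steps. In the diagrams, A to D are the GPUs, each row is one chunk, and a1 to a4 are GPU A's four chunks (likewise for B, C, D).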
Reduce-Scatter:
```
          A                 B                 C                 D
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a1     | --> |    b1     |     |    c1     |     |    d1     |
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a2     |     |    b2     | --> |    c2     |     |    d2     |
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a3     |     |    b3     |     |    c3     | --> |    d3     |
    +-----------+     +-----------+     +-----------+     +-----------+
+-> |    a4     |     |    b4     |     |    c4     |     |    d4     | -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+

          A                 B                 C                 D
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a1     |     |   a1+b1   | --> |    c1     |     |    d1     |
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a2     |     |    b2     |     |   b2+c2   | --> |    d2     |
    +-----------+     +-----------+     +-----------+     +-----------+
+-> |    a3     |     |    b3     |     |    c3     |     |   c3+d3   | -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   |   d4+a4   | --> |    b4     |     |    c4     |     |    d4     |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+

          A                 B                 C                 D
    +-----------+     +-----------+     +-----------+     +-----------+
    |    a1     |     |   a1+b1   |     | a1+b1+c1  | --> |    d1     |
    +-----------+     +-----------+     +-----------+     +-----------+
+-> |    a2     |     |    b2     |     |   b2+c2   |     | b2+c2+d2  | -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   | c3+d3+a3  | --> |    b3     |     |    c3     |     |   c3+d3   |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   |   d4+a4   |     | d4+a4+b4  | --> |    c4     |     |    d4     |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+
```
AllGather:
```
          A                 B                 C                 D
    +-----------+     +-----------+     +-----------+     +-----------+
+-> |    a1     |     |   a1+b1   |     | a1+b1+c1  |     |a1+b1+c1+d1| -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   |b2+c2+d2+a2| --> |    b2     |     |   b2+c2   |     | b2+c2+d2  |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   | c3+d3+a3  |     |c3+d3+a3+b3| --> |    c3     |     |   c3+d3   |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   |   d4+a4   |     | d4+a4+b4  |     |d4+a4+b4+c4| --> |    d4     |  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+

          A                 B                 C                 D
    +-----------+     +-----------+     +-----------+     +-----------+
    |a1+b1+c1+d1| --> |   a1+b1   |     | a1+b1+c1  |     |a1+b1+c1+d1|
    +-----------+     +-----------+     +-----------+     +-----------+
    |b2+c2+d2+a2|     |b2+c2+d2+a2| --> |   b2+c2   |     | b2+c2+d2  |
    +-----------+     +-----------+     +-----------+     +-----------+
    | c3+d3+a3  |     |c3+d3+a3+b3|     |c3+d3+a3+b3| --> |   c3+d3   |
    +-----------+     +-----------+     +-----------+     +-----------+
+-> |   d4+a4   |     | d4+a4+b4  |     |d4+a4+b4+c4|     |d4+a4+b4+c4| -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+

          A                 B                 C                 D
    +-----------+     +-----------+     +-----------+     +-----------+
    |a1+b1+c1+d1|     |a1+b1+c1+d1| --> | a1+b1+c1  |     |a1+b1+c1+d1|
    +-----------+     +-----------+     +-----------+     +-----------+
    |b2+c2+d2+a2|     |b2+c2+d2+a2|     |b2+c2+d2+a2| --> | b2+c2+d2  |
    +-----------+     +-----------+     +-----------+     +-----------+
+-> | c3+d3+a3  |     |c3+d3+a3+b3|     |c3+d3+a3+b3|     |c3+d3+a3+b3| -+
|   +-----------+     +-----------+     +-----------+     +-----------+  |
|   |d4+a4+b4+c4| --> | d4+a4+b4  |     |d4+a4+b4+c4|     |d4+a4+b4+c4|  |
|   +-----------+     +-----------+     +-----------+     +-----------+  |
+------------------------------------------------------------------------+

          A                 B                 C                 D
    +-----------+     +-----------+     +-----------+     +-----------+
    |a1+b1+c1+d1|     |a1+b1+c1+d1|     |a1+b1+c1+d1|     |a1+b1+c1+d1|
    +-----------+     +-----------+     +-----------+     +-----------+
    |b2+c2+d2+a2|     |b2+c2+d2+a2|     |b2+c2+d2+a2|     |b2+c2+d2+a2|
    +-----------+     +-----------+     +-----------+     +-----------+
    |c3+d3+a3+b3|     |c3+d3+a3+b3|     |c3+d3+a3+b3|     |c3+d3+a3+b3|
    +-----------+     +-----------+     +-----------+     +-----------+
    |d4+a4+b4+c4|     |d4+a4+b4+c4|     |d4+a4+b4+c4|     |d4+a4+b4+c4|
    +-----------+     +-----------+     +-----------+     +-----------+
```
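To double-check the diagrams, here is a minimal sequential Rust sketch of the same schedule. It assumes each chunk is a single i64, that reduce means addition, and that all devices advance in lockstep; ring_allreduce is a hypothetical helper for illustration, not an NCCL API.

```rust
/// Sequential Ring AllReduce over data[device][chunk], following the diagrams step by step.
/// Assumptions: one i64 per chunk, reduce = addition, all devices move in lockstep.
fn ring_allreduce(data: &mut [Vec<i64>]) {
    let n = data.len();
    if n < 2 {
        return; // nothing to exchange
    }
    // Reduce-Scatter: in step s, device k sends chunk (k - s) mod n to device k + 1,
    // which adds it into its own copy of that chunk.
    for s in 0..n - 1 {
        // Collect every outgoing value first so all devices "send at the same time".
        let sent: Vec<(usize, i64)> = (0..n)
            .map(|k| {
                let c = (k + n - s) % n;
                (c, data[k][c])
            })
            .collect();
        for (k, &(c, v)) in sent.iter().enumerate() {
            data[(k + 1) % n][c] += v;
        }
    }
    // Now device k holds the fully reduced chunk (k + 1) mod n.
    // AllGather: in step s, device k forwards chunk (k + 1 - s) mod n and the receiver overwrites.
    for s in 0..n - 1 {
        let sent: Vec<(usize, i64)> = (0..n)
            .map(|k| {
                let c = (k + 1 + n - s) % n;
                (c, data[k][c])
            })
            .collect();
        for (k, &(c, v)) in sent.iter().enumerate() {
            data[(k + 1) % n][c] = v;
        }
    }
}

fn main() {
    // Device k (A..D) starts with chunk c = 10 * (k + 1) + (c + 1), e.g. A = [11, 12, 13, 14].
    let mut data: Vec<Vec<i64>> = (0..4)
        .map(|k| (0..4).map(|c| 10 * (k + 1) + c + 1).collect())
        .collect();
    ring_allreduce(&mut data);
    for row in &data {
        println!("{:?}", row); // each device prints [104, 108, 112, 116]
    }
}
```

Each printed row is one device's buffer. Snapshotting the outgoing values before applying them models every device sending at once, which is what a single diagram step shows.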
Implementation
From the diagrams you can see at a glance that the main structure of the algorithm should look like this:
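```
// k: device index, n: total number of devices
fn kernel(k: u64, n: u64, rx: Receiver, tx: Sender) {
    tx.send(this[?]); // first step: nothing has arrived from the previous device yet
    for i in 1 .. n {
        this[?] = reduce(this[?], rx.recv()); // reducing n values takes n - 1 operations
        tx.send(this[?]);
    }
    for i in 0 .. n - 2 { // the result must reach the other n - 1 devices, but the loop above already sent it once
        this[?] = rx.recv();
        tx.send(this[?]);
    }
    this[?] = rx.recv(); // the next device produced this chunk itself, so nothing is left to forward
}
```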
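As a quick sanity check on the loop bounds in this skeleton, the Rust sketch below only counts what one device does, ignoring chunk indices. The Counts struct and skeleton_counts function are hypothetical helpers for illustration. It confirms that each device sends and receives 2(n - 1) chunks (n - 1 Reduce-Scatter steps plus n - 1 AllGather steps) and performs n - 1 reductions.

```rust
/// How many messages and reductions one device performs (chunk indices ignored).
#[derive(Debug, Default, PartialEq)]
struct Counts {
    sends: u64,
    recvs: u64,
    reduces: u64,
}

/// Mirrors the loop structure of the skeleton for a ring of n >= 2 devices.
fn skeleton_counts(n: u64) -> Counts {
    assert!(n >= 2, "a ring needs at least two devices");
    let mut c = Counts::default();
    c.sends += 1; // initial send: nothing has arrived from the previous device yet
    for _ in 1..n {
        // Reduce-Scatter: receive, reduce, forward
        c.recvs += 1;
        c.reduces += 1;
        c.sends += 1;
    }
    for _ in 0..n - 2 {
        // AllGather: receive, store, forward
        c.recvs += 1;
        c.sends += 1;
    }
    c.recvs += 1; // final receive, not forwarded
    c
}

fn main() {
    for n in 2..=8 {
        let c = skeleton_counts(n);
        // n - 1 Reduce-Scatter steps plus n - 1 AllGather steps
        assert_eq!(c, Counts { sends: 2 * (n - 1), recvs: 2 * (n - 1), reduces: n - 1 });
        println!("n = {n}: {c:?}");
    }
}
```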
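Drawing on years of exam-guessing experience, look for the pattern in which chunk is processed at each step: the chunk index walks backwards (all indices are taken mod n):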
```
fn kernel(k: u64, n: u64, rx: Receiver, tx: Sender) {
    tx.send(this[k]);
    for i in 1 .. n {
        this[k - i] = reduce(this[k - i], rx.recv());
        tx.send(this[k - i]);
    }
    for i in 0 .. n - 2 {
        this[k - n - i] = rx.recv(); // k - n - i ≡ k - i (mod n)
        tx.send(this[k - n - i]);
    }
    this[k - n - (n - 2)] = rx.recv(); // k - n - (n - 2) ≡ k + 2 (mod n)
}
```
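The index expressions can be checked mechanically. The Rust sketch below (chunk_order and m are hypothetical helper names) lists, in order, the chunks device k touches using exactly the expressions above, for n = 4. One deliberate difference from the pseudocode: it uses i64 with rem_euclid, because k - i on the pseudocode's u64 parameters would underflow instead of wrapping around the ring.

```rust
/// (a mod n) for a possibly negative a, so expressions like k - i wrap around the ring.
fn m(a: i64, n: i64) -> i64 {
    a.rem_euclid(n)
}

/// Chunk indices device k touches, in order, using the index expressions from the pseudocode.
fn chunk_order(k: i64, n: i64) -> Vec<i64> {
    let mut order = vec![m(k, n)]; // tx.send(this[k])
    for i in 1..n {
        order.push(m(k - i, n)); // Reduce-Scatter: this[k - i]
    }
    for i in 0..n - 2 {
        order.push(m(k - n - i, n)); // AllGather: this[k - n - i]
    }
    order.push(m(k - n - (n - 2), n)); // final receive
    order
}

fn main() {
    let n = 4;
    for k in 0..n {
        // device A prints [0, 3, 2, 1, 0, 3, 2]
        println!("device {}: {:?}", ["A", "B", "C", "D"][k as usize], chunk_order(k, n));
    }
}
```

Chunk 0 here is the first row (a1, b1, ...) in the diagrams, so device A finishes Reduce-Scatter on chunk 1, the b2+c2+d2+a2 row, just as drawn. The whole sequence is simply (k - t) mod n at step t, which is why what device k sends at step t is exactly what device k + 1 receives at step t + 1.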
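In fact, the intermediate reductions don't need to be written back to device memory at all; only the final reduction has to be stored. This also makes an in-place AllReduce convenient, since each chunk in the buffer is only ever overwritten with its final result.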
```
fn kernel(k: u64, n: u64, rx: Receiver, tx: Sender) {
    tx.send(this[k]);
    for i in 1 .. n - 1 {
        tx.send(reduce(this[k - i], rx.recv())); // partial sum is forwarded, not stored
    }
    // the last reduction yields the final value of chunk k - (n - 1) ≡ k + 1, so store it
    this[k - (n - 1)] = reduce(this[k - (n - 1)], rx.recv());
    tx.send(this[k - (n - 1)]);
    for i in 0 .. n - 2 {
        this[k - i] = rx.recv();
        tx.send(this[k - i]);
    }
    this[k + 2] = rx.recv();
}
```
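Here is a minimal runnable Rust sketch of the in-place kernel above. It assumes each device is a thread, each ring link is a std::sync::mpsc channel (so Receiver and Sender become concrete types), a chunk is a Vec<f32>, and reduce is element-wise addition. kernel and all_reduce are hypothetical names; this only models the data flow, not a GPU kernel.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

type Chunk = Vec<f32>;

/// Element-wise sum of two chunks (the pseudocode's reduce, assumed to be addition).
fn reduce(a: &Chunk, b: &Chunk) -> Chunk {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// In-place kernel for device k of n: partial sums are forwarded, only final values are stored.
fn kernel(k: usize, n: usize, this: &mut [Chunk], rx: Receiver<Chunk>, tx: Sender<Chunk>) {
    let back = |d: usize| (k + n - d) % n; // (k - d) mod n for 0 <= d < n, without usize underflow
    tx.send(this[k].clone()).unwrap();
    for i in 1..n - 1 {
        let partial = reduce(&this[back(i)], &rx.recv().unwrap());
        tx.send(partial).unwrap();
    }
    let c = back(n - 1); // == (k + 1) % n: this device's fully reduced chunk
    this[c] = reduce(&this[c], &rx.recv().unwrap());
    tx.send(this[c].clone()).unwrap();
    for i in 0..n - 2 {
        let c = back(i);
        this[c] = rx.recv().unwrap();
        tx.send(this[c].clone()).unwrap();
    }
    this[(k + 2) % n] = rx.recv().unwrap();
}

/// Spawns one thread per device; channel k carries data from device k to device (k + 1) % n.
fn all_reduce(inputs: Vec<Vec<Chunk>>) -> Vec<Vec<Chunk>> {
    let n = inputs.len();
    assert!(n >= 2, "a ring needs at least two devices");
    let (txs, mut rxs): (Vec<Sender<Chunk>>, Vec<Receiver<Chunk>>) =
        (0..n).map(|_| channel()).unzip();
    rxs.rotate_right(1); // rxs[k] now receives from device k - 1
    let handles: Vec<_> = inputs
        .into_iter()
        .zip(txs.into_iter().zip(rxs))
        .enumerate()
        .map(|(k, (mut this, (tx, rx)))| {
            thread::spawn(move || {
                kernel(k, n, &mut this, rx, tx);
                this
            })
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}

fn main() {
    // Device k's chunk c is [10 * (k + 1) + c + 1; 2], so chunk c should sum to 104 + 4 * c.
    let inputs: Vec<Vec<Chunk>> = (0..4)
        .map(|k| (0..4).map(|c| vec![(10 * (k + 1) + c + 1) as f32; 2]).collect())
        .collect();
    for (k, out) in all_reduce(inputs).iter().enumerate() {
        println!("device {k}: {out:?}");
    }
}
```

Every device should print [[104.0, 104.0], [108.0, 108.0], [112.0, 112.0], [116.0, 116.0]]. Because mpsc channels are unbounded, send never blocks and this toy ring cannot deadlock; a real implementation with bounded buffers would have to handle flow control.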
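As you can see, this now splits into five small parts (send, recvReduceSend, recvReduceCopySend, recvCopySend, recv). Abstract them a little and, OK, a perfect homage to NCCL!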
```
fn kernel(k: u64, n: u64, rx: Receiver, tx: Sender) {
    send(k);
    // NCCL's ringIx corresponds to k + 1 here (k ≡ ringIx - 1 mod n)
    for i in 1 .. n - 1 {
        recvReduceSend(k - i);
    }
    recvReduceCopySend(k + 1);
    for i in 0 .. n - 2 {
        recvCopySend(k - i);
    }
    recv(k + 2);
}
```
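To see the five primitives side by side, here is a Rust sketch that wraps them as methods on a hypothetical Prims struct (snake_case versions of the names above) and then writes kernel in exactly the five-part shape. These are not NCCL's real signatures: NCCL's primitives take buffer offsets and element counts rather than chunk indices, and they run on the GPU. Here each device is a thread linked by std::sync::mpsc channels, and reduce is assumed to be element-wise addition.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

type Chunk = Vec<f32>;

/// Hypothetical per-device primitives. Chunk indices may be negative or >= n; they wrap mod n.
struct Prims {
    n: i64,
    buf: Vec<Chunk>,
    rx: Receiver<Chunk>,
    tx: Sender<Chunk>,
}

impl Prims {
    fn at(&self, c: i64) -> usize { c.rem_euclid(self.n) as usize }
    fn pull(&self) -> Chunk { self.rx.recv().expect("previous device hung up") }
    fn push(&self, data: Chunk) { self.tx.send(data).expect("next device hung up") }
    fn sum(&self, c: usize, incoming: &Chunk) -> Chunk {
        self.buf[c].iter().zip(incoming).map(|(a, b)| a + b).collect()
    }
    /// Send our own chunk c.
    fn send(&self, c: i64) { self.push(self.buf[self.at(c)].clone()) }
    /// Receive a partial sum, add our chunk c, forward it without storing.
    fn recv_reduce_send(&self, c: i64) {
        let incoming = self.pull();
        self.push(self.sum(self.at(c), &incoming));
    }
    /// Receive, add our chunk c, store the final result, and forward it.
    fn recv_reduce_copy_send(&mut self, c: i64) {
        let c = self.at(c);
        let incoming = self.pull();
        self.buf[c] = self.sum(c, &incoming);
        self.push(self.buf[c].clone());
    }
    /// Receive a final chunk, store it as chunk c, and forward it.
    fn recv_copy_send(&mut self, c: i64) {
        let c = self.at(c);
        self.buf[c] = self.pull();
        self.push(self.buf[c].clone());
    }
    /// Receive a final chunk and store it as chunk c; nothing is forwarded.
    fn recv(&mut self, c: i64) {
        let c = self.at(c);
        self.buf[c] = self.pull();
    }
}

/// The five-part kernel from the text, written against the hypothetical primitives.
fn kernel(k: i64, n: i64, p: &mut Prims) {
    p.send(k);
    for i in 1..n - 1 {
        p.recv_reduce_send(k - i);
    }
    p.recv_reduce_copy_send(k + 1);
    for i in 0..n - 2 {
        p.recv_copy_send(k - i);
    }
    p.recv(k + 2);
}

/// Runs one thread per device; channel k carries data from device k to device k + 1.
fn all_reduce(inputs: Vec<Vec<Chunk>>) -> Vec<Vec<Chunk>> {
    let n = inputs.len();
    assert!(n >= 2, "a ring needs at least two devices");
    let (txs, mut rxs): (Vec<Sender<Chunk>>, Vec<Receiver<Chunk>>) =
        (0..n).map(|_| channel()).unzip();
    rxs.rotate_right(1); // rxs[k] is now the receiving end of channel k - 1
    let handles: Vec<_> = inputs
        .into_iter()
        .zip(txs.into_iter().zip(rxs))
        .enumerate()
        .map(|(k, (buf, (tx, rx)))| {
            thread::spawn(move || {
                let mut p = Prims { n: n as i64, buf, rx, tx };
                kernel(k as i64, n as i64, &mut p);
                p.buf
            })
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}

fn main() {
    // Device k's chunk c is [10 * (k + 1) + c + 1; 2], so chunk c sums to 104 + 4 * c.
    let inputs: Vec<Vec<Chunk>> = (0..4)
        .map(|k| (0..4).map(|c| vec![(10 * (k + 1) + c + 1) as f32; 2]).collect())
        .collect();
    for (k, out) in all_reduce(inputs).iter().enumerate() {
        println!("device {k}: {out:?}");
    }
}
```

Keeping the mod-n wrapping inside the primitives is what lets kernel use the raw k - i, k + 1, k + 2 expressions unchanged.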