1
0
Fork 0

distributed collectives (#1519)

* feat: world

* feat: tests

* feat: no more backwards

* feat: recv into

* feat: whoops

* feat: test in ci

* feat: some debug logging

* feat: workflow naming

* feat: need to set pythonpath

* feat: just send to same device

* feat: allreduce

* feat: test

* feat: need contiguous

* feat: test in ci

* feat: exit with correct code

* feat: don't need that

* feat: opencl wait_for just doesn't work

* feat: synchronize on out

* feat: try?

* feat: try again?

* feat: add extra realizes

* feat: print

* feat: seed

* feat: tol

* feat: test ones and zeros

* feat: remove print

* feat: are you just flaky

* feat: seperate scatter and gather?

* feat: just try synchronizing

* feat: remove print again

* feat: bring back difference

* feat: no sync

* feat: revert that

* feat: back to wait_for

* fix: typo
pull/1530/head
wozeparrot 2023-08-11 13:22:07 -04:00 committed by GitHub
parent 2e85fce068
commit 29d5801387
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 106 additions and 2 deletions

View File

@ -158,6 +158,7 @@ jobs:
name: Test multigpu
run: |
PYTHONPATH="." python test/external/dist/test_world.py
PYTHONPATH="." python test/external/dist/test_collectives.py
testmetalwebgpu:
name: Metal and WebGPU Tests

43
extra/dist/collectives.py vendored 100644
View File

@ -0,0 +1,43 @@
from tinygrad.tensor import Tensor, Device
from tinygrad.helpers import getenv
from extra.dist import world
def allreduce(t:Tensor, cache_id=None) -> Tensor:
RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
cache_id = f"{RANK}-{cache_id}" if cache_id is not None else None
# flatten
flattened = t.flatten()
# pad to evenly divide
if flattened.shape[0] % WORLD_SIZE != 0:
flattened = Tensor.cat(flattened, Tensor.zeros(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))
# chunk
chunks = flattened.chunk(WORLD_SIZE, dim=0)
reduced = chunks[RANK]
next_rank = (RANK + 1) % WORLD_SIZE
prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE
# scatter reduce
current_chunk_index = RANK
for i in range(WORLD_SIZE - 1):
world.send(reduced, next_rank, cache_id=f"{cache_id}-{i}-s" if cache_id is not None else None)
current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
recv_buf = Tensor.empty(*reduced.shape)
world.recv(recv_buf, prev_rank)
reduced = chunks[current_chunk_index] + recv_buf
# gather
chunks[current_chunk_index] = reduced
current_chunk_index = (RANK + 1) % WORLD_SIZE
for i in range(WORLD_SIZE - 1):
world.send(reduced, next_rank, cache_id=f"{cache_id}-{i}-g" if cache_id is not None else None)
current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
recv_buf = Tensor.empty(*reduced.shape)
world.recv(recv_buf, prev_rank)
reduced = chunks[current_chunk_index] = recv_buf
return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape)

2
extra/dist/world.py vendored
View File

@ -59,7 +59,7 @@ def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None
# receive a lazybuffer from the target rank
def _recv_lb(x:LazyBuffer, target_rank:int) -> LazyBuffer:
_recv_rb(x.realize().realized, target_rank)
_recv_rb(x.contiguous().realize().realized, target_rank)
return x
class Send(Function):

View File

@ -0,0 +1,56 @@
from extra import dist
from tinygrad.jit import TinyJit
if __name__ == "__main__":
dist.preinit()
from extra.dist import collectives
from tinygrad.helpers import CI, getenv
from tinygrad.tensor import Tensor
import numpy as np
@TinyJit
def allreduce_jit(t:Tensor, cache_id=None) -> Tensor:
return collectives.allreduce(t, cache_id=cache_id).realize()
SIZE = 2048 if not CI else 2
SIZE_2 = 255 if not CI else 3
def run():
# set a deterministic seed so that both ranks generate the same random tensor
Tensor.manual_seed(42)
rank = getenv("RANK")
# loop 3 times to make sure it works with the jit
for _ in range(3):
# create a tensor to send
t = Tensor.zeros(SIZE, SIZE) if rank == 0 else Tensor.ones(SIZE, SIZE)
t2 = allreduce_jit(t.contiguous().realize(), cache_id="test")
assert np.allclose(np.ones((SIZE, SIZE)), t2.numpy())
# reset jit
allreduce_jit.cnt = 0
# test uneven chunk sizes
for _ in range(3):
# create a tensor to send
t = Tensor.ones(SIZE_2, SIZE_2, SIZE_2) if rank == 0 else Tensor.zeros(SIZE_2, SIZE_2, SIZE_2)
t2 = allreduce_jit(t.contiguous().realize(), cache_id="test2")
assert np.allclose(np.ones((SIZE_2, SIZE_2, SIZE_2)), t2.numpy())
print(f"rank {rank} passed")
if __name__ == "__main__":
devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
world_size = len(devices)
dist.init_oob(world_size)
processes = []
for rank, device in enumerate(devices):
processes.append(dist.spawn(rank, device, fn=run, args=()))
for p in processes: p.join()
# exit with error code if any of the processes failed
for p in processes:
if p.exitcode != 0: exit(p.exitcode)

View File

@ -59,3 +59,7 @@ if __name__ == "__main__":
for rank, device in enumerate(devices):
processes.append(dist.spawn(rank, device, fn=run, args=()))
for p in processes: p.join()
# exit with error code if any of the processes failed
for p in processes:
if p.exitcode != 0: exit(p.exitcode)

View File

@ -48,7 +48,7 @@ class CLBuffer(RawBufferCopyInOut):
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False)
with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event])
with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else []))
class CLProgram:
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):