
distributed world (#1481)

* feat: world

* feat: tests

* feat: no more backwards

* feat: recv into

* feat: whoops

* feat: test in ci

* feat: some debug logging

* feat: workflow naming

* feat: need to set pythonpath

* feat: just send to same device
wozeparrot 2023-08-10 13:00:51 -04:00 committed by GitHub
parent e3c6c0c6db
commit 7e7c9001e9
5 changed files with 206 additions and 3 deletions


@@ -104,8 +104,8 @@ jobs:
strategy:
fail-fast: false
matrix:
task: [optimage, openpilot]
name: ${{ matrix.task=='optimage'&&'GPU OPT and IMAGE Tests'||'openpilot (OpenCL) Tests'}}
task: [optimage, openpilot, multigpu]
name: ${{ matrix.task=='optimage'&&'GPU OPT and IMAGE Tests'|| matrix.task=='openpilot'&&'openpilot (OpenCL) Tests'|| matrix.task=='multigpu'&&'MultiGPU Tests'}}
runs-on: ubuntu-20.04
timeout-minutes: 20
@@ -154,6 +154,10 @@ jobs:
- if: ${{ matrix.task == 'openpilot' }}
name: Test tensor core ops
run: GPU=1 TC=2 python3 -m pytest -n=auto test/test_ops.py
- if: ${{ matrix.task == 'multigpu' }}
name: Test multigpu
run: |
PYTHONPATH="." python test/external/dist/test_world.py
testmetalwebgpu:
name: Metal and WebGPU Tests

2
.gitignore vendored

@@ -7,7 +7,7 @@ notebooks
*.pyc
*.so
build
dist
/dist
*.egg-info
/env
a.out

61
extra/dist/__init__.py vendored 100644

@@ -0,0 +1,61 @@
# this file needs to be very careful with its imports as to not accidentally initialize the runtimes
from multiprocessing.connection import Connection
from typing import Any, Callable, List, Tuple
from tinygrad.helpers import DEBUG, getenv
import multiprocessing as mp
import os
# this needs to be called before everything else if you are using distributed
def preinit():
os.environ["DELAYED_RUNTIME_INIT"] = "1"
mp.set_start_method("spawn")
# out-of-band communication/synchronization
class _OOB:
def __init__(self, pipes:List[Tuple[Connection, Connection]]):
self.pipes = pipes
# send some data to a target rank, blocks until data is received
def send(self, data:Any, target_rank:int):
self.pipes[getenv("RANK") * getenv("WORLD_SIZE") + target_rank][1].send(data)
# receive some data from a target rank, blocks until data is received
def recv(self, target_rank:int) -> Any:
return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
OOB = None
def init_oob(world_size:int):
os.environ["WORLD_SIZE"] = str(world_size)
global OOB
OOB = _OOB([mp.Pipe(False) for _ in range(world_size * world_size)])
# this runs in the spawned process so we can do all the delayed runtime initialization
def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()):
# setup the rank
os.environ["RANK"] = str(rank)
# setup out of band communication
global OOB
OOB = oob
# do specific runtime initialization for distributed
from tinygrad.lazy import Device
device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(device.split(":")[-1])
if "GPU" in device:
from tinygrad.runtime.ops_gpu import CL
CL.post_init(device_num)
elif "HIP" in device:
import extra.hip_wrapper as hip
hip.hipSetDevice(device_num)
if DEBUG >= 1: print(f"distributed process {rank} initialized runtime for device {device}")
# convert device to be process specific
Device.DEFAULT = device.split(":")[0]
fn(*args)
# wrapper around mp.Process that initializes the runtime
def spawn(rank:int, device:str, fn:Callable, args=()) -> mp.Process:
(p := mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))).start()
return p
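
Taken together, this file is the distributed bootstrap: preinit() delays runtime initialization and forces the spawn start method, init_oob() builds a world_size**2 mesh of one-way pipes, and spawn() starts one process per rank, with _process_wrap setting RANK and doing the per-device runtime setup before handing control to user code. Below is a minimal sketch of how these pieces are meant to be wired together, mirroring the test added later in this commit; the worker body and the device list are illustrative, not part of the commit.

# minimal bootstrap sketch (illustrative); mirrors test/external/dist/test_world.py
from extra import dist

# preinit must run before anything that could initialize a runtime is imported,
# and only in the parent (spawned children re-import this script under "spawn")
if __name__ == "__main__":
  dist.preinit()

from tinygrad.helpers import getenv

def worker():
  rank = getenv("RANK")    # set per process by dist._process_wrap
  # ranks can exchange small python objects over the out-of-band pipes
  if rank == 0: dist.OOB.send("hello", 1)
  else: print(f"rank {rank} got: {dist.OOB.recv(0)}")

if __name__ == "__main__":
  devices = ["gpu:0", "gpu:1"]   # one device string per rank
  dist.init_oob(len(devices))    # creates the world_size**2 pipe mesh
  procs = [dist.spawn(rank, dev, fn=worker, args=()) for rank, dev in enumerate(devices)]
  for p in procs: p.join()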

77
extra/dist/world.py vendored 100644

@@ -0,0 +1,77 @@
from typing import Any, Optional, Tuple
from extra import dist
from multiprocessing import shared_memory
from tinygrad.helpers import DEBUG, GlobalCounters, colored
from tinygrad.lazy import LazyBuffer
from tinygrad.runtime.lib import RawBufferCopyIn, RawBufferCopyInOut
from tinygrad.runtime.ops_shm import RawShmBuffer
from tinygrad.tensor import Tensor, Function
import numpy as np
# fake the function signature of ASTRunner so we can put it in the cache
def __send_rb(args:Tuple[RawBufferCopyInOut, RawShmBuffer, int, Any], jit=False, force_wait=False):
args[0]._copyout(np.frombuffer(args[1]._buffer(), dtype=args[0].dtype.np))
dist.OOB.send(args[3], args[2])
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} sent {args[0]} to rank {args[2]}")
def __recv_rb(args:Tuple[RawBufferCopyIn, RawShmBuffer, int], jit=False, force_wait=False):
dist.OOB.recv(args[2])
args[0]._copyin(args[1].toCPU())
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} recv {args[0]} from rank {args[2]}")
# send a rawbuffer from our rank to the target rank
def _send_rb(x:RawBufferCopyInOut, target_rank:int, cache_id:Optional[str]=None):
assert isinstance(x, RawBufferCopyInOut), "we only support RawBufferCopyInOut for now"
# cache the shared memory so we don't have to create it every time
if cache_id is not None and cache_id in _send_rb.shared_memory_cache:
shm_name = _send_rb.shared_memory_cache[cache_id]
else:
shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name
s.close()
if cache_id is not None: _send_rb.shared_memory_cache[cache_id] = shm_name
# copy the buffer into shared memory
device = f"{shm_name},{cache_id}" if cache_id is not None else shm_name
rb = RawShmBuffer(x.size, x.dtype, device=device)
__send_rb((x, rb, target_rank, (shm_name, cache_id)))
# jit support
if GlobalCounters.cache is not None: GlobalCounters.cache.append((__send_rb, [x, rb, target_rank, None]))
setattr(_send_rb, "shared_memory_cache", {})
# receive a rawbuffer from the target rank
def _recv_rb(x:RawBufferCopyIn, target_rank:int):
assert isinstance(x, RawBufferCopyIn), "we only support RawBufferCopyIn for now"
extra = dist.OOB.recv(target_rank)
device = f"{extra[0]},{extra[1]}" if extra[1] is not None else f"{extra[0]}"
rb = RawShmBuffer(x.size, x.dtype, device=device)
x._copyin(rb.toCPU())
if DEBUG >= 2: print(f"**** got {x} from rank {target_rank}")
if extra[1] is None:
(s := shared_memory.SharedMemory(name=extra[0])).close()
s.unlink()
# jit support
if GlobalCounters.cache is not None: GlobalCounters.cache.append((__recv_rb, [x, rb, target_rank]))
# sends a lazybuffer from our rank to the target rank
def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None: _send_rb(x.contiguous().realize().realized, target_rank, cache_id=cache_id)
# receive a lazybuffer from the target rank
def _recv_lb(x:LazyBuffer, target_rank:int) -> LazyBuffer:
_recv_rb(x.realize().realized, target_rank)
return x
class Send(Function):
def forward(self, x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> LazyBuffer:
self.target_rank, self.shape, self.dtype = target_rank, x.shape, x.dtype
_send_lb(x, target_rank, cache_id=cache_id)
return x
class Recv(Function):
def forward(self, x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> LazyBuffer:
self.target_rank, self.cache_id = target_rank, cache_id
return _recv_lb(x, target_rank)
def send(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Send.apply(x, target_rank=target_rank, cache_id=cache_id)
def recv(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Recv.apply(x, target_rank=target_rank, cache_id=cache_id)
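
The transport here is shared memory: _send_rb copies the raw buffer into a RawShmBuffer and only sends the segment name (plus cache_id) over the OOB pipe, while _recv_rb attaches to that name and copies the data into its own buffer. Below is a minimal, stdlib-only sketch of that handoff, assuming numpy for the host-side copy; sender()/receiver() are illustrative helpers, not part of the tinygrad API.

# stdlib-only sketch of the shared-memory handoff (illustrative, not tinygrad API)
from multiprocessing import shared_memory
import numpy as np

def sender(data: np.ndarray) -> str:
  shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
  view = np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)
  view[:] = data        # copy the payload into the segment
  del view              # drop the buffer export before closing the handle
  name = shm.name
  shm.close()           # the segment itself stays alive until unlink()
  return name           # only this name needs to travel over the OOB pipe

def receiver(name: str, shape, dtype) -> np.ndarray:
  shm = shared_memory.SharedMemory(name=name)
  view = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
  out = view.copy()     # copy out before the mapping goes away
  del view
  shm.close()
  shm.unlink()          # uncached path: tear the segment down after one use
  return out

In the commit itself, passing a cache_id keeps the shared-memory segment (and its name) cached across JIT replays, so only the uncached path unlinks the segment after a single transfer.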

61
test/external/dist/test_world.py 100644

@@ -0,0 +1,61 @@
from extra import dist
from tinygrad.jit import TinyJit
if __name__ == "__main__":
dist.preinit()
from extra.dist import world
from tinygrad.helpers import CI, getenv
from tinygrad.tensor import Tensor
import numpy as np
@TinyJit
def send_jit(t, target_rank, cache_id=None) -> Tensor:
return world.send(t, target_rank, cache_id=cache_id).realize()
@TinyJit
def recv_jit(t, target_rank, cache_id=None) -> Tensor:
return world.recv(t, target_rank, cache_id=cache_id).realize()
SIZE = 2048 if not CI else 2
def run():
# set a deterministic seed so that both ranks generate the same random tensor
Tensor.manual_seed(42)
rank = getenv("RANK")
# loop 3 times to make sure it works with the jit
for _ in range(3):
# create a tensor to send
t = Tensor.randn(SIZE, SIZE)
# send to rank 1
if rank == 0:
send_jit(t, 1, cache_id="test")
elif rank == 1:
t2 = Tensor.empty(SIZE, SIZE)
recv_jit(t2, 0, cache_id="test")
# recv from rank 1
if rank == 0:
t2 = Tensor.empty(SIZE, SIZE)
recv_jit(t2, 1, cache_id="test2")
elif rank == 1:
send_jit(t2, 0, cache_id="test2")
# check that the received tensor is the same as the sent tensor
if rank == 0:
assert np.allclose(t.numpy(), t2.numpy())
print(f"rank {rank} passed")
if __name__ == "__main__":
devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
world_size = len(devices)
dist.init_oob(world_size)
processes = []
for rank, device in enumerate(devices):
processes.append(dist.spawn(rank, device, fn=run, args=()))
for p in processes: p.join()