75 lines
2.4 KiB
Python
75 lines
2.4 KiB
Python
# this file needs to be very careful with its imports so as not to accidentally initialize the runtimes
|
|
from multiprocessing.connection import Connection
|
|
from typing import Any, Callable, List, Optional, Tuple
|
|
from tinygrad.helpers import DEBUG, getenv
|
|
import multiprocessing as mp
|
|
import os
|
|
|
|
|
|
# this needs to be called before everything else if you are using distributed
|
|
def preinit():
    """Prepare the current process for distributed use.

    Sets ``DELAYED_RUNTIME_INIT=1`` in the environment and selects the
    ``spawn`` multiprocessing start method. Calling it again after ``spawn``
    has been selected is a no-op for the start method.

    Raises:
        RuntimeError: if a start method other than ``spawn`` has already been
            fixed for this process (raised by ``mp.set_start_method``).
    """
    os.environ["DELAYED_RUNTIME_INIT"] = "1"
    # set_start_method() raises RuntimeError once a method has been fixed, so
    # only call it when spawn is not already the selected method
    if mp.get_start_method(allow_none=True) != "spawn":
        mp.set_start_method("spawn")
|
|
|
|
|
|
# out-of-band communication/synchronization
|
|
class _OOB:
|
|
def __init__(self, pipes: List[Tuple[Connection, Connection]]):
|
|
self.pipes = pipes
|
|
|
|
# send some data to a target rank, blocks until data is received
|
|
def send(self, data: Any, target_rank: int):
|
|
self.pipes[getenv("RANK") * getenv("WORLD_SIZE") + target_rank][1].send(data)
|
|
|
|
# receive some data from a target rank, blocks until data is received
|
|
def recv(self, target_rank: int) -> Any:
|
|
return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
|
|
|
|
|
|
# module-level out-of-band channel; None until init_oob() or _process_wrap() assigns it
OOB: Optional[_OOB] = None
|
|
|
|
|
|
def init_oob(world_size: int):
    """Create the out-of-band pipes and publish them through the global ``OOB``.

    Writes ``world_size`` to the ``WORLD_SIZE`` environment variable, then
    builds ``world_size * world_size`` one-way pipes (one per ordered pair of
    ranks) and wraps them in an ``_OOB``.
    """
    global OOB

    os.environ["WORLD_SIZE"] = str(world_size)

    # duplex=False gives a one-way pipe: (receiving end, sending end)
    channels = []
    for _ in range(world_size * world_size):
        channels.append(mp.Pipe(duplex=False))
    OOB = _OOB(channels)
|
|
|
|
|
|
# this runs in the spawned process so we can do all the delayed runtime initialization
|
|
def _process_wrap(rank: int, device: str, oob: _OOB, fn: Callable, args=()):
    """Entry point used as the target of the spawned process.

    Records the rank in the environment, installs ``oob`` as the global
    ``OOB``, performs the device-specific runtime initialization, sets
    ``Device.DEFAULT`` and finally calls ``fn(*args)``.
    """
    global OOB

    # the rank goes into the environment before anything else runs
    os.environ["RANK"] = str(rank)

    # make the out-of-band channel available through the module global
    OOB = oob

    # imported here, not at module level, so the runtime is only touched
    # inside the spawned process
    from tinygrad import Device

    # the device index is parsed from the device string as passed in,
    # after canonicalize() has been called on it
    canonical = Device.canonicalize(device)
    if ":" in device:
        device_num = int(device.split(":")[-1])
    else:
        device_num = 0
    device = canonical

    is_gpu = "GPU" in device
    if is_gpu:
        from tinygrad.runtime.ops_gpu import CL

        CL.post_init(device_num)
    elif "HIP" in device:
        ordinal = str(device_num)
        os.environ["HIP_DEFAULT_DEVICE"] = ordinal
        os.environ["HIP_VISIBLE_DEVICES"] = ordinal

    if DEBUG >= 1:
        print(f"distributed process {rank} initialized runtime for device {device}")

    # for GPU devices the default drops the ":N" suffix; others are kept whole
    if is_gpu:
        Device.DEFAULT = device.split(":")[0]
    else:
        Device.DEFAULT = device

    fn(*args)
|
|
|
|
|
|
# wrapper around mp.Process that initializes the runtime
|
|
def spawn(rank: int, device: str, fn: Callable, args=()) -> mp.Process:
    """Start a process that runs ``fn(*args)`` through ``_process_wrap``.

    The current value of the global ``OOB`` is passed to the new process
    together with ``rank`` and ``device``. Returns the started ``mp.Process``.
    """
    worker = mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))
    worker.start()
    return worker
|