move device to ops (#1646)

* move device to ops

* mlops types

* 2 lines
George Hotz 2023-08-23 08:30:17 -07:00 committed by GitHub
parent a65ae1198b
commit a6d842af7a
29 changed files with 60 additions and 59 deletions
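
For downstream code, the whole change boils down to a one-line import move; a minimal before/after sketch (paths taken from the diffs below):

    # before this commit
    from tinygrad.lazy import Device
    # after this commit
    from tinygrad.ops import Device
    print(Device.DEFAULT)  # backend selection is unchanged, only the import path moves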


@ -190,7 +190,7 @@ jobs:
- name: Run symbolic shapetracker test
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py test/test_symbolic_jit.py
- name: Check Device.DEFAULT
run: WEBGPU=1 python -c "from tinygrad.lazy import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
run: WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
#- name: Run webgpu pytest
# run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto --ignore test/models/ --ignore test/unit/test_example.py --ignore test/extra/test_lr_scheduler.py --ignore test/test_linearizer.py test/
#- name: Build WEBGPU Efficientnet
@ -261,7 +261,7 @@ jobs:
- name: Install dependencies
run: pip install -e '.[testing${{matrix.backend=='llvm'&&',llvm'||matrix.backend=='cuda'&&',cuda'||matrix.backend=='ptx'&&',cuda'||''}}]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Check Device.DEFAULT
run: python -c "from tinygrad.lazy import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU'], Device.DEFAULT"
run: python -c "from tinygrad.ops import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU'], Device.DEFAULT"
- name: Run pytest (not cuda)
if: matrix.backend!='cuda' && matrix.backend!='ptx'
run: python -m pytest -n=auto test/ -k '${{matrix.backend=='llvm'&&'not (test_nn.py and test_conv_transpose2d)'||'test'}}' -m 'not exclude_${{matrix.backend}}'


@ -22,7 +22,7 @@ from abc import ABC
# let's trace an addition down through the layers of abstraction.
# we will be using the clang backend
from tinygrad.lazy import Device
from tinygrad.ops import Device
Device.DEFAULT = "CLANG"
# first, 2+3 as a Tensor, the highest level
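
That doc goes on to trace 2+3 through the stack on the forced CLANG backend; a minimal stand-alone sketch of the same first step (assuming the Tensor API at this commit):

    from tinygrad.ops import Device
    from tinygrad.tensor import Tensor
    Device.DEFAULT = "CLANG"          # force the clang backend, as in the doc
    out = Tensor([2]) + Tensor([3])   # 2+3 at the Tensor level
    print(out.numpy())                # realizes the lazy graph and prints 5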


@ -50,7 +50,6 @@ PRINT_PRG | [1] | print program code
IMAGE | [1] | enable 2d specific optimizations
FLOAT16 | [1] | use float16 for images instead of float32
ENABLE_METHOD_CACHE | [1] | enable method cache (this is the default)
EARLY_STOPPING | [# > 0] | stop after this many kernels
DISALLOW_ASSIGN | [1] | disallow assignment of tensors
CL_EXCLUDE | [name0,name1] | comma-separated list of device names to exclude when using OpenCL GPU backend (like `CL_EXCLUDE=gfx1036`)
CL_PLATFORM | [# >= 0] | index of the OpenCL [platform](https://documen.tician.de/pyopencl/runtime_platform.html#pyopencl.Platform) to run on. Defaults to 0.


@ -9,7 +9,7 @@ from typing import Optional, Tuple
from tinygrad.helpers import Timing, getenv, dtypes, DEBUG
from tinygrad.ops import GlobalCounters
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear
from tinygrad.jit import TinyJit


@ -16,7 +16,7 @@ from extra.datasets import fetch_cifar, cifar_mean, cifar_std
from tinygrad import nn
from tinygrad.nn.state import get_state_dict
from tinygrad.nn import optim
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters
from extra.lr_scheduler import OneCycleLR


@ -11,7 +11,7 @@ np.set_printoptions(linewidth=200)
from typing import Optional, Tuple
from tinygrad.helpers import Timing, getenv, DEBUG, dtypes
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear
from tinygrad.ops import GlobalCounters


@ -40,7 +40,7 @@ def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()):
OOB = oob
# do specific runtime initialization for distributed
from tinygrad.lazy import Device
from tinygrad.ops import Device
device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(device.split(":")[-1])
if "GPU" in device:
from tinygrad.runtime.ops_gpu import CL
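As a worked example of the canonicalize line above (its implementation appears in the tinygrad/ops.py hunk further down): "gpu:1" becomes ("GPU:1", 1), "gpu:0" becomes ("GPU", 0) because a ":0" suffix is stripped, and a bare "gpu" becomes ("GPU", 0).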


@ -33,7 +33,7 @@ except RuntimeError:
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.lazy import Device
from tinygrad.ops import Device
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator


@ -6,7 +6,7 @@ from collections import defaultdict
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.shape.shapetracker import strides_for_shape
OSX = platform.system() == "Darwin"
WINDOWS = platform.system() == "Windows"


@ -1,6 +1,6 @@
import unittest
from tinygrad.helpers import prod
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters


@ -4,7 +4,7 @@ import torch, json, argparse
from examples.llama import LLaMa
from tinygrad.tensor import Tensor
from tinygrad.lazy import Device
from tinygrad.ops import Device
class LLaMaAdaptor(BaseLM):
def __init__(


@ -10,7 +10,7 @@ from extra.utils import download_file
from extra.onnx import get_run_onnx
from tinygrad.helpers import OSX, DEBUG
from tinygrad.tensor import Tensor
from tinygrad.lazy import Device
from tinygrad.ops import Device
MODELS = {
"resnet50": "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-caffe2-v1-9.onnx",


@ -6,7 +6,7 @@ from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad.lazy import Device
from tinygrad.ops import Device
from examples.llama import Transformer


@ -4,7 +4,7 @@ import numpy as np
from examples.llama import Transformer, MODEL_PARAMS
from test.test_net_speed import start_profile, stop_profile
from tinygrad.tensor import Tensor
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.nn.state import get_state_dict
from tinygrad.ops import Compiled
from tinygrad.helpers import dtypes, prod


@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.lazy import Device
from tinygrad.ops import Device
import torch
def get_question_samp(bsz, seq_len, vocab_size, seed):


@ -4,7 +4,7 @@ from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.helpers import CI, dtypes
from examples.hlb_cifar10 import SpeedyResNet


@ -5,7 +5,7 @@ from weakref import ref
from tinygrad.ops import GlobalCounters
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad.lazy import Device
from tinygrad.ops import Device
def check_gc():
if Device.DEFAULT == "GPU":


@ -1,7 +1,7 @@
import unittest
import numpy as np
from tinygrad.helpers import getenv, DType, DEBUG, CI
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.tensor import Tensor, dtypes
from typing import List, Optional
from extra.utils import OSX, temp


@ -2,7 +2,7 @@ import numpy as np
import unittest
from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.ops import GlobalCounters, Compiled
from tinygrad.tensor import Tensor


@ -5,7 +5,7 @@ import numpy as np
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes
from tinygrad.lazy import Device
from tinygrad.ops import Device
if CI:
import warnings


@ -1,7 +1,7 @@
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from tinygrad.lazy import Device
from tinygrad.ops import Device
import pytest
# similar to test/external/external_test_gpu_ast.py, but universal


@ -10,7 +10,7 @@ import time
import numpy as np
np.set_printoptions(linewidth=160)
from functools import partial
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d


@ -1,6 +1,6 @@
import unittest
import numpy as np
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, CI


@ -1,7 +1,7 @@
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.lazy import Device
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters, RawBuffer
from tinygrad.shape.shapetracker import ShapeTracker


@ -1,17 +1,17 @@
from __future__ import annotations
import operator, math
import sys, operator, math
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast
import sys, importlib, inspect, functools, pathlib
from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
from tinygrad.shape.symbolic import Node
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer, RawBufferTransfer
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
from tinygrad.runtime.ops_disk import RawDiskBuffer
# lazy can recurse a lot
sys.setrecursionlimit(10000)
@ -176,7 +176,7 @@ class LazyBuffer:
return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
# create a constant with the shape and dtype of self
def const(self, val) -> LazyBuffer:
def const(self, val:Union[float, int]) -> LazyBuffer:
# NOTE: dtypes.from_np(self.dtype.np) to deal with image types
return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
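In other words, const(val) builds a scalar LoadOps.CONST with this buffer's dtype (via dtypes.from_np to cope with image dtypes), reshapes it to (1,)*len(self.shape), and expands it to self.shape, so the constant is broadcast lazily rather than materialized at full size.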
@ -323,23 +323,6 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
new_srcs.append(x)
return tuple(new_srcs)
class _Device:
def __init__(self) -> None:
self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) or self._default_device()
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
x = x.split(":")[0].upper()
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
def _default_device(self) -> str:
for device in ["METAL", "CUDA", "GPU"]:
try:
if self[device]: return device
except Exception: pass
return "CPU"
Device = _Device()
def _realize_contiguous(buffer: LazyBuffer) -> None:
realized = buffer.op.src[0].realize().realized
if buffer.op.src[0].st.contiguous and realized.__class__ is not RawConst and cast(RawBuffer, realized).size == prod(buffer.shape):
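Note that the _Device singleton deleted here, along with the importlib/inspect/functools/pathlib imports it needed, reappears unchanged in the tinygrad/ops.py hunk below.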


@ -13,6 +13,7 @@ class Cast(Function):
def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
self.input_dtype, self.bitcast = x.dtype, bitcast
return x.e(UnaryOps.CAST, arg=(dtype, bitcast))
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.e(UnaryOps.CAST, arg=(self.input_dtype, self.bitcast))
@ -22,6 +23,7 @@ class Sin(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.x = x
return x.e(UnaryOps.SIN)
def backward(self, grad:LazyBuffer) -> LazyBuffer:
return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)
@ -137,7 +139,7 @@ class Where(Function):
self.x = x
return x.e(TernaryOps.WHERE, y, z)
def backward(self, grad_output:LazyBuffer):
def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
return None, \
self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
@ -158,7 +160,7 @@ class Reshape(Function):
self.input_shape = x.shape
return x.reshape(shape)
def backward(self, grad_output:LazyBuffer):
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.reshape(self.input_shape)
class Permute(Function):
@ -186,7 +188,7 @@ class Shrink(Function):
return grad_output.pad(self.narg)
class Flip(Function):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
return x.stride(self.arg)
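
For instance, Flip.forward on a 3-D input with axis=(0, 2) sets self.arg to (-1, 1, -1), reversing the stride of the first and last dimensions only. These mlops are normally reached through Tensor's autograd; a minimal sketch of the Sin pair above (assuming the Tensor API at this commit):

    from tinygrad.tensor import Tensor

    t = Tensor([0.0, 1.0, 2.0], requires_grad=True)
    out = t.sin().sum()    # Sin.forward records UnaryOps.SIN on the lazy graph
    out.backward()         # Sin.backward computes grad * sin(pi/2 - x), i.e. grad * cos(x)
    print(t.grad.numpy())  # approximately cos([0, 1, 2])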


@ -4,7 +4,7 @@ from typing import Dict, Union, List
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
from tinygrad.shape.shapetracker import strides_for_shape
from tinygrad.lazy import Device
from tinygrad.ops import Device
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}


@ -1,5 +1,5 @@
from __future__ import annotations
import time
import time, importlib, inspect, functools, pathlib
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts
@ -77,6 +77,25 @@ class LazyOp:
def shrink(self, _): raise NotImplementedError
def stride(self, _): raise NotImplementedError
# **************** Device ****************
class _Device:
def __init__(self) -> None:
self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) or self._default_device()
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
x = x.split(":")[0].upper()
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
def _default_device(self) -> str:
for device in ["METAL", "CUDA", "GPU"]:
try:
if self[device]: return device
except Exception: pass
return "CPU"
Device = _Device()
# **************** for Interpreted Buffers ****************
class Interpreted:
@ -148,7 +167,6 @@ class ASTRunner:
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += self.mem_estimate
if getenv("EARLY_STOPPING") and GlobalCounters.kernel_count == getenv("EARLY_STOPPING"): exit(0)
return et
class Compiled:
@ -181,8 +199,7 @@ class Compiled:
break
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
if not output.realized:
output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
if not output.realized: output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
# update the output var_vals from src
output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
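
With the _Device singleton now living in tinygrad.ops, a rough sketch of how it behaves (read off the class body above; what DEFAULT prints depends on which runtimes are installed):

    from tinygrad.ops import Device

    print(Device.DEFAULT)                 # env override like GPU=1 wins, else METAL/CUDA/GPU are probed, else CPU
    print(Device.canonicalize(None))      # None falls back to Device.DEFAULT
    print(Device.canonicalize("cuda:1"))  # "CUDA:1"
    print(Device.canonicalize("gpu:0"))   # "GPU" (a ":0" suffix is stripped)
    backend = Device[Device.DEFAULT]      # the Interpreted or Compiled backend object, per the annotation above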


@ -7,8 +7,8 @@ import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from tinygrad.lazy import Device, LazyBuffer
from tinygrad.ops import LoadOps
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import Device, LoadOps
# An instantiation of the Function is the Context
class Function: