1
0
Fork 0

move to shapetracker.py

pull/682/head
George Hotz 2023-03-11 07:49:20 -08:00
parent f3ac52aee8
commit 01f39b19dc
10 changed files with 11 additions and 11 deletions

View File

@@ -11,7 +11,7 @@ import triton.language as tl # type: ignore # noqa: F401
 from typing import Union, Tuple, Optional, Dict
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
-from tinygrad.shape import ShapeTracker
+from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.helpers import prod, DEBUG
 from tinygrad.runtime.cuda import CLBuffer
 from tinygrad.compiler.ast import ASTKernel
@@ -57,7 +57,7 @@ class TritonASTKernel(ASTKernel):
self.kernel = ["@triton.jit"]
self.kernel.append("def fxn("+','.join(f"data{i}" for i in range(len(self.bufs)))+"):")
self.output_shape = list(self.sts[0].shape[:self.first_reduce])
self.output_shape = list(self.sts[0].shape[:self.first_reduce])
# copied from ops_gpu
# TODO CUDA only supports a grid of (2^31-1, 65535, 65535), that results in invalid kernel launches for some shapes, so flattern the grid for now.
@@ -70,7 +70,7 @@ class TritonASTKernel(ASTKernel):
self.output_shape = [prod(self.output_shape[0:final_dimension+1])] + list(self.output_shape[final_dimension+1:])
if DEBUG >= 3: print(f"replaced output shape with {self.output_shape}")
elif len(self.output_shape) == 0: self.output_shape = [1]
if self.reduceop:
full_shape = [st.shape for st in self.sts if st.shape != self.sts[0].shape]
full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]

View File

@@ -5,7 +5,7 @@ import itertools
 from enum import Enum
 import numpy as np
 from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
-from tinygrad.shape import ShapeTracker, View, ZeroView
+from tinygrad.shape.shapetracker import ShapeTracker, View, ZeroView
 from tinygrad.runtime.ops_gpu import GPUBuffer, CLASTKernel
 from tinygrad.runtime.opencl import OSX_TIMING_RATIO
 from tinygrad.helpers import getenv, DEBUG

View File

@@ -7,7 +7,7 @@ from tinygrad.helpers import prod, getenv, DEBUG
 from tinygrad.ops import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyNumpyArray, Device
-from tinygrad.shape import strides_for_shape
+from tinygrad.shape.shapetracker import strides_for_shape
def fetch(url):
if url.startswith("/"):

View File

@@ -2,7 +2,7 @@
 import unittest
 import numpy as np
 from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
-from tinygrad.shape import ShapeTracker, View, ZeroView
+from tinygrad.shape.shapetracker import ShapeTracker, View, ZeroView
 from tinygrad.runtime.ops_gpu import GPUBuffer, CLProgram, CLCodegen
 #from tinygrad.runtime.ops_metal import MetalBuffer as GPUBuffer, MetalProgram as CLProgram, MetalCodegen as CLCodegen
 from tinygrad.helpers import getenv

View File

@@ -2,7 +2,7 @@
 import unittest
 import numpy as np
 from tinygrad.helpers import prod, all_same
-from tinygrad.shape import ShapeTracker, View, ZeroView, merge_views, get_contraction
+from tinygrad.shape.shapetracker import ShapeTracker, View, ZeroView, merge_views, get_contraction
 from tinygrad.codegen.gpu import to_image_idx
def shapetracker_getitem(st, val):

View File

@@ -3,7 +3,7 @@ from enum import Enum, auto
 from typing import List, Tuple
 from tinygrad.helpers import prod, dedup, all_same, colored, dtypes
 from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops, map_buffers
-from tinygrad.shape import ShapeTracker, View, strides_for_shape
+from tinygrad.shape.shapetracker import ShapeTracker, View, strides_for_shape
def get_first_reduce(shapes):
for i in range(len(shapes[0])):

View File

@@ -4,7 +4,7 @@ from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple, ClassVar
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ASTRunner
 from tinygrad.codegen.ast import ASTKernel, Token, Types
 from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, AndNode, Variable, render_python
-from tinygrad.shape import ShapeTracker, View
+from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.helpers import getenv, DEBUG, prod, partition, mnum, all_same, dedup, dtypes
# div is different in cl than python

View File

@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
 import os, sys, weakref, importlib, inspect, functools
 from weakref import WeakValueDictionary
 from tinygrad.helpers import prod, getenv, DType, dtypes, LazyNumpyArray
-from tinygrad.shape import ShapeTracker, get_contraction
+from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.ops import InterpretedBuffer, DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers
 from tinygrad.graph import log_op

View File

@@ -4,7 +4,7 @@ import numpy as np
 from enum import Enum, auto
 from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict, TypeVar, Set, Final
 from tinygrad.helpers import prod, DEBUG, getenv, DType, dtypes
-from tinygrad.shape import ShapeTracker, MovementOps
+from tinygrad.shape.shapetracker import ShapeTracker, MovementOps
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly