move to shapetracker.py
parent
f3ac52aee8
commit
01f39b19dc
|
@ -11,7 +11,7 @@ import triton.language as tl # type: ignore # noqa: F401
|
|||
|
||||
from typing import Union, Tuple, Optional, Dict
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
|
||||
from tinygrad.shape import ShapeTracker
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.runtime.cuda import CLBuffer
|
||||
from tinygrad.compiler.ast import ASTKernel
|
||||
|
@ -57,7 +57,7 @@ class TritonASTKernel(ASTKernel):
|
|||
self.kernel = ["@triton.jit"]
|
||||
self.kernel.append("def fxn("+','.join(f"data{i}" for i in range(len(self.bufs)))+"):")
|
||||
|
||||
self.output_shape = list(self.sts[0].shape[:self.first_reduce])
|
||||
self.output_shape = list(self.sts[0].shape[:self.first_reduce])
|
||||
|
||||
# copied from ops_gpu
|
||||
# TODO CUDA only supports a grid of (2^31-1, 65535, 65535), that results in invalid kernel launches for some shapes, so flattern the grid for now.
|
||||
|
@ -70,7 +70,7 @@ class TritonASTKernel(ASTKernel):
|
|||
self.output_shape = [prod(self.output_shape[0:final_dimension+1])] + list(self.output_shape[final_dimension+1:])
|
||||
if DEBUG >= 3: print(f"replaced output shape with {self.output_shape}")
|
||||
elif len(self.output_shape) == 0: self.output_shape = [1]
|
||||
|
||||
|
||||
if self.reduceop:
|
||||
full_shape = [st.shape for st in self.sts if st.shape != self.sts[0].shape]
|
||||
full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]
|
||||
|
|
|
@ -5,7 +5,7 @@ import itertools
|
|||
from enum import Enum
|
||||
import numpy as np
|
||||
from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
|
||||
from tinygrad.shape import ShapeTracker, View, ZeroView
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View, ZeroView
|
||||
from tinygrad.runtime.ops_gpu import GPUBuffer, CLASTKernel
|
||||
from tinygrad.runtime.opencl import OSX_TIMING_RATIO
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad.helpers import prod, getenv, DEBUG
|
|||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LazyNumpyArray, Device
|
||||
from tinygrad.shape import strides_for_shape
|
||||
from tinygrad.shape.shapetracker import strides_for_shape
|
||||
|
||||
def fetch(url):
|
||||
if url.startswith("/"):
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
|
||||
from tinygrad.shape import ShapeTracker, View, ZeroView
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View, ZeroView
|
||||
from tinygrad.runtime.ops_gpu import GPUBuffer, CLProgram, CLCodegen
|
||||
#from tinygrad.runtime.ops_metal import MetalBuffer as GPUBuffer, MetalProgram as CLProgram, MetalCodegen as CLCodegen
|
||||
from tinygrad.helpers import getenv
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import prod, all_same
|
||||
from tinygrad.shape import ShapeTracker, View, ZeroView, merge_views, get_contraction
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View, ZeroView, merge_views, get_contraction
|
||||
from tinygrad.codegen.gpu import to_image_idx
|
||||
|
||||
def shapetracker_getitem(st, val):
|
||||
|
|
|
@ -3,7 +3,7 @@ from enum import Enum, auto
|
|||
from typing import List, Tuple
|
||||
from tinygrad.helpers import prod, dedup, all_same, colored, dtypes
|
||||
from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops, map_buffers
|
||||
from tinygrad.shape import ShapeTracker, View, strides_for_shape
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View, strides_for_shape
|
||||
|
||||
def get_first_reduce(shapes):
|
||||
for i in range(len(shapes[0])):
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple, ClassVar
|
|||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ASTRunner
|
||||
from tinygrad.codegen.ast import ASTKernel, Token, Types
|
||||
from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, AndNode, Variable, render_python
|
||||
from tinygrad.shape import ShapeTracker, View
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.helpers import getenv, DEBUG, prod, partition, mnum, all_same, dedup, dtypes
|
||||
|
||||
# div is different in cl than python
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
|
|||
import os, sys, weakref, importlib, inspect, functools
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.helpers import prod, getenv, DType, dtypes, LazyNumpyArray
|
||||
from tinygrad.shape import ShapeTracker, get_contraction
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.ops import InterpretedBuffer, DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers
|
||||
from tinygrad.graph import log_op
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import numpy as np
|
|||
from enum import Enum, auto
|
||||
from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict, TypeVar, Set, Final
|
||||
from tinygrad.helpers import prod, DEBUG, getenv, DType, dtypes
|
||||
from tinygrad.shape import ShapeTracker, MovementOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, MovementOps
|
||||
|
||||
# these are the llops your accelerator must implement, along with toCpu
|
||||
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
|
||||
|
|
Loading…
Reference in New Issue