From 0b5930d406e55ffd3cfd10f1386241f2770b3fe6 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 15 Aug 2023 09:07:26 -0700
Subject: [PATCH] more uops testing, who isn't passing right now... (#1522)

* more uops

* llvm refactor

* update test uops

* rest of the nodes

* ors and ands
---
 test/test_uops.py            | 53 ++++++++++++++++++++++------------
 tinygrad/ops.py              |  2 +-
 tinygrad/renderer/cstyle.py  |  2 +-
 tinygrad/renderer/llvmir.py  | 54 +++++++++++++++++++++---------------
 tinygrad/runtime/ops_llvm.py |  3 +-
 5 files changed, 69 insertions(+), 45 deletions(-)

diff --git a/test/test_uops.py b/test/test_uops.py
index 766ad2e31..4f8bcf8d5 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -37,36 +37,20 @@ def _test_single_value_const(tc, tt, vals, op):
   prg([buf])
   return buf.toCPU()[0]
 
-@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
 class TestUOps(unittest.TestCase):
   def _equal(self, v1, v2):
     if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5)
 
   def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
     for f in [_test_single_value, _test_single_value_const]:
-      for a in [-2.0, 2.0]:
+      for a in [-2.0, 0.0, 1.0, 2.0]:
         self._equal(f(Token('c', dt), [Token('a', dt)], [a], bop), fxn(a))
 
-  def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
-  def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('nan'))
-  def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
-  def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
-  #def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a)
-
-  def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32):
+  def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False):
     for f in [_test_single_value, _test_single_value_const]:
-      for a in [-2.0, 2.0]:
-        for b in [-3.0, 3.0]:
+      for a in [-2.0, 0.0, 1.0, 2.0]:
+        for b in [-3.0, 1.0, 3.0] + ([] if no_b_zero else [0.0]):
           self._equal(f(Token('c', dt), [Token('a', dt), Token('b', dt)], [a,b], bop), fxn(a,b))
 
-  def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
-  def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
-  def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
-  def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b)
-  def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
-  def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
+  def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
+  def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
+  def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
+  def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
+  # this is not on most backends
+  #def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a if a != 0 else float('inf'))
+
+  def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
+  def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
+  def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
+  def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf'))
+  def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
+  def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
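Reviewer note, not part of the patch: the new operand values 0.0 and 1.0 pin down IEEE-754 edge cases that the old +/-2.0-only values never reached. A standalone numpy sketch of the semantics the updated reference lambdas encode, assuming float32 behavior on the backend:

import math
import numpy as np

with np.errstate(divide='ignore', invalid='ignore'):
    assert np.exp2(np.float32(0.0)) == 1.0                      # exp2 is defined everywhere
    assert np.log2(np.float32(0.0)) == float('-inf')            # log2(0) -> -inf
    assert math.isnan(np.log2(np.float32(-2.0)))                # log2 of a negative -> nan
    assert np.float32(2.0) / np.float32(0.0) == float('inf')    # division by zero -> signed inf
    assert np.float32(-2.0) / np.float32(0.0) == float('-inf')

This also explains test_div's reference `a/b if b != 0 else a*float('inf')`: multiplying by inf keeps the sign of a, and yields nan when a == 0.0, matching the 0/0 case.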
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ ... @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
 
 class ASTRunner:
   def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
-    if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args): print(prg)
+    if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
     self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
 
   def build(self, runtime):
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 756160d75..1eaa0b433 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -35,7 +35,7 @@ class CStyleLanguage(NamedTuple):
     UnaryOps.SQRT: lambda x: f"sqrt({x})",
     BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
     BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
-    BinaryOps.MAX: lambda a,b: f"max({a},{b})",
+    BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
     BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
     TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
   }
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index cf017b694..5c0073464 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -32,6 +32,35 @@ code_for_op: Final[Dict[Op, Callable]] = {
   TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
 }
 
+dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
+
+def cast(bb, val, input_type, output_type):
+  if input_type == output_type: return val
+
+  if output_type == dtypes.float32:
+    if dtypes.is_int(input_type) or input_type == dtypes.bool:
+      val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
+    elif input_type == dtypes.bfloat16:
+      val = bb[-1].sext(val, ir.IntType(32))
+      val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
+      val = bb[-1].bitcast(val, ir.FloatType())
+    else:
+      val = bb[-1].fpext(val, ir.FloatType())
+    return val
+
+  if input_type == dtypes.float32:
+    if dtypes.is_int(output_type) or output_type == dtypes.bool:
+      val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type]) if dtypes.is_unsigned(output_type) or output_type == dtypes.bool else bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
+    elif output_type == dtypes.bfloat16:
+      val = bb[-1].bitcast(val, ir.IntType(32))
+      val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
+      val = bb[-1].trunc(val, ir.IntType(16))
+    else:
+      val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
+    return val
+
+  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
+
 def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
   # all llvm stuff goes into a module
   module = ir.Module(name=__file__)
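Reviewer note, not part of the patch: the bfloat16 arms of the new cast() helper implement the standard "bfloat16 is the top half of a float32" trick: widen the 16 stored bits, shift left by 16, and bitcast to float; the reverse path shifts right by 16 and truncates, which is round-toward-zero rather than round-to-nearest. A stdlib-only Python sketch of the same bit manipulation:

import struct

def bf16_bits_to_f32(bits: int) -> float:
    # place the 16 bfloat16 bits in the high half of a float32 pattern (the IR's shl 16 + bitcast)
    return struct.unpack('<f', struct.pack('<I', (bits & 0xFFFF) << 16))[0]

def f32_to_bf16_bits(x: float) -> int:
    # drop the low 16 mantissa bits (the IR's lshr 16 + trunc); truncates instead of rounding
    return struct.unpack('<I', struct.pack('<f', x))[0] >> 16

pi = 3.14159
assert abs(bf16_bits_to_f32(f32_to_bf16_bits(pi)) - pi) < 0.01  # bfloat16 keeps roughly 2-3 decimal digits

One quirk worth flagging: the load path uses sext to widen i16 to i32, but since the value is immediately shifted left by 16, the sign-extension bits fall off the top and the result is identical to zext.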
@@ -41,7 +70,6 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
   buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
 
   # create llvm function
-  dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
   func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
   func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name)
 
@@ -84,9 +112,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
       bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block)
     if uop == UOps.LOAD:
       assert newvar is not None and isinstance(args, (MemOp, ConstOp))
-      assert newvar.dtype == dtypes.float, "newvar must be float"
       valid = args.valid.render(render_llvm, bb[-1])
       if isinstance(args, ConstOp):
+        assert newvar.dtype == dtypes.float, "newvar must be float"
         if args.valid.min == 0 and args.valid.max == 1:
           val = bb[-1].select(valid, ir.Constant(ir.FloatType(), args.value), ir.Constant(ir.FloatType(), args.invalid_value))
         else:
@@ -100,30 +128,12 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
           val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value))
         else:
           val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
-
-      if args.memory_dtype != newvar.dtype:
-        if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
-          val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
-        elif args.memory_dtype == dtypes.bfloat16:
-          val = bb[-1].sext(val, ir.IntType(32))
-          val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
-          val = bb[-1].bitcast(val, ir.FloatType())
-        else:
-          val = bb[-1].fpext(val, ir.FloatType())
+      val = cast(bb, val, args.memory_dtype, newvar.dtype)
       lvars[newvar] = val
     if uop == UOps.STORE:
       assert args.valid.min == 1 and isinstance(args, MemOp), "store must be valid and to memory"
       idx = args.idx.render(render_llvm, bb[-1])
-      element = lvars[vin[0]]
-      if args.memory_dtype != vin[0].dtype:
-        if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
-          element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
-        elif args.memory_dtype == dtypes.bfloat16:
-          element = bb[-1].bitcast(element, ir.IntType(32))
-          element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16))
-          element = bb[-1].trunc(element, ir.IntType(16))
-        else:
-          element = bb[-1].fptrunc(element, dtype_to_llvm_dtype[args.memory_dtype])
+      element = cast(bb, lvars[vin[0]], vin[0].dtype, args.memory_dtype)
       bb[-1].store(element, bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
     if uop == UOps.ALU:
       lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
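Reviewer note, not part of the patch: with LOAD and STORE both routed through cast(), the uitofp/sitofp vs. fptoui/fptosi choice now lives in one place, keyed on dtypes.is_unsigned (bool rides the unsigned path). A self-contained llvmlite sketch, independent of tinygrad, showing the kind of instruction this emits for a uint8 value feeding float math:

import llvmlite.ir as ir

module = ir.Module(name="cast_demo")
func = ir.Function(module, ir.FunctionType(ir.FloatType(), [ir.IntType(8)]), name="u8_to_f32")
builder = ir.IRBuilder(func.append_basic_block(name="entry"))
# unsigned int -> float conversion: the same uitofp call cast() makes for dtypes.uint8
builder.ret(builder.uitofp(func.args[0], ir.FloatType()))
print(module)  # prints textual LLVM IR containing a "uitofp i8 ... to float" instruction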
diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py
index ab0f8b9a2..00d404a79 100644
--- a/tinygrad/runtime/ops_llvm.py
+++ b/tinygrad/runtime/ops_llvm.py
@@ -55,7 +55,8 @@ class LLVMProgram:
     LLVM.engine.finalize_object()
     self.fxn = LLVM.engine.get_function_address(name)
 
-  def __del__(self): LLVM.engine.remove_module(self.mod)
+  def __del__(self):
+    if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod)
 
   def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
     cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
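Reviewer note, not part of the patch: the hasattr guard matters because CPython calls __del__ even on a partially constructed object. If anything in __init__ raises before self.mod is assigned, the old one-liner would raise AttributeError during teardown. A minimal repro of the pattern, using a hypothetical Resource class:

class Resource:
    def __init__(self, fail=False):
        if fail: raise RuntimeError("died before self.mod was assigned")
        self.mod = object()

    def __del__(self):
        # without the guard, the failed-construction case raises AttributeError here
        if hasattr(self, 'mod'): pass  # release self.mod

try:
    Resource(fail=True)
except RuntimeError:
    pass  # __del__ still ran on the half-built instance; the guard kept it quiet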