# Owner(s): ["NNC"]

import unittest

import numpy as np

import torch
import torch._C._te as te
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase

LLVM_ENABLED = torch._C._llvm_enabled()


def construct_adder(n: int, dtype=torch.float32):
    A = te.BufHandle("A", [n], dtype)
    B = te.BufHandle("B", [n], dtype)

    def compute(i):
        return A.load([i]) + B.load([i])

    C = te.Compute("C", [n], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    return te.construct_codegen("ir_eval", stmt, [A, B, C])


class TestTensorExprPyBind(JitTestCase):
    def test_simple_sum(self):
        n = 32
        cg = construct_adder(n)

        tA = torch.randn(n)
        tB = torch.randn(n)
        tC = torch.empty(n)
        cg.call([tA, tB, tC])
        torch.testing.assert_close(tA + tB, tC)

    def test_call_raw(self):
        n = 16
        cg = construct_adder(n, dtype=torch.float64)

        tA = torch.randn(n, dtype=torch.float64)
        tB = torch.randn(n, dtype=torch.float64)
        tC = torch.empty(n, dtype=torch.float64)
        cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])
        torch.testing.assert_close(tA + tB, tC)

    def test_external_calls(self):
        dtype = torch.float32

        A = te.BufHandle("A", [1, 4], dtype)
        B = te.BufHandle("B", [4, 1], dtype)
        C = te.BufHandle("C", [1, 1], dtype)

        s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

        loopnest = te.LoopNest(s, [C])
        loopnest.prepare_for_codegen()
        codegen = te.construct_codegen("ir_eval", s, [A, B, C])

        tA = torch.ones(1, 4)
        tB = torch.ones(4, 1)
        tC = torch.empty(1, 1)
        codegen.call([tA, tB, tC])
        torch.testing.assert_close(torch.matmul(tA, tB), tC)

    def test_dynamic_shape(self):
        dN = te.VarHandle(torch.int32)
        A = te.BufHandle([dN], torch.float64)
        B = te.BufHandle([dN], torch.float64)

        def compute(i):
            return A.load(i) - B.load(i)

        C = te.Compute("C", [dN], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()

        cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN])

        def test_with_shape(n):
            tA = torch.randn(n, dtype=torch.double)
            tB = torch.randn(n, dtype=torch.double)
            tC = torch.empty(n, dtype=torch.double)
            cg.call([tA, tB, tC, n])
            torch.testing.assert_close(tA - tB, tC)

        test_with_shape(8)
        test_with_shape(31)

    def test_dynamic_shape_2d(self):
        dN = te.VarHandle(torch.int32)
        dM = te.VarHandle(torch.int32)
        A = te.BufHandle([dN, dM], torch.float64)
        B = te.BufHandle([dN, dM], torch.float64)

        def compute(i, j):
            return A.load([i, j]) - B.load([i, j])

        C = te.Compute("C", [dN, dM], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()

        cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN, dM])

        def test_with_shape(n, m):
            tA = torch.randn(n, m, dtype=torch.double)
            tB = torch.randn(n, m, dtype=torch.double)
            tC = torch.empty(n, m, dtype=torch.double)
            cg.call([tA, tB, tC, n, m])
            torch.testing.assert_close(tA - tB, tC)

        test_with_shape(2, 4)
        test_with_shape(5, 3)

    def test_dtype_error(self):
        te.BufHandle("a", [1], torch.float32)  # ok
        self.assertRaises(TypeError, lambda: te.BufHandle("a", [1], "float55"))

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_tensor_inputs(self):
        def f(a, b, c):
            return a + b + c

        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)
        z = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %6 : int = prim::Constant[value=1]()
  %7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6)
  %3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6)
  return (%3)
"""
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x, y, z))
        res2 = kernel.fallback((x, y, z))
        correct = f(x, y, z)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_scalar_inputs(self):
        def f(a, b, c):
            return a + b + c

        x = torch.tensor(0.1, dtype=torch.float, device="cpu")
        y = torch.tensor(0.6, dtype=torch.float, device="cpu")
        z = torch.tensor(0.7, dtype=torch.float, device="cpu")

        graph_str = """
graph(%a.1 : Float(requires_grad=0, device=cpu),
      %b.1 : Float(requires_grad=0, device=cpu),
      %c.1 : Float(requires_grad=0, device=cpu)):
  %3 : int = prim::Constant[value=1]()
  %6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3)
  %9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3)
  return (%9)
"""
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x, y, z))
        res2 = kernel.fallback((x, y, z))
        correct = f(x, y, z)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_shape_prop(self):
        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)

        graph_str = """
graph(%a : Tensor, %b : Tensor):
  %c : Tensor = aten::mul(%a, %b)
  return (%c)
"""
        graph = torch._C.parse_ir(graph_str)

        exception_thrown = False
        try:
            kernel = te.TensorExprKernel(graph)
        except RuntimeError:
            # Graph doesn't have shape info for inputs => compilation should
            # fail
            exception_thrown = True
            pass
        assert exception_thrown

        # Inject shape info and try compiling again
        example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
        torch._C._te.annotate_input_shapes(graph, example_inputs)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)

        # Now compilation should pass
        kernel = te.TensorExprKernel(graph)

        res = kernel.run((x, y))
        correct = torch.mul(x, y)
        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_shape_prop_module(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                return x * x + y

        graph = torch.jit.script(TestModule()).graph

        # Try compiling the graph as-is. It should fail because it doesn't
        # have shape info.
        exception_thrown = False
        try:
            kernel = te.TensorExprKernel(graph)
        except RuntimeError:
            exception_thrown = True
            pass
        assert exception_thrown

        # Try injecting shape info for graph inputs
        example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]

        exception_thrown = False
        try:
            torch._C._te.annotate_input_shapes(graph, example_inputs)
        except RuntimeError:
            # Graph has a 'self' argument for which we can't set shapes
            exception_thrown = True
            pass
        assert exception_thrown

        # Remove 'self' argument and try annotating shapes one more time
        torch._C._te.remove_unused_self_argument(graph)

        # Inject shape info and try compiling again
        torch._C._te.annotate_input_shapes(graph, example_inputs)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)

        # Now compilation should pass
        kernel = te.TensorExprKernel(graph)

        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)

        res = kernel.run((x, y))
        correct = TestModule().forward(x, y)
        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_t(self):
        def f(a):
            return a.t()

        device, size = "cpu", (3, 4)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1)
  return (%3)
"""
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_transpose(self):
        def f(a):
            return a.transpose(-1, -2)

        device, size = "cpu", (3, 4)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=-1]()
  %3 : int = prim::Constant[value=-2]()
  %4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3)
  return (%4)
"""
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_permute(self):
        def f(a):
            return a.permute([2, 1, 0])

        device, size = "cpu", (3, 4, 5)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=2]()
  %2 : int = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=0]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4)
  return (%5)
"""
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_custom_lowering(self):
        def f(a):
            return a.nan_to_num()

        device = "cpu"
        x = torch.ones((2, 2), device=device)
        x[0, 0] = x[1, 1] = torch.nan
        graph_str = """
graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
  %none : NoneType = prim::Constant()
  %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none)
  return (%y)
"""
        graph = torch._C.parse_ir(graph_str)

        def my_custom_lowering(inputs, out_shape, out_stride, out_type, device):
            def compute(idxs):
                load = inputs[0].as_buf().load(idxs)
                return te.ifThenElse(
                    te.ExprHandle.isnan(load), te.ExprHandle.float(0.0), load
                )

            return te.Compute2("custom_nan_to_num", out_shape, compute)

        kernel = te.TensorExprKernel(graph, {"aten::nan_to_num": my_custom_lowering})
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_expand(self):
        def f(a):
            return a.expand((2, 3, 4))

        device = "cpu"
        x = torch.rand((1, 3, 1), device=device)
        graph_str = """
graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=2]()
  %2 : int = prim::Constant[value=3]()
  %3 : int = prim::Constant[value=4]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : bool = prim::Constant[value=0]()
  %6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5)
  return (%6)
"""
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_alloc_in_loop(self):
        a, tmp, b = (
            te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]
        )
        body = te.Block([tmp.store([0], a.load([0])), b.store([0], tmp.load([0]))])
        for _ in range(4):
            i = te.VarHandle("i", torch.int32)
            body = te.For.make(i, 0, 100, body)
        nest = te.LoopNest(body, [b])
        nest.prepare_for_codegen()
        f = te.construct_codegen("llvm", nest.simplify(), [a, b])
        ta, tb = (torch.ones(1) for _ in range(2))
        f.call([ta.data_ptr(), tb.data_ptr()])


class TestExprHandlePyBind(JitTestCase):
    def test_unary_ops(self):
        unary_operators = {
            torch.sin: torch._C._te.sin,
            torch.cos: torch._C._te.cos,
            torch.tan: torch._C._te.tan,
            torch.asin: torch._C._te.asin,
            torch.acos: torch._C._te.acos,
            torch.atan: torch._C._te.atan,
            torch.sinh: torch._C._te.sinh,
            torch.cosh: torch._C._te.cosh,
            torch.tanh: torch._C._te.tanh,
            torch.sigmoid: torch._C._te.sigmoid,
            torch.exp: torch._C._te.exp,
            torch.expm1: torch._C._te.expm1,
            torch.abs: torch._C._te.abs,
            torch.log: torch._C._te.log,
            torch.log2: torch._C._te.log2,
            torch.log10: torch._C._te.log10,
            torch.log1p: torch._C._te.log1p,
            torch.erf: torch._C._te.erf,
            torch.erfc: torch._C._te.erfc,
            torch.sqrt: torch._C._te.sqrt,
            torch.rsqrt: torch._C._te.rsqrt,
            torch.ceil: torch._C._te.ceil,
            torch.floor: torch._C._te.floor,
            torch.round: torch._C._te.round,
            torch.trunc: torch._C._te.trunc,
            torch.lgamma: torch._C._te.lgamma,
            torch.frac: torch._C._te.frac,
        }

        def construct_te_fn(op, n: int, dtype=torch.float32):
            A = torch._C._te.BufHandle("A", [n], dtype)

            def compute(i):
                return op(A.load([i]))

            C = te.Compute("C", [n], compute)

            loopnest = te.LoopNest([C])
            loopnest.prepare_for_codegen()
            stmt = te.simplify(loopnest.root_stmt())

            return te.construct_codegen("ir_eval", stmt, [A, C])

        n = 10
        a = torch.rand(n)
        for torch_op, te_op in unary_operators.items():
            ref = torch_op(a)
            te_fn = construct_te_fn(te_op, n, torch.float32)
            res = torch.empty(n)
            te_fn.call([a, res])
            assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
    run_tests()