1
0
Fork 0

share duplicate renders with cstyle (#2538)

pull/2541/head
qazal 2023-12-01 11:10:36 -05:00 committed by GitHub
parent 7fec966b5e
commit 0fb4ff30c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 11 deletions

View File

@ -1,7 +1,7 @@
from tinygrad.helpers import dtypes, DType
from tinygrad.renderer.cstyle import CStyleLanguage
from typing import List, Union
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.ops import BinaryOps, TernaryOps
import math
from typing import Tuple
@ -13,15 +13,7 @@ class WGSLLanguage(CStyleLanguage):
barrier="workgroupBarrier();"
generic_var_prefix = "var "
external_local_bufs = True
code_for_op = {
UnaryOps.NEG: lambda x: f"(-{x})",
UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})",
UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})",
BinaryOps.DIV: lambda x,y: f"({x}/{y})", BinaryOps.MOD: lambda x,y: f"({x}%{y})",
BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})",
TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)"
}
code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)" }
def render_local(self, name: str, size: int):
return f"var<workgroup> {name}: array<f32,{size}>;"
@ -59,4 +51,4 @@ class WGSLLanguage(CStyleLanguage):
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
if buf_dtype != var_dtype:
var_name = f"{type_map[buf_dtype]}({var_name})"
return f"{buf_name}[{idx}] = {var_name};"
return f"{buf_name}[{idx}] = {var_name};"