style: else-after-return (#1216)
Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>pull/1077/head^2
parent
ab663c46e8
commit
8f2e2f5ee2
|
@ -64,7 +64,7 @@ disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1
|
|||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
enable=c-extension-no-member,use-a-generator
|
||||
enable=c-extension-no-member,use-a-generator, no-else-return
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
|
|
@ -188,7 +188,7 @@ class BoxList:
|
|||
if self.mode == "xyxy":
|
||||
xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1)
|
||||
return xmin, ymin, xmax, ymax
|
||||
elif self.mode == "xywh":
|
||||
if self.mode == "xywh":
|
||||
TO_REMOVE = 1
|
||||
xmin, ymin, w, h = self.bbox.chunk(4, dim=-1)
|
||||
return (
|
||||
|
|
|
@ -96,9 +96,8 @@ class AssemblyCodegen(Linearizer):
|
|||
ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
|
||||
reg = new_reg
|
||||
return tor[f"buf{args.i}"], reg, off
|
||||
else:
|
||||
reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
|
||||
return reg, None, off
|
||||
reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
|
||||
return reg, None, off
|
||||
|
||||
ins = []
|
||||
ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
|
||||
|
|
|
@ -30,7 +30,7 @@ class CStyleLanguage(NamedTuple):
|
|||
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
|
||||
assert self.float4 is not None, "cast is not supported on this platform"
|
||||
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
|
||||
elif var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
|
||||
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
|
||||
raise NotImplementedError(f"no cast for {var_dtype}")
|
||||
|
||||
# returns a str expression of the const with the given type
|
||||
|
@ -45,24 +45,22 @@ class CStyleLanguage(NamedTuple):
|
|||
if isinstance(buf_dtype, ImageDType):
|
||||
assert output_dtype == dtypes._float4, "images must be float4"
|
||||
return f"read_imagef({buf_name}, smp, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}))"
|
||||
elif self.uses_vload and buf_dtype == dtypes.float16:
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx.render(render_cl, strip_parens=True)})"
|
||||
elif output_dtype.sz > 1:
|
||||
if output_dtype.sz > 1:
|
||||
return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
|
||||
else:
|
||||
return f"{buf_name}[{idx.render(render_cl)}]"
|
||||
return f"{buf_name}[{idx.render(render_cl)}]"
|
||||
|
||||
# returns a str statement that does the store
|
||||
def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert var_dtype == dtypes._float4, "images must be float4"
|
||||
return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
|
||||
elif self.uses_vload and buf_dtype == dtypes.float16:
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx.render(render_cl, strip_parens=True)});"
|
||||
elif var_dtype.sz > 1:
|
||||
if var_dtype.sz > 1:
|
||||
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
|
||||
else:
|
||||
return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
|
||||
return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
|
||||
|
||||
code_for_op: Final[Dict[Op, Callable]] = {
|
||||
UnaryOps.EXP2: lambda x: f"exp2({x})",
|
||||
|
@ -85,9 +83,8 @@ def add_gl_dimension(args, i, var, local_size, xid):
|
|||
lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
|
||||
assert lidx.max == var.max and lidx.min == var.min
|
||||
return f"{{ int {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
|
||||
else:
|
||||
local_size.append(var.max+1)
|
||||
return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
|
||||
local_size.append(var.max+1)
|
||||
return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
|
||||
|
||||
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
|
||||
prekernel: Set[str] = set()
|
||||
|
|
|
@ -101,10 +101,10 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
|
|||
# TODO: generic visitor pattern?
|
||||
def expand_node(idx:Node) -> List[Node]:
|
||||
if isinstance(idx, Variable): return [idx] if idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)]
|
||||
elif isinstance(idx, NumNode): return [idx]
|
||||
elif isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
|
||||
elif isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
|
||||
else: raise NotImplementedError(idx)
|
||||
if isinstance(idx, NumNode): return [idx]
|
||||
if isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
|
||||
if isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
|
||||
raise NotImplementedError(idx)
|
||||
|
||||
def expand_idxs(idxs:Sequence[Node]) -> Iterator[Tuple[Node, ...]]:
|
||||
for x in itertools.product(*[expand_node(idx) for idx in idxs[::-1]]):
|
||||
|
|
|
@ -104,8 +104,7 @@ class Interpreted:
|
|||
assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
|
||||
output.output_buffer._buf = ret._buf
|
||||
return output.output_buffer
|
||||
else:
|
||||
return ret
|
||||
return ret
|
||||
|
||||
class FlopCounter:
|
||||
def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops, self._buf = *tup, self
|
||||
|
|
|
@ -70,8 +70,7 @@ class MetalProgram:
|
|||
if wait:
|
||||
command_buffer.waitUntilCompleted()
|
||||
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
else:
|
||||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
|
||||
class MetalCodegen(CStyleCodegen):
|
||||
lang = CStyleLanguage(
|
||||
|
|
|
@ -119,11 +119,9 @@ def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
|
|||
|
||||
new_view = View(new_shape, strides_for_shape(new_shape))
|
||||
if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset
|
||||
else:
|
||||
if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
|
||||
else:
|
||||
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
|
||||
return new_view, True
|
||||
if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
|
||||
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
|
||||
return new_view, True
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]):
|
||||
|
|
|
@ -49,7 +49,7 @@ class Node:
|
|||
return create_node(LtNode(lhs, b))
|
||||
def __mul__(self, b:int):
|
||||
if b == 0: return NumNode(0)
|
||||
elif b == 1: return self
|
||||
if b == 1: return self
|
||||
return create_node(MulNode(self, b))
|
||||
|
||||
# *** complex ops ***
|
||||
|
|
|
@ -422,17 +422,16 @@ class Tensor:
|
|||
xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
|
||||
xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
|
||||
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
|
||||
else:
|
||||
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
|
||||
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
|
||||
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
|
||||
xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
|
||||
if len(_insert_dims):
|
||||
xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
|
||||
prefix += _insert_dims
|
||||
slc_prefix += [(0,x) for x in _insert_dims]
|
||||
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
|
||||
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
|
||||
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
|
||||
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
|
||||
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
|
||||
xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
|
||||
if len(_insert_dims):
|
||||
xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
|
||||
prefix += _insert_dims
|
||||
slc_prefix += [(0,x) for x in _insert_dims]
|
||||
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
|
||||
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
|
||||
|
||||
# NOTE: these work for more than 2D
|
||||
def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
|
||||
|
|
Loading…
Reference in New Issue