1
0
Fork 0

style: else-after-return (#1216)

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
pull/1077/head^2
Roelof van Dijk 2023-07-12 19:26:38 +02:00 committed by GitHub
parent ab663c46e8
commit 8f2e2f5ee2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 33 additions and 42 deletions

View File

@@ -64,7 +64,7 @@ disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member,use-a-generator
enable=c-extension-no-member,use-a-generator, no-else-return
[REPORTS]

View File

@@ -188,7 +188,7 @@ class BoxList:
if self.mode == "xyxy":
xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1)
return xmin, ymin, xmax, ymax
elif self.mode == "xywh":
if self.mode == "xywh":
TO_REMOVE = 1
xmin, ymin, w, h = self.bbox.chunk(4, dim=-1)
return (

View File

@@ -96,9 +96,8 @@ class AssemblyCodegen(Linearizer):
ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
reg = new_reg
return tor[f"buf{args.i}"], reg, off
else:
reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
return reg, None, off
reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
return reg, None, off
ins = []
ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]

View File

@@ -30,7 +30,7 @@ class CStyleLanguage(NamedTuple):
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
assert self.float4 is not None, "cast is not supported on this platform"
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
elif var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
raise NotImplementedError(f"no cast for {var_dtype}")
# returns a str expression of the const with the given type
@@ -45,24 +45,22 @@ class CStyleLanguage(NamedTuple):
if isinstance(buf_dtype, ImageDType):
assert output_dtype == dtypes._float4, "images must be float4"
return f"read_imagef({buf_name}, smp, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}))"
elif self.uses_vload and buf_dtype == dtypes.float16:
if self.uses_vload and buf_dtype == dtypes.float16:
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx.render(render_cl, strip_parens=True)})"
elif output_dtype.sz > 1:
if output_dtype.sz > 1:
return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
else:
return f"{buf_name}[{idx.render(render_cl)}]"
return f"{buf_name}[{idx.render(render_cl)}]"
# returns a str statement that does the store
def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str:
if isinstance(buf_dtype, ImageDType):
assert var_dtype == dtypes._float4, "images must be float4"
return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
elif self.uses_vload and buf_dtype == dtypes.float16:
if self.uses_vload and buf_dtype == dtypes.float16:
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx.render(render_cl, strip_parens=True)});"
elif var_dtype.sz > 1:
if var_dtype.sz > 1:
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
else:
return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.EXP2: lambda x: f"exp2({x})",
@@ -85,9 +83,8 @@ def add_gl_dimension(args, i, var, local_size, xid):
lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
assert lidx.max == var.max and lidx.min == var.min
return f"{{ int {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
else:
local_size.append(var.max+1)
return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
local_size.append(var.max+1)
return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
prekernel: Set[str] = set()

View File

@@ -101,10 +101,10 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
# TODO: generic visitor pattern?
def expand_node(idx:Node) -> List[Node]:
if isinstance(idx, Variable): return [idx] if idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)]
elif isinstance(idx, NumNode): return [idx]
elif isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
elif isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
else: raise NotImplementedError(idx)
if isinstance(idx, NumNode): return [idx]
if isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
if isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
raise NotImplementedError(idx)
def expand_idxs(idxs:Sequence[Node]) -> Iterator[Tuple[Node, ...]]:
for x in itertools.product(*[expand_node(idx) for idx in idxs[::-1]]):

View File

@@ -104,8 +104,7 @@ class Interpreted:
assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
output.output_buffer._buf = ret._buf
return output.output_buffer
else:
return ret
return ret
class FlopCounter:
def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops, self._buf = *tup, self

View File

@@ -70,8 +70,7 @@ class MetalProgram:
if wait:
command_buffer.waitUntilCompleted()
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
else:
METAL.mtl_buffers_in_flight.append(command_buffer)
METAL.mtl_buffers_in_flight.append(command_buffer)
class MetalCodegen(CStyleCodegen):
lang = CStyleLanguage(

View File

@@ -119,11 +119,9 @@ def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
new_view = View(new_shape, strides_for_shape(new_shape))
if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset
else:
if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
else:
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
return new_view, True
if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
return new_view, True
@functools.lru_cache(maxsize=None)
def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]):

View File

@@ -49,7 +49,7 @@ class Node:
return create_node(LtNode(lhs, b))
def __mul__(self, b:int):
if b == 0: return NumNode(0)
elif b == 1: return self
if b == 1: return self
return create_node(MulNode(self, b))
# *** complex ops ***

View File

@@ -422,17 +422,16 @@ class Tensor:
xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
else:
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
if len(_insert_dims):
xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
prefix += _insert_dims
slc_prefix += [(0,x) for x in _insert_dims]
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
if len(_insert_dims):
xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
prefix += _insert_dims
slc_prefix += [(0,x) for x in _insert_dims]
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
# NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))