style: else-after-return (#1216)
Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
pull/1077/head^2
parent
ab663c46e8
commit
8f2e2f5ee2
|
@ -64,7 +64,7 @@ disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1
|
|||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
enable=c-extension-no-member,use-a-generator
|
||||
enable=c-extension-no-member,use-a-generator, no-else-return
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
|
|
@ -188,7 +188,7 @@ class BoxList:
|
|||
if self.mode == "xyxy":
|
||||
xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1)
|
||||
return xmin, ymin, xmax, ymax
|
||||
elif self.mode == "xywh":
|
||||
if self.mode == "xywh":
|
||||
TO_REMOVE = 1
|
||||
xmin, ymin, w, h = self.bbox.chunk(4, dim=-1)
|
||||
return (
|
||||
|
|
|
@ -96,7 +96,6 @@ class AssemblyCodegen(Linearizer):
|
|||
ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
|
||||
reg = new_reg
|
||||
return tor[f"buf{args.i}"], reg, off
|
||||
else:
|
||||
reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
|
||||
return reg, None, off
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ class CStyleLanguage(NamedTuple):
|
|||
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
|
||||
assert self.float4 is not None, "cast is not supported on this platform"
|
||||
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
|
||||
elif var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
|
||||
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
|
||||
raise NotImplementedError(f"no cast for {var_dtype}")
|
||||
|
||||
# returns a str expression of the const with the given type
|
||||
|
@ -45,11 +45,10 @@ class CStyleLanguage(NamedTuple):
|
|||
if isinstance(buf_dtype, ImageDType):
|
||||
assert output_dtype == dtypes._float4, "images must be float4"
|
||||
return f"read_imagef({buf_name}, smp, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}))"
|
||||
elif self.uses_vload and buf_dtype == dtypes.float16:
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx.render(render_cl, strip_parens=True)})"
|
||||
elif output_dtype.sz > 1:
|
||||
if output_dtype.sz > 1:
|
||||
return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
|
||||
else:
|
||||
return f"{buf_name}[{idx.render(render_cl)}]"
|
||||
|
||||
# returns a str statement that does the store
|
||||
|
@ -57,11 +56,10 @@ class CStyleLanguage(NamedTuple):
|
|||
if isinstance(buf_dtype, ImageDType):
|
||||
assert var_dtype == dtypes._float4, "images must be float4"
|
||||
return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
|
||||
elif self.uses_vload and buf_dtype == dtypes.float16:
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx.render(render_cl, strip_parens=True)});"
|
||||
elif var_dtype.sz > 1:
|
||||
if var_dtype.sz > 1:
|
||||
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
|
||||
else:
|
||||
return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
|
||||
|
||||
code_for_op: Final[Dict[Op, Callable]] = {
|
||||
|
@ -85,7 +83,6 @@ def add_gl_dimension(args, i, var, local_size, xid):
|
|||
lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
|
||||
assert lidx.max == var.max and lidx.min == var.min
|
||||
return f"{{ int {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
|
||||
else:
|
||||
local_size.append(var.max+1)
|
||||
return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
|
||||
|
||||
|
|
|
@ -101,10 +101,10 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
|
|||
# TODO: generic visitor pattern?
|
||||
def expand_node(idx:Node) -> List[Node]:
|
||||
if isinstance(idx, Variable): return [idx] if idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)]
|
||||
elif isinstance(idx, NumNode): return [idx]
|
||||
elif isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
|
||||
elif isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
|
||||
else: raise NotImplementedError(idx)
|
||||
if isinstance(idx, NumNode): return [idx]
|
||||
if isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
|
||||
if isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
|
||||
raise NotImplementedError(idx)
|
||||
|
||||
def expand_idxs(idxs:Sequence[Node]) -> Iterator[Tuple[Node, ...]]:
|
||||
for x in itertools.product(*[expand_node(idx) for idx in idxs[::-1]]):
|
||||
|
|
|
@ -104,7 +104,6 @@ class Interpreted:
|
|||
assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
|
||||
output.output_buffer._buf = ret._buf
|
||||
return output.output_buffer
|
||||
else:
|
||||
return ret
|
||||
|
||||
class FlopCounter:
|
||||
|
|
|
@ -70,7 +70,6 @@ class MetalProgram:
|
|||
if wait:
|
||||
command_buffer.waitUntilCompleted()
|
||||
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
else:
|
||||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
|
||||
class MetalCodegen(CStyleCodegen):
|
||||
|
|
|
@ -119,9 +119,7 @@ def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
|
|||
|
||||
new_view = View(new_shape, strides_for_shape(new_shape))
|
||||
if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset
|
||||
else:
|
||||
if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
|
||||
else:
|
||||
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
|
||||
return new_view, True
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ class Node:
|
|||
return create_node(LtNode(lhs, b))
|
||||
def __mul__(self, b:int):
|
||||
if b == 0: return NumNode(0)
|
||||
elif b == 1: return self
|
||||
if b == 1: return self
|
||||
return create_node(MulNode(self, b))
|
||||
|
||||
# *** complex ops ***
|
||||
|
|
|
@ -422,7 +422,6 @@ class Tensor:
|
|||
xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
|
||||
xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
|
||||
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
|
||||
else:
|
||||
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
|
||||
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
|
||||
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
|
||||
|
|
Loading…
Reference in New Issue