1
0
Fork 0

Add commit hooks (#478)

* Add pre-commit hook

* We need ret

* Fix some type definitions
pull/479/head
Jacky Lee 2023-01-26 22:24:31 -08:00 committed by GitHub
parent c07bc39941
commit 026ba78526
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 22 additions and 5 deletions

View File

@ -0,0 +1,15 @@
repos:
- repo: local
hooks:
- id: pylint
name: pylint
entry: pylint tinygrad/
language: system
always_run: true
pass_filenames: false
- id: mypy
name: mypy
entry: mypy tinygrad/ --ignore-missing-imports
language: system
always_run: true
pass_filenames: false

View File

@ -32,6 +32,8 @@ setup(name='tinygrad',
"onnx",
"onnx2torch",
"mypy",
"pylint",
"pre-commit",
],
},
include_package_data=True)

View File

@ -116,12 +116,12 @@ class ASTKernel:
for i in range(1, len(shapes[0])):
can_merge = []
for j in range(len(shapes)):
# TODO: added the always mergability of 1s, is this right? if so, add to shapetracker in the 1 case
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0))
# more can merge than this
can_merge = all(can_merge) and i != self.first_reduce
mergeable = all(can_merge) and i != self.first_reduce
for j in range(len(shapes)):
if can_merge:
if mergeable:
rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
else:
rets[j].append((shapes[j][i], strides[j][i]))

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import os
import functools
from typing import Tuple, Union, List, Optional
from typing import Tuple, Union, List, Optional, Any
from tinygrad.helpers import prod
from tinygrad.shape.symbolic import Variable
@ -21,7 +21,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
return ret
class View:
def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
def __init__(self, shape:Union[Tuple[int, ...],List[Any]], strides:Union[Tuple[int, ...],List[Any]], offset:int=0):
self.shape, self.strides, self.offset = tuple(shape), tuple(strides), offset
self.shape_strides = to_shape_strides(self.shape, self.strides)