parent
5db0cdfbd3
commit
7d26452305
|
@ -39,7 +39,7 @@ jobs:
|
|||
- name: Lint with ruff
|
||||
run: |
|
||||
pip3 install --upgrade --force-reinstall ruff
|
||||
python3 -m ruff .
|
||||
python3 -m ruff . --preview
|
||||
- name: Lint tinygrad with pylint
|
||||
run: python -m pylint tinygrad/
|
||||
- name: Run mypy
|
||||
|
|
|
@ -15,7 +15,7 @@ repos:
|
|||
pass_filenames: false
|
||||
- id: ruff
|
||||
name: ruff
|
||||
entry: ruff .
|
||||
entry: ruff . --preview
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
|
|
|
@ -55,7 +55,7 @@ def compare_tiny_torch(model, model_torch, X, Y):
|
|||
np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}')
|
||||
|
||||
def get_mnist_data():
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
_X_train, _Y_train, X_test, Y_test = fetch_mnist()
|
||||
BS = 32
|
||||
num_classes = 10
|
||||
X = Tensor(X_test[0:BS].astype(np.float32))
|
||||
|
|
|
@ -1229,7 +1229,7 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,:,None,d,e], lambda x: x[i,:,None,o,p])
|
||||
|
||||
def test_slice_fancy_indexing_dim_inject_and_collapse(self):
|
||||
a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
|
||||
a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() # noqa
|
||||
# dim injection and collapse
|
||||
helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,b,None,d,1], lambda x: x[1,j,None,o,1])
|
||||
helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,b,2,d,None], lambda x: x[None,j,2,o,None])
|
||||
|
|
|
@ -165,7 +165,7 @@ class TestSymbolicShapeExpr(unittest.TestCase):
|
|||
shape = (i+1, 8, 4)
|
||||
strides = (1, (i*4)+4, i+1)
|
||||
st = ShapeTracker((View.create(shape, strides), ))
|
||||
idx, valid = st.expr_idxs(idx)
|
||||
idx, _valid = st.expr_idxs(idx)
|
||||
assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -169,7 +169,7 @@ class dtypes:
|
|||
def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp)
|
||||
|
||||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod}
|
||||
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
|
||||
|
||||
class GlobalCounters:
|
||||
|
|
Loading…
Reference in New Issue