move state to nn/state (#1619)
parent
1e93fd5449
commit
718ced296c
|
@ -165,7 +165,7 @@ opt = SGD([net.l1.weight, net.l2.weight], lr=3e-4)
|
|||
|
||||
We can see that we are passing in the parameters of our neural network to the optimizer.
|
||||
This is due to the fact that the optimizer needs to know which parameters to update.
|
||||
There is a simpler way to do this just by using `get_parameters(net)` from `tinygrad.state` which will return a list of all the parameters in the neural network.
|
||||
There is a simpler way to do this just by using `get_parameters(net)` from `tinygrad.nn.state` which will return a list of all the parameters in the neural network.
|
||||
The parameters are just listed out explicitly here for clarity.
|
||||
|
||||
Now that we have our network, loss function, and optimizer defined all we are missing is the data to train on!
|
||||
|
@ -291,7 +291,7 @@ The standard weight format for tinygrad is [safetensors](https://github.com/hugg
|
|||
There are functions in [state.py](/tinygrad/nn/state.py) to save and load models to and from this format.
|
||||
|
||||
```python
|
||||
from tinygrad.state import safe_save, safe_load, get_state_dict, load_state_dict
|
||||
from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
|
||||
|
||||
# first we need the state dict of our model
|
||||
state_dict = get_state_dict(net)
|
||||
|
|
|
@ -3,7 +3,7 @@ import gc
|
|||
import time
|
||||
from tqdm import trange
|
||||
from models.efficientnet import EfficientNet
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import GlobalCounters
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from models.efficientnet import EfficientNet
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.state import safe_save
|
||||
from tinygrad.nn.state import safe_save
|
||||
from extra.utils import fetch
|
||||
from extra.export_model import export_model
|
||||
from tinygrad.helpers import getenv
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Optional, Tuple
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import getenv
|
||||
|
|
|
@ -132,7 +132,7 @@ class GPT2:
|
|||
@staticmethod
|
||||
def build(model_size="gpt2"):
|
||||
import tiktoken
|
||||
from tinygrad.state import torch_load, load_state_dict, get_state_dict
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
from extra.utils import fetch_as_file
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ import random
|
|||
import numpy as np
|
||||
from extra.datasets import fetch_cifar, cifar_mean, cifar_std
|
||||
from tinygrad import nn
|
||||
from tinygrad.state import get_state_dict
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
|
|
|
@ -251,7 +251,7 @@ class LLaMa:
|
|||
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
|
||||
assert sp_model.vocab_size() == VOCAB_SIZE
|
||||
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
from tinygrad.nn.state import torch_load, load_state_dict
|
||||
params = MODEL_PARAMS[model_gen][model_size]
|
||||
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"])
|
||||
weights = concat_weights([torch_load(filename) for filename in [f"{model_path}/{model_size}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
|
||||
|
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
from tqdm import trange
|
||||
import torch
|
||||
from torchvision.utils import make_grid, save_image
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.nn import optim
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
|
||||
import sys
|
||||
import numpy as np
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import BatchNorm2d, optim
|
||||
from tinygrad.helpers import getenv
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, BatchNorm2d
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.training = True
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import Tuple, Optional, Type
|
|||
from tinygrad import nn
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, getenv
|
||||
from tinygrad.state import torch_load
|
||||
from tinygrad.nn.state import torch_load
|
||||
from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, download_if_not_present, get_hparams_from_file, load_checkpoint, weight_norm, HParams
|
||||
from examples.sovits_helpers import preprocess
|
||||
import soundfile
|
||||
|
|
|
@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.helpers import dtypes, GlobalCounters
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
from extra.utils import download_file
|
||||
from tinygrad.state import torch_load, load_state_dict, get_state_dict
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
|
||||
class AttnBlock:
|
||||
def __init__(self, in_channels):
|
||||
|
|
|
@ -3,7 +3,7 @@ import time
|
|||
from multiprocessing import Process, Queue
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.tensor import Tensor
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.training import train, evaluate
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import numpy as np
|
||||
import random
|
||||
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn.optim import Adam
|
||||
from extra.training import train, evaluate
|
||||
from models.transformer import Transformer
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import List
|
|||
from extra.utils import download_file
|
||||
from tinygrad import nn
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.state import torch_load
|
||||
from tinygrad.nn.state import torch_load
|
||||
from tinygrad.tensor import Tensor
|
||||
from unidecode import unidecode
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import multiprocessing
|
|||
import numpy as np
|
||||
from typing import Optional
|
||||
from extra.utils import download_file
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
from tinygrad.nn.state import torch_load, load_state_dict
|
||||
from tinygrad.helpers import getenv
|
||||
import tinygrad.nn as nn
|
||||
from tinygrad.tensor import Tensor
|
||||
|
|
|
@ -8,7 +8,7 @@ import cv2
|
|||
from collections import defaultdict
|
||||
import os
|
||||
import time, io, sys
|
||||
from tinygrad.state import safe_load, load_state_dict
|
||||
from tinygrad.nn.state import safe_load, load_state_dict
|
||||
|
||||
|
||||
#Model architecture from https://github.com/ultralytics/ultralytics/issues/189
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import Tuple, Dict, List
|
|||
from tinygrad.helpers import DType
|
||||
from tinygrad.tensor import Device, Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.state import get_state_dict
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
import json
|
||||
|
||||
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
|
||||
|
|
|
@ -143,7 +143,7 @@ class EfficientNet:
|
|||
}
|
||||
|
||||
from extra.utils import fetch_as_file
|
||||
from tinygrad.state import torch_load
|
||||
from tinygrad.nn.state import torch_load
|
||||
b0 = torch_load(fetch_as_file(model_urls[self.number]))
|
||||
for k,v in b0.items():
|
||||
if k.endswith("num_batches_tracked"): continue
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad import nn
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes
|
||||
from extra.utils import get_child, download_file
|
||||
from tinygrad.state import torch_load
|
||||
from tinygrad.nn.state import torch_load
|
||||
from models.resnet import ResNet
|
||||
from models.retinanet import nms as _box_nms
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import unittest, gc
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.state import get_parameters, get_state_dict
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
|
||||
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
|
||||
from tinygrad.helpers import dtypes, prod
|
||||
|
|
|
@ -36,7 +36,7 @@ from models.convnext import ConvNeXt
|
|||
from models.efficientnet import EfficientNet
|
||||
from models.resnet import ResNet18
|
||||
from models.vit import ViT
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
|
||||
class TestInferenceMinKernels(unittest.TestCase):
|
||||
|
|
|
@ -5,7 +5,7 @@ from examples.llama import Transformer, MODEL_PARAMS
|
|||
from test.test_net_speed import start_profile, stop_profile
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.state import get_state_dict
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.helpers import dtypes, prod
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
|
|
@ -6,7 +6,7 @@ import unittest
|
|||
import io, cv2, os
|
||||
import onnxruntime as ort
|
||||
import ultralytics
|
||||
from tinygrad.state import safe_load, load_state_dict
|
||||
from tinygrad.nn.state import safe_load, load_state_dict
|
||||
|
||||
class TestYOLOv8(unittest.TestCase):
|
||||
def test_all_load_weights(self):
|
||||
|
@ -74,4 +74,3 @@ class TestYOLOv8(unittest.TestCase):
|
|||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import unittest
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, BatchNorm2d, optim
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import numpy as np
|
|||
import torch
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn.optim import Adam
|
||||
from extra.lr_scheduler import MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR
|
||||
from extra.training import train, evaluate
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch
|
|||
import numpy as np
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.utils import fetch, temp, download_file
|
||||
from tinygrad.state import torch_load
|
||||
from tinygrad.nn.state import torch_load
|
||||
from PIL import Image
|
||||
|
||||
@unittest.skipIf(getenv("CI", "") != "", "no internet tests in CI")
|
||||
|
|
|
@ -13,7 +13,7 @@ def get_question_samp(bsz, seq_len, vocab_size, seed):
|
|||
return in_ids, mask, seg_ids
|
||||
|
||||
def set_equal_weights(mdl, torch_mdl):
|
||||
from tinygrad.state import get_state_dict
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
state, torch_state = get_state_dict(mdl), torch_mdl.state_dict()
|
||||
assert len(state) == len(torch_state)
|
||||
for k, v in state.items():
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch import nn
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.state import get_parameters, get_state_dict
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
|
||||
from tinygrad.tensor import Tensor
|
||||
from extra.datasets import fetch_mnist
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.nn import optim, BatchNorm2d
|
||||
from extra.training import train, evaluate
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import unittest, time
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
|
||||
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
|
||||
from tinygrad.lazy import Device
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
import time
|
||||
import numpy as np
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.tensor import Device
|
||||
from tinygrad.helpers import getenv
|
||||
|
|
|
@ -2,7 +2,7 @@ import pathlib
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
from tinygrad.helpers import Timing
|
||||
|
|
Loading…
Reference in New Issue