1
0
Fork 0
tinygrab/examples/mlperf/model_train.py

37 lines
597 B
Python

from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
def train_resnet():
  """Placeholder for ResNet50-v1.5 training.

  Not implemented yet: calling it does nothing and returns None.
  """
  # TODO: Resnet50-v1.5
def train_retinanet():
  """Placeholder for RetinaNet training.

  Not implemented yet: calling it does nothing and returns None.
  """
  # TODO: Retinanet
def train_unet3d():
  """Placeholder for 3D U-Net training.

  Not implemented yet: calling it does nothing and returns None.
  """
  # TODO: Unet3d
def train_rnnt():
  """Placeholder for RNN-T training.

  Not implemented yet: calling it does nothing and returns None.
  """
  # TODO: RNN-T
def train_bert():
  """Placeholder for BERT training.

  Not implemented yet: calling it does nothing and returns None.
  """
  # TODO: BERT
def train_maskrcnn():
  """Placeholder for Mask R-CNN training.

  Not implemented yet: calling it does nothing and returns None.
  """
  # TODO: Mask RCNN
if __name__ == "__main__":
  # Script entry point: the MODEL env var (read via getenv, defaulting to all
  # six models) is a comma-separated list of names; each name is dispatched
  # to the module-level function train_<name>, run inside Tensor.train().
  import sys

  with Tensor.train():
    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
      # Tolerate "resnet, bert"-style spacing; previously " bert" was looked
      # up as "train_ bert" and silently skipped.
      m = m.strip()
      if not m:
        continue  # empty entry from a stray or trailing comma
      fn = globals().get(f"train_{m}")
      if not callable(fn):
        # Keep best-effort behavior (the remaining models still run), but
        # don't hide a misspelled model name.
        print(f"skipping unknown model {m!r}", file=sys.stderr)
        continue
      print(f"training {m}")
      fn()