1
0
Fork 0

Make cross_process use cloudpickle (#1118)

* fix syntax issues in imagenet_download.py

* use cloudpickle in cross_process to make it work in Python 3.9+

* add cross_process test

* prevent unpickling on every function call

* add cloudpickle to setup.py

* add support for args/kwargs
pull/1025/head^2
Daniel Hipke 2023-07-04 00:47:34 -07:00 committed by GitHub
parent c709dec8b5
commit b4ce23e4b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 35 additions and 4 deletions

View File

@ -23,7 +23,7 @@ def imagenet_prepare_val():
# Create folders and move files into those
for co,dir in enumerate(labels):
os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/val" / dir, exist_ok=True)
os.replace(Path(__file__).parent.parent / "datasets/imagenet/val" / images[co], Path(__file__).parent.parent / "datasets/imagenet/val" / dir / images[co], exist_ok=True)
os.replace(Path(__file__).parent.parent / "datasets/imagenet/val" / images[co], Path(__file__).parent.parent / "datasets/imagenet/val" / dir / images[co])
os.remove(Path(__file__).parent.parent / "datasets/imagenet/imagenet_2012_validation_synset_labels.txt")
def imagenet_prepare_train():
@ -45,7 +45,7 @@ if __name__ == "__main__":
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar") # 7GB
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/val")
imagenet_prepare_val()
if os.getenv['IMGNET_TRAIN'] is not None:
if os.getenv('IMGNET_TRAIN', None) is not None:
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar") #138GB!
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/train")
imagenet_prepare_train()

View File

@ -1,4 +1,6 @@
from tinygrad.helpers import Timing
from typing import Any
import cloudpickle
import subprocess
import multiprocessing
@ -28,10 +30,23 @@ def proc(itermaker, q) -> None:
q.put(None)
q.close()
class _CloudpickleFunctionWrapper:
def __init__(self, fn):
self.fn = fn
def __getstate__(self):
return cloudpickle.dumps(self.fn)
def __setstate__(self, pfn):
self.fn = cloudpickle.loads(pfn)
def __call__(self, *args, **kwargs) -> Any:
return self.fn(*args, **kwargs)
def cross_process(itermaker, maxsize=16):
# TODO: use cloudpickle for itermaker
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
p = multiprocessing.Process(target=proc, args=(itermaker, q))
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q))
#p.daemon = True
p.start()

View File

@ -42,6 +42,7 @@ setup(name='tinygrad',
"tabulate",
"safetensors",
"types-PyYAML",
"cloudpickle",
],
},
include_package_data=True)

View File

@ -0,0 +1,15 @@
#!/usr/bin/env python
import multiprocessing
import unittest
from extra.helpers import cross_process
class TestCrossProcess(unittest.TestCase):
def test_cross_process(self):
def _iterate():
for i in range(3): yield i
ret = cross_process(lambda: _iterate())
assert len(list(ret)) == 3
if __name__ == '__main__':
unittest.main()