Make cross_process use cloudpickle (#1118)
* fix syntax issues in imagenet_download.py * use cloudpickle in cross_process to make it work in Python 3.9+ * add cross_process test * prevent unpickling on every function call * add cloudpickle to setup.py * add support for args/kwargspull/1025/head^2
parent
c709dec8b5
commit
b4ce23e4b8
|
@ -23,7 +23,7 @@ def imagenet_prepare_val():
|
|||
# Create folders and move files into those
|
||||
for co,dir in enumerate(labels):
|
||||
os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/val" / dir, exist_ok=True)
|
||||
os.replace(Path(__file__).parent.parent / "datasets/imagenet/val" / images[co], Path(__file__).parent.parent / "datasets/imagenet/val" / dir / images[co], exist_ok=True)
|
||||
os.replace(Path(__file__).parent.parent / "datasets/imagenet/val" / images[co], Path(__file__).parent.parent / "datasets/imagenet/val" / dir / images[co])
|
||||
os.remove(Path(__file__).parent.parent / "datasets/imagenet/imagenet_2012_validation_synset_labels.txt")
|
||||
|
||||
def imagenet_prepare_train():
|
||||
|
@ -45,7 +45,7 @@ if __name__ == "__main__":
|
|||
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar") # 7GB
|
||||
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/val")
|
||||
imagenet_prepare_val()
|
||||
if os.getenv['IMGNET_TRAIN'] is not None:
|
||||
if os.getenv('IMGNET_TRAIN', None) is not None:
|
||||
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar") #138GB!
|
||||
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/train")
|
||||
imagenet_prepare_train()
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from tinygrad.helpers import Timing
|
||||
from typing import Any
|
||||
import cloudpickle
|
||||
import subprocess
|
||||
import multiprocessing
|
||||
|
||||
|
@ -28,10 +30,23 @@ def proc(itermaker, q) -> None:
|
|||
q.put(None)
|
||||
q.close()
|
||||
|
||||
class _CloudpickleFunctionWrapper:
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __getstate__(self):
|
||||
return cloudpickle.dumps(self.fn)
|
||||
|
||||
def __setstate__(self, pfn):
|
||||
self.fn = cloudpickle.loads(pfn)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
def cross_process(itermaker, maxsize=16):
|
||||
# TODO: use cloudpickle for itermaker
|
||||
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
|
||||
p = multiprocessing.Process(target=proc, args=(itermaker, q))
|
||||
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
|
||||
p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q))
|
||||
#p.daemon = True
|
||||
p.start()
|
||||
|
||||
|
|
1
setup.py
1
setup.py
|
@ -42,6 +42,7 @@ setup(name='tinygrad',
|
|||
"tabulate",
|
||||
"safetensors",
|
||||
"types-PyYAML",
|
||||
"cloudpickle",
|
||||
],
|
||||
},
|
||||
include_package_data=True)
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
import multiprocessing
|
||||
import unittest
|
||||
from extra.helpers import cross_process
|
||||
|
||||
class TestCrossProcess(unittest.TestCase):
|
||||
def test_cross_process(self):
|
||||
def _iterate():
|
||||
for i in range(3): yield i
|
||||
|
||||
ret = cross_process(lambda: _iterate())
|
||||
assert len(list(ret)) == 3
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue