Make cross_process use cloudpickle (#1118)
* fix syntax issues in imagenet_download.py * use cloudpickle in cross_process to make it work in Python 3.9+ * add cross_process test * prevent unpickling on every function call * add cloudpickle to setup.py * add support for args/kwargspull/1025/head^2
parent
c709dec8b5
commit
b4ce23e4b8
|
@ -23,7 +23,7 @@ def imagenet_prepare_val():
|
||||||
# Create folders and move files into those
|
# Create folders and move files into those
|
||||||
for co,dir in enumerate(labels):
|
for co,dir in enumerate(labels):
|
||||||
os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/val" / dir, exist_ok=True)
|
os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/val" / dir, exist_ok=True)
|
||||||
os.replace(Path(__file__).parent.parent / "datasets/imagenet/val" / images[co], Path(__file__).parent.parent / "datasets/imagenet/val" / dir / images[co], exist_ok=True)
|
os.replace(Path(__file__).parent.parent / "datasets/imagenet/val" / images[co], Path(__file__).parent.parent / "datasets/imagenet/val" / dir / images[co])
|
||||||
os.remove(Path(__file__).parent.parent / "datasets/imagenet/imagenet_2012_validation_synset_labels.txt")
|
os.remove(Path(__file__).parent.parent / "datasets/imagenet/imagenet_2012_validation_synset_labels.txt")
|
||||||
|
|
||||||
def imagenet_prepare_train():
|
def imagenet_prepare_train():
|
||||||
|
@ -45,7 +45,7 @@ if __name__ == "__main__":
|
||||||
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar") # 7GB
|
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar") # 7GB
|
||||||
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/val")
|
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/val")
|
||||||
imagenet_prepare_val()
|
imagenet_prepare_val()
|
||||||
if os.getenv['IMGNET_TRAIN'] is not None:
|
if os.getenv('IMGNET_TRAIN', None) is not None:
|
||||||
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar") #138GB!
|
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar") #138GB!
|
||||||
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/train")
|
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/train")
|
||||||
imagenet_prepare_train()
|
imagenet_prepare_train()
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
from tinygrad.helpers import Timing
|
from tinygrad.helpers import Timing
|
||||||
|
from typing import Any
|
||||||
|
import cloudpickle
|
||||||
import subprocess
|
import subprocess
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
|
@ -28,10 +30,23 @@ def proc(itermaker, q) -> None:
|
||||||
q.put(None)
|
q.put(None)
|
||||||
q.close()
|
q.close()
|
||||||
|
|
||||||
|
class _CloudpickleFunctionWrapper:
|
||||||
|
def __init__(self, fn):
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return cloudpickle.dumps(self.fn)
|
||||||
|
|
||||||
|
def __setstate__(self, pfn):
|
||||||
|
self.fn = cloudpickle.loads(pfn)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs) -> Any:
|
||||||
|
return self.fn(*args, **kwargs)
|
||||||
|
|
||||||
def cross_process(itermaker, maxsize=16):
|
def cross_process(itermaker, maxsize=16):
|
||||||
# TODO: use cloudpickle for itermaker
|
|
||||||
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
|
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
|
||||||
p = multiprocessing.Process(target=proc, args=(itermaker, q))
|
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
|
||||||
|
p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q))
|
||||||
#p.daemon = True
|
#p.daemon = True
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -42,6 +42,7 @@ setup(name='tinygrad',
|
||||||
"tabulate",
|
"tabulate",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
"types-PyYAML",
|
"types-PyYAML",
|
||||||
|
"cloudpickle",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
include_package_data=True)
|
include_package_data=True)
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import multiprocessing
|
||||||
|
import unittest
|
||||||
|
from extra.helpers import cross_process
|
||||||
|
|
||||||
|
class TestCrossProcess(unittest.TestCase):
|
||||||
|
def test_cross_process(self):
|
||||||
|
def _iterate():
|
||||||
|
for i in range(3): yield i
|
||||||
|
|
||||||
|
ret = cross_process(lambda: _iterate())
|
||||||
|
assert len(list(ret)) == 3
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue