#!/usr/bin/env python
# Standard library
import io
import os
import tempfile
import unittest
from unittest.mock import MagicMock, patch

# Third-party
import numpy as np
import torch
from PIL import Image

# Project-local
from extra.utils import download_file, fetch, temp
from tinygrad.helpers import getenv
from tinygrad.state import torch_load
@unittest.skipIf(getenv("CI", "") != "", "no internet tests in CI")
class TestFetch(unittest.TestCase):
  """Exercises extra.utils.fetch against live HTTP endpoints (skipped in CI)."""

  def test_fetch_bad_http(self):
    # fetch raises AssertionError on any non-2xx status code
    for code in (500, 404, 400):
      with self.assertRaises(AssertionError):
        fetch(f'http://httpstat.us/{code}')

  def test_fetch_small(self):
    # a successful fetch returns a non-empty payload
    self.assertGreater(len(fetch('https://google.com')), 0)

  def test_fetch_img(self):
    # fetched bytes must decode as a valid image with the known dimensions
    raw = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190")
    pimg = Image.open(io.BytesIO(raw))
    self.assertEqual(pimg.size, (705, 1024))
class TestFetchRelative(unittest.TestCase):
  """fetch() must resolve relative paths against the current working directory."""

  def setUp(self):
    # run each test from inside a fresh temporary directory that holds
    # exactly one known file with known contents
    self.working_dir = os.getcwd()
    self.tempdir = tempfile.TemporaryDirectory()
    os.chdir(self.tempdir.name)
    with open('test_file.txt', 'x') as f:
      f.write("12345")

  def tearDown(self):
    # restore the original cwd before deleting the temporary directory
    os.chdir(self.working_dir)
    self.tempdir.cleanup()

  def test_fetch_relative_dotslash(self):
    # "./" resolves against the current directory
    self.assertEqual(fetch("./test_file.txt"), b'12345')

  def test_fetch_relative_dotdotslash(self):
    # "../" resolves against the parent of the current directory
    os.mkdir('test_file_path')
    os.chdir('test_file_path')
    self.assertEqual(fetch("../test_file.txt"), b'12345')
class TestDownloadFile(unittest.TestCase):
  """Unit-tests download_file with requests.get mocked out (no real network)."""

  def setUp(self):
    from pathlib import Path
    # the target lives in a directory that does not exist yet, so the test
    # can verify that download_file creates intermediate directories
    self.test_file = Path(temp("test_download_file/test_file.txt"))

  def tearDown(self):
    # Guard the cleanup: if the test failed before the file was written, an
    # unconditional os.remove/os.removedirs would raise FileNotFoundError
    # here and mask the real test failure.
    if self.test_file.exists():
      os.remove(self.test_file)
    if self.test_file.parent.exists():
      os.removedirs(self.test_file.parent)

  @patch('requests.get')
  def test_download_file_with_mkdir(self, mock_requests):
    # fake a successful streaming HTTP response delivering 8 bytes in two chunks
    mock_response = MagicMock()
    mock_response.iter_content.return_value = [b'1234', b'5678']
    mock_response.status_code = 200
    mock_response.headers = {'content-length': '8'}
    mock_requests.return_value = mock_response

    self.assertFalse(os.path.exists(self.test_file.parent))
    download_file("https://www.mock.com/fake.txt", self.test_file, skip_if_exists=False)
    # download_file must create the parent directory, write the file, and
    # concatenate the streamed chunks in order
    self.assertTrue(os.path.exists(self.test_file.parent))
    self.assertTrue(os.path.isfile(self.test_file))
    self.assertEqual('12345678', self.test_file.read_text())
class TestUtils(unittest.TestCase):
  """Round-trips a torch state_dict through tinygrad's torch_load."""

  def test_fake_torch_load_zipped(self): self._test_fake_torch_load_zipped()
  def test_fake_torch_load_zipped_float16(self): self._test_fake_torch_load_zipped(isfloat16=True)

  def _test_fake_torch_load_zipped(self, isfloat16=False):
    # A module whose parameters are strided views into one shared storage
    # with nonzero storage offsets -- the tricky case for torch_load.
    class LayerWithOffset(torch.nn.Module):
      def __init__(self):
        super().__init__()
        storage = torch.randn(16)
        self.param1 = torch.nn.Parameter(storage.as_strided([2, 2], [1, 2], storage_offset=5))
        self.param2 = torch.nn.Parameter(storage.as_strided([2, 2], [1, 2], storage_offset=4))

    model = torch.nn.Sequential(
      torch.nn.Linear(4, 8),
      torch.nn.Linear(8, 3),
      LayerWithOffset()
    )
    if isfloat16: model = model.half()

    # save with real torch, reload with tinygrad's torch_load
    path = temp(f"test_load_{isfloat16}.pt")
    torch.save(model.state_dict(), path)
    loaded = torch_load(path)

    # every tensor must survive the round trip with identical shape, dtype, and values
    for name, expected in model.state_dict().items():
      got = loaded[name].numpy()
      expected = expected.numpy()
      assert expected.shape == got.shape
      assert expected.dtype == got.dtype
      assert np.array_equal(expected, got)
# Allow running this test module directly as a script.
if __name__ == "__main__":
  unittest.main()