1
0
Fork 0

Allow Tensor(tuple) (#911)

pull/843/head
Alexey Zaytsev 2023-06-04 13:48:19 +07:00 committed by GitHub
parent afd0be8a9c
commit d429553730
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -34,10 +34,10 @@ class Tensor:
no_grad: ClassVar[bool] = False
default_type: ClassVar[DType] = dtypes.float32
def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
def __init__(self, data:Union[int, float, list, tuple, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
device = Device.canonicalize(device)
if isinstance(data, list):
if isinstance(data, (list, tuple)):
data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
if isinstance(data, LazyBuffer):