1
0
Fork 0

Handle broadcast flag on gemm (#1103)

pull/1117/head
Frank Pinnola 2023-07-03 01:15:07 -04:00 committed by GitHub
parent cbb5c655e5
commit 2071e53da8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -17,9 +17,9 @@ def Unsqueeze(data, axes):
ptr += 1
return data.reshape(new_shape)
def Gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
def Gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
ret = alpha * ((A.transpose() if transA == 1 else A) @ (B.transpose() if transB == 1 else B))
if C is not None: ret += beta * C
if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]))
return ret
# TODO: this is copied from tinygrad/nn/__init__.py