Handle broadcast flag on gemm (#1103)
parent
cbb5c655e5
commit
2071e53da8
|
@ -17,9 +17,9 @@ def Unsqueeze(data, axes):
|
|||
ptr += 1
|
||||
return data.reshape(new_shape)
|
||||
|
||||
def Gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
|
||||
def Gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
|
||||
ret = alpha * ((A.transpose() if transA == 1 else A) @ (B.transpose() if transB == 1 else B))
|
||||
if C is not None: ret += beta * C
|
||||
if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]))
|
||||
return ret
|
||||
|
||||
# TODO: this is copied from tinygrad/nn/__init__.py
|
||||
|
|
Loading…
Reference in New Issue