Handle broadcast flag on gemm (#1103)
parent
cbb5c655e5
commit
2071e53da8
|
@ -17,9 +17,9 @@ def Unsqueeze(data, axes):
|
||||||
ptr += 1
|
ptr += 1
|
||||||
return data.reshape(new_shape)
|
return data.reshape(new_shape)
|
||||||
|
|
||||||
def Gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
|
def Gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
|
||||||
ret = alpha * ((A.transpose() if transA == 1 else A) @ (B.transpose() if transB == 1 else B))
|
ret = alpha * ((A.transpose() if transA == 1 else A) @ (B.transpose() if transB == 1 else B))
|
||||||
if C is not None: ret += beta * C
|
if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]))
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
# TODO: this is copied from tinygrad/nn/__init__.py
|
# TODO: this is copied from tinygrad/nn/__init__.py
|
||||||
|
|
Loading…
Reference in New Issue