1
0
Fork 0

Implement onnx MeanVarianceNormalization (#943)

pull/948/head
M4tthewDE 2023-06-06 19:28:19 +02:00 committed by GitHub
parent 3bb38c3518
commit 664d6cc7e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 1 deletions

View File

@ -222,4 +222,9 @@ def CastLike(input, target_type):
assert isinstance(target_type, Tensor), "can only CastLike Tensor"
return input
def Binarizer(input, threshold=0.0): return input > threshold
def Binarizer(input, threshold=0.0): return input > threshold
def MeanVarianceNormalization(input, axis=(0, 2, 3)):
data_mean = input.mean(axis=axis, keepdim=True)
std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt()
return (input - data_mean) / (std + 1e-9)