Implement onnx MeanVarianceNormalization (#943)
parent
3bb38c3518
commit
664d6cc7e5
|
@ -223,3 +223,8 @@ def CastLike(input, target_type):
|
|||
return input
|
||||
|
||||
def Binarizer(input, threshold=0.0): return input > threshold
|
||||
|
||||
def MeanVarianceNormalization(input, axis=(0, 2, 3)):
|
||||
data_mean = input.mean(axis=axis, keepdim=True)
|
||||
std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt()
|
||||
return (input - data_mean) / (std + 1e-9)
|
||||
|
|
Loading…
Reference in New Issue