531 lines
18 KiB
Python
531 lines
18 KiB
Python
# Owner(s): ["oncall: pt2"]
|
|
|
|
import itertools
|
|
import sys
|
|
|
|
import sympy
|
|
from typing import Callable, List, Tuple, Type
|
|
from torch.testing._internal.common_device_type import skipIf
|
|
from torch.testing._internal.common_utils import (
|
|
TEST_Z3,
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
TestCase,
|
|
)
|
|
from torch.utils._sympy.functions import FloorDiv
|
|
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
|
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
|
from torch.utils._sympy.reference import ReferenceAnalysis
|
|
from torch.utils._sympy.interp import sympy_interp
|
|
|
|
|
|
# Names of the unary numeric ops under test; each is looked up by name
# (via getattr) on the analysis classes.
UNARY_OPS = ["reciprocal", "square", "abs", "neg", "exp", "log", "sqrt", "floor", "ceil"]
# Names of the binary numeric ops under test.
BINARY_OPS = [
    "truediv",
    "div",
    "floordiv",
    "truncdiv",
    "add",
    "mul",
    "sub",
    "pow",
    "minimum",
    "maximum",
    "mod",
]

# Boolean and comparison op names.
UNARY_BOOL_OPS = ["not_"]
BINARY_BOOL_OPS = ["or_", "and_"]
COMPARE_OPS = ["eq", "ne", "lt", "gt", "le", "ge"]

# Sample operands: a mix of small constants, powers of two, primes, and
# values at the edge of the native integer range.
CONSTANTS = [
    -1, 0, 1, 2, 3, 4, 5,
    8, 16, 32, 64,
    100, 101,
    2**24, 2**32, 2**37 - 1,
    sys.maxsize - 1, sys.maxsize,
]
# Fewer constants, for tests whose cost grows as N^2.
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
|
|
# SymPy relational types: equality, disequality and the four ordering
# relations. Used as the fallback set of ops by
# `parametrize_relational_types` when it is called without arguments.
RELATIONAL_TYPES = [sympy.Eq, sympy.Ne, sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le]
|
|
|
|
|
|
def valid_unary(fn, v):
    """Tell whether the unary op named ``fn`` should be exercised at ``v``.

    Rejected inputs are ``log`` of a non-positive value, ``reciprocal`` of
    zero and ``sqrt`` of a negative value; everything else is accepted.

    Args:
        fn: op name (one of the strings in ``UNARY_OPS``).
        v: the numeric operand.

    Returns:
        ``False`` for the rejected combinations above, ``True`` otherwise.
    """
    if fn == "log":
        return not v <= 0
    if fn == "reciprocal":
        return not v == 0
    if fn == "sqrt":
        return not v < 0
    return True
|
|
|
|
|
|
def valid_binary(fn, a, b):
    """Tell whether the binary op named ``fn`` should be exercised at ``(a, b)``.

    Args:
        fn: op name (one of the strings in ``BINARY_OPS`` or ``COMPARE_OPS``).
        a: left operand.
        b: right operand.

    Returns:
        ``False`` for ``pow`` with an exponent above 4, ``pow`` of a
        non-positive base to the power -1, ``0 ** 0``, and for ``mod``,
        ``div`` or ``truediv`` by zero; ``True`` for everything else.
    """
    if fn == "pow":
        # sympy will expand to x*x*... for integral b; don't do it if it's big
        exponent_too_big = b > 4
        # non-positive base raised to -1 (original rationale: "no imaginary
        # numbers"); this also covers 0 ** -1
        nonpositive_base_inverse = a <= 0 and b == -1
        # 0**0 is undefined
        zero_to_zero = a == b == 0
        return not (exponent_too_big or nonpositive_base_inverse or zero_to_zero)
    if fn in ("mod", "div", "truediv"):
        # division-like ops with a zero right-hand side
        return not b == 0
    return True
|
|
|
|
|
|
def generate_range(vals):
    """Yield a ``ValueRanges(lo, hi)`` for each usable ordered pair from ``vals``.

    Every pair from the Cartesian product of ``vals`` with itself is
    considered. For boolean endpoints only ``(true, false)`` is dropped; for
    all other endpoints a pair is dropped when ``lo > hi`` or when the range
    could only contain an infinite value.

    Args:
        vals: candidate endpoints (sympy booleans, numbers or infinities).

    Yields:
        One ``ValueRanges`` per surviving pair, in product order.
    """
    booleans = [sympy.true, sympy.false]
    for lo, hi in itertools.product(vals, repeat=2):
        if lo in booleans:
            # (true, false) is the only inverted boolean range
            if lo == sympy.true and hi == sympy.false:
                continue
        elif lo > hi:
            continue
        elif lo == sympy.oo or hi == -sympy.oo:
            # ranges that only admit infinite values are not interesting
            continue
        yield ValueRanges(lo, hi)
|
|
|
|
|
|
class TestValueRanges(TestCase):
    """Checks ``ValueRangeAnalysis`` results against ``ReferenceAnalysis``.

    The ``*_ref`` tests feed single constants to both analyses and expect a
    singleton range equal to the reference value. The ``*_ref_range`` tests
    feed whole ranges to ``ValueRangeAnalysis`` and expect every reference
    value computed from a point inside the inputs to lie in the result.
    """

    @parametrize("fn", UNARY_OPS)
    @parametrize("dtype", ("int", "float"))
    def test_unary_ref(self, fn, dtype):
        """A unary op on one constant yields a singleton matching the reference."""
        to_sympy = {"int": sympy.Integer, "float": sympy.Float}[dtype]
        reference_op = getattr(ReferenceAnalysis, fn)
        range_op = getattr(ValueRangeAnalysis, fn)
        for v in CONSTANTS:
            if not valid_unary(fn, v):
                continue
            with self.subTest(v=v):
                operand = to_sympy(v)
                expected = reference_op(operand)
                result = range_op(operand)
                # Both bounds agree with each other and with the reference,
                # in value and in integer-ness.
                self.assertEqual(result.lower.is_integer, result.upper.is_integer)
                self.assertEqual(result.lower, result.upper)
                self.assertEqual(expected.is_integer, result.upper.is_integer)
                self.assertEqual(expected, result.lower)

    def test_pow_half(self):
        """Smoke test: pow of an unknown range by 0.5 must not raise."""
        exponent = ValueRanges.wrap(0.5)
        ValueRangeAnalysis.pow(ValueRanges.unknown(), exponent)

    @parametrize("fn", BINARY_OPS)
    @parametrize("dtype_a", ("int", "float"))
    @parametrize("dtype_b", ("int", "float"))
    def test_binary_ref(self, fn, dtype_a, dtype_b):
        """A binary op on two constants yields a singleton matching the reference."""
        sympy_types = {"int": sympy.Integer, "float": sympy.Float}
        make_a = sympy_types[dtype_a]
        make_b = sympy_types[dtype_b]
        for raw_a, raw_b in itertools.product(CONSTANTS, repeat=2):
            if not valid_binary(fn, raw_a, raw_b):
                continue
            a = make_a(raw_a)
            b = make_b(raw_b)
            with self.subTest(a=a, b=b):
                result = getattr(ValueRangeAnalysis, fn)(a, b)
                # Nothing to compare when the analysis gave up.
                if result == ValueRanges.unknown():
                    continue
                expected = getattr(ReferenceAnalysis, fn)(a, b)

                # sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf
                if fn != "floordiv":
                    self.assertEqual(result.lower.is_integer, result.upper.is_integer)
                    self.assertEqual(expected.is_integer, result.upper.is_integer)
                self.assertEqual(result.lower, result.upper)
                self.assertEqual(expected, result.lower)

    def test_mul_zero_unknown(self):
        """Zero times an unknown range is exactly zero."""
        product = ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown())
        self.assertEqual(product, ValueRanges.wrap(0))

    @parametrize("fn", UNARY_BOOL_OPS)
    def test_unary_bool_ref_range(self, fn):
        """A unary boolean op on a range covers exactly the reachable outputs."""
        vals = [sympy.false, sympy.true]
        for a in generate_range(vals):
            with self.subTest(a=a):
                bounds = getattr(ValueRangeAnalysis, fn)(a)
                unique = set()
                for a0 in vals:
                    if a0 not in a:
                        continue
                    with self.subTest(a0=a0):
                        point = getattr(ReferenceAnalysis, fn)(a0)
                        self.assertIn(point, bounds)
                        unique.add(point)
                # A singleton result must come from one distinct output, a
                # two-valued result from two.
                expected_count = 1 if bounds.lower == bounds.upper else 2
                self.assertEqual(len(unique), expected_count)

    @parametrize("fn", BINARY_BOOL_OPS)
    def test_binary_bool_ref_range(self, fn):
        """A binary boolean op on two ranges covers exactly the reachable outputs."""
        vals = [sympy.false, sympy.true]
        for a, b in itertools.product(generate_range(vals), repeat=2):
            with self.subTest(a=a, b=b):
                bounds = getattr(ValueRangeAnalysis, fn)(a, b)
                unique = set()
                for a0, b0 in itertools.product(vals, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        point = getattr(ReferenceAnalysis, fn)(a0, b0)
                        self.assertIn(point, bounds)
                        unique.add(point)
                expected_count = 1 if bounds.lower == bounds.upper else 2
                self.assertEqual(len(unique), expected_count)

    @parametrize("fn", UNARY_OPS)
    def test_unary_ref_range(self, fn):
        """A unary op on a range contains the reference value of every valid point."""
        vals = [-sympy.oo, *CONSTANTS, sympy.oo]
        for a in generate_range(vals):
            with self.subTest(a=a):
                bounds = getattr(ValueRangeAnalysis, fn)(a)
                for a0 in CONSTANTS:
                    # Skip points outside the input range or outside the
                    # op's domain.
                    if a0 not in a or not valid_unary(fn, a0):
                        continue
                    with self.subTest(a0=a0):
                        point = getattr(ReferenceAnalysis, fn)(sympy.Integer(a0))
                        self.assertIn(point, bounds)

    # This takes about 4s for all the variants
    @parametrize("fn", BINARY_OPS + COMPARE_OPS)
    def test_binary_ref_range(self, fn):
        """A binary op on two ranges contains every finite reference value."""
        vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo]
        for a, b in itertools.product(generate_range(vals), repeat=2):
            # don't attempt pow on exponents that are too large (but oo is OK)
            if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
                continue
            with self.subTest(a=a, b=b):
                bounds = getattr(ValueRangeAnalysis, fn)(a, b)
                for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    if not valid_binary(fn, a0, b0):
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        point = getattr(ReferenceAnalysis, fn)(
                            sympy.Integer(a0), sympy.Integer(b0)
                        )
                        # Non-finite reference values are not checked.
                        if point.is_finite:
                            self.assertIn(point, bounds)

    def test_rational_bounds(self):
        """Range analysis of a floor/division-heavy equation must admit ``True``."""
        # Repro from https://github.com/pytorch/pytorch/issues/105097
        from sympy import floor, Eq

        shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
        # The two sub-expressions that recur throughout the equation.
        blocks = (shape_0 + 1) // 96
        per_block = (shape_0 + 62017) // (blocks + 646)
        new_expr = Eq(
            30 * floor(4 * blocks * per_block / 647 + 2584 * per_block / 647),
            2880 * floor(blocks * per_block / 15528 + 323 * per_block / 7764),
        )
        new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
        # The equation holds at shape_0 == 95, so True has to be in the
        # computed range.
        self.assertTrue(new_expr.subs({shape_0: 95}))
        self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
|
|
|
|
|
|
class TestSympyInterp(TestCase):
    """Checks that ``sympy_interp`` agrees with calling the op directly."""

    @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
    def test_interp(self, fn):
        """Interpreting the symbolic form of ``fn`` matches direct evaluation.

        The op is built once on symbols and once on concrete values; running
        ``sympy_interp`` on the symbolic expression with the symbols bound to
        those values must reproduce the concrete result.
        """
        # SymPy does not implement truncation for Expressions
        if fn in ("div", "truncdiv", "minimum", "maximum"):
            return

        from sympy.abc import x, y

        is_boolean = fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}
        vals = [True, False] if is_boolean else CONSTANTS
        is_binary = fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}
        arity = 2 if is_binary else 1
        symbols = [x, y] if arity == 2 else [x]
        is_valid = valid_binary if arity == 2 else valid_unary

        for args in itertools.product(vals, repeat=arity):
            if not is_valid(fn, *args):
                continue
            with self.subTest(args=args):
                sargs = [sympy.sympify(arg) for arg in args]
                sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
                expected = getattr(ReferenceAnalysis, fn)(*sargs)
                # Yes, I know this is a longwinded way of saying xreplace; the
                # point is to test sympy_interp
                bindings = dict(zip(symbols, sargs))
                actual = sympy_interp(ReferenceAnalysis, bindings, sympy_expr)
                self.assertEqual(expected, actual)
|
|
|
|
|
|
def type_name_fn(type: Type) -> str:
    """Return the ``__name__`` of the given class.

    Passed as ``name_fn`` to ``parametrize`` by
    ``parametrize_relational_types``.
    """
    name = type.__name__
    return name
|
|
|
|
def parametrize_relational_types(*types):
    """Build a decorator that parametrizes a test over relational types.

    Args:
        *types: the relational classes to bind to the test's ``op`` argument.
            When none are given, every entry of ``RELATIONAL_TYPES`` is used.

    Returns:
        A decorator applying ``parametrize("op", ...)`` with each variant
        named through ``type_name_fn``.
    """

    def decorate(test_fn: Callable):
        selected = types or RELATIONAL_TYPES
        apply_parametrize = parametrize("op", selected, name_fn=type_name_fn)
        return apply_parametrize(test_fn)

    return decorate
|
|
|
|
|
|
class TestSympySolve(TestCase):
    """Tests for ``try_solve``.

    As exercised here, ``try_solve(expr, thing)`` returns either ``None``
    (it gave up) or a pair ``(relation, rhs)`` in which ``relation`` has
    ``thing`` alone on its left-hand side and ``rhs`` on its right.
    """

    def _create_integer_symbols(self) -> Tuple[sympy.Symbol, ...]:
        """Return the integer symbols ``a``, ``b`` and ``c``.

        ``sympy.symbols`` returns a tuple for a multi-name string, hence the
        tuple annotation.
        """
        return sympy.symbols("a b c", integer=True)

    def test_give_up(self):
        """``try_solve`` returns ``None`` for inputs it cannot handle."""
        from sympy import Eq, Ne

        a, b, c = self._create_integer_symbols()

        cases = [
            # Not a relational operation.
            a + b,
            # 'a' appears on both sides.
            Eq(a, a + 1),
            # 'a' appears on neither side.
            Eq(b, c + 1),
            # Result is a 'sympy.And'.
            Eq(FloorDiv(a, b), c),
            # Result is a 'sympy.Or'.
            Ne(FloorDiv(a, b), c),
        ]

        for case in cases:
            # subTest so that a failure names the offending case.
            with self.subTest(case=case):
                self.assertIsNone(try_solve(case, a))

    @parametrize_relational_types()
    def test_noop(self, op):
        """An already-solved relation ``op(a, rhs)`` is returned unchanged."""
        a, b, _ = self._create_integer_symbols()

        lhs, rhs = a, 42 * b
        expr = op(lhs, rhs)

        r = try_solve(expr, a)
        self.assertIsNotNone(r)

        r_expr, r_rhs = r
        self.assertEqual(r_expr, expr)
        self.assertEqual(r_rhs, rhs)

    @parametrize_relational_types()
    def test_noop_rhs(self, op):
        """With ``a`` on the right, the relation is mirrored to put it on the left."""
        a, b, _ = self._create_integer_symbols()

        lhs, rhs = 42 * b, a

        mirror = mirror_rel_op(op)
        self.assertIsNotNone(mirror)

        expr = op(lhs, rhs)

        r = try_solve(expr, a)
        self.assertIsNotNone(r)

        r_expr, r_rhs = r
        self.assertEqual(r_expr, mirror(rhs, lhs))
        self.assertEqual(r_rhs, lhs)

    def _test_cases(self, cases: List[Tuple[sympy.Basic, sympy.Basic]], thing: sympy.Basic, op: Type[sympy.Rel], **kwargs):
        """Run ``try_solve`` on each case and compare with the expectation.

        Args:
            cases: ``(source, expected)`` pairs; ``expected`` is the
                right-hand side the solution should have, or ``None`` when
                ``try_solve`` is expected to give up.
            thing: the expression to solve for.
            op: relational type the solved expression should use.
            **kwargs: forwarded verbatim to ``try_solve``.
        """
        for source, expected in cases:
            # subTest so that a failure names the offending case.
            with self.subTest(source=source):
                r = try_solve(source, thing, **kwargs)

                # Either both are None (gave up as expected) or neither is.
                self.assertEqual(r is None, expected is None)

                if r is not None:
                    r_expr, r_rhs = r
                    self.assertEqual(r_rhs, expected)
                    self.assertEqual(r_expr, op(thing, expected))

    def test_addition(self):
        """Additive terms are moved to the right-hand side."""
        from sympy import Eq

        a, b, c = self._create_integer_symbols()

        cases = [
            (Eq(a + b, 0), -b),
            (Eq(a + 5, b - 5), b - 10),
            (Eq(a + c * b, 1), 1 - c * b),
        ]

        self._test_cases(cases, a, Eq)

    @parametrize_relational_types(sympy.Eq, sympy.Ne)
    def test_multiplication_division(self, op):
        """For ``Eq``/``Ne``, factors of ``a`` are divided out whatever their sign."""
        a, b, c = self._create_integer_symbols()

        cases = [
            (op(a * b, 1), 1 / b),
            (op(a * 5, b - 5), (b - 5) / 5),
            (op(a * b, c), c / b),
        ]

        self._test_cases(cases, a, op)

    @parametrize_relational_types(*INEQUALITY_TYPES)
    def test_multiplication_division_inequality(self, op):
        """For inequalities, the sign of the factor decides the outcome."""
        a, b, _ = self._create_integer_symbols()
        intneg = sympy.Symbol("neg", integer=True, negative=True)
        intpos = sympy.Symbol("pos", integer=True, positive=True)

        cases = [
            # Divide/multiply both sides by positive number.
            (op(a * intpos, 1), 1 / intpos),
            (op(a / (5 * intpos), 1), 5 * intpos),
            (op(a * 5, b - 5), (b - 5) / 5),
            # 'b' is not strictly positive nor negative, so we can't
            # divide/multiply both sides by 'b'.
            (op(a * b, 1), None),
            (op(a / b, 1), None),
            (op(a * b * intpos, 1), None),
        ]

        mirror_cases = [
            # Divide/multiply both sides by negative number.
            (op(a * intneg, 1), 1 / intneg),
            (op(a / (5 * intneg), 1), 5 * intneg),
            (op(a * -5, b - 5), -(b - 5) / 5),
        ]
        mirror_op = mirror_rel_op(op)
        # A unittest assertion rather than a bare `assert`, which is stripped
        # under `python -O`; this also matches test_noop_rhs.
        self.assertIsNotNone(mirror_op)

        self._test_cases(cases, a, op)
        self._test_cases(mirror_cases, a, mirror_op)

    @parametrize_relational_types()
    def test_floordiv(self, op):
        """``FloorDiv(a, pos) <op> integer`` is solved only for inequalities."""
        from sympy import Eq, Ne, Gt, Ge, Lt, Le

        a, b, c = sympy.symbols("a b c")
        pos = sympy.Symbol("pos", positive=True)
        integer = sympy.Symbol("integer", integer=True)

        # (Eq(FloorDiv(a, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
        # (Eq(FloorDiv(a + 5, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
        # (Ne(FloorDiv(a, pos), integer), Or(Lt(a, integer * pos), Ge(a, (integer + 1) * pos))),

        special_case = {
            # 'FloorDiv' turns into 'And', which can't be simplified any further.
            Eq: (Eq(FloorDiv(a, pos), integer), None),
            # 'FloorDiv' turns into 'Or', which can't be simplified any further.
            Ne: (Ne(FloorDiv(a, pos), integer), None),
            Gt: (Gt(FloorDiv(a, pos), integer), (integer + 1) * pos),
            Ge: (Ge(FloorDiv(a, pos), integer), integer * pos),
            Lt: (Lt(FloorDiv(a, pos), integer), integer * pos),
            Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
        }[op]

        cases: List[Tuple[sympy.Basic, sympy.Basic]] = [
            # 'b' is not strictly positive
            (op(FloorDiv(a, b), integer), None),
            # 'c' carries no assumptions (the original note read "'c' is not
            # strictly positive"; NOTE(review): presumably the missing
            # integer assumption is what makes try_solve give up - confirm).
            (op(FloorDiv(a, pos), c), None),
        ]

        # The result might change after 'FloorDiv' transformation.
        # Specifically:
        #   - [Ge, Gt] => Ge
        #   - [Le, Lt] => Lt
        if op in (sympy.Gt, sympy.Ge):
            r_op = sympy.Ge
        elif op in (sympy.Lt, sympy.Le):
            r_op = sympy.Lt
        else:
            r_op = op

        self._test_cases([special_case, *cases], a, r_op)
        # With the floordiv rewrite disabled, even the special case gives up.
        self._test_cases([(special_case[0], None), *cases], a, r_op, floordiv_inequality=False)

    def test_floordiv_eq_simplify(self):
        """``Eq`` over ``FloorDiv`` is solved when one of its two bounds is trivially true."""
        from sympy import Eq, Lt, Le

        a = sympy.Symbol("a", positive=True, integer=True)

        def check(expr, expected):
            r = try_solve(expr, a)
            self.assertIsNotNone(r)
            r_expr, _ = r
            self.assertEqual(r_expr, expected)

        # (a + 10) // 3 == 3
        # =====================================
        # 3 * 3 <= a + 10         (always true)
        #          a + 10 < 4 * 3 (not sure)
        check(Eq(FloorDiv(a + 10, 3), 3), Lt(a, (3 + 1) * 3 - 10))

        # (10 - a) // 2 == 4
        # =====================================
        # 4 * 2 <= 10 - a         (not sure)
        #          10 - a < 5 * 2 (always true)
        check(Eq(FloorDiv(10 - a, 2), 4), Le(a, -(4 * 2 - 10)))

    @skipIf(not TEST_Z3, "Z3 not installed")
    def test_z3_proof_floordiv_eq_simplify(self):
        """Use Z3 to confirm the ``FloorDiv`` equality simplification is sound."""
        import z3
        from sympy import Eq, Lt

        a = sympy.Symbol("a", positive=True, integer=True)
        a_ = z3.Int("a")

        # (a + 10) // 3 == 3
        # =====================================
        # 3 * 3 <= a + 10         (always true)
        #          a + 10 < 4 * 3 (not sure)
        solver = z3.SolverFor("QF_NRA")

        # Add assertions for 'a_'.
        solver.add(a_ > 0)

        expr = Eq(FloorDiv(a + 10, 3), 3)
        solved = try_solve(expr, a)
        # Fail with a clear message instead of a TypeError on unpacking.
        self.assertIsNotNone(solved)
        r_expr, _ = solved

        # Check 'try_solve' really returns the 'expected' below.
        expected = Lt(a, (3 + 1) * 3 - 10)
        self.assertEqual(r_expr, expected)

        # Check whether there is an integer 'a_' such that the
        # equation below is satisfied.
        solver.add(
            # expr
            (z3.ToInt((a_ + 10) / 3.0) == 3)
            !=
            # expected
            (a_ < (3 + 1) * 3 - 10)
        )

        # Assert that there's no such an integer.
        # i.e. the transformation is sound.
        outcome = solver.check()
        self.assertEqual(outcome, z3.unsat)
|
|
|
|
|
|
# Hand each test class to the torch test-utils helper; presumably this is
# what turns the @parametrize-decorated methods into concrete test cases.
instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)
instantiate_parametrized_tests(TestSympySolve)


# Run the suite when the file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
|