unparser: refactor delimiting with context managers in ast.unparse

Backport of 4b3b1226e8
pull/28346/head
Todd Gamblin 2021-12-21 16:43:03 -08:00 committed by Greg Becker
parent 5847eb1e65
commit afb358313a
1 changed files with 177 additions and 190 deletions

View File

@ -13,6 +13,7 @@ from six import StringIO
# TODO: if we require Python 3.7, use its `nullcontext()`
@contextmanager
def nullcontext():
yield
@ -101,6 +102,21 @@ class Unparser:
def block(self):
return self._Block(self)
@contextmanager
def delimit(self, start, end):
"""A context manager for preparing the source for expressions. It adds
*start* to the buffer and enters, after exit it adds *end*."""
self.write(start)
yield
self.write(end)
def delimit_if(self, start, end, condition):
if condition:
return self.delimit(start, end)
else:
return nullcontext()
def dispatch(self, tree):
"Dispatcher function, dispatching tree type T to method _T."
if isinstance(tree, list):
@ -135,11 +151,10 @@ class Unparser:
self.dispatch(tree.value)
def _NamedExpr(self, tree):
self.write("(")
self.dispatch(tree.target)
self.write(" := ")
self.dispatch(tree.value)
self.write(")")
with self.delimit("(", ")"):
self.dispatch(tree.target)
self.write(" := ")
self.dispatch(tree.value)
def _Import(self, t):
self.fill("import ")
@ -172,11 +187,9 @@ class Unparser:
def _AnnAssign(self, t):
self.fill()
if not t.simple and isinstance(t.target, ast.Name):
self.write('(')
self.dispatch(t.target)
if not t.simple and isinstance(t.target, ast.Name):
self.write(')')
with self.delimit_if(
"(", ")", not node.simple and isinstance(t.target, ast.Name)):
self.dispatch(t.target)
self.write(": ")
self.dispatch(t.annotation)
if t.value:
@ -250,28 +263,25 @@ class Unparser:
interleave(lambda: self.write(", "), self.write, t.names)
def _Await(self, t):
self.write("(")
self.write("await")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
with self.delimit("(", ")"):
self.write("await")
if t.value:
self.write(" ")
self.dispatch(t.value)
def _Yield(self, t):
self.write("(")
self.write("yield")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
with self.delimit("(", ")"):
self.write("yield")
if t.value:
self.write(" ")
self.dispatch(t.value)
def _YieldFrom(self, t):
self.write("(")
self.write("yield from")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
with self.delimit("(", ")"):
self.write("yield from")
if t.value:
self.write(" ")
self.dispatch(t.value)
def _Raise(self, t):
self.fill("raise")
@ -356,35 +366,33 @@ class Unparser:
self.dispatch(deco)
self.fill("class "+t.name)
if six.PY3:
self.write("(")
comma = False
for e in t.bases:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
for e in t.keywords:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
if sys.version_info[:2] < (3, 5):
if t.starargs:
with self.delimit("(", ")"):
comma = False
for e in t.bases:
if comma: self.write(", ")
else: comma = True
self.write("*")
self.dispatch(t.starargs)
if t.kwargs:
self.dispatch(e)
for e in t.keywords:
if comma: self.write(", ")
else: comma = True
self.write("**")
self.dispatch(t.kwargs)
self.write(")")
self.dispatch(e)
if sys.version_info[:2] < (3, 5):
if t.starargs:
if comma: self.write(", ")
else: comma = True
self.write("*")
self.dispatch(t.starargs)
if t.kwargs:
if comma: self.write(", ")
else: comma = True
self.write("**")
self.dispatch(t.kwargs)
elif t.bases:
self.write("(")
with self.delimit("(", ")"):
for a in t.bases[:-1]:
self.dispatch(a)
self.write(", ")
self.dispatch(t.bases[-1])
self.write(")")
with self.block():
self.dispatch(t.body)
@ -399,10 +407,10 @@ class Unparser:
for deco in t.decorator_list:
self.fill("@")
self.dispatch(deco)
def_str = fill_suffix+" "+t.name + "("
def_str = fill_suffix + " " + t.name
self.fill(def_str)
self.dispatch(t.args)
self.write(")")
with self.delimit("(", ")"):
self.dispatch(t.args)
if getattr(t, "returns", False):
self.write(" -> ")
self.dispatch(t.returns)
@ -574,13 +582,12 @@ class Unparser:
def _Constant(self, t):
value = t.value
if isinstance(value, tuple):
self.write("(")
if len(value) == 1:
self._write_constant(value[0])
self.write(",")
else:
interleave(lambda: self.write(", "), self._write_constant, value)
self.write(")")
with self.delimit("(", ")"):
if len(value) == 1:
self._write_constant(value[0])
self.write(",")
else:
interleave(lambda: self.write(", "), self._write_constant, value)
elif value is Ellipsis: # instead of `...` for Py2 compatibility
self.write("...")
else:
@ -594,49 +601,41 @@ class Unparser:
self.write(repr_n.replace("inf", INFSTR))
else:
# Parenthesize negative numbers, to avoid turning (-1)**2 into -1**2.
if repr_n.startswith("-"):
self.write("(")
if "inf" in repr_n and repr_n.endswith("*j"):
repr_n = repr_n.replace("*j", "j")
# Substitute overflowing decimal literal for AST infinities.
self.write(repr_n.replace("inf", INFSTR))
if repr_n.startswith("-"):
self.write(")")
with self.delimit_if("(", ")", repr_n.startswith("-")):
if "inf" in repr_n and repr_n.endswith("*j"):
repr_n = repr_n.replace("*j", "j")
# Substitute overflowing decimal literal for AST infinities.
self.write(repr_n.replace("inf", INFSTR))
def _List(self, t):
self.write("[")
interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write("]")
with self.delimit("[", "]"):
interleave(lambda: self.write(", "), self.dispatch, t.elts)
def _ListComp(self, t):
self.write("[")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write("]")
with self.delimit("[", "]"):
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
def _GeneratorExp(self, t):
self.write("(")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write(")")
with self.delimit("(", ")"):
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
def _SetComp(self, t):
self.write("{")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write("}")
with self.delimit("{", "}"):
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
def _DictComp(self, t):
self.write("{")
self.dispatch(t.key)
self.write(": ")
self.dispatch(t.value)
for gen in t.generators:
self.dispatch(gen)
self.write("}")
with self.delimit("{", "}"):
self.dispatch(t.key)
self.write(": ")
self.dispatch(t.value)
for gen in t.generators:
self.dispatch(gen)
def _comprehension(self, t):
if getattr(t, 'is_async', False):
@ -651,22 +650,19 @@ class Unparser:
self.dispatch(if_clause)
def _IfExp(self, t):
self.write("(")
self.dispatch(t.body)
self.write(" if ")
self.dispatch(t.test)
self.write(" else ")
self.dispatch(t.orelse)
self.write(")")
with self.delimit("(", ")"):
self.dispatch(t.body)
self.write(" if ")
self.dispatch(t.test)
self.write(" else ")
self.dispatch(t.orelse)
def _Set(self, t):
assert(t.elts) # should be at least one element
self.write("{")
interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write("}")
with self.delimit("{", "}"):
interleave(lambda: self.write(", "), self.dispatch, t.elts)
def _Dict(self, t):
self.write("{")
def write_key_value_pair(k, v):
self.dispatch(k)
self.write(": ")
@ -681,64 +677,59 @@ class Unparser:
self.dispatch(v)
else:
write_key_value_pair(k, v)
interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values))
self.write("}")
with self.delimit("{", "}"):
interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values))
def _Tuple(self, t):
self.write("(")
if len(t.elts) == 1:
elt = t.elts[0]
self.dispatch(elt)
self.write(",")
else:
interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write(")")
with self.delimit("(", ")"):
if len(t.elts) == 1:
elt = t.elts[0]
self.dispatch(elt)
self.write(",")
else:
interleave(lambda: self.write(", "), self.dispatch, t.elts)
unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"}
def _UnaryOp(self, t):
self.write("(")
self.write(self.unop[t.op.__class__.__name__])
if not self._py_ver_consistent:
self.write(" ")
if six.PY2 and isinstance(t.op, ast.USub) and isinstance(t.operand, ast.Num):
# If we're applying unary minus to a number, parenthesize the number.
# This is necessary: -2147483648 is different from -(2147483648) on
# a 32-bit machine (the first is an int, the second a long), and
# -7j is different from -(7j). (The first has real part 0.0, the second
# has real part -0.0.)
self.write("(")
self.dispatch(t.operand)
self.write(")")
else:
self.dispatch(t.operand)
self.write(")")
with self.delimit("(", ")"):
self.write(self.unop[t.op.__class__.__name__])
if not self._py_ver_consistent:
self.write(" ")
if six.PY2 and isinstance(t.op, ast.USub) and isinstance(t.operand, ast.Num):
# If we're applying unary minus to a number, parenthesize the number.
# This is necessary: -2147483648 is different from -(2147483648) on
# a 32-bit machine (the first is an int, the second a long), and
# -7j is different from -(7j). (The first has real part 0.0, the second
# has real part -0.0.)
with self.delimit("(", ")"):
self.dispatch(t.operand)
else:
self.dispatch(t.operand)
binop = { "Add":"+", "Sub":"-", "Mult":"*", "MatMult":"@", "Div":"/", "Mod":"%",
"LShift":"<<", "RShift":">>", "BitOr":"|", "BitXor":"^", "BitAnd":"&",
"FloorDiv":"//", "Pow": "**"}
def _BinOp(self, t):
self.write("(")
self.dispatch(t.left)
self.write(" " + self.binop[t.op.__class__.__name__] + " ")
self.dispatch(t.right)
self.write(")")
with self.delimit("(", ")"):
self.dispatch(t.left)
self.write(" " + self.binop[t.op.__class__.__name__] + " ")
self.dispatch(t.right)
cmpops = {"Eq":"==", "NotEq":"!=", "Lt":"<", "LtE":"<=", "Gt":">", "GtE":">=",
"Is":"is", "IsNot":"is not", "In":"in", "NotIn":"not in"}
def _Compare(self, t):
self.write("(")
self.dispatch(t.left)
for o, e in zip(t.ops, t.comparators):
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
self.dispatch(e)
self.write(")")
with self.delimit("(", ")"):
self.dispatch(t.left)
for o, e in zip(t.ops, t.comparators):
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
self.dispatch(e)
boolops = {ast.And: 'and', ast.Or: 'or'}
def _BoolOp(self, t):
self.write("(")
s = " %s " % self.boolops[t.op.__class__]
interleave(lambda: self.write(s), self.dispatch, t.values)
self.write(")")
with self.delimit("(", ")"):
s = " %s " % self.boolops[t.op.__class__]
interleave(lambda: self.write(s), self.dispatch, t.values)
def _Attribute(self,t):
self.dispatch(t.value)
@ -752,55 +743,52 @@ class Unparser:
def _Call(self, t):
self.dispatch(t.func)
self.write("(")
comma = False
with self.delimit("(", ")"):
comma = False
# move starred arguments last in Python 3.5+, for consistency w/earlier versions
star_and_kwargs = []
move_stars_last = sys.version_info[:2] >= (3, 5)
# starred arguments last in Python 3.5+, for consistency w/earlier versions
star_and_kwargs = []
move_stars_last = sys.version_info[:2] >= (3, 5)
for e in t.args:
if move_stars_last and isinstance(e, ast.Starred):
star_and_kwargs.append(e)
else:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
for e in t.args:
if move_stars_last and isinstance(e, ast.Starred):
star_and_kwargs.append(e)
else:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
for e in t.keywords:
# starting from Python 3.5 this denotes a kwargs part of the invocation
if e.arg is None and move_stars_last:
star_and_kwargs.append(e)
else:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
for e in t.keywords:
# starting from Python 3.5 this denotes a kwargs part of the invocation
if e.arg is None and move_stars_last:
star_and_kwargs.append(e)
else:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
if move_stars_last:
for e in star_and_kwargs:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
if move_stars_last:
for e in star_and_kwargs:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
if sys.version_info[:2] < (3, 5):
if t.starargs:
if comma: self.write(", ")
else: comma = True
self.write("*")
self.dispatch(t.starargs)
if t.kwargs:
if comma: self.write(", ")
else: comma = True
self.write("**")
self.dispatch(t.kwargs)
self.write(")")
if sys.version_info[:2] < (3, 5):
if t.starargs:
if comma: self.write(", ")
else: comma = True
self.write("*")
self.dispatch(t.starargs)
if t.kwargs:
if comma: self.write(", ")
else: comma = True
self.write("**")
self.dispatch(t.kwargs)
def _Subscript(self, t):
self.dispatch(t.value)
self.write("[")
self.dispatch(t.slice)
self.write("]")
with self.delimit("[", "]"):
self.dispatch(t.slice)
def _Starred(self, t):
self.write("*")
@ -902,12 +890,11 @@ class Unparser:
self.dispatch(t.value)
def _Lambda(self, t):
self.write("(")
self.write("lambda ")
self.dispatch(t.args)
self.write(": ")
self.dispatch(t.body)
self.write(")")
with self.delimit("(", ")"):
self.write("lambda ")
self.dispatch(t.args)
self.write(": ")
self.dispatch(t.body)
def _alias(self, t):
self.write(t.name)