# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for symbolic computation."""
import threading
import sympy
class Symbol(sympy.Dummy):
    """Symbol type used by this module.

    A plain subclass of `sympy.Dummy` that adds no behavior of its own; having a
    dedicated type lets `isinstance` checks tell this module's symbols apart
    from other sympy objects.
    """
def IsSymbol(x):
    """Returns True iff `x` is an instance of `Symbol` (or a subclass of it)."""
    return isinstance(x, Symbol)
def IsExpr(x):
    """Returns True iff `x` is a `sympy.Expr` instance."""
    return isinstance(x, sympy.Expr)
# Names of the supported kinds of values a symbol can be bound to.
STATIC_VALUES = 'static'
TENSOR_VALUES = 'tensor'

# Every supported value type; one symbol-to-value stack is kept per entry.
VALUE_TYPES = (
    STATIC_VALUES,
    TENSOR_VALUES,
)
class _LocalSymbolToValueStack(threading.local):
    """Thread-local storage holding one stack of symbol-to-value dicts per type.

    `stack` maps each entry of `VALUE_TYPES` to a list of dicts. Every list
    starts out with a single empty dict, so the top of a stack always exists.
    """

    def __init__(self):
        super().__init__()
        # One independent stack per value type, each seeded with an empty map.
        self.stack = {value_type: [{}] for value_type in VALUE_TYPES}
class SymbolToValueMap:
    """A symbol-to-value mapping that is activated as a context manager.

    Usage:

        with SymbolToValueMap('static', {symbol1: value1, symbol2: value2, ...}):
            with SymbolToValueMap('tensor', {symbol1: value1, ...}):
                ... = EvalExpr(value_type, symbolic_expr)

    Multiple SymbolToValueMap contexts can be nested inside one another. The
    inner contexts take precedence over outer ones when multiple contexts
    provide values for the same symbol.

    Note: the mapping is merged with the mapping on top of the current thread's
    stack when the object is *constructed*, not when the `with` block is
    entered. Construct the object inside the contexts it is meant to extend.
    """

    # Per-thread stacks of merged mappings, one stack per value type.
    _local_stack = _LocalSymbolToValueStack()

    def __init__(self, value_type, symbol_to_value_map):
        """Creates a new symbol to value map.

        Args:
          value_type: the type of values in 'symbol_to_value_map'; must be one
            of VALUE_TYPES.
          symbol_to_value_map: a dict from Symbol to values.
        """
        assert value_type in VALUE_TYPES, value_type
        self.value_type = value_type
        # Start from a copy of the currently active mapping so that outer
        # bindings stay visible, then let the new bindings override them.
        self.merged = dict(self.Stack(value_type)[-1])
        self.merged.update(symbol_to_value_map)

    @staticmethod
    def Stack(value_type):
        """Returns the current thread's stack (a list of dicts) for value_type."""
        return SymbolToValueMap._local_stack.stack[value_type]

    def __enter__(self):
        """Pushes the merged mapping and returns this object for `as` binding."""
        self.Stack(self.value_type).append(self.merged)
        return self

    def __exit__(self, type_arg, value_arg, traceback_arg):
        """Pops the merged mapping; exceptions from the body are not suppressed."""
        stack = self.Stack(self.value_type)
        # Contexts must be exited in the reverse order they were entered.
        assert stack
        assert stack[-1] is self.merged
        stack.pop()

    @staticmethod
    def Get(value_type):
        """Returns the symbol-to-value mapping on top of Stack(value_type)."""
        return SymbolToValueMap.Stack(value_type)[-1]
def EvalExpr(value_type, x):
    """Evaluates x with symbol_to_value_map within the current context.

    Args:
      value_type: the target value type (see VALUE_TYPES).
      x: a sympy.Expr, an object, or a (possibly nested) list/tuple/namedtuple
        of Exprs and objects.

    Returns:
      Evaluation result of 'x'. Lists and tuples are rebuilt with the same type
      and each element evaluated recursively. A sympy.Expr is evaluated against
      the active mapping for 'value_type', or returned unchanged when that
      mapping is empty. Any other object is returned as is.
    """
    if isinstance(x, (list, tuple)):
        evaluated = [EvalExpr(value_type, item) for item in x]
        if isinstance(x, tuple) and hasattr(x, '_fields'):
            # namedtuple constructors take one positional argument per field
            # rather than a single iterable.
            return type(x)(*evaluated)
        return type(x)(evaluated)

    if not isinstance(x, sympy.Expr):
        return x

    symbol_to_value_map = SymbolToValueMap.Get(value_type)
    if not symbol_to_value_map:
        return x

    # In theory the below should be equivalent to:
    #   y = x.subs(symbol_to_value_map).
    # In practice subs() doesn't work for when values are Tensors.
    symbols = tuple(symbol_to_value_map)
    values = tuple(symbol_to_value_map.values())
    return sympy.lambdify(symbols, x)(*values)
def ToStatic(expr):
    """Evaluates `expr` via EvalExpr using the STATIC_VALUES value type."""
    return EvalExpr(value_type=STATIC_VALUES, x=expr)
def ToTensor(expr):
    """Evaluates `expr` via EvalExpr using the TENSOR_VALUES value type."""
    return EvalExpr(value_type=TENSOR_VALUES, x=expr)