#![cfg(feature = "tensorflow_unstable")]
use super::Graph;
use super::Operation;
use super::Shape;
use super::Status;
use super::Tensor;
use super::TensorType;
use std::cmp::Eq;
use std::collections::HashMap;
use std::convert::From;
use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Error;
use std::fmt::Formatter;
use std::hash::Hash;
use std::hash::Hasher;
use std::marker::PhantomData;
use std::ops;
use std::rc::Rc;
#[derive(Ord, PartialOrd, Eq, PartialEq, Debug, Copy, Clone)]
pub enum OpLevel {
Assign,
Add,
Mul,
Unary,
Atom,
}
#[derive(Debug, Clone)]
pub struct Expr<T: TensorType> {
expr: Rc<dyn ExprImpl<T>>,
}
impl<T: TensorType> Expr<T> {
pub fn new<I>(expr: I) -> Expr<T>
where
I: ExprImpl<T> + 'static,
{
Expr {
expr: Rc::new(expr),
}
}
}
impl<T: TensorType> ops::Deref for Expr<T> {
type Target = dyn ExprImpl<T>;
fn deref(&self) -> &Self::Target {
self.expr.deref()
}
}
impl<T: TensorType> Display for Expr<T> {
fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
Display::fmt(&self.expr, f)
}
}
impl<T: TensorType> From<T> for Expr<T> {
fn from(value: T) -> Self {
Expr::new(value)
}
}
#[derive(Debug)]
pub enum ShapeHint<'a> {
Unknown,
Exactly(&'a [u64]),
}
pub trait ExprImpl<T: TensorType>: Display + Debug {
fn op_level(&self) -> OpLevel;
fn children(&self) -> Vec<Box<dyn AnyExpr>>; fn create_operation(
&self,
graph: &mut Graph,
children: &[Operation],
id_gen: &mut dyn FnMut() -> String,
) -> Result<Operation, Status>;
fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status>;
fn shape_hint(&self) -> ShapeHint {
ShapeHint::Unknown
}
}
impl<T: TensorType> ExprImpl<T> for T {
fn op_level(&self) -> OpLevel {
OpLevel::Atom
}
fn children(&self) -> Vec<Box<dyn AnyExpr>> {
vec![]
}
fn create_operation(
&self,
graph: &mut Graph,
_children: &[Operation],
id_gen: &mut dyn FnMut() -> String,
) -> Result<Operation, Status> {
let mut nd = graph.new_operation("Const", &id_gen())?;
nd.set_attr_type("dtype", T::data_type())?;
let mut value = Tensor::new(&[1]);
value[0] = self.clone();
nd.set_attr_tensor("value", value)?;
nd.finish()
}
fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
Ok(Expr::from(T::zero()))
}
}
macro_rules! impl_bin_op {
($name:ident, $fn_name:ident, $op:expr, $op_level:ident, $assoc:expr,
$tf_op:expr, $doc:expr, $($ximpl:tt)*) => {
#[doc = $doc]
#[derive(Debug)]
pub struct $name<T: TensorType> {
left: Expr<T>,
right: Expr<T>,
}
impl<T: TensorType> ops::$name for Expr<T> {
type Output = Expr<T>;
fn $fn_name(self, rhs: Expr<T>) -> Expr<T> {
Expr::new($name {
left: self,
right: rhs,
})
}
}
impl<T: TensorType> ops::$name<T> for Expr<T> {
type Output = Expr<T>;
fn $fn_name(self, rhs: T) -> Expr<T> {
Expr::new($name {
left: self,
right: Expr::from(rhs),
})
}
}
impl<T: TensorType> Display for $name<T> {
fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
if self.left.op_level() < OpLevel::$op_level {
write!(f, "({})", self.left)?;
} else {
write!(f, "{}", self.left)?;
}
write!(f, concat!(" ", $op, " "))?;
let paren = if $assoc {
self.right.op_level() < OpLevel::$op_level
} else {
self.right.op_level() <= OpLevel::$op_level
};
if paren {
write!(f, "({})", self.right)
} else {
write!(f, "{}", self.right)
}
}
}
impl<T: TensorType> ExprImpl<T> for $name<T> {
fn op_level(&self) -> OpLevel {
OpLevel::$op_level
}
fn children(&self) -> Vec<Box<dyn AnyExpr>> {
vec![Box::new(self.left.clone()), Box::new(self.right.clone())]
}
fn create_operation(&self, graph: &mut Graph, children: &[Operation],
id_gen: &mut dyn FnMut() -> String) -> Result<Operation, Status> {
let mut nd = graph.new_operation($tf_op, &id_gen())?;
nd.add_input(children[0].clone());
nd.add_input(children[1].clone());
nd.finish()
}
$($ximpl)*
}
}
}
impl_bin_op!(
Add,
add,
"+",
Add,
true,
"Add",
"Expression resulting from adding two subexpressions.",
fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
Ok(self.left.derivative_by_variable(var)? + self.right.derivative_by_variable(var)?)
}
);
impl_bin_op!(
Sub,
sub,
"-",
Add,
false,
"Sub",
"Expression resulting from subtracting two subexpressions.",
fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
Ok(self.left.derivative_by_variable(var)? - self.right.derivative_by_variable(var)?)
}
);
impl_bin_op!(
Mul,
mul,
"*",
Mul,
true,
"Mul",
"Expression resulting from multiplying two subexpressions.",
fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
Ok(self.left.derivative_by_variable(var)? * self.right.clone()
+ self.left.clone() * self.right.derivative_by_variable(var)?)
}
);
impl_bin_op!(
Div,
div,
"/",
Mul,
false,
"Div",
"Expression resulting from dividing two subexpressions.",
fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
let num = self.left.derivative_by_variable(var)? * self.right.clone()
- self.left.clone() * self.right.derivative_by_variable(var)?;
let denom = self.right.clone() * self.right.clone();
Ok(num / denom)
}
);
impl_bin_op!(
Rem,
rem,
"%",
Mul,
false,
"Mod",
"Expression resulting from taking a modulus.",
fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
Ok(self.left.derivative_by_variable(var)?
- TruncateDiv::new_expr(self.left.clone(), self.right.clone())
* self.right.derivative_by_variable(var)?)
}
);
#[derive(Debug)]
pub struct TruncateDiv<T: TensorType> {
left: Expr<T>,
right: Expr<T>,
}
impl<T: TensorType> TruncateDiv<T> {
fn new(left: Expr<T>, right: Expr<T>) -> Self {
TruncateDiv { left, right }
}
pub fn new_expr(left: Expr<T>, right: Expr<T>) -> Expr<T> {
Expr::new(TruncateDiv::new(left, right))
}
}
impl<T: TensorType> Display for TruncateDiv<T> {
fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
write!(f, "{} // {}", self.left, self.right)
}
}
impl<T: TensorType> ExprImpl<T> for TruncateDiv<T> {
fn op_level(&self) -> OpLevel {
OpLevel::Mul
}
fn children(&self) -> Vec<Box<dyn AnyExpr>> {
vec![Box::new(self.left.clone()), Box::new(self.right.clone())]
}
fn create_operation(
&self,
graph: &mut Graph,
children: &[Operation],
id_gen: &mut dyn FnMut() -> String,
) -> Result<Operation, Status> {
let mut nd = graph.new_operation("TruncateDiv", &id_gen())?;
nd.add_input(children[0].clone());
nd.add_input(children[1].clone());
nd.finish()
}
fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
let diff = self.left.clone() - self.left.clone() % self.right.clone();
let term1 = self.right.clone() * diff.derivative_by_variable(var)?;
let term2 = diff * self.right.derivative_by_variable(var)?;
Ok((term1 - term2) / (self.right.clone() * self.right.clone()))
}
}
#[derive(Debug)]
pub struct Neg<T: TensorType> {
expr: Expr<T>,
}
impl<T: TensorType> ops::Neg for Expr<T> {
type Output = Expr<T>;
fn neg(self) -> Expr<T> {
Expr::new(Neg { expr: self })
}
}
impl<T: TensorType> Display for Neg<T> {
fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
write!(f, "-")?;
if self.expr.op_level() <= OpLevel::Unary {
write!(f, "({})", self.expr)
} else {
write!(f, "{}", self.expr)
}
}
}
impl<T: TensorType> ExprImpl<T> for Neg<T> {
fn op_level(&self) -> OpLevel {
OpLevel::Unary
}
fn children(&self) -> Vec<Box<dyn AnyExpr>> {
vec![Box::new(self.expr.clone())]
}
fn create_operation(
&self,
graph: &mut Graph,
children: &[Operation],
id_gen: &mut dyn FnMut() -> String,
) -> Result<Operation, Status> {
let mut nd = graph.new_operation("Neg", &id_gen())?;
nd.add_input(children[0].clone());
nd.finish()
}
fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
Ok(-self.expr.derivative_by_variable(var)?)
}
}
#[derive(Debug)]
pub struct Variable<T: TensorType> {
shape: Vec<u64>,
name: String,
phantom: PhantomData<T>,
}
impl<T: TensorType> Variable<T> {
fn new(shape: &[u64], name: &str) -> Self {
Variable {
shape: Vec::from(shape),
name: name.to_string(),
phantom: PhantomData,
}
}
pub fn new_expr(shape: &[u64], name: &str) -> Expr<T> {
Expr::new(Variable::new(shape, name))
}
}
impl<T: TensorType> Display for Variable<T> {
fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
write!(f, "{}", self.name)
}
}
impl<T: TensorType> ExprImpl<T> for Variable<T> {
fn op_level(&self) -> OpLevel {
OpLevel::Atom
}
fn children(&self) -> Vec<Box<dyn AnyExpr>> {
vec![]
}
fn create_operation(
&self,
graph: &mut Graph,
_children: &[Operation],
_id_gen: &mut dyn FnMut() -> String,
) -> Result<Operation, Status> {
let mut nd = graph.new_operation("Variable", &self.name)?;
let shape = self
.shape
.iter()
.map(|dim_size| Some(*dim_size as i64))
.collect();
nd.set_attr_type("dtype", T::data_type()).unwrap();
nd.set_attr_shape("shape", &Shape(Some(shape))).unwrap();
nd.finish()
}
fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
Ok(if var == self.name {
Expr::from(T::one())
} else {
Expr::from(T::zero())
})
}
fn shape_hint(&self) -> ShapeHint {
ShapeHint::Exactly(&self.shape)
}
}
#[derive(Debug)]
pub struct Placeholder<T: TensorType> {
shape: Vec<u64>,
name: String,
phantom: PhantomData<T>,
}
impl<T: TensorType> Placeholder<T> {
fn new(shape: &[u64], name: &str) -> Self {
Placeholder {
shape: Vec::from(shape),
name: name.to_string(),
phantom: PhantomData,
}
}
pub fn new_expr(shape: &[u64], name: &str) -> Expr<T> {
Expr::new(Placeholder::new(shape, name))
}
}
impl<T: TensorType> Display for Placeholder<T> {
fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
write!(f, "{}", self.name)
}
}
impl<T: TensorType> ExprImpl<T> for Placeholder<T> {
fn op_level(&self) -> OpLevel {
OpLevel::Atom
}
fn children(&self) -> Vec<Box<dyn AnyExpr>> {
vec![]
}
fn create_operation(
&self,
graph: &mut Graph,
_children: &[Operation],
_id_gen: &mut dyn FnMut() -> String,
) -> Result<Operation, Status> {
let mut nd = graph.new_operation("Placeholder", &self.name)?;
let shape = self
.shape
.iter()
.map(|dim_size| Some(*dim_size as i64))
.collect();
nd.set_attr_type("dtype", T::data_type()).unwrap();
nd.set_attr_shape("shape", &Shape(Some(shape))).unwrap();
nd.finish()
}
fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
Ok(Expr::from(T::zero()))
}
fn shape_hint(&self) -> ShapeHint {
ShapeHint::Exactly(&self.shape)
}
}
#[derive(Debug)]
pub struct Constant<T: TensorType> {
tensor: Tensor<T>,
}
impl<T: TensorType> Constant<T> {
pub fn new(tensor: Tensor<T>) -> Self {
Constant { tensor }
}
pub fn new_expr(tensor: Tensor<T>) -> Expr<T> {
Expr::new(Constant { tensor })
}
}
impl<T: TensorType> Display for Constant<T> {
fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
write!(f, "{}", self.tensor)
}
}
impl<T: TensorType> ExprImpl<T> for Constant<T> {
fn op_level(&self) -> OpLevel {
OpLevel::Atom
}
fn children(&self) -> Vec<Box<dyn AnyExpr>> {
vec![]
}
fn create_operation(
&self,
graph: &mut Graph,
_children: &[Operation],
id_gen: &mut dyn FnMut() -> String,
) -> Result<Operation, Status> {
let mut nd = graph.new_operation("Const", &id_gen())?;
nd.set_attr_type("dtype", T::data_type())?;
nd.set_attr_tensor("value", self.tensor.clone())?;
nd.finish()
}
fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
Ok(Expr::from(T::zero()))
}
}
#[derive(Debug)]
pub struct Assign<T: TensorType> {
variable: Expr<T>,
value: Expr<T>,
}
impl<T: TensorType> Assign<T> {
fn new(variable: Expr<T>, value: Expr<T>) -> Self {
Assign { variable, value }
}
pub fn new_expr(variable: Expr<T>, value: Expr<T>) -> Expr<T> {
Expr::new(Assign::new(variable, value))
}
pub fn to(variable: Expr<T>, iterable: impl Iterator<Item = T>) -> crate::Result<Expr<T>> {
let constant = if let ShapeHint::Exactly(shape) = variable.expr.shape_hint() {
let values: Vec<_> = iterable
.take(shape.iter().product::<u64>() as usize)
.collect();
Constant::new_expr(Tensor::new(shape).with_values(&values)?)
} else {
return Err(invalid_arg!(
"Cannot assign to expression {} with unknown size!",
variable
));
};
Ok(Assign::new_expr(variable, constant))
}
}
impl<T: TensorType> Display for Assign<T> {
fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
write!(f, "{} = {}", self.variable, self.value)
}
}
impl<T: TensorType> ExprImpl<T> for Assign<T> {
fn op_level(&self) -> OpLevel {
OpLevel::Assign
}
fn children(&self) -> Vec<Box<dyn AnyExpr>> {
vec![
Box::new(self.variable.clone()),
Box::new(self.value.clone()),
]
}
fn create_operation(
&self,
graph: &mut Graph,
children: &[Operation],
id_gen: &mut dyn FnMut() -> String,
) -> Result<Operation, Status> {
let mut nd = graph.new_operation("Assign", &id_gen())?;
nd.add_input(children[0].clone());
nd.add_input(children[1].clone());
nd.finish()
}
fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
Err(invalid_arg!("Cannot take the derivative of an assignment"))
}
}
pub trait AnyExpr: Debug {
fn key(&self) -> *const ();
fn children(&self) -> Vec<Box<dyn AnyExpr>>; fn create_operation(
&self,
graph: &mut Graph,
children: &[Operation],
id_gen: &mut dyn FnMut() -> String,
) -> Result<Operation, Status>;
fn clone_box(&self) -> Box<dyn AnyExpr>;
}
impl<T: TensorType> AnyExpr for Expr<T> {
#[allow(trivial_casts)]
fn key(&self) -> *const () {
self.expr.as_ref() as *const dyn ExprImpl<T> as *const ()
}
fn children(&self) -> Vec<Box<dyn AnyExpr>> {
self.expr.children()
}
fn create_operation(
&self,
graph: &mut Graph,
children: &[Operation],
id_gen: &mut dyn FnMut() -> String,
) -> Result<Operation, Status> {
self.expr.create_operation(graph, children, id_gen)
}
fn clone_box(&self) -> Box<dyn AnyExpr> {
Box::new(self.clone())
}
}
#[derive(Debug)]
struct Key(Box<dyn AnyExpr>);
impl PartialEq for Key {
fn eq(&self, other: &Key) -> bool {
self.0.key() == other.0.key()
}
}
impl Eq for Key {}
impl Hash for Key {
fn hash<H>(&self, state: &mut H)
where
H: Hasher,
{
state.write_isize(self.0.key() as isize)
}
}
#[derive(Debug)]
pub struct Compiler<'l> {
graph: &'l mut Graph,
operations: HashMap<Key, Operation>,
next_id: i32,
}
impl<'l> Compiler<'l> {
pub fn new(graph: &'l mut Graph) -> Self {
Compiler {
graph,
operations: HashMap::new(),
next_id: 0,
}
}
pub fn compile<T: TensorType>(&mut self, expr: Expr<T>) -> Result<Operation, Status> {
self.compile_any(Box::new(expr))
}
pub fn compile_any(&mut self, expr: Box<dyn AnyExpr>) -> Result<Operation, Status> {
let mut child_operations = vec![];
for child in expr.children() {
let key = Key(child.clone_box());
let value = self.operations.get(&key).cloned();
child_operations.push(match value {
Some(v) => v,
None => self.compile_any(child)?,
});
}
let mut next_id = self.next_id;
let result = expr.create_operation(self.graph, &child_operations, &mut || {
let id = format!("operation_{}", next_id);
next_id += 1;
id
});
self.next_id = next_id;
let operation = result?;
self.operations.insert(Key(expr), operation.clone());
Ok(operation)
}
}
#[cfg(test)]
mod tests {
use super::super::Graph;
use super::*;
#[test]
fn test_display() {
assert_eq!("1 + 2 + 3", format!("{}", (Expr::from(1) + 2) + 3));
assert_eq!(
"1 + 2 + 3",
format!("{}", Expr::from(1) + (Expr::from(2) + 3))
);
assert_eq!("1 + 2 - 3", format!("{}", (Expr::from(1) + 2) - 3));
assert_eq!(
"1 - (2 + 3)",
format!("{}", Expr::from(1) - (Expr::from(2) + 3))
);
assert_eq!("(1 + 2) * 3", format!("{}", (Expr::from(1) + 2) * 3));
assert_eq!(
"1 * (2 + 3)",
format!("{}", Expr::from(1) * (Expr::from(2) + 3))
);
assert_eq!("1 * 2 * 3", format!("{}", (Expr::from(1) * 2) * 3));
assert_eq!(
"1 * 2 * 3",
format!("{}", Expr::from(1) * (Expr::from(2) * 3))
);
assert_eq!("(1 + 2) / 3", format!("{}", (Expr::from(1) + 2) / 3));
assert_eq!(
"1 / (2 + 3)",
format!("{}", Expr::from(1) / (Expr::from(2) + 3))
);
assert_eq!("1 * 2 / 3", format!("{}", (Expr::from(1) * 2) / 3));
assert_eq!(
"1 / (2 * 3)",
format!("{}", Expr::from(1) / (Expr::from(2) * 3))
);
assert_eq!("(1 + 2) % 3", format!("{}", (Expr::from(1) + 2) % 3));
assert_eq!(
"1 % (2 + 3)",
format!("{}", Expr::from(1) % (Expr::from(2) + 3))
);
assert_eq!("1 * 2 % 3", format!("{}", (Expr::from(1) * 2) % 3));
assert_eq!(
"1 % (2 * 3)",
format!("{}", Expr::from(1) % (Expr::from(2) * 3))
);
assert_eq!("-1", format!("{}", -Expr::from(1)));
assert_eq!("-(-1)", format!("{}", -(-Expr::from(1))));
assert_eq!("-(1 + 2)", format!("{}", -(Expr::from(1) + 2)));
assert_eq!("x", format!("{}", <Variable<f32>>::new(&vec![2, 3], "x")));
assert_eq!(
"x",
format!("{}", <Placeholder<f32>>::new(&vec![2, 3], "x"))
);
assert_eq!(
"x = 1 + 2",
format!(
"{}",
Assign::new(
<Placeholder<f32>>::new_expr(&vec![2, 3], "x"),
Expr::from(1.0f32) + 2.0f32
)
)
);
}
#[test]
fn test_compile() {
let mut g = Graph::new();
let x = <Placeholder<f32>>::new_expr(&vec![2, 3], "x");
let w = <Variable<f32>>::new_expr(&vec![2, 3], "w");
let mut compiler = Compiler::new(&mut g);
compiler
.compile(x * w.clone() / w.clone() % w.clone() + w.clone() - w.clone())
.unwrap();
compiler
.compile(Assign::to(w, ::std::iter::repeat(1.)).unwrap())
.unwrap();
}
#[test]
fn test_derivative_by_variable() {
let x = <Variable<f32>>::new_expr(&[], "x");
let y = <Variable<f32>>::new_expr(&[], "y");
for &(ref expected, ref expression) in [
("0", Expr::from(1.0f32)),
("1", x.clone()),
("0", y.clone()),
("1 + 0", x.clone() + y.clone()),
("1 - 0", x.clone() - y.clone()),
("1 * x + x * 1", x.clone() * x.clone()),
("1 * y + x * 0", x.clone() * y.clone()),
(
"(1 * x + x * 1) * x + x * x * 1",
x.clone() * x.clone() * x.clone(),
),
("(1 * y - x * 0) / (y * y)", x.clone() / y.clone()),
("1 - x // y * 0", x.clone() % y.clone()),
("0 - y // x * 1", y.clone() % x.clone()),
(
"(y * (1 - (1 - x // y * 0)) - (x - x % y) * 0) / (y * y)",
TruncateDiv::new_expr(x.clone(), y.clone()),
),
]
.iter()
{
assert_eq!(
*expected,
format!("{}", expression.derivative_by_variable("x").unwrap())
);
}
}
}