use super::Graph;
use super::Output;
use super::Result;
use super::Status;
use std::ffi::CString;
use std::ffi::NulError;
use std::mem;
use std::os::raw::c_int;
use std::ptr;
use std::result;
use std::slice;
#[cfg(feature = "default")]
use tensorflow_sys as tf;
#[cfg(feature = "tensorflow_runtime_linking")]
use tensorflow_sys_runtime as tf;
/// Owning wrapper around a C `TF_WhileParams` value.
///
/// `finished` tracks whether the params still need to be released: while it
/// is `false`, dropping the wrapper calls `TF_AbortWhile` on them.
#[derive(Debug)]
struct CWhileParams {
    inner: tf::TF_WhileParams,
    finished: bool,
}

impl Drop for CWhileParams {
    fn drop(&mut self) {
        // Already handed to TF_FinishWhile (or marked as not needing cleanup).
        if self.finished {
            return;
        }
        // SAFETY: `finished` is still false, so within this module `inner`
        // has not been passed to `TF_FinishWhile`; aborting releases it.
        unsafe { tf::TF_AbortWhile(&self.inner) };
    }
}
/// Builder for a while loop in a [`Graph`].
///
/// Created with [`WhileBuilder::new`], optionally named with
/// [`WhileBuilder::name`], and added to the graph with
/// [`WhileBuilder::finish`]. If the builder is dropped without `finish`
/// having been called, the wrapped C loop parameters are aborted by their
/// `Drop` impl.
#[derive(Debug)]
pub struct WhileBuilder<'a> {
    /// Parent graph the loop is added to.
    graph: &'a mut Graph,
    /// C-level loop parameters; aborted on drop unless marked finished.
    inner: CWhileParams,
    /// Explicit loop name; a unique `while_loop_N` name is generated when `None`.
    name: Option<CString>,
    /// C copies of the loop's initial inputs. They are never read after
    /// construction, hence the `dead_code` allowance. The field keeps the
    /// buffer alive for the builder's lifetime because the C loop parameters
    /// were created from a pointer into it.
    /// NOTE(review): presumably TensorFlow retains that pointer until the loop
    /// is finished or aborted — confirm before removing this field.
    #[allow(dead_code)]
    c_inputs: Vec<tf::TF_Output>,
}
impl<'a> WhileBuilder<'a> {
    /// Starts building a while loop in `graph`.
    ///
    /// # Arguments
    ///
    /// * `graph` - parent graph the loop is added to.
    /// * `cond` - called once with the condition subgraph and the loop
    ///   variables as seen inside it; returns the loop-condition output.
    /// * `body` - called once with the body subgraph and the loop variables
    ///   as seen inside it; must return exactly one updated value per loop
    ///   variable.
    /// * `inputs` - initial values of the loop variables.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    ///
    /// * TensorFlow fails to create the loop parameters.
    /// * `cond` or `body` returns an error.
    /// * `body` returns a different number of outputs than `inputs.len()`.
    ///
    /// If the parameters were created successfully, they are aborted when
    /// the partially built loop is dropped on these error paths.
    pub fn new<
        CF: Fn(&mut Graph, &[Output]) -> Result<Output>,
        BF: Fn(&mut Graph, &[Output]) -> Result<Vec<Output>>,
    >(
        graph: &'a mut Graph,
        cond: CF,
        body: BF,
        inputs: &[Output],
    ) -> Result<Self> {
        let mut status = Status::new();
        let c_inputs: Vec<_> = inputs.iter().map(Output::to_c).collect();
        let mut inner = CWhileParams {
            // SAFETY: `c_inputs` holds `c_inputs.len()` initialized
            // `TF_Output`s, and both handles come from live wrappers.
            inner: unsafe {
                tf::TF_NewWhile(
                    graph.inner(),
                    c_inputs.as_ptr() as *mut _,
                    c_inputs.len() as c_int,
                    status.inner(),
                )
            },
            finished: false,
        };
        if let Err(e) = status.into_result() {
            // Creation failed: mark the params finished so `Drop` does not
            // call `TF_AbortWhile` on them.
            inner.finished = true;
            return Err(e);
        }

        // Let the caller populate the condition subgraph.
        // NOTE(review): assumes `Graph::from_c` wraps the pointer without
        // taking ownership, since the subgraph belongs to the C params — confirm.
        let mut cond_graph = unsafe { Graph::from_c(inner.inner.cond_graph as *mut _) };
        // SAFETY: after a successful TF_NewWhile, `cond_inputs` points at
        // one TF_Output per loop input.
        let cond_inputs: Vec<_> = unsafe {
            slice::from_raw_parts(inner.inner.cond_inputs, inputs.len())
                .iter()
                .map(|out| Output::from_c(graph, out))
                .collect()
        };
        let cond_out = cond(&mut cond_graph, &cond_inputs)?;
        inner.inner.cond_output = cond_out.to_c();

        // Let the caller populate the body subgraph.
        let mut body_graph = unsafe { Graph::from_c(inner.inner.body_graph as *mut _) };
        // SAFETY: as for `cond_inputs`, one TF_Output per loop input.
        let body_inputs: Vec<_> = unsafe {
            slice::from_raw_parts(inner.inner.body_inputs, inputs.len())
                .iter()
                .map(|out| Output::from_c(graph, out))
                .collect()
        };
        let body_out = body(&mut body_graph, &body_inputs)?;
        if body_out.len() != inputs.len() {
            return Err(invalid_arg!(
                "Expected {} outputs, but got {}",
                inputs.len(),
                body_out.len()
            ));
        }
        // SAFETY: `body_outputs` has room for one TF_Output per loop input,
        // and `body_out` was just checked to have exactly that many entries.
        let c_body_out = unsafe {
            slice::from_raw_parts_mut(inner.inner.body_outputs as *mut tf::TF_Output, inputs.len())
        };
        for (slot, out) in c_body_out.iter_mut().zip(&body_out) {
            *slot = out.to_c();
        }

        Ok(WhileBuilder {
            graph,
            inner,
            name: None,
            c_inputs,
        })
    }

    /// Sets an explicit name for the while loop.
    ///
    /// Without a name, [`WhileBuilder::finish`] generates a unique
    /// `while_loop_N` name.
    ///
    /// # Errors
    ///
    /// Returns [`NulError`] if `name` contains an interior NUL byte.
    pub fn name(mut self, name: &str) -> result::Result<Self, NulError> {
        self.name = Some(CString::new(name)?);
        Ok(self)
    }

    /// Adds the while loop to the graph and returns its outputs, one per
    /// loop variable.
    ///
    /// # Errors
    ///
    /// Returns an error if no loop name can be generated, or if
    /// `TF_FinishWhile` reports an error.
    pub fn finish(mut self) -> Result<Vec<Output>> {
        let status = Status::new();
        let name = match mem::take(&mut self.name) {
            Some(name) => name,
            None => {
                let while_loop_index = self.graph.generate_operation_name("while_loop_{}/Merge")?;
                CString::new(format!("while_loop_{}", while_loop_index))?
            }
        };
        // The C params only store this pointer, so `name` must stay alive
        // until TF_FinishWhile returns (it lives to the end of this function).
        self.inner.inner.name = name.as_ptr();

        // Fully initialize every output slot. The previous `set_len` and
        // write-through-`&mut` approach formed references to uninitialized
        // memory.
        let ninputs = self.inner.inner.ninputs as usize;
        let mut c_outputs: Vec<tf::TF_Output> = (0..ninputs)
            .map(|_| tf::TF_Output {
                oper: ptr::null_mut(),
                index: -1,
            })
            .collect();

        // Mark finished before the call so `Drop` never calls TF_AbortWhile
        // after TF_FinishWhile, even if the status reports an error.
        self.inner.finished = true;
        // SAFETY: the params came from a successful TF_NewWhile, and
        // `c_outputs` has `ninputs` initialized, writable slots.
        unsafe {
            tf::TF_FinishWhile(&self.inner.inner, status.inner, c_outputs.as_mut_ptr());
        }
        status.into_result()?;
        Ok(c_outputs
            .iter()
            .map(|out| Output::from_c(self.graph, out))
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::super::DataType;
    use super::super::Operation;
    use super::super::Session;
    use super::super::SessionOptions;
    use super::super::SessionRunArgs;
    use super::super::Tensor;
    use super::*;

    /// Adds a scalar `Int32` `Const` op named `name` holding `value` to `graph`.
    fn constant(graph: &mut Graph, name: &str, value: i32) -> Operation {
        let value = Tensor::<i32>::new(&[]).with_values(&[value]).unwrap();
        let mut nd = graph.new_operation("Const", name).unwrap();
        nd.set_attr_type("dtype", DataType::Int32).unwrap();
        nd.set_attr_tensor("value", value).unwrap();
        nd.finish().unwrap()
    }

    /// Loop condition: `counter < 10`, where `counter` is the first loop variable.
    fn while_cond(graph: &mut Graph, inputs: &[Output]) -> Result<Output> {
        let ten = constant(graph, "ten", 10);
        let counter = inputs[0].clone();
        let less = {
            let mut nd = graph.new_operation("Less", "less").unwrap();
            // Pass the loop variable's `Output` itself, as `while_body` does,
            // rather than only `counter.operation`.
            nd.add_input(counter);
            nd.add_input(ten);
            nd.finish().unwrap()
        };
        Ok(less.into())
    }

    /// Loop body: `counter * 2`.
    fn while_body(graph: &mut Graph, inputs: &[Output]) -> Result<Vec<Output>> {
        let two = constant(graph, "two", 2);
        let counter = inputs[0].clone();
        let mul = {
            let mut nd = graph.new_operation("Mul", "mul").unwrap();
            nd.add_input(counter);
            nd.add_input(two);
            nd.finish().unwrap()
        };
        Ok(vec![mul.into()])
    }

    /// Starting at 1 and doubling while below 10 yields 1 -> 2 -> 4 -> 8 -> 16.
    #[test]
    fn simple_while() {
        let mut main_graph = Graph::new();
        let one = constant(&mut main_graph, "one", 1);
        let output = WhileBuilder::new(&mut main_graph, while_cond, while_body, &[one.into()])
            .unwrap()
            .name("foo")
            .unwrap()
            .finish()
            .unwrap();
        assert_eq!(1, output.len());
        let options = SessionOptions::new();
        let session = Session::new(&options, &main_graph).unwrap();
        let mut step = SessionRunArgs::new();
        let output_token = step.request_fetch(&output[0].operation, 0);
        session.run(&mut step).unwrap();
        let output_tensor = step.fetch::<i32>(output_token).unwrap();
        assert_eq!(&output_tensor[..], &[16i32]);
    }

    /// Two unnamed loops in the same graph must both get generated names
    /// without colliding.
    #[test]
    fn generated_name_while() {
        let mut main_graph = Graph::new();
        let one = constant(&mut main_graph, "one", 1);
        WhileBuilder::new(
            &mut main_graph,
            while_cond,
            while_body,
            &[one.clone().into()],
        )
        .unwrap()
        .finish()
        .unwrap();
        WhileBuilder::new(&mut main_graph, while_cond, while_body, &[one.into()])
            .unwrap()
            .finish()
            .unwrap();
    }
}