use super::AnyTensor;
use super::Buffer;
use super::Code;
use super::DataType;
use super::Graph;
use super::MetaGraphDef;
use super::Operation;
use super::Result;
use super::SessionOptions;
use super::Status;
use super::Tensor;
use super::TensorType;
use crate::tf;
use libc::{c_char, c_int};
use std::ffi::CStr;
use std::ffi::CString;
use std::marker;
use std::path::Path;
use std::ptr;
/// Aggregation of a loaded SavedModel: the session connected to the loaded
/// graph plus the metadata that was read alongside it.
#[derive(Debug)]
pub struct SavedModelBundle {
    /// The session with which the loaded graph can be run.
    pub session: Session,
    /// Raw serialized `MetaGraphDef` proto bytes returned by the loader.
    /// Kept for backward compatibility only.
    #[deprecated(
        note = "Please use SavedModelBundle::meta_graph_def() instead",
        since = "0.16.0"
    )]
    pub meta_graph_def: Vec<u8>,
    // Parsed form of the bytes above; exposed via `meta_graph_def()`.
    meta_graph: MetaGraphDef,
}
impl SavedModelBundle {
    /// Loads a session from an exported model stored at `export_dir`,
    /// creating a bundle.
    ///
    /// `tags` select which graph(s) within the SavedModel to load; the graph
    /// definition is loaded into `graph` and the returned bundle holds the
    /// session plus the SavedModel's `MetaGraphDef`.
    ///
    /// # Errors
    /// Fails if the path or any tag cannot be converted to a C string, if the
    /// underlying `TF_LoadSessionFromSavedModel` call reports an error, or if
    /// the returned `MetaGraphDef` proto cannot be parsed.
    pub fn load<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>(
        options: &SessionOptions,
        tags: Tags,
        graph: &mut Graph,
        export_dir: P,
    ) -> Result<SavedModelBundle> {
        let mut status = Status::new();
        // The C API needs NUL-terminated strings; reject paths that are not
        // valid UTF-8 or contain interior NUL bytes.
        let export_dir_cstr = export_dir
            .as_ref()
            .to_str()
            .and_then(|s| CString::new(s.as_bytes()).ok())
            .ok_or_else(|| invalid_arg!("Invalid export directory path"))?;
        // Collecting into Result short-circuits on the first invalid tag.
        let tags_cstr: Vec<_> = tags
            .into_iter()
            .map(|t| CString::new(t.as_ref()))
            .collect::<::std::result::Result<_, _>>()
            .map_err(|_| invalid_arg!("Invalid tag name"))?;
        // `tags_cstr` must stay alive while these raw pointers are in use.
        let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();
        // Empty buffer that the loader fills with the serialized MetaGraphDef.
        let mut meta = unsafe { Buffer::<u8>::from_ptr(ptr::null_mut(), 0) };
        let inner = unsafe {
            tf::TF_LoadSessionFromSavedModel(
                options.inner,
                ptr::null(), // no RunOptions buffer
                export_dir_cstr.as_ptr(),
                tags_ptr.as_ptr(),
                tags_ptr.len() as c_int,
                graph.inner(),
                meta.inner_mut(),
                status.inner(),
            )
        };
        if inner.is_null() {
            Err(status)
        } else {
            let session = Session { inner };
            // The deprecated raw-bytes field is still populated for
            // backward compatibility.
            #[allow(deprecated)]
            Ok(SavedModelBundle {
                session,
                meta_graph_def: Vec::from(meta.as_ref()),
                meta_graph: MetaGraphDef::from_serialized_proto(meta.as_ref())?,
            })
        }
    }

    /// Returns the parsed `MetaGraphDef` of the loaded SavedModel.
    pub fn meta_graph_def(&self) -> &MetaGraphDef {
        &self.meta_graph
    }
}
/// Manages a single graph and execution session.
#[derive(Debug)]
pub struct Session {
    // Owned raw pointer to the underlying TF_Session; released in Drop.
    inner: *mut tf::TF_Session,
}
impl Session {
    /// Creates a session for running operations in `graph`.
    ///
    /// # Errors
    /// Returns the status reported by `TF_NewSession` on failure.
    pub fn new(options: &SessionOptions, graph: &Graph) -> Result<Self> {
        let mut status = Status::new();
        let inner = unsafe { tf::TF_NewSession(graph.inner(), options.inner, status.inner()) };
        if inner.is_null() {
            Err(status)
        } else {
            Ok(Session { inner })
        }
    }

    /// Loads a session from an exported model. Prefer
    /// `SavedModelBundle::load`, which also exposes the model's metadata.
    #[deprecated(note = "Please use SavedModelBundle::load() instead", since = "0.17.0")]
    pub fn from_saved_model<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>(
        options: &SessionOptions,
        tags: Tags,
        graph: &mut Graph,
        export_dir: P,
    ) -> Result<Self> {
        Ok(SavedModelBundle::load(options, tags, graph, export_dir)?.session)
    }

    /// Closes the session, releasing its resources. The session object
    /// itself remains valid to drop.
    pub fn close(&mut self) -> Result<()> {
        let mut status = Status::new();
        unsafe {
            tf::TF_CloseSession(self.inner, status.inner());
        }
        status.into_result()
    }

    /// Runs one step: feeds the inputs recorded in `step`, executes its
    /// targets, and fetches its requested outputs. Fetched tensors are then
    /// available via `SessionRunArgs::fetch`; run metadata is captured when
    /// `step.set_request_metadata(true)` was called.
    pub fn run(&self, step: &mut SessionRunArgs<'_>) -> Result<()> {
        // Discard leftovers from a previous run so the same args can be
        // reused for multiple steps.
        step.drop_output_tensors();
        step.maybe_reset_run_metadata();
        let mut status = Status::new();
        // Collect raw TF_Tensor pointers; `inner()` is fallible, so the
        // whole collection short-circuits on the first bad tensor.
        let maybe_tensors: Result<_> = step.input_tensors.iter().map(|t| t.inner()).collect();
        let input_tensors: Vec<_> = maybe_tensors?;
        let run_options_ptr = match step.run_options.as_ref() {
            Some(buf) => buf.inner(),
            None => ptr::null(),
        };
        // Only allocate a metadata buffer when the caller asked for it.
        let mut run_metadata_buf = if step.request_metadata {
            Some(unsafe { Buffer::new_unallocated() })
        } else {
            None
        };
        let run_metadata_ptr = match run_metadata_buf.as_mut() {
            Some(meta) => meta.inner_mut(),
            None => ptr::null_mut(),
        };
        unsafe {
            tf::TF_SessionRun(
                self.inner,
                run_options_ptr,
                step.input_ports.as_ptr(),
                input_tensors.as_ptr() as *const *mut tf::TF_Tensor,
                input_tensors.len() as c_int,
                step.output_ports.as_ptr(),
                step.output_tensors.as_mut_ptr(),
                step.output_tensors.len() as c_int,
                step.target_operations.as_mut_ptr(),
                step.target_operations.len() as c_int,
                run_metadata_ptr,
                status.inner(),
            );
            // Hand the (possibly filled) metadata buffer back to the args.
            step.run_metadata = run_metadata_buf.map(Into::into);
        }
        status.into_result()
    }

    /// Lists all devices available in this session.
    pub fn device_list(&self) -> Result<Vec<Device>> {
        let status = Status::new();
        unsafe {
            let list = tf::TF_SessionListDevices(self.inner, status.inner);
            if !status.is_ok() {
                return Err(status);
            }
            // Run the fallible loop inside a closure so that
            // TF_DeleteDeviceList below is always executed, whether the
            // enumeration succeeds or bails out early.
            let result = (|| {
                let n = tf::TF_DeviceListCount(list);
                let mut devices = Vec::with_capacity(n as usize);
                for i in 0..n {
                    let c_name = tf::TF_DeviceListName(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    let c_type = tf::TF_DeviceListType(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    let bytes = tf::TF_DeviceListMemoryBytes(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    let incarnation = tf::TF_DeviceListIncarnation(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    devices.push(Device {
                        // `?` converts a UTF-8 decoding failure into a Status.
                        name: CStr::from_ptr(c_name).to_str()?.to_string(),
                        device_type: CStr::from_ptr(c_type).to_str()?.to_string(),
                        memory_bytes: bytes,
                        incarnation,
                    });
                }
                Ok(devices)
            })();
            tf::TF_DeleteDeviceList(list);
            result
        }
    }
}
impl Drop for Session {
    /// Deletes the underlying session. Any error reported by TensorFlow is
    /// necessarily ignored, since `drop` has no way to surface it.
    fn drop(&mut self) {
        let mut status = Status::new();
        unsafe {
            tf::TF_DeleteSession(self.inner, status.inner());
        }
    }
}
// SAFETY(review): relies on the TensorFlow C API allowing a TF_Session to be
// used concurrently from multiple threads — confirm against the C API docs.
unsafe impl Send for Session {}
unsafe impl Sync for Session {}
/// An opaque token for retrieving an output from a computation, returned by
/// `SessionRunArgs::request_fetch` and later consumed by
/// `SessionRunArgs::fetch`.
#[derive(Copy, Clone, Debug)]
pub struct FetchToken {
    // Index into SessionRunArgs::output_tensors.
    index: usize,
}
/// Deprecated alias for [`FetchToken`].
#[deprecated(note = "Use FetchToken instead.", since = "0.10.0")]
pub type OutputToken = FetchToken;
/// Manages the inputs and outputs for a single execution of a graph.
/// The lifetime `'l` ties the borrowed input tensors to these args.
#[derive(Debug)]
pub struct SessionRunArgs<'l> {
    // Feed endpoints; parallel to `input_tensors`.
    input_ports: Vec<tf::TF_Output>,
    input_tensors: Vec<&'l dyn AnyTensor>,
    // Fetch endpoints; parallel to `output_tensors`.
    output_ports: Vec<tf::TF_Output>,
    // Slots filled by `Session::run`; null until run, and nulled again once
    // a tensor is taken with `fetch`.
    output_tensors: Vec<*mut tf::TF_Tensor>,
    // Operations executed for their side effects only.
    target_operations: Vec<*const tf::TF_Operation>,
    // Serialized RunOptions proto, if set.
    run_options: Option<Buffer<u8>>,
    // Serialized RunMetadata proto captured by the last run, if requested.
    run_metadata: Option<Vec<u8>>,
    // Whether the next run should capture metadata.
    request_metadata: bool,
    phantom: marker::PhantomData<&'l ()>,
}
// SAFETY(review): the raw pointers held here are owned tensor/operation
// handles; presumed safe to move/share across threads per the TensorFlow C
// API's thread-safety guarantees — confirm against the C API docs.
unsafe impl<'l> Send for SessionRunArgs<'l> {}
unsafe impl<'l> Sync for SessionRunArgs<'l> {}
impl<'l> Default for SessionRunArgs<'l> {
fn default() -> Self {
Self::new()
}
}
impl<'l> SessionRunArgs<'l> {
    /// Creates a `SessionRunArgs` with no feeds, fetches, or targets.
    pub fn new() -> Self {
        SessionRunArgs {
            input_ports: vec![],
            input_tensors: vec![],
            output_ports: vec![],
            output_tensors: vec![],
            run_options: None,
            run_metadata: None,
            request_metadata: false,
            target_operations: vec![],
            phantom: marker::PhantomData,
        }
    }

    /// Adds an input to be fed to the graph: output `index` of `operation`
    /// takes the value of `tensor` for the duration of the run.
    pub fn add_feed<T: TensorType>(
        &mut self,
        operation: &Operation,
        index: c_int,
        tensor: &'l Tensor<T>,
    ) {
        self.input_ports.push(tf::TF_Output {
            oper: operation.inner(),
            index,
        });
        self.input_tensors.push(tensor);
    }

    /// Deprecated alias for [`SessionRunArgs::add_feed`].
    #[deprecated(note = "Use add_feed instead.", since = "0.10.0")]
    pub fn add_input<T: TensorType>(
        &mut self,
        operation: &Operation,
        index: c_int,
        tensor: &'l Tensor<T>,
    ) {
        self.add_feed(operation, index, tensor)
    }

    /// Requests that output `index` of `operation` be fetched during the
    /// run. The returned token is passed to [`SessionRunArgs::fetch`] after
    /// the run to retrieve the tensor.
    pub fn request_fetch(&mut self, operation: &Operation, index: c_int) -> FetchToken {
        self.output_ports.push(tf::TF_Output {
            oper: operation.inner(),
            index,
        });
        // Placeholder slot; Session::run fills it with the fetched tensor.
        self.output_tensors.push(ptr::null_mut());
        FetchToken {
            index: self.output_tensors.len() - 1,
        }
    }

    /// Deprecated alias for [`SessionRunArgs::request_fetch`].
    #[deprecated(note = "Use request_fetch instead.", since = "0.10.0")]
    #[allow(deprecated)]
    pub fn request_output(&mut self, operation: &Operation, index: c_int) -> OutputToken {
        self.request_fetch(operation, index)
    }

    /// Takes the tensor fetched for `token`, transferring ownership to the
    /// caller; a second call with the same token returns an error.
    ///
    /// # Errors
    /// Returns `OutOfRange` for an invalid token, `Unavailable` if the
    /// output was already taken or the step has not been successfully run,
    /// and `InvalidArgument` if `T` does not match the tensor's data type.
    pub fn fetch<T: TensorType>(&mut self, token: FetchToken) -> Result<Tensor<T>> {
        let output_idx = token.index;
        if output_idx >= self.output_tensors.len() {
            return Err(Status::new_set(
                Code::OutOfRange,
                &format!(
                    "Requested output index is out of range: {} vs \
                     {}",
                    output_idx,
                    self.output_tensors.len()
                ),
            )
            .unwrap());
        }
        if self.output_tensors[output_idx].is_null() {
            // Typo fix: was "sucessfully".
            return Err(Status::new_set(
                Code::Unavailable,
                "Output not available. Either it was already taken, or \
                 this step has not been successfully run yet.",
            )
            .unwrap());
        }
        // Unwrap is safe: index and non-null pointer were checked above.
        let actual_data_type = self.output_data_type(output_idx).unwrap();
        if actual_data_type != T::data_type() {
            return Err(invalid_arg!(
                "Requested tensor type does not match actual tensor type: \
                 {} vs {}",
                actual_data_type,
                T::data_type()
            ));
        }
        let tensor = unsafe { Tensor::from_tf_tensor(self.output_tensors[output_idx]).unwrap() };
        // Null out the slot so ownership is handed over exactly once.
        self.output_tensors[output_idx] = ptr::null_mut();
        Ok(tensor)
    }

    /// Deprecated alias for [`SessionRunArgs::fetch`].
    #[deprecated(note = "Use fetch instead.", since = "0.10.0")]
    #[allow(deprecated)]
    pub fn take_output<T: TensorType>(&mut self, token: OutputToken) -> Result<Tensor<T>> {
        self.fetch(token)
    }

    /// Adds an operation to be executed for its side effects; its outputs
    /// are not fetched.
    pub fn add_target(&mut self, operation: &Operation) {
        self.target_operations.push(operation.inner());
    }

    /// Returns the data type of the fetched tensor at `output_idx`, or
    /// `None` if the index is invalid or the output is not (or no longer)
    /// available.
    pub fn output_data_type(&self, output_idx: usize) -> Option<DataType> {
        if output_idx >= self.output_tensors.len() {
            return None;
        }
        if self.output_tensors[output_idx].is_null() {
            return None;
        }
        unsafe {
            Some(DataType::from_c(tf::TF_TensorType(
                self.output_tensors[output_idx],
            )))
        }
    }

    /// Sets the serialized `RunOptions` proto to use for subsequent runs.
    pub fn set_run_options(&mut self, run_options: &[u8]) {
        self.run_options = Some(Buffer::from(run_options))
    }

    /// Returns the serialized run options, if any were set.
    pub fn get_run_options(&self) -> Option<&[u8]> {
        self.run_options.as_ref().map(std::convert::AsRef::as_ref)
    }

    /// Returns the serialized `RunMetadata` proto captured by the last run,
    /// if metadata capture was requested.
    pub fn get_metadata(&mut self) -> Option<&[u8]> {
        self.run_metadata.as_ref().map(std::convert::AsRef::as_ref)
    }

    /// Requests (or cancels the request) that run metadata be captured on
    /// the next run.
    pub fn set_request_metadata(&mut self, request: bool) {
        self.request_metadata = request;
    }

    /// Reports whether run metadata capture has been requested.
    pub fn is_request_metadata(&self) -> bool {
        self.request_metadata
    }

    // Frees any fetched tensors that were never taken and resets all slots.
    fn drop_output_tensors(&mut self) {
        for tensor in &mut self.output_tensors {
            if !tensor.is_null() {
                unsafe {
                    tf::TF_DeleteTensor(*tensor);
                }
            }
            *tensor = ptr::null_mut();
        }
    }

    // Clears metadata from a previous run so stale data is never returned.
    fn maybe_reset_run_metadata(&mut self) {
        self.run_metadata = None;
    }
}
impl<'l> Drop for SessionRunArgs<'l> {
    /// Frees any fetched tensors that were never taken with `fetch`.
    fn drop(&mut self) {
        self.drop_output_tensors();
    }
}
/// Deprecated alias for [`SessionRunArgs`].
#[deprecated(note = "Use SessionRunArgs instead.", since = "0.10.0")]
pub type StepWithGraph<'l> = SessionRunArgs<'l>;
/// Metadata about a device in a session, as reported by
/// `Session::device_list`.
#[derive(Debug, Eq, PartialEq, Clone, Hash)]
pub struct Device {
    /// Full name of the device (format defined by TensorFlow).
    pub name: String,
    /// Device type, e.g. "CPU".
    pub device_type: String,
    /// Amount of memory on the device, in bytes.
    pub memory_bytes: i64,
    /// Incarnation identifier reported by TensorFlow for this device.
    pub incarnation: u64,
}
#[cfg(test)]
mod tests {
    use super::super::DataType;
    use super::super::Graph;
    use super::super::Operation;
    use super::super::SessionOptions;
    use super::super::Shape;
    use super::super::Tensor;
    use super::*;
    use serial_test::serial;

    /// Builds a graph computing y = 2 * x and returns a session over it,
    /// along with the `x` (placeholder) and `y` operations.
    fn create_session() -> (Session, Operation, Operation) {
        let mut g = Graph::new();
        let two = {
            let mut nd = g.new_operation("Const", "two").unwrap();
            nd.set_attr_type("dtype", DataType::Float).unwrap();
            let mut value = Tensor::new(&[1]);
            value[0] = 2.0f32;
            nd.set_attr_tensor("value", value).unwrap();
            nd.finish().unwrap()
        };
        let x = {
            let mut nd = g.new_operation("Placeholder", "x").unwrap();
            nd.set_attr_type("dtype", DataType::Float).unwrap();
            nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
            nd.finish().unwrap()
        };
        let y = {
            let mut nd = g.new_operation("Mul", "y").unwrap();
            nd.add_input(two);
            nd.add_input(x.clone());
            nd.finish().unwrap()
        };
        let options = SessionOptions::new();
        match Session::new(&options, &g) {
            Ok(session) => (session, x, y),
            Err(status) => panic!("Creating session failed with status: {}", status),
        }
    }

    #[test]
    fn smoke() {
        create_session();
    }

    #[test]
    fn test_close() {
        let (mut session, _, _) = create_session();
        let status = session.close();
        assert!(status.is_ok());
    }

    #[test]
    fn test_run() {
        let (session, x_operation, y_operation) = create_session();
        let mut x = <Tensor<f32>>::new(&[2]);
        x[0] = 2.0;
        x[1] = 3.0;
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    #[serial]
    fn test_run_metadata() {
        let (session, x_operation, y_operation) = create_session();
        let x = Tensor::<f32>::from(&[2.0, 3.0][..]);
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        step.set_run_options(&[8u8, 3u8]);
        // Fixed: this call was accidentally duplicated.
        step.set_request_metadata(true);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        step.get_metadata().unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
        // A second run with the same args must re-capture metadata and
        // outputs after the internal reset.
        session.run(&mut step).unwrap();
        step.get_metadata().unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    #[serial]
    fn test_run_options() {
        let (session, x_operation, y_operation) = create_session();
        let x = Tensor::<f32>::from(&[2.0, 3.0][..]);
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        step.set_run_options(&[8u8, 3u8]);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    fn test_run_metadata_no_run_options() {
        let (session, x_operation, y_operation) = create_session();
        let x = Tensor::<f32>::from(&[2.0, 3.0][..]);
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        step.set_request_metadata(true);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        step.get_metadata().unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    fn test_savedmodelbundle() {
        let mut graph = Graph::new();
        let bundle = SavedModelBundle::load(
            &SessionOptions::new(),
            &["train", "serve"],
            &mut graph,
            "test_resources/regression-model",
        )
        .unwrap();
        let x_op = graph.operation_by_name_required("x").unwrap();
        let y_op = graph.operation_by_name_required("y").unwrap();
        let y_hat_op = graph.operation_by_name_required("y_hat").unwrap();
        let _train_op = graph.operation_by_name_required("train").unwrap();
        // The deprecated raw-bytes field must still be populated.
        #[allow(deprecated)]
        let SavedModelBundle {
            session,
            meta_graph_def,
            meta_graph: _,
        } = bundle;
        assert!(!meta_graph_def.is_empty());
        let mut x = <Tensor<f32>>::new(&[1]);
        x[0] = 2.0;
        let mut y = <Tensor<f32>>::new(&[1]);
        y[0] = 4.0;
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_op, 0, &x);
        step.add_feed(&y_op, 0, &y);
        let output_token = step.request_fetch(&y_hat_op, 0);
        session.run(&mut step).unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 1);
    }

    #[test]
    fn test_device_list() {
        let (session, _, _) = create_session();
        let devices = session.device_list().unwrap();
        assert!(
            devices.iter().any(|d| d.device_type == "CPU"),
            "devices: {:?}",
            devices
        );
    }
}