#![allow(dead_code)] use libc::c_float;
use libc::c_int;
use libc::c_uchar;
use libc::c_void;
use libc::size_t;
use std::ffi::{CStr, CString};
use std::marker::PhantomData;
use std::mem::{self, ManuallyDrop};
use std::os::raw::c_void as std_c_void;
use std::ptr;
use crate::eager::{Context, TensorHandle};
use crate::{AnyTensor, Code, DataType, Result, Shape, Status};
use tensorflow_sys as tf;
#[cfg(test)]
mod op_test_util;
#[allow(
non_snake_case,
clippy::too_many_arguments,
clippy::derivable_impls,
clippy::needless_lifetimes
)]
pub mod raw_ops;
#[derive(Debug)]
/// A TensorFlow eager operation (or function call) under construction.
///
/// Thin owning wrapper around a raw `TFE_Op` pointer. The `PhantomData`
/// borrow ties the op's lifetime to the `Context` it was created from, so
/// the context cannot be dropped while the op is still alive.
#[derive(Debug)]
struct Op<'a> {
    // Raw pointer to the C-side op; allocated by `TFE_NewOp` and released in `Drop`.
    inner: *mut tf::TFE_Op,
    // Zero-sized marker borrowing the creating `Context` for `'a`.
    ctx: PhantomData<&'a Context>,
}
impl<'a> Drop for Op<'a> {
    /// Releases the underlying C op.
    fn drop(&mut self) {
        // SAFETY: `inner` was produced by `TFE_NewOp` in `Op::new` and is not
        // freed anywhere else, so deleting it exactly once here is sound.
        unsafe {
            tf::TFE_DeleteOp(self.inner);
        }
    }
}
/// A borrowed view of the `Context` an op belongs to (see `Op::get_context`).
///
/// The wrapped `Context` is held in `ManuallyDrop` so that dropping this
/// struct does NOT run `Context`'s destructor — the underlying `TFE_Context`
/// is owned elsewhere; this is only a non-owning handle, with the borrow
/// tracked by the `PhantomData` lifetime.
struct OpContext<'a> {
    ctx: ManuallyDrop<Context>,
    lifetime: PhantomData<&'a Context>,
}
impl<'a> Op<'a> {
    /// Creates a new eager op (or function call) named `op_or_function_name`
    /// within `ctx`.
    ///
    /// Returns an error if the name contains an interior NUL byte or if the
    /// C API reports a failure / returns a null handle.
    fn new(ctx: &'a Context, op_or_function_name: &str) -> Result<Self> {
        let status = Status::new();
        let c_op_or_function_name = CString::new(op_or_function_name)?;
        let inner =
            unsafe { tf::TFE_NewOp(ctx.inner, c_op_or_function_name.as_ptr(), status.inner) };
        // Defend against a null pointer even if the status happens to be OK.
        if inner.is_null() || !status.is_ok() {
            return Err(status);
        }
        Ok(Self {
            inner,
            ctx: PhantomData,
        })
    }

    /// Returns the op's name as reported by the C API.
    ///
    /// The returned `&str` points into C-owned memory belonging to the op;
    /// its borrow is tied to `&self`.
    fn get_name(&self) -> Result<&str> {
        let status = Status::new();
        let c_name = unsafe { tf::TFE_OpGetName(self.inner, status.inner) };
        status.into_result()?;
        let name = unsafe { CStr::from_ptr(c_name).to_str()? };
        Ok(name)
    }

    /// Returns a non-owning handle to the context this op was created in.
    fn get_context(&self) -> Result<OpContext<'a>> {
        let status = Status::new();
        let inner = unsafe { tf::TFE_OpGetContext(self.inner, status.inner) };
        status.into_result()?;
        // ManuallyDrop: the returned context pointer is owned by the runtime,
        // so the temporary `Context` wrapper must never delete it.
        let ctx = ManuallyDrop::new(Context { inner });
        Ok(OpContext {
            ctx,
            lifetime: PhantomData,
        })
    }

    /// Appends a single input tensor handle to the op.
    fn add_input(&mut self, input: &TensorHandle) -> Result<()> {
        let status = Status::new();
        unsafe {
            tf::TFE_OpAddInput(self.inner, input.inner, status.inner);
        };
        status.into_result()
    }

    /// Pins the op to the device named `device_name`
    /// (e.g. "/job:localhost/replica:0/task:0/device:GPU:0").
    fn set_device(&mut self, device_name: &str) -> Result<()> {
        let status = Status::new();
        let c_device_name = CString::new(device_name)?;
        unsafe {
            tf::TFE_OpSetDevice(self.inner, c_device_name.as_ptr(), status.inner);
        }
        status.into_result()
    }

    /// Returns the name of the device the op is currently assigned to.
    ///
    /// The returned `&str` borrows C-owned memory; its lifetime is tied to
    /// `&self`.
    fn get_device(&self) -> Result<&str> {
        let status = Status::new();
        let c_device_name = unsafe { tf::TFE_OpGetDevice(self.inner, status.inner) };
        status.into_result()?;
        let device_name = unsafe { CStr::from_ptr(c_device_name).to_str()? };
        Ok(device_name)
    }

    /// Appends a list input (a variadic group of tensor handles) to the op.
    fn add_input_list(&mut self, inputs: &[TensorHandle]) -> Result<()> {
        let status = Status::new();
        unsafe {
            // The C API wants a mutable array of raw handle pointers.
            let mut inputs: Vec<*mut tf::TFE_TensorHandle> =
                inputs.iter().map(|v| v.inner).collect();
            tf::TFE_OpAddInputList(
                self.inner,
                inputs.as_mut_ptr(),
                inputs.len() as c_int,
                status.inner,
            );
        };
        status.into_result()
    }

    /// Sets a string-valued attribute.
    ///
    /// Passed as a pointer + length pair, so `value` may contain any bytes
    /// (no NUL-termination requirement).
    fn set_attr_string(&mut self, attr_name: &str, value: &str) -> Result<()> {
        let attr_name = CString::new(attr_name)?;
        let c_value = value.as_bytes();
        unsafe {
            tf::TFE_OpSetAttrString(
                self.inner,
                attr_name.as_ptr(),
                c_value.as_ptr() as *const std_c_void,
                c_value.len() as size_t,
            );
        }
        Ok(())
    }

    /// Sets a list-of-strings attribute.
    fn set_attr_string_list<S: AsRef<str>>(&mut self, attr_name: &str, values: &[S]) -> Result<()> {
        let c_attr_name = CString::new(attr_name)?;
        // Build parallel arrays of (pointer, length) pairs; `bytes` keeps the
        // borrowed byte slices alive for the duration of the FFI call.
        let bytes: Vec<&[u8]> = values.iter().map(|x| x.as_ref().as_bytes()).collect();
        let ptrs: Vec<*const c_void> = bytes.iter().map(|x| x.as_ptr() as *const c_void).collect();
        let lens: Vec<size_t> = bytes.iter().map(|x| x.len() as size_t).collect();
        unsafe {
            tf::TFE_OpSetAttrStringList(
                self.inner,
                c_attr_name.as_ptr(),
                ptrs.as_ptr() as *const *const std_c_void,
                lens.as_ptr(),
                ptrs.len() as c_int,
            );
        }
        Ok(())
    }

    /// Sets an integer-valued attribute.
    fn set_attr_int(&mut self, attr_name: &str, value: i64) -> Result<()> {
        let c_attr_name = CString::new(attr_name)?;
        unsafe {
            tf::TFE_OpSetAttrInt(self.inner, c_attr_name.as_ptr(), value);
        }
        Ok(())
    }

    /// Sets a list-of-integers attribute.
    fn set_attr_int_list(&mut self, attr_name: &str, value: &[i64]) -> Result<()> {
        let c_attr_name = CString::new(attr_name)?;
        unsafe {
            tf::TFE_OpSetAttrIntList(
                self.inner,
                c_attr_name.as_ptr(),
                value.as_ptr(),
                value.len() as i32,
            );
        }
        Ok(())
    }

    /// Sets a float-valued attribute.
    fn set_attr_float(&mut self, attr_name: &str, value: f32) -> Result<()> {
        let c_attr_name = CString::new(attr_name)?;
        unsafe {
            tf::TFE_OpSetAttrFloat(self.inner, c_attr_name.as_ptr(), value);
        }
        Ok(())
    }

    /// Sets a list-of-floats attribute.
    fn set_attr_float_list(&mut self, attr_name: &str, value: &[f32]) -> Result<()> {
        let c_attr_name = CString::new(attr_name)?;
        // Copy into `c_float` explicitly in case the platform's C float type
        // is not `f32`.
        let c_value: Vec<c_float> = value.iter().map(|x| *x as c_float).collect();
        unsafe {
            tf::TFE_OpSetAttrFloatList(
                self.inner,
                c_attr_name.as_ptr(),
                c_value.as_ptr(),
                c_value.len() as i32,
            );
        }
        Ok(())
    }

    /// Sets a boolean-valued attribute (passed as 0/1 to the C API).
    fn set_attr_bool(&mut self, attr_name: &str, value: bool) -> Result<()> {
        let c_attr_name = CString::new(attr_name)?;
        unsafe {
            tf::TFE_OpSetAttrBool(self.inner, c_attr_name.as_ptr(), if value { 1 } else { 0 });
        }
        Ok(())
    }

    /// Sets a list-of-booleans attribute (each element passed as 0/1).
    fn set_attr_bool_list(&mut self, attr_name: &str, value: &[bool]) -> Result<()> {
        let c_attr_name = CString::new(attr_name)?;
        let c_value: Vec<c_uchar> = value.iter().map(|x| if *x { 1 } else { 0 }).collect();
        unsafe {
            tf::TFE_OpSetAttrBoolList(
                self.inner,
                c_attr_name.as_ptr(),
                c_value.as_ptr(),
                c_value.len() as c_int,
            );
        }
        Ok(())
    }

    /// Sets a type-valued attribute (a `DataType` such as `DT_FLOAT`).
    fn set_attr_type(&mut self, attr_name: &str, value: DataType) -> Result<()> {
        let c_attr_name = CString::new(attr_name)?;
        unsafe {
            tf::TFE_OpSetAttrType(self.inner, c_attr_name.as_ptr(), value.to_c());
        }
        Ok(())
    }

    /// Sets a list-of-types attribute.
    fn set_attr_type_list(&mut self, attr_name: &str, value: &[DataType]) -> Result<()> {
        let c_attr_name = CString::new(attr_name)?;
        let c_value: Vec<tf::TF_DataType> = value.iter().map(|x| x.to_c()).collect();
        unsafe {
            tf::TFE_OpSetAttrTypeList(
                self.inner,
                c_attr_name.as_ptr(),
                c_value.as_ptr(),
                c_value.len() as i32,
            );
        }
        Ok(())
    }

    /// Sets a shape-valued attribute.
    ///
    /// A `Shape` with no dims (`None`) is passed as an unknown-rank shape
    /// (null dims, num_dims = -1); unknown dimensions are encoded as -1.
    fn set_attr_shape(&mut self, attr_name: &str, value: &Shape) -> Result<()> {
        let status = Status::new();
        let c_attr_name = CString::new(attr_name)?;
        unsafe {
            match value.0 {
                None => tf::TFE_OpSetAttrShape(
                    self.inner,
                    c_attr_name.as_ptr(),
                    ptr::null(),
                    -1,
                    status.inner,
                ),
                Some(ref dims) => {
                    let c_dims: Vec<i64> = dims.iter().map(|x| (*x).unwrap_or(-1)).collect();
                    tf::TFE_OpSetAttrShape(
                        self.inner,
                        c_attr_name.as_ptr(),
                        c_dims.as_ptr(),
                        c_dims.len() as i32,
                        status.inner,
                    );
                }
            }
        }
        status.into_result()
    }

    /// Sets a list-of-shapes attribute, using the same encoding per element
    /// as `set_attr_shape` (null/-1 for unknown rank, -1 for unknown dims).
    fn set_attr_shape_list(&mut self, attr_name: &str, value: &[Shape]) -> Result<()> {
        let status = Status::new();
        let c_attr_name = CString::new(attr_name)?;
        // `c_dims` owns the flattened dimension vectors so the raw pointers
        // in `ptrs` stay valid through the FFI call.
        let c_dims: Vec<Option<Vec<i64>>> = value
            .iter()
            .map(|x| {
                x.0.as_ref()
                    .map(|dims| dims.iter().map(|x| (*x).unwrap_or(-1)).collect())
            })
            .collect();
        let mut ptrs: Vec<*const i64> = c_dims
            .iter()
            .map(|x| match *x {
                None => ptr::null(),
                Some(ref dims) => dims.as_ptr(),
            })
            .collect();
        let lens: Vec<c_int> = value
            .iter()
            .map(|x| match x.0 {
                None => -1,
                Some(ref dims) => dims.len() as c_int,
            })
            .collect();
        unsafe {
            tf::TFE_OpSetAttrShapeList(
                self.inner,
                c_attr_name.as_ptr(),
                ptrs.as_mut_ptr(),
                lens.as_ptr(),
                ptrs.len() as c_int,
                status.inner,
            );
        }
        status.into_result()
    }

    /// Sets a tensor-valued attribute from any type-erased tensor.
    fn set_attr_any_tensor(&mut self, attr_name: &str, value: &dyn AnyTensor) -> Result<()> {
        let c_attr_name = CString::new(attr_name)?;
        // `mut` is required: `Status::inner()` takes `&mut self`.
        let mut status = Status::new();
        unsafe {
            tf::TFE_OpSetAttrTensor(
                self.inner,
                c_attr_name.as_ptr(),
                value.inner()?,
                status.inner(),
            );
        }
        status.into_result()
    }

    /// Executes the op, consuming it, and returns exactly `N` output handles.
    ///
    /// Fails if the C API reports an error or if the op produced a number of
    /// outputs different from `N` (in which case any handles that were
    /// produced are deleted before returning an `InvalidArgument` error).
    fn execute<const N: usize>(self, ctx: &'a Context) -> Result<[TensorHandle; N]> {
        let status = Status::new();
        // In/out parameter: holds the capacity of `retvals` on entry and the
        // actual number of outputs on exit.
        let mut num_retvals = N as i32;
        let mut retvals: [*mut tf::TFE_TensorHandle; N] = [ptr::null_mut(); N];
        unsafe {
            tf::TFE_Execute(
                self.inner,
                retvals.as_mut_ptr(),
                &mut num_retvals,
                status.inner,
            );
        }
        status.into_result()?;
        if num_retvals != N as i32 {
            // Avoid leaking the handles that were actually produced.
            for i in 0..num_retvals as usize {
                unsafe {
                    tf::TFE_DeleteTensorHandle(retvals[i]);
                }
            }
            let status = Status::new_set_lossy(
                Code::InvalidArgument,
                &format!("Expected {} outputs, got {}", N, num_retvals),
            );
            return Err(status);
        }
        // Build the output array element by element via MaybeUninit, since
        // `TensorHandle` is not `Copy`/`Default` and N is a const generic.
        // SAFETY: an array of `MaybeUninit` needs no initialization.
        let mut handles_uninit: [mem::MaybeUninit<TensorHandle>; N] =
            unsafe { mem::MaybeUninit::uninit().assume_init() };
        for i in 0..N {
            let t = unsafe { TensorHandle::from_tensor_handle(ctx, retvals[i]) };
            handles_uninit[i].write(t);
        }
        // SAFETY: all N slots were written above, so reading the array as
        // fully-initialized is sound; `forget` prevents a double-drop of the
        // elements now owned by `handles`.
        let ptr = &mut handles_uninit as *mut _ as *mut [TensorHandle; N];
        let handles: [TensorHandle; N] = unsafe { ptr.read() };
        mem::forget(handles_uninit);
        Ok(handles)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eager::{Context, ContextOptions, TensorHandle};
    use crate::Tensor;
    use op_test_util::add as add_ut;
    use raw_ops::{add, concat_v2};

    #[cfg(feature = "ndarray")]
    use ndarray::array;

    /// Builds an "Add" op by hand and checks x + x on a 2x2 i32 tensor.
    #[test]
    fn test_add_op() {
        let ctx = Context::new(ContextOptions::new()).unwrap();
        let x = Tensor::new(&[2, 2])
            .with_values(&[1i32, 2, 3, 4])
            .unwrap()
            .freeze();
        let h_x = TensorHandle::new(&ctx, &x).unwrap();
        let h_y = h_x.copy_sharing_tensor().unwrap();
        let op_name = "Add";
        let mut op = Op::new(&ctx, op_name).unwrap();
        op.add_input(&h_x).unwrap();
        op.add_input(&h_y).unwrap();
        const NUMBER_OF_OUTPUTS: usize = 1;
        let [h] = op.execute::<NUMBER_OF_OUTPUTS>(&ctx).unwrap();
        let z = h.resolve::<i32>().unwrap();
        let expected = Tensor::new(&[2, 2]).with_values(&[2i32, 4, 6, 8]).unwrap();
        assert_eq!(z, expected);
    }

    /// Requesting the wrong number of outputs from `execute` must error
    /// rather than panic or leak.
    #[test]
    fn test_invalid_add() {
        let ctx = Context::new(ContextOptions::new()).unwrap();
        let x = Tensor::new(&[2, 2])
            .with_values(&[1i32, 2, 3, 4])
            .unwrap()
            .freeze();
        let h_x = TensorHandle::new(&ctx, &x).unwrap();
        let h_y = h_x.copy_sharing_tensor().unwrap();
        let op_name = "Add";
        let mut op = Op::new(&ctx, op_name).unwrap();
        op.add_input(&h_x).unwrap();
        op.add_input(&h_y).unwrap();
        const WRONG_NUMBER_OF_OUTPUTS: usize = 2;
        let res = op.execute::<WRONG_NUMBER_OF_OUTPUTS>(&ctx);
        assert!(res.is_err());
    }

    /// The test-util `add` must accept every combination of
    /// `Tensor` / `TensorHandle` arguments.
    #[test]
    fn test_add_ut() {
        let values = [1i32, 2, 3, 4];
        let ctx = Context::new(ContextOptions::new()).unwrap();
        let x = Tensor::new(&[2, 2]).with_values(&values).unwrap().freeze();
        let h_x = TensorHandle::new(&ctx, &x).unwrap();
        let h_y = h_x.copy_sharing_tensor().unwrap();
        let expected = Tensor::new(&[2, 2]).with_values(&[2i32, 4, 6, 8]).unwrap();
        let h_z = add_ut(&ctx, &x, &x).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        assert_eq!(z, expected);
        let h_z = add_ut(&ctx, &x, &h_y).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        assert_eq!(z, expected);
        let h_z = add_ut(&ctx, &h_x, &x).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        assert_eq!(z, expected);
        let h_z = add_ut(&ctx, &h_x, &h_y).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        assert_eq!(z, expected);
    }

    /// Same combinations as `test_add_ut`, but through the generated
    /// `raw_ops::add`.
    #[test]
    fn test_raw_ops_add() {
        let values = [1i32, 2, 3, 4];
        let ctx = Context::new(ContextOptions::new()).unwrap();
        let x = Tensor::new(&[2, 2]).with_values(&values).unwrap().freeze();
        let h_x = TensorHandle::new(&ctx, &x).unwrap();
        let h_y = h_x.copy_sharing_tensor().unwrap();
        let expected = Tensor::new(&[2, 2]).with_values(&[2i32, 4, 6, 8]).unwrap();
        let h_z = add(&ctx, &x, &x).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        assert_eq!(z, expected);
        let h_z = add(&ctx, &x, &h_y).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        assert_eq!(z, expected);
        let h_z = add(&ctx, &h_x, &x).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        assert_eq!(z, expected);
        let h_z = add(&ctx, &h_x, &h_y).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        assert_eq!(z, expected);
    }

    /// `raw_ops::concat_v2` with a list input, concatenating along both axes.
    #[test]
    fn test_raw_ops_concat() {
        let values = [1i32, 2, 3, 4];
        let ctx = Context::new(ContextOptions::new()).unwrap();
        let h = Tensor::new(&[2, 2])
            .with_values(&values)
            .unwrap()
            .into_handle(&ctx)
            .unwrap();
        let h_z = concat_v2(&ctx, &[&h, &h], &Tensor::from(0i32).freeze()).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        let expected = Tensor::new(&[4, 2])
            .with_values(&[1i32, 2, 3, 4, 1, 2, 3, 4])
            .unwrap();
        assert_eq!(z, expected);
        let h_z = concat_v2(&ctx, &[&h, &h], &Tensor::from(1i32).freeze()).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        let expected = Tensor::new(&[2, 4])
            .with_values(&[1i32, 2, 1, 2, 3, 4, 3, 4])
            .unwrap();
        assert_eq!(z, expected);
    }

    /// Broadcasting adds between a handle and plain Rust scalars/slices.
    ///
    /// FIX: this function was missing the `#[test]` attribute, so it was
    /// never executed by the test harness (and the file-level
    /// `#![allow(dead_code)]` masked the dead-code warning).
    #[test]
    fn test_add_tensor_and_others() {
        let values = [1i32, 2, 3, 4];
        let ctx = Context::new(ContextOptions::new()).unwrap();
        let h = Tensor::new(&[2, 2])
            .with_values(&values)
            .unwrap()
            .into_handle(&ctx)
            .unwrap();
        let h_z = add(&ctx, &h, &1).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        let expected = Tensor::new(&[2, 2]).with_values(&[2i32, 3, 4, 5]).unwrap();
        assert_eq!(z, expected);
        let h_z = add(&ctx, &h, &[1]).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        let expected = Tensor::new(&[2, 2]).with_values(&[2i32, 3, 4, 5]).unwrap();
        assert_eq!(z, expected);
        let h_z = add(&ctx, &h, &[1, 2]).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        let expected = Tensor::new(&[2, 2]).with_values(&[2i32, 4, 4, 6]).unwrap();
        assert_eq!(z, expected);
    }

    /// Broadcasting adds between a handle and ndarray arrays.
    #[cfg(feature = "ndarray")]
    #[test]
    fn test_add_tensor_and_ndarray() {
        let values = [1i32, 2, 3, 4];
        let ctx = Context::new(ContextOptions::new()).unwrap();
        let h = Tensor::new(&[2, 2])
            .with_values(&values)
            .unwrap()
            .into_handle(&ctx)
            .unwrap();
        let h_z = add(&ctx, &h, &array![1]).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        let expected = Tensor::new(&[2, 2]).with_values(&[2i32, 3, 4, 5]).unwrap();
        assert_eq!(z, expected);
        let h_z = add(&ctx, &h, &array![[1]]).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        let expected = Tensor::new(&[2, 2]).with_values(&[2i32, 3, 4, 5]).unwrap();
        assert_eq!(z, expected);
        let h_z = add(&ctx, &h, &array![1, 2]).unwrap();
        let z = h_z.resolve::<i32>().unwrap();
        let expected = Tensor::new(&[2, 2]).with_values(&[2i32, 4, 4, 6]).unwrap();
        assert_eq!(z, expected);
    }

    /// GPU round-trip add; ignored by default since it requires a GPU device.
    #[cfg(feature = "tensorflow_gpu")]
    #[test]
    #[ignore]
    fn test_add_gpu() {
        let opts = ContextOptions::new();
        let ctx = Context::new(opts).unwrap();
        let devices = ctx.device_list().unwrap();
        let gpu_device = devices
            .iter()
            .find(|d| d.device_type == "GPU")
            .expect("No GPU device was found.");
        let target_device = &gpu_device.name;
        let x = Tensor::new(&[2, 2])
            .with_values(&[1.0f32, 2.0, 3.0, 4.0])
            .unwrap()
            .freeze();
        let h = TensorHandle::new(&ctx, &x).unwrap();
        let h_gpu = h.copy_to_device(&ctx, target_device).unwrap();
        let op_name = "Add";
        let mut op = Op::new(&ctx, op_name).unwrap();
        op.add_input(&h).unwrap();
        op.add_input(&h_gpu).unwrap();
        op.set_device(target_device).unwrap();
        let [h_z_gpu] = op.execute(&ctx).unwrap();
        assert!(&h_z_gpu.device_name().unwrap() == target_device);
        let z = h_z_gpu.resolve::<f32>().unwrap();
        let expected = [2.0f32, 4.0, 6.0, 8.0];
        for (v0, v1) in z.iter().zip(&expected) {
            assert!((v0 - v1).abs() < f32::EPSILON);
        }
    }
}