1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
use crate::{
write_tensor_recursive, AnyTensor, DataType, Result, Shape, Tensor, TensorInner, TensorType,
};
use core::fmt;
use fmt::{Debug, Formatter};
use libc::c_int;
use std::{fmt::Display, ops::Deref};
use tensorflow_sys as tf;
/// A tensor whose contents are only exposed for reading.
///
/// `ReadonlyTensor` is a [`Tensor`](Tensor) that does not support mutation: its elements are
/// reachable through `Deref<Target = [T]>` as a shared slice.
#[derive(Eq, Clone)]
pub struct ReadonlyTensor<T: TensorType> {
    /// Backing element storage (the tensor type's `InnerType`).
    pub(super) inner: T::InnerType,
    /// Size of every dimension, outermost first.
    pub(super) dims: Vec<u64>,
}
impl<T: TensorType> AnyTensor for ReadonlyTensor<T> {
    /// Obtains the raw `TF_Tensor` pointer from the backing storage, using this tensor's dims.
    fn inner(&self) -> Result<*mut tf::TF_Tensor> {
        let Self { inner, dims } = self;
        inner.as_mut_ptr(dims)
    }

    /// Reports the TensorFlow data type associated with the element type `T`.
    fn data_type(&self) -> DataType {
        <T as TensorType>::data_type()
    }
}
impl<T: TensorType> Deref for ReadonlyTensor<T> {
    type Target = [T];

    /// Exposes the tensor's elements as one flat, shared slice.
    #[inline]
    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
impl<T: TensorType> Display for ReadonlyTensor<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> ::std::fmt::Result {
let mut counter: i64 = match std::env::var("TF_RUST_DISPLAY_MAX") {
Ok(e) => e.parse().unwrap_or(-1),
Err(_) => -1,
};
write_tensor_recursive(f, self, self.dims(), &mut counter)
}
}
impl<T: TensorType> Debug for ReadonlyTensor<T> {
    /// Delegates to the crate's shared tensor formatter, labelled `"ReadonlyTensor"`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let dims = self.dims();
        crate::format_tensor(self, "ReadonlyTensor", dims, f)
    }
}
impl<T: TensorType + PartialEq> PartialEq for ReadonlyTensor<T> {
    /// Two read-only tensors are equal when their dimensions match and their flattened
    /// elements compare equal pairwise.
    fn eq(&self, other: &Self) -> bool {
        // Cheap shape check first; the element data is only compared when the shapes agree.
        if self.dims != other.dims {
            return false;
        }
        self[..] == other[..]
    }
}
impl<T: TensorType + PartialEq> PartialEq<Tensor<T>> for ReadonlyTensor<T> {
    /// A read-only tensor equals a mutable [`Tensor`] when both have the same dimensions and
    /// the same flattened elements.
    fn eq(&self, other: &Tensor<T>) -> bool {
        if self.dims != other.dims {
            return false;
        }
        // Both sides deref to `[T]`; compare them as plain slices.
        let ours: &[T] = self;
        let theirs: &[T] = other;
        ours == theirs
    }
}
impl<T: TensorType> ReadonlyTensor<T> {
    /// Get one single value from the Tensor.
    ///
    /// `indices` holds one index per dimension, in the same order as `dims()`. They are
    /// converted to a flat position with `get_index`, and the element there is cloned.
    ///
    /// # Panics
    ///
    /// Panics if `indices.len()` differs from the tensor's rank, or if any index is not
    /// smaller than the size of its dimension.
    ///
    /// ```
    /// # use tensorflow::Tensor;
    /// # use tensorflow::eager::ReadonlyTensor;
    /// let mut a = Tensor::<i32>::new(&[2, 3, 5]);
    ///
    /// a[1*15 + 1*5 + 1] = 5;
    /// let a: ReadonlyTensor<_> = a.freeze();
    /// assert_eq!(a.get(&[1, 1, 1]), 5);
    /// ```
    pub fn get(&self, indices: &[u64]) -> T {
        let index = self.get_index(indices);
        self[index].clone()
    }

    /// Get the array index from rows / columns indices.
    ///
    /// The layout is row-major: the last dimension has stride 1, and each earlier dimension's
    /// stride is the product of the sizes of all dimensions after it.
    ///
    /// # Panics
    ///
    /// Panics if `indices.len()` differs from the tensor's rank, or if any index is not
    /// smaller than the size of its dimension.
    ///
    /// ```
    /// # use tensorflow::Tensor;
    /// # use tensorflow::eager::ReadonlyTensor;
    /// let a: ReadonlyTensor<_> = Tensor::<f32>::new(&[3, 3, 3]).freeze();
    ///
    /// assert_eq!(a.get_index(&[2, 2, 2]), 26);
    /// assert_eq!(a.get_index(&[1, 2, 2]), 17);
    /// assert_eq!(a.get_index(&[1, 2, 0]), 15);
    /// assert_eq!(a.get_index(&[1, 0, 1]), 10);
    /// ```
    pub fn get_index(&self, indices: &[u64]) -> usize {
        assert!(
            self.dims.len() == indices.len(),
            "expected {} indices for a rank-{} tensor, got {}",
            self.dims.len(),
            self.dims.len(),
            indices.len()
        );
        let mut index = 0;
        let mut stride = 1;
        // Walk from the innermost (fastest-varying) dimension outwards. The lengths were
        // checked above, so the reversed zip keeps every index paired with its own dimension.
        for (&idx, &dim) in indices.iter().zip(&self.dims).rev() {
            assert!(
                idx < dim,
                "index {} out of bounds for dimension of size {}",
                idx,
                dim
            );
            index += idx * stride;
            stride *= dim;
        }
        index as usize
    }

    /// Returns the tensor's dimensions.
    pub fn dims(&self) -> &[u64] {
        &self.dims
    }

    /// Returns the tensor's dimensions as a Shape.
    pub fn shape(&self) -> Shape {
        Shape::from(&self.dims[..])
    }

    /// Wraps a `TF_Tensor`. Returns `None` if types don't match.
    ///
    /// The dimensions are read from the C tensor before the storage is wrapped by
    /// `T::InnerType::from_tf_tensor`, which decides the `None` case.
    ///
    /// # Safety
    ///
    /// `tensor` must point to a valid, live `TF_Tensor`. NOTE(review): whether ownership of
    /// `tensor` moves into the returned value (and who frees it on `None`) is determined by
    /// `T::InnerType::from_tf_tensor`; confirm there.
    pub(super) unsafe fn from_tf_tensor(tensor: *mut tf::TF_Tensor) -> Option<Self> {
        let num_dims: c_int = tf::TF_NumDims(tensor);
        // Bound the loop by the rank itself, not by `Vec::capacity()`: `with_capacity` only
        // guarantees *at least* the requested capacity. A negative count yields no dims
        // instead of wrapping to an enormous `usize`.
        let mut dims = Vec::with_capacity(num_dims.max(0) as usize);
        for i in 0..num_dims {
            dims.push(tf::TF_Dim(tensor, i) as u64);
        }
        Some(Self {
            inner: T::InnerType::from_tf_tensor(tensor)?,
            dims,
        })
    }

    /// Convert back to a Tensor.
    ///
    /// # Safety
    ///
    /// This is unsafe because modifying the returned Tensor will modify the underlying memory,
    /// which may affect other Tensors that share the same memory.
    ///
    /// ```
    /// # use tensorflow::{Tensor, Result};
    /// # use tensorflow::eager::*;
    /// # fn main() -> Result<()> {
    /// let ctx = Context::new(ContextOptions::new()).unwrap();
    /// let tensor = Tensor::from(0i32).freeze();
    /// let h = tensor.to_handle(&ctx).unwrap();
    ///
    /// let t0 = h.resolve::<i32>().unwrap();
    /// assert_eq!(t0[0], 0i32);
    ///
    /// // Manipulating the Tensor will affect the Tensor that shares underlying buffer.
    /// {
    /// // Getting multiple times should return the same Tensor.
    /// let t1 = h.resolve::<i32>().unwrap();
    ///
    /// // Convert back from a TensorHandle to a Tensor.
    /// let mut t1 = unsafe { t1.into_tensor() };
    /// t1[0] = 5;
    /// }
    ///
    /// // Check that t0 shares the same underlying buffer with t1.
    /// // This is why we need to use unsafe.
    /// assert_eq!(t0[0], 5);
    /// # Ok(())
    /// # }
    /// ```
    pub unsafe fn into_tensor(self) -> Tensor<T> {
        // The storage and dims move over unchanged; no copy is made.
        Tensor {
            inner: self.inner,
            dims: self.dims,
        }
    }
}