use std::ffi::CStr;
use tensorflow_sys as tf;
use crate::{Device, Result, Status};
/// Options used to configure a [`Context`]; consumed by `Context::new`.
///
/// Wraps a raw `TFE_ContextOptions` pointer from the TensorFlow C API.
#[derive(Debug)]
pub struct ContextOptions {
    // Raw C handle. NOTE(review): presumably created by `TFE_NewContextOptions` (see the
    // `impl_new!` invocation) and freed by `TFE_DeleteContextOptions` (see `impl_drop!`) —
    // confirm against the macro definitions.
    inner: *mut tf::TFE_ContextOptions,
}
// Generates the `ContextOptions::new()` constructor used by the tests; the string is the
// doc text handed to the macro. NOTE(review): expansion details live in the macro definition.
impl_new!(
    ContextOptions,
    TFE_NewContextOptions,
    "Creates a blank set of context options."
);
// NOTE(review): presumably expands to a `Drop` impl calling `TFE_DeleteContextOptions` — confirm.
impl_drop!(ContextOptions, TFE_DeleteContextOptions);
impl ContextOptions {
    /// Hands a serialized configuration buffer to `TFE_ContextOptionsSetConfig`.
    ///
    /// The test suite passes a serialized `ConfigProto` here.
    ///
    /// # Errors
    ///
    /// Propagates the status filled in by the C API, via `Status::into_result`.
    pub fn set_config(&mut self, config: &[u8]) -> Result<()> {
        let mut status = Status::new();
        let (bytes, byte_len) = (config.as_ptr(), config.len());
        // SAFETY: `bytes`/`byte_len` describe the borrowed `config` slice, which outlives
        // this call. Assumes `self.inner` is a live handle from `TFE_NewContextOptions`
        // (NOTE(review): established by `impl_new!` — confirm).
        unsafe {
            tf::TFE_ContextOptionsSetConfig(
                self.inner,
                bytes as *const _,
                byte_len,
                status.inner(),
            );
        }
        status.into_result()
    }

    /// Turns the async flag on these options on or off via `TFE_ContextOptionsSetAsync`.
    pub fn set_async(&mut self, enable: bool) {
        let flag = u8::from(enable);
        // SAFETY: assumes `self.inner` is a live options handle (see `set_config`).
        unsafe { tf::TFE_ContextOptionsSetAsync(self.inner, flag) };
    }
}
/// A TensorFlow eager context, wrapping a raw `TFE_Context` pointer.
///
/// Built with `Context::new` from a [`ContextOptions`].
#[derive(Debug)]
pub struct Context {
    // Raw C handle, visible crate-wide so other wrappers can pass it to the C API.
    // NOTE(review): presumably released with `TFE_DeleteContext` by `impl_drop!` — confirm.
    pub(crate) inner: *mut tf::TFE_Context,
}
// NOTE(review): presumably expands to a `Drop` impl calling `TFE_DeleteContext` — confirm.
impl_drop!(Context, TFE_DeleteContext);
impl Context {
    /// Creates a new context from `opts` via `TFE_NewContext`.
    ///
    /// `opts` is taken by value and dropped when this function returns, after the C call.
    ///
    /// # Errors
    ///
    /// Returns the status passed to `TFE_NewContext` when it yields a null handle.
    pub fn new(opts: ContextOptions) -> Result<Self> {
        let status = Status::new();
        // SAFETY: `opts` is owned by this function, so its handle stays valid for the call.
        let inner = unsafe { tf::TFE_NewContext(opts.inner, status.inner) };
        if inner.is_null() {
            Err(status)
        } else {
            Ok(Context { inner })
        }
    }

    /// Returns one [`Device`] per entry reported by `TFE_ContextListDevices`.
    ///
    /// # Errors
    ///
    /// Returns the failing status if listing the devices or reading any device attribute
    /// fails. Returns an error converted from `std::str::Utf8Error` if a device name or
    /// type is not valid UTF-8.
    pub fn device_list(&self) -> Result<Vec<Device>> {
        let status = Status::new();
        unsafe {
            let list = tf::TFE_ContextListDevices(self.inner, status.inner);
            if !status.is_ok() {
                // Free any list the C API may still have allocated, so the error path
                // does not leak it.
                if !list.is_null() {
                    tf::TF_DeleteDeviceList(list);
                }
                return Err(status);
            }
            // Extract inside a closure so every early `return`/`?` still reaches
            // `TF_DeleteDeviceList` below.
            let result = (|| {
                let n = tf::TF_DeviceListCount(list);
                // Clamp before casting: a negative count would otherwise wrap to an
                // enormous capacity and abort on allocation.
                let mut devices = Vec::with_capacity(n.max(0) as usize);
                for i in 0..n {
                    let c_name = tf::TF_DeviceListName(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    let c_type = tf::TF_DeviceListType(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    let bytes = tf::TF_DeviceListMemoryBytes(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    let incarnation = tf::TF_DeviceListIncarnation(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    // The C strings point into `list`, so copy them into owned `String`s
                    // before the list is deleted.
                    devices.push(Device {
                        name: CStr::from_ptr(c_name).to_str()?.to_string(),
                        device_type: CStr::from_ptr(c_type).to_str()?.to_string(),
                        memory_bytes: bytes,
                        incarnation,
                    });
                }
                Ok(devices)
            })();
            tf::TF_DeleteDeviceList(list);
            result
        }
    }

    /// Calls `TFE_ContextClearCaches` on this context.
    pub fn clear_caches(&mut self) {
        // SAFETY: `self.inner` is the non-null handle checked in `Context::new`.
        unsafe {
            tf::TFE_ContextClearCaches(self.inner);
        }
    }
}
// SAFETY: NOTE(review): these impls let a `Context` (a raw `TFE_Context` pointer) be moved
// between threads and shared by reference across them. Soundness depends on the TensorFlow
// eager runtime synchronizing access internally — confirm against the TF C API before
// relying on concurrent use.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}
#[cfg(test)]
mod test {
    use super::*;

    /// Default options must produce a context.
    #[test]
    fn test_create_context() {
        Context::new(ContextOptions::new()).unwrap();
    }

    /// Enabling async mode must still produce a context.
    #[test]
    fn test_create_async_context() {
        let mut options = ContextOptions::new();
        options.set_async(true);
        Context::new(options).unwrap();
    }

    /// A serialized `ConfigProto` is accepted and the resulting context can be created.
    #[test]
    fn test_context_set_config() {
        use crate::protos::config::{ConfigProto, GPUOptions};
        use protobuf::Message;

        let mut config = ConfigProto::new();
        config.set_gpu_options(GPUOptions {
            per_process_gpu_memory_fraction: 0.5,
            allow_growth: true,
            ..Default::default()
        });

        let mut serialized = Vec::new();
        config.write_to_writer(&mut serialized).unwrap();

        let mut options = ContextOptions::new();
        options.set_config(&serialized).unwrap();
        Context::new(options).unwrap();
    }

    /// Every reported device has a non-empty name.
    #[test]
    fn test_device_list() {
        let context = Context::new(ContextOptions::new()).unwrap();
        let devices = context.device_list().unwrap();
        for device in devices.iter() {
            assert_ne!(String::from(""), device.name);
        }
    }
}