use crate::{ops, Operation, Scope, Session, SessionRunArgs, Status, Tensor, Variable};
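
/// The graph operations backing `CheckpointMaker`: a string placeholder per
/// direction, fed with the checkpoint path prefix at run time, and the
/// operation to use as the session target for saving or restoring.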
#[derive(Debug)]
struct SaveRestoreOps {
prefix_save: Operation,
prefix_restore: Operation,
save_op: Operation,
restore_op: Operation,
}
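
/// Adds save and restore ops for a set of variables to a graph and drives
/// them through a `Session`, writing checkpoints to disk and reading them
/// back. The ops are created lazily on the first call to `save` or `restore`
/// and reused afterwards.
///
/// A minimal usage sketch (not compiled; it assumes a mutable `scope`, a
/// running `session`, and a `variables` collection built elsewhere, and uses
/// an illustrative path):
///
/// ```ignore
/// let mut maker = CheckpointMaker::new(
///     scope.new_sub_scope("checkpoint"),
///     Box::from(variables),
/// );
/// // Write the current variable values under the given path prefix.
/// maker.save(&session, "/tmp/my-model/checkpoint")?;
/// // ...later, possibly after rebuilding the same variables in a new graph...
/// maker.restore(&session, "/tmp/my-model/checkpoint")?;
/// ```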
#[derive(Debug)]
pub struct CheckpointMaker {
scope: Scope,
variables: Box<[Variable]>,
save_restore_ops: Option<SaveRestoreOps>,
}
impl CheckpointMaker {
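    /// Creates a new CheckpointMaker for the given scope and variables.
    /// Nothing is added to the graph until `save` or `restore` is first
    /// called.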
pub fn new(scope: Scope, variables: Box<[Variable]>) -> CheckpointMaker {
CheckpointMaker {
scope,
variables,
save_restore_ops: None,
}
}
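
    /// Builds the save and restore ops, reusing ops named "save" and
    /// "restore" if the graph already contains them.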
fn build_save_ops(&mut self) -> Result<SaveRestoreOps, Status> {
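        // The variables' producing ops, collected lazily by whichever branch
        // below needs them first.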
let mut all_variable_ops_opt: Option<Vec<Operation>> = None;
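        // Reuse an existing op named "save" if present; otherwise build a
        // SaveV2 op fed by the prefix placeholder, the variable names, empty
        // shape-and-slice specs, and the variable tensors themselves.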
let existing_save_op = self.scope.graph().operation_by_name("save")?;
let (prefix_save, save_op) = if let Some(op) = existing_save_op {
let prefix_save_op = self
.scope
.graph()
.operation_by_name_required("prefix_save")?;
(prefix_save_op, op)
} else {
let all_variable_ops = all_variable_ops_opt.get_or_insert_with(|| {
self.variables
.iter()
.map(|v| v.output.operation.clone())
.collect::<Vec<_>>()
});
let prefix_save = ops::Placeholder::new()
.dtype(crate::DataType::String)
.build(&mut self.scope.with_op_name("prefix_save"))?;
let tensor_names = ops::constant(
self.variables
.iter()
.map(|v| String::from(v.name()))
.collect::<Vec<_>>()
.as_slice(),
&mut self.scope,
)?;
let shape_and_slices = ops::constant(
&self
.variables
.iter()
.map(|_| "".to_string())
.collect::<Vec<_>>()[..],
&mut self.scope,
)?;
let tensors = all_variable_ops
.iter()
.map(|v| v.output(0))
.collect::<Vec<_>>();
let mut g = self.scope.graph_mut();
let mut nd = g.new_operation("SaveV2", "save")?;
nd.add_input(prefix_save.clone());
nd.add_input(tensor_names);
nd.add_input(shape_and_slices);
nd.add_input_list(&tensors[..]);
let dtypes = all_variable_ops
.iter()
.map(|v| v.get_attr_type("dtype"))
.collect::<Result<Vec<_>, Status>>()?;
nd.set_attr_type_list("dtypes", &dtypes[..])?;
let save_op = nd.finish()?;
(prefix_save, save_op)
};
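        // Same pattern for restoring: reuse an existing op named "restore",
        // or build a RestoreV2 op plus one assign op per variable to write
        // the restored values back.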
let opt_restore_op = self.scope.graph().operation_by_name("restore")?;
let (prefix_restore, restore_op) = if let Some(op) = opt_restore_op {
let the_prefix_restore = self
.scope
.graph()
.operation_by_name_required("prefix_restore")?;
(the_prefix_restore, op)
} else {
let all_variable_ops = all_variable_ops_opt.get_or_insert_with(|| {
self.variables
.iter()
.map(|v| v.output.operation.clone())
.collect::<Vec<_>>()
});
let prefix_restore = ops::Placeholder::new()
.dtype(crate::DataType::String)
.build(&mut self.scope.with_op_name("prefix_restore"))?;
let all_var_names = self
.variables
.iter()
.map(|v| v.name.clone())
.collect::<Vec<_>>();
let tensor_names = ops::constant(&all_var_names[..], &mut self.scope)?;
let shape_and_slices = ops::constant(
&self
.variables
.iter()
.map(|_| "".to_string())
.collect::<Vec<_>>()[..],
&mut self.scope,
)?;
let mut g = self.scope.graph_mut();
let mut nd = g.new_operation("RestoreV2", "restore")?;
nd.add_input(prefix_restore.clone());
nd.add_input(tensor_names);
nd.add_input(shape_and_slices);
let dtypes = all_variable_ops
.iter()
.map(|v| v.get_attr_type("dtype"))
.collect::<Result<Vec<_>, Status>>()?;
nd.set_attr_type_list("dtypes", &dtypes[..])?;
let restore_op = nd.finish()?;
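            // Release the mutable graph borrow before building the assign ops
            // below, which access the graph through the scope again.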
drop(g);
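            // RestoreV2 produces one output per requested tensor; assign each
            // output back into the corresponding variable, then group the
            // assignments under a single NoOp target.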
let mut restore_var_ops = Vec::<Operation>::new();
for (i, var) in self.variables.iter().enumerate() {
let var_op = var.output.operation.clone();
restore_var_ops.push(ops::assign(
var_op,
crate::Output {
operation: restore_op.clone(),
index: i as i32,
},
&mut self.scope.new_sub_scope(format!("restore{}", i).as_str()),
)?);
}
let mut no_op = ops::NoOp::new();
for op in restore_var_ops {
no_op = no_op.add_control_input(op);
}
(prefix_restore, no_op.build(&mut self.scope)?)
};
Ok(SaveRestoreOps {
prefix_save,
prefix_restore,
save_op,
restore_op,
})
}
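
    /// Returns the cached save/restore ops, building them and adding them to
    /// the graph on first use.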
    fn get_save_operation(&mut self) -> Result<&SaveRestoreOps, Status> {
        if self.save_restore_ops.is_none() {
            self.save_restore_ops = Some(self.build_save_ops()?);
        }
        // The option was populated just above, so this unwrap cannot panic
        // and no unsafe code is needed.
        Ok(self.save_restore_ops.as_ref().unwrap())
    }
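
    /// Saves the current values of the tracked variables to a checkpoint
    /// whose file names start with `backup_filename_base`.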
pub fn save(&mut self, session: &Session, backup_filename_base: &str) -> Result<(), Status> {
let save_restore_ops = self.get_save_operation()?;
let prefix_arg = Tensor::from(backup_filename_base.to_string());
let mut run_args = SessionRunArgs::new();
run_args.add_feed(&save_restore_ops.prefix_save, 0, &prefix_arg);
run_args.add_target(&save_restore_ops.save_op);
session.run(&mut run_args)?;
Ok(())
}
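
    /// Restores the tracked variables from a checkpoint previously written
    /// under the path prefix `path_base`.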
pub fn restore(&mut self, session: &Session, path_base: &str) -> Result<(), Status> {
let save_restore_ops = self.get_save_operation()?;
let prefix_arg = Tensor::from(path_base.to_string());
let mut run_args = SessionRunArgs::new();
run_args.add_feed(&save_restore_ops.prefix_restore, 0, &prefix_arg);
run_args.add_target(&save_restore_ops.restore_op);
session.run(&mut run_args)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::ops::Placeholder;
use crate::{
ops, CheckpointMaker, Code, DataType, FetchToken, Operation, Scope, Session,
SessionOptions, SessionRunArgs, Status, Tensor, Variable,
};
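
    // Builds a float variable with the given name, shape, and initial values.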
fn make_variable(
scope: &mut Scope,
name: &str,
dims: &[u64],
values: &[f32],
) -> Result<Variable, Status> {
Ok(Variable::builder()
.const_initial_value(Tensor::new(dims).with_values(values)?)
.data_type(DataType::Float)
.build(&mut scope.with_op_name(name))?)
}
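
    // Returns a placeholder and an assign op that copies the placeholder's
    // value into the variable.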
fn create_assignment(
var: &Variable,
scope: &mut Scope,
) -> Result<(Operation, Operation), Status> {
let placeholder = Placeholder::new()
.dtype(DataType::Float)
.shape(var.shape.clone())
.build(&mut scope.with_op_name(var.name.as_str()))?;
Ok((
placeholder.clone(),
ops::assign(var.output.clone(), placeholder, scope)?,
))
}
struct MyScopeData {
scope: Scope,
variables: [Variable; 3],
}
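
    // Builds a fresh root scope holding a scalar, a vector, and a matrix
    // variable.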
fn create_scope() -> Result<MyScopeData, Status> {
let mut scope = Scope::new_root_scope();
let var_w = make_variable(&mut scope, "w", &[], &[2.2])?;
let var_b = make_variable(&mut scope, "b", &[3], &[1.0, 2.0, 4.5])?;
let var_a = make_variable(&mut scope, "a", &[3, 2], &[1.0, 2.0, 3.3, 7.0, 8.0, 8.5])?;
Ok(MyScopeData {
scope,
variables: [var_w, var_b, var_a],
})
}
struct AssignData {
pub placeholder_ops: Box<[Operation]>,
pub assign_op: Operation,
}
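
    // Creates one placeholder per variable plus a single NoOp target whose
    // control inputs perform all the assignments.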
fn add_assign_op(scope_data: &mut MyScopeData) -> Result<AssignData, Status> {
let mut placeholder_scope = scope_data.scope.new_sub_scope("placeholder");
let mut placeholders: Vec<Operation> = Vec::new();
let mut no_op_bld = ops::NoOp::new();
for var in scope_data.variables.as_ref() {
            let (placeholder, assign_op) = create_assignment(var, &mut placeholder_scope)?;
placeholders.push(placeholder);
no_op_bld = no_op_bld.add_control_input(assign_op);
}
let assign_op = no_op_bld.build(&mut scope_data.scope)?;
Ok(AssignData {
placeholder_ops: placeholders.into_boxed_slice(),
assign_op,
})
}
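
    // Feeds `values` into the placeholders and runs the grouped assign op so
    // that every variable takes its new value in one session run.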
fn assign_variables(
session: &Session,
scope_data: &MyScopeData,
assign_data: &AssignData,
values: &[&[f32]],
) -> Result<(), Status> {
let mut values_fed: Vec<Tensor<f32>> =
Vec::with_capacity(assign_data.placeholder_ops.len());
let mut session_run = SessionRunArgs::new();
for i_var in 0..assign_data.placeholder_ops.len() {
let value_fed_as_tensor = Tensor::new(
&scope_data.variables[i_var]
.shape()
.0
.as_ref()
.ok_or(Status::new_set(Code::Internal, "Shape not present")?)?
.iter()
.map(|o| {
o.map(|i| i as u64)
.ok_or(Status::new_set(Code::Internal, "Shape item not present")?)
})
.collect::<Result<Vec<u64>, Status>>()?
.as_ref(),
)
.with_values(&values[i_var])?;
values_fed.push(value_fed_as_tensor);
}
for i_var in 0..assign_data.placeholder_ops.len() {
session_run.add_feed(&assign_data.placeholder_ops[i_var], 0, &values_fed[i_var]);
}
session_run.add_target(&assign_data.assign_op);
session.run(&mut session_run)?;
Ok(())
}
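
    // Fetches every variable and asserts that it holds the expected values.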
fn check_variables(
session: &Session,
variables: &[Variable],
values: &[&[f32]],
) -> Result<(), Status> {
let mut session_run = SessionRunArgs::new();
let mut tokens: Vec<FetchToken> = Vec::with_capacity(variables.len());
for i in 0..variables.len() {
tokens.push(session_run.request_fetch(
&variables[i].output().operation,
variables[i].output().index,
));
}
session.run(&mut session_run)?;
for i in 0..variables.len() {
let got_tensor: Tensor<f32> = session_run.fetch(tokens[i])?;
assert_eq!(values[i], got_tensor.as_ref());
}
Ok(())
}
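
    // End-to-end test: assign fresh values, save a checkpoint, rebuild the
    // same graph from scratch, restore, and verify the values round-trip.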
#[test]
fn simple_save() -> Result<(), Box<dyn std::error::Error>> {
let mut first_scope_data = create_scope()?;
let assign_data = add_assign_op(&mut first_scope_data)?;
let first_session = Session::new(&SessionOptions::new(), &first_scope_data.scope.graph())?;
let new_values: [&[f32]; 3] = [
&[5.1],
&[4.0, 2.2, 6.0],
&[11.0, 12.0, 13.6, 17.1, 18.4, 19.5],
];
assign_variables(&first_session, &first_scope_data, &assign_data, &new_values)?;
let mut checkpoint = CheckpointMaker::new(
first_scope_data.scope.new_sub_scope("checkpoint"),
Box::from(first_scope_data.variables.clone()),
);
let temp_dir = tempfile::tempdir()?;
let checkpoint_path = temp_dir.path().join("checkpoint-vars");
let checkpoint_path_str = checkpoint_path
.into_os_string()
.into_string()
.map_err(|_| "Cannot convert checkpoint path")?;
checkpoint.save(&first_session, checkpoint_path_str.as_str())?;
let MyScopeData {
scope: second_scope,
variables: second_variables,
} = create_scope()?;
let second_session = Session::new(&SessionOptions::new(), &second_scope.graph())?;
let mut second_checkpoint =
CheckpointMaker::new(second_scope, Box::new(second_variables.clone()));
second_checkpoint.restore(&second_session, checkpoint_path_str.as_str())?;
check_variables(&second_session, &second_variables, &new_values)?;
Ok(())
}
}