1use super::ops;
2use super::protos;
3use super::Code;
4use super::DataType;
5use super::Graph;
6use super::Operation;
7use super::Output;
8use super::OutputName;
9use super::Result;
10use super::Scope;
11use super::Session;
12use super::SessionRunArgs;
13use super::Shape;
14use super::Status;
15use super::Tensor;
16use super::Variable;
17use protobuf::Message;
18use protobuf::ProtobufError;
19use std::borrow::Borrow;
20use std::collections::HashMap;
21use std::error::Error;
22use std::fmt;
23use std::fmt::Display;
24use std::fmt::Formatter;
25use std::fs;
26use std::fs::File;
27use std::io;
28use std::io::Write;
29use std::path::Path;
30
/// Signature key string `"serving_default"`.
///
/// NOTE(review): by its name this is the key under which a model's default serving
/// signature is registered; it presumably mirrors TensorFlow's signature constants — confirm.
pub const DEFAULT_SERVING_SIGNATURE_DEF_KEY: &str = "serving_default";

/// Name of the input entry (`"inputs"`) for classification signatures.
pub const CLASSIFY_INPUTS: &str = "inputs";

/// Method name string for classification signatures.
pub const CLASSIFY_METHOD_NAME: &str = "tensorflow/serving/classify";

/// Name of the classes output entry (`"classes"`) for classification signatures.
pub const CLASSIFY_OUTPUT_CLASSES: &str = "classes";

/// Name of the scores output entry (`"scores"`) for classification signatures.
pub const CLASSIFY_OUTPUT_SCORES: &str = "scores";

/// Name of the input entry (`"inputs"`) for prediction signatures.
pub const PREDICT_INPUTS: &str = "inputs";

/// Method name string for prediction signatures.
pub const PREDICT_METHOD_NAME: &str = "tensorflow/serving/predict";

/// Name of the output entry (`"outputs"`) for prediction signatures.
pub const PREDICT_OUTPUTS: &str = "outputs";

/// Name of the input entry (`"inputs"`) for regression signatures.
pub const REGRESS_INPUTS: &str = "inputs";

/// Method name string for regression signatures.
pub const REGRESS_METHOD_NAME: &str = "tensorflow/serving/regress";

/// Name of the output entry (`"outputs"`) for regression signatures.
pub const REGRESS_OUTPUTS: &str = "outputs";
65
66#[derive(Debug)]
68pub struct SaveModelError {
69 source: Box<dyn Error>,
70}
71
72impl SaveModelError {
73 fn from_protobuf_error(e: ProtobufError) -> Self {
75 Self {
76 source: Box::new(e),
77 }
78 }
79}
80
81impl Display for SaveModelError {
82 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
83 write!(f, "SaveModelError: {}", &self.source)
84 }
85}
86
87impl Error for SaveModelError {
88 fn source(&self) -> Option<&(dyn Error + 'static)> {
89 Some(self.source.borrow())
90 }
91}
92
93impl From<Status> for SaveModelError {
94 fn from(e: Status) -> Self {
95 Self {
96 source: Box::new(e),
97 }
98 }
99}
100
101impl From<io::Error> for SaveModelError {
102 fn from(e: io::Error) -> Self {
103 Self {
104 source: Box::new(e),
105 }
106 }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq, Default)]
110pub struct TensorInfo {
112 dtype: DataType,
113 shape: Shape,
114 name: OutputName,
115}
116
117impl TensorInfo {
118 pub fn new(dtype: DataType, shape: Shape, name: OutputName) -> TensorInfo {
120 TensorInfo { dtype, shape, name }
121 }
122
123 pub fn name(&self) -> &OutputName {
125 &self.name
126 }
127
128 pub fn dtype(&self) -> DataType {
130 self.dtype
131 }
132
133 pub fn shape(&self) -> &Shape {
135 &self.shape
136 }
137
138 fn into_proto(self) -> protos::meta_graph::TensorInfo {
140 let mut info = protos::meta_graph::TensorInfo::new();
141 info.set_dtype(self.dtype.into_proto());
142 info.set_tensor_shape(self.shape.into_proto());
143 info.set_name(self.name.to_string());
144 info
145 }
146
147 fn from_proto(proto: &protos::meta_graph::TensorInfo) -> Result<Self> {
149 Ok(Self {
150 dtype: DataType::from_proto(proto.get_dtype()),
151 shape: Shape::from_proto(proto.get_tensor_shape()),
152 name: proto.get_name().parse()?,
153 })
154 }
155}
156
157#[derive(Debug, Clone)]
158pub struct SignatureDef {
161 method_name: String,
162 inputs: HashMap<String, TensorInfo>,
163 outputs: HashMap<String, TensorInfo>,
164}
165
166impl SignatureDef {
167 pub fn new(method_name: String) -> SignatureDef {
169 SignatureDef {
170 method_name,
171 inputs: HashMap::new(),
172 outputs: HashMap::new(),
173 }
174 }
175
176 pub fn add_input_info(&mut self, name: String, info: TensorInfo) {
178 self.inputs.insert(name, info);
179 }
180
181 pub fn add_output_info(&mut self, name: String, info: TensorInfo) {
183 self.outputs.insert(name, info);
184 }
185
186 pub fn method_name(&self) -> &str {
188 &self.method_name
189 }
190
191 pub fn inputs(&self) -> &HashMap<String, TensorInfo> {
193 &self.inputs
194 }
195
196 pub fn outputs(&self) -> &HashMap<String, TensorInfo> {
198 &self.outputs
199 }
200
201 pub fn get_input(&self, name: &str) -> Result<&TensorInfo> {
203 self.inputs.get(name).ok_or_else(|| {
204 Status::new_set_lossy(
205 Code::InvalidArgument,
206 &format!("Input '{}' not found", name),
207 )
208 })
209 }
210
211 pub fn get_output(&self, name: &str) -> Result<&TensorInfo> {
213 self.outputs.get(name).ok_or_else(|| {
214 Status::new_set_lossy(
215 Code::InvalidArgument,
216 &format!("Output '{}' not found", name),
217 )
218 })
219 }
220
221 fn into_proto(self) -> protos::meta_graph::SignatureDef {
223 let mut signature_def = protos::meta_graph::SignatureDef::new();
224 signature_def.set_method_name(self.method_name);
225 for (name, info) in self.inputs {
226 signature_def.mut_inputs().insert(name, info.into_proto());
227 }
228 for (name, info) in self.outputs {
229 signature_def.mut_outputs().insert(name, info.into_proto());
230 }
231 signature_def
232 }
233
234 fn from_proto(proto: &protos::meta_graph::SignatureDef) -> Result<Self> {
236 let mut inputs = HashMap::new();
237 let mut outputs = HashMap::new();
238 for (key, proto) in proto.get_inputs() {
239 inputs.insert(key.clone(), TensorInfo::from_proto(proto)?);
240 }
241 for (key, proto) in proto.get_outputs() {
242 outputs.insert(key.clone(), TensorInfo::from_proto(proto)?);
243 }
244 Ok(Self {
245 method_name: proto.get_method_name().to_string(),
246 inputs,
247 outputs,
248 })
249 }
250}
251
252#[derive(Debug, Clone)]
253pub struct MetaGraphDef {
258 signatures: HashMap<String, SignatureDef>,
260}
261
262impl MetaGraphDef {
263 pub(crate) fn from_serialized_proto(data: &[u8]) -> Result<Self> {
265 let proto: protos::meta_graph::MetaGraphDef = protobuf::Message::parse_from_bytes(data)
266 .map_err(|e| {
267 Status::new_set_lossy(
268 Code::InvalidArgument,
269 &format!("Invalid serialized MetaGraphDef: {}", e),
270 )
271 })?;
272 let mut signatures = HashMap::new();
273 for (key, signature_proto) in proto.get_signature_def() {
274 signatures.insert(key.clone(), SignatureDef::from_proto(signature_proto)?);
275 }
276 Ok(Self { signatures })
277 }
278
279 pub fn signatures(&self) -> &HashMap<String, SignatureDef> {
281 &self.signatures
282 }
283
284 pub fn get_signature(&self, name: &str) -> Result<&SignatureDef> {
286 self.signatures.get(name).ok_or_else(|| {
287 Status::new_set_lossy(Code::Internal, &format!("Signature '{}' not found", name))
288 })
289 }
290}
291
292#[derive(Debug)]
294pub struct SavedModelBuilder {
295 collections: HashMap<String, Vec<Variable>>,
296 tags: Vec<String>,
297 signatures: HashMap<String, SignatureDef>,
298}
299
300impl Default for SavedModelBuilder {
301 fn default() -> Self {
302 Self::new()
303 }
304}
305
306impl SavedModelBuilder {
307 pub fn new() -> SavedModelBuilder {
309 SavedModelBuilder {
310 collections: HashMap::new(),
311 tags: Vec::new(),
312 signatures: HashMap::new(),
313 }
314 }
315
316 pub fn add_collection(&mut self, key: &str, variables: &[Variable]) -> &mut Self {
318 self.collections.insert(key.to_string(), variables.to_vec());
319 self
320 }
321
322 pub fn add_tag(&mut self, tag: &str) -> &mut Self {
324 self.tags.push(tag.to_string());
325 self
326 }
327
328 pub fn add_signature(&mut self, key: &str, signature_def: SignatureDef) -> &mut Self {
330 self.signatures.insert(key.to_string(), signature_def);
331 self
332 }
333
334 pub fn inject(self, scope: &mut Scope) -> Result<SavedModelSaver> {
337 let all_vars = self.collections.values().flatten().collect::<Vec<_>>();
338 let prefix = ops::Placeholder::new()
339 .dtype(DataType::String)
340 .build(scope)?;
341 let save_op = {
342 let tensor_names = ops::constant(
343 &all_vars
344 .iter()
345 .map(|v| v.name().to_string())
346 .collect::<Vec<_>>()[..],
347 scope,
348 )?;
349 let shape_and_slices = ops::constant(
350 &all_vars.iter().map(|_| "".to_string()).collect::<Vec<_>>()[..],
351 scope,
352 )?;
353 let tensors = all_vars
354 .iter()
355 .map(|v| v.output().clone())
356 .collect::<Vec<_>>();
357 let mut g = scope.graph_mut();
358 let mut nd = g.new_operation("SaveV2", "save")?;
359 nd.add_input(prefix.clone());
360 nd.add_input(tensor_names);
361 nd.add_input(shape_and_slices);
362 nd.add_input_list(&tensors[..]);
363 nd.set_attr_type_list(
364 "dtypes",
365 &all_vars.iter().map(|v| v.data_type()).collect::<Vec<_>>()[..],
366 )?;
367 nd.finish()?
368 };
369
370 let filename_tensor = ops::Placeholder::new()
371 .dtype(DataType::String)
372 .build(scope)?;
373 let restore_op = {
374 let all_var_names = all_vars
375 .iter()
376 .map(|v| v.name().to_string())
377 .collect::<Vec<_>>();
378 let tensor_names = ops::constant(&all_var_names[..], scope)?;
379 let shape_and_slices = ops::constant(
380 &all_vars.iter().map(|_| "".to_string()).collect::<Vec<_>>()[..],
381 scope,
382 )?;
383 let mut g = scope.graph_mut();
384 let mut nd = g.new_operation("RestoreV2", "restore")?;
385 nd.add_input(filename_tensor.clone());
386 nd.add_input(tensor_names);
387 nd.add_input(shape_and_slices);
388 nd.set_attr_type_list(
389 "dtypes",
390 &all_vars.iter().map(|v| v.data_type()).collect::<Vec<_>>()[..],
391 )?;
392 nd.finish()?
393 };
394 let really_restore_op = {
395 let mut restore_var_ops = Vec::<Operation>::new();
396 for (i, var) in all_vars.iter().enumerate() {
397 restore_var_ops.push(ops::assign(
398 var.output().clone(),
399 Output {
400 operation: restore_op.clone(),
401 index: i as i32,
402 },
403 scope,
404 )?);
405 }
406 let mut no_op = ops::NoOp::new();
407 for op in restore_var_ops {
408 no_op = no_op.add_control_input(op);
409 }
410 no_op.build(scope)?
411 };
412
413 SavedModelSaver::new(
414 filename_tensor.name()?,
415 prefix,
416 save_op,
417 really_restore_op.name()?,
418 self.collections,
419 self.tags,
420 self.signatures,
421 )
422 }
423}
424
425#[derive(Debug)]
426pub struct SavedModelSaver {
428 meta_graph: protos::meta_graph::MetaGraphDef,
429 prefix: Operation,
430 save_op: Operation,
431}
432
433impl SavedModelSaver {
434 fn new(
435 filename_tensor_name: String,
436 prefix: Operation,
437 save_op: Operation,
438 restore_op_name: String,
439 collections: HashMap<String, Vec<Variable>>,
440 tags: Vec<String>,
441 signatures: HashMap<String, SignatureDef>,
442 ) -> Result<SavedModelSaver> {
443 let mut meta_graph = protos::meta_graph::MetaGraphDef::new();
444 meta_graph
445 .mut_saver_def()
446 .set_filename_tensor_name(filename_tensor_name);
447 meta_graph
448 .mut_saver_def()
449 .set_restore_op_name(restore_op_name);
450 for (key, variables) in collections {
451 let mut trainable_variables_bytes_list =
452 protos::meta_graph::CollectionDef_BytesList::new();
453 for variable in variables {
454 let mut variable_def = protos::variable::VariableDef::new();
455 variable_def.set_variable_name(variable.name().to_string());
456 trainable_variables_bytes_list.mut_value().push(
457 match variable_def.write_to_bytes() {
458 Ok(x) => x,
459 Err(e) => {
460 return Err(Status::new_set_lossy(
461 Code::InvalidArgument,
462 &format!("Unable to encode variable definition: {}", e),
463 ));
464 }
465 },
466 );
467 }
468 let mut trainable_collection_def = protos::meta_graph::CollectionDef::new();
469 trainable_collection_def.set_bytes_list(trainable_variables_bytes_list);
470 meta_graph
471 .mut_collection_def()
472 .insert(key.to_string(), trainable_collection_def);
473 }
474 let graph_tags = meta_graph.mut_meta_info_def().mut_tags();
475 for tag in tags {
476 graph_tags.push(tag);
477 }
478 let graph_signatures = meta_graph.mut_signature_def();
479 for (key, signature) in signatures {
480 graph_signatures.insert(key, signature.into_proto());
481 }
482 Ok(SavedModelSaver {
483 meta_graph,
484 prefix,
485 save_op,
486 })
487 }
488
489 pub fn save<P: AsRef<Path>>(
491 &self,
492 session: &Session,
493 graph: &Graph,
494 save_dir: P,
495 ) -> std::result::Result<(), SaveModelError> {
496 let mut meta_graph = self.meta_graph.clone();
497 let graph_bytes = graph.graph_def()?;
498 let graph_def = protobuf::Message::parse_from_bytes(&graph_bytes).map_err(|e| {
499 SaveModelError::from(Status::new_set_lossy(
500 Code::InvalidArgument,
501 &format!("Unable to parse graph definition: {}", e),
502 ))
503 })?;
504 meta_graph.set_graph_def(graph_def);
505 let mut saved_model = protos::saved_model::SavedModel::new();
506 saved_model.set_saved_model_schema_version(1);
507 saved_model.mut_meta_graphs().push(meta_graph);
508 let saved_model_bytes = saved_model
509 .write_to_bytes()
510 .map_err(SaveModelError::from_protobuf_error)?;
511 fs::create_dir(save_dir.as_ref())?;
512 let mut file = File::create(save_dir.as_ref().join("saved_model.pb"))?;
513 file.write_all(&saved_model_bytes)?;
514 let prefix = Tensor::from(
515 save_dir
516 .as_ref()
517 .join("variables/variables")
518 .to_str()
519 .ok_or_else(|| {
520 Status::new_set(Code::OutOfRange, "Path is not valid Unicode").unwrap()
521 })?
522 .to_string(),
523 );
524
525 let mut run_args = SessionRunArgs::new();
526 run_args.add_feed(&self.prefix, 0, &prefix);
527 run_args.add_target(&self.save_op);
528 session.run(&mut run_args)?;
529 Ok(())
530 }
531}