tensorflow/
saved_model.rs

1use super::ops;
2use super::protos;
3use super::Code;
4use super::DataType;
5use super::Graph;
6use super::Operation;
7use super::Output;
8use super::OutputName;
9use super::Result;
10use super::Scope;
11use super::Session;
12use super::SessionRunArgs;
13use super::Shape;
14use super::Status;
15use super::Tensor;
16use super::Variable;
17use protobuf::Message;
18use protobuf::ProtobufError;
19use std::borrow::Borrow;
20use std::collections::HashMap;
21use std::error::Error;
22use std::fmt;
23use std::fmt::Display;
24use std::fmt::Formatter;
25use std::fs;
26use std::fs::File;
27use std::io;
28use std::io::Write;
29use std::path::Path;
30
/// Key under which the `default` serving signature is stored in a signature
/// def map. The default signature is used by inference requests that do not
/// name a specific signature.
pub const DEFAULT_SERVING_SIGNATURE_DEF_KEY: &str = "serving_default";

// Classification signatures.

/// Method name that marks a SignatureDef as a classification signature.
pub const CLASSIFY_METHOD_NAME: &str = "tensorflow/serving/classify";

/// Classification inputs.
pub const CLASSIFY_INPUTS: &str = "inputs";

/// Classification output holding the predicted classes.
pub const CLASSIFY_OUTPUT_CLASSES: &str = "classes";

/// Classification output holding the class scores.
pub const CLASSIFY_OUTPUT_SCORES: &str = "scores";

// Prediction signatures.

/// Method name that marks a SignatureDef as a prediction signature.
pub const PREDICT_METHOD_NAME: &str = "tensorflow/serving/predict";

/// Prediction inputs.
pub const PREDICT_INPUTS: &str = "inputs";

/// Prediction outputs.
pub const PREDICT_OUTPUTS: &str = "outputs";

// Regression signatures.

/// Method name that marks a SignatureDef as a regression signature.
pub const REGRESS_METHOD_NAME: &str = "tensorflow/serving/regress";

/// Regression inputs.
pub const REGRESS_INPUTS: &str = "inputs";

/// Regression outputs.
pub const REGRESS_OUTPUTS: &str = "outputs";
65
66/// Error generated while saving a model.
67#[derive(Debug)]
68pub struct SaveModelError {
69    source: Box<dyn Error>,
70}
71
72impl SaveModelError {
73    // We don't use From, because we don't want this to be public API.
74    fn from_protobuf_error(e: ProtobufError) -> Self {
75        Self {
76            source: Box::new(e),
77        }
78    }
79}
80
81impl Display for SaveModelError {
82    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
83        write!(f, "SaveModelError: {}", &self.source)
84    }
85}
86
87impl Error for SaveModelError {
88    fn source(&self) -> Option<&(dyn Error + 'static)> {
89        Some(self.source.borrow())
90    }
91}
92
93impl From<Status> for SaveModelError {
94    fn from(e: Status) -> Self {
95        Self {
96            source: Box::new(e),
97        }
98    }
99}
100
101impl From<io::Error> for SaveModelError {
102    fn from(e: io::Error) -> Self {
103        Self {
104            source: Box::new(e),
105        }
106    }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq, Default)]
110/// Information about a Tensor necessary for feeding or retrieval.
111pub struct TensorInfo {
112    dtype: DataType,
113    shape: Shape,
114    name: OutputName,
115}
116
117impl TensorInfo {
118    /// Creates a TensorInfo.
119    pub fn new(dtype: DataType, shape: Shape, name: OutputName) -> TensorInfo {
120        TensorInfo { dtype, shape, name }
121    }
122
123    /// Returns the name of the tensor.
124    pub fn name(&self) -> &OutputName {
125        &self.name
126    }
127
128    /// Returns the data type of the tensor.
129    pub fn dtype(&self) -> DataType {
130        self.dtype
131    }
132
133    /// Returns the shape of the tensor.
134    pub fn shape(&self) -> &Shape {
135        &self.shape
136    }
137
138    // We don't use Into, because we don't want this to be public API.
139    fn into_proto(self) -> protos::meta_graph::TensorInfo {
140        let mut info = protos::meta_graph::TensorInfo::new();
141        info.set_dtype(self.dtype.into_proto());
142        info.set_tensor_shape(self.shape.into_proto());
143        info.set_name(self.name.to_string());
144        info
145    }
146
147    // We don't use From, because we don't want this to be public API.
148    fn from_proto(proto: &protos::meta_graph::TensorInfo) -> Result<Self> {
149        Ok(Self {
150            dtype: DataType::from_proto(proto.get_dtype()),
151            shape: Shape::from_proto(proto.get_tensor_shape()),
152            name: proto.get_name().parse()?,
153        })
154    }
155}
156
157#[derive(Debug, Clone)]
158/// SignatureDef defines the signature of a computation supported by a
159/// TensorFlow graph.
160pub struct SignatureDef {
161    method_name: String,
162    inputs: HashMap<String, TensorInfo>,
163    outputs: HashMap<String, TensorInfo>,
164}
165
166impl SignatureDef {
167    /// Creates a SignatureDef with the given method name.
168    pub fn new(method_name: String) -> SignatureDef {
169        SignatureDef {
170            method_name,
171            inputs: HashMap::new(),
172            outputs: HashMap::new(),
173        }
174    }
175
176    /// Adds an input parameter.
177    pub fn add_input_info(&mut self, name: String, info: TensorInfo) {
178        self.inputs.insert(name, info);
179    }
180
181    /// Adds an output parameter.
182    pub fn add_output_info(&mut self, name: String, info: TensorInfo) {
183        self.outputs.insert(name, info);
184    }
185
186    /// Returns the method name.
187    pub fn method_name(&self) -> &str {
188        &self.method_name
189    }
190
191    /// Returns the input parameters.
192    pub fn inputs(&self) -> &HashMap<String, TensorInfo> {
193        &self.inputs
194    }
195
196    /// Returns the output parameters.
197    pub fn outputs(&self) -> &HashMap<String, TensorInfo> {
198        &self.outputs
199    }
200
201    /// Returns the given input parameter.
202    pub fn get_input(&self, name: &str) -> Result<&TensorInfo> {
203        self.inputs.get(name).ok_or_else(|| {
204            Status::new_set_lossy(
205                Code::InvalidArgument,
206                &format!("Input '{}' not found", name),
207            )
208        })
209    }
210
211    /// Returns the given output parameter.
212    pub fn get_output(&self, name: &str) -> Result<&TensorInfo> {
213        self.outputs.get(name).ok_or_else(|| {
214            Status::new_set_lossy(
215                Code::InvalidArgument,
216                &format!("Output '{}' not found", name),
217            )
218        })
219    }
220
221    // We don't use Into, because we don't want this to be public API.
222    fn into_proto(self) -> protos::meta_graph::SignatureDef {
223        let mut signature_def = protos::meta_graph::SignatureDef::new();
224        signature_def.set_method_name(self.method_name);
225        for (name, info) in self.inputs {
226            signature_def.mut_inputs().insert(name, info.into_proto());
227        }
228        for (name, info) in self.outputs {
229            signature_def.mut_outputs().insert(name, info.into_proto());
230        }
231        signature_def
232    }
233
234    // We don't use From, because we don't want this to be public API.
235    fn from_proto(proto: &protos::meta_graph::SignatureDef) -> Result<Self> {
236        let mut inputs = HashMap::new();
237        let mut outputs = HashMap::new();
238        for (key, proto) in proto.get_inputs() {
239            inputs.insert(key.clone(), TensorInfo::from_proto(proto)?);
240        }
241        for (key, proto) in proto.get_outputs() {
242            outputs.insert(key.clone(), TensorInfo::from_proto(proto)?);
243        }
244        Ok(Self {
245            method_name: proto.get_method_name().to_string(),
246            inputs,
247            outputs,
248        })
249    }
250}
251
252#[derive(Debug, Clone)]
253/// Contains data necessary to restart training, run inference. It can be used
254/// to serialize/de-serialize memory objects necessary for running computation
255/// in a graph when crossing the process boundary. It can be used for long term
256/// storage of graphs, cross-language execution of graphs, etc.
257pub struct MetaGraphDef {
258    // TODO: support all fields
259    signatures: HashMap<String, SignatureDef>,
260}
261
262impl MetaGraphDef {
263    // We don't use From, because we don't want this to be public API.
264    pub(crate) fn from_serialized_proto(data: &[u8]) -> Result<Self> {
265        let proto: protos::meta_graph::MetaGraphDef = protobuf::Message::parse_from_bytes(data)
266            .map_err(|e| {
267                Status::new_set_lossy(
268                    Code::InvalidArgument,
269                    &format!("Invalid serialized MetaGraphDef: {}", e),
270                )
271            })?;
272        let mut signatures = HashMap::new();
273        for (key, signature_proto) in proto.get_signature_def() {
274            signatures.insert(key.clone(), SignatureDef::from_proto(signature_proto)?);
275        }
276        Ok(Self { signatures })
277    }
278
279    /// Returns the defined signatures.
280    pub fn signatures(&self) -> &HashMap<String, SignatureDef> {
281        &self.signatures
282    }
283
284    /// Returns the specified signature.
285    pub fn get_signature(&self, name: &str) -> Result<&SignatureDef> {
286        self.signatures.get(name).ok_or_else(|| {
287            Status::new_set_lossy(Code::Internal, &format!("Signature '{}' not found", name))
288        })
289    }
290}
291
292/// Builds a SavedModelSaver, which can be used to save models.
293#[derive(Debug)]
294pub struct SavedModelBuilder {
295    collections: HashMap<String, Vec<Variable>>,
296    tags: Vec<String>,
297    signatures: HashMap<String, SignatureDef>,
298}
299
300impl Default for SavedModelBuilder {
301    fn default() -> Self {
302        Self::new()
303    }
304}
305
306impl SavedModelBuilder {
307    /// Creates a new SavedModelBuilder.
308    pub fn new() -> SavedModelBuilder {
309        SavedModelBuilder {
310            collections: HashMap::new(),
311            tags: Vec::new(),
312            signatures: HashMap::new(),
313        }
314    }
315
316    /// Adds a collection to be saved.
317    pub fn add_collection(&mut self, key: &str, variables: &[Variable]) -> &mut Self {
318        self.collections.insert(key.to_string(), variables.to_vec());
319        self
320    }
321
322    /// Adds a tag.
323    pub fn add_tag(&mut self, tag: &str) -> &mut Self {
324        self.tags.push(tag.to_string());
325        self
326    }
327
328    /// Adds a signature.
329    pub fn add_signature(&mut self, key: &str, signature_def: SignatureDef) -> &mut Self {
330        self.signatures.insert(key.to_string(), signature_def);
331        self
332    }
333
334    /// Adds ops to the graph necessary for saving and restoring models,
335    /// returning a SavedModelSaver which handles the actual model saving.
336    pub fn inject(self, scope: &mut Scope) -> Result<SavedModelSaver> {
337        let all_vars = self.collections.values().flatten().collect::<Vec<_>>();
338        let prefix = ops::Placeholder::new()
339            .dtype(DataType::String)
340            .build(scope)?;
341        let save_op = {
342            let tensor_names = ops::constant(
343                &all_vars
344                    .iter()
345                    .map(|v| v.name().to_string())
346                    .collect::<Vec<_>>()[..],
347                scope,
348            )?;
349            let shape_and_slices = ops::constant(
350                &all_vars.iter().map(|_| "".to_string()).collect::<Vec<_>>()[..],
351                scope,
352            )?;
353            let tensors = all_vars
354                .iter()
355                .map(|v| v.output().clone())
356                .collect::<Vec<_>>();
357            let mut g = scope.graph_mut();
358            let mut nd = g.new_operation("SaveV2", "save")?;
359            nd.add_input(prefix.clone());
360            nd.add_input(tensor_names);
361            nd.add_input(shape_and_slices);
362            nd.add_input_list(&tensors[..]);
363            nd.set_attr_type_list(
364                "dtypes",
365                &all_vars.iter().map(|v| v.data_type()).collect::<Vec<_>>()[..],
366            )?;
367            nd.finish()?
368        };
369
370        let filename_tensor = ops::Placeholder::new()
371            .dtype(DataType::String)
372            .build(scope)?;
373        let restore_op = {
374            let all_var_names = all_vars
375                .iter()
376                .map(|v| v.name().to_string())
377                .collect::<Vec<_>>();
378            let tensor_names = ops::constant(&all_var_names[..], scope)?;
379            let shape_and_slices = ops::constant(
380                &all_vars.iter().map(|_| "".to_string()).collect::<Vec<_>>()[..],
381                scope,
382            )?;
383            let mut g = scope.graph_mut();
384            let mut nd = g.new_operation("RestoreV2", "restore")?;
385            nd.add_input(filename_tensor.clone());
386            nd.add_input(tensor_names);
387            nd.add_input(shape_and_slices);
388            nd.set_attr_type_list(
389                "dtypes",
390                &all_vars.iter().map(|v| v.data_type()).collect::<Vec<_>>()[..],
391            )?;
392            nd.finish()?
393        };
394        let really_restore_op = {
395            let mut restore_var_ops = Vec::<Operation>::new();
396            for (i, var) in all_vars.iter().enumerate() {
397                restore_var_ops.push(ops::assign(
398                    var.output().clone(),
399                    Output {
400                        operation: restore_op.clone(),
401                        index: i as i32,
402                    },
403                    scope,
404                )?);
405            }
406            let mut no_op = ops::NoOp::new();
407            for op in restore_var_ops {
408                no_op = no_op.add_control_input(op);
409            }
410            no_op.build(scope)?
411        };
412
413        SavedModelSaver::new(
414            filename_tensor.name()?,
415            prefix,
416            save_op,
417            really_restore_op.name()?,
418            self.collections,
419            self.tags,
420            self.signatures,
421        )
422    }
423}
424
425#[derive(Debug)]
426/// Creates saved models. Use a SavedModelBuilder to create a SavedModelSaver.
427pub struct SavedModelSaver {
428    meta_graph: protos::meta_graph::MetaGraphDef,
429    prefix: Operation,
430    save_op: Operation,
431}
432
433impl SavedModelSaver {
434    fn new(
435        filename_tensor_name: String,
436        prefix: Operation,
437        save_op: Operation,
438        restore_op_name: String,
439        collections: HashMap<String, Vec<Variable>>,
440        tags: Vec<String>,
441        signatures: HashMap<String, SignatureDef>,
442    ) -> Result<SavedModelSaver> {
443        let mut meta_graph = protos::meta_graph::MetaGraphDef::new();
444        meta_graph
445            .mut_saver_def()
446            .set_filename_tensor_name(filename_tensor_name);
447        meta_graph
448            .mut_saver_def()
449            .set_restore_op_name(restore_op_name);
450        for (key, variables) in collections {
451            let mut trainable_variables_bytes_list =
452                protos::meta_graph::CollectionDef_BytesList::new();
453            for variable in variables {
454                let mut variable_def = protos::variable::VariableDef::new();
455                variable_def.set_variable_name(variable.name().to_string());
456                trainable_variables_bytes_list.mut_value().push(
457                    match variable_def.write_to_bytes() {
458                        Ok(x) => x,
459                        Err(e) => {
460                            return Err(Status::new_set_lossy(
461                                Code::InvalidArgument,
462                                &format!("Unable to encode variable definition: {}", e),
463                            ));
464                        }
465                    },
466                );
467            }
468            let mut trainable_collection_def = protos::meta_graph::CollectionDef::new();
469            trainable_collection_def.set_bytes_list(trainable_variables_bytes_list);
470            meta_graph
471                .mut_collection_def()
472                .insert(key.to_string(), trainable_collection_def);
473        }
474        let graph_tags = meta_graph.mut_meta_info_def().mut_tags();
475        for tag in tags {
476            graph_tags.push(tag);
477        }
478        let graph_signatures = meta_graph.mut_signature_def();
479        for (key, signature) in signatures {
480            graph_signatures.insert(key, signature.into_proto());
481        }
482        Ok(SavedModelSaver {
483            meta_graph,
484            prefix,
485            save_op,
486        })
487    }
488
489    /// Saves the graph and current variable values as a saved model.
490    pub fn save<P: AsRef<Path>>(
491        &self,
492        session: &Session,
493        graph: &Graph,
494        save_dir: P,
495    ) -> std::result::Result<(), SaveModelError> {
496        let mut meta_graph = self.meta_graph.clone();
497        let graph_bytes = graph.graph_def()?;
498        let graph_def = protobuf::Message::parse_from_bytes(&graph_bytes).map_err(|e| {
499            SaveModelError::from(Status::new_set_lossy(
500                Code::InvalidArgument,
501                &format!("Unable to parse graph definition: {}", e),
502            ))
503        })?;
504        meta_graph.set_graph_def(graph_def);
505        let mut saved_model = protos::saved_model::SavedModel::new();
506        saved_model.set_saved_model_schema_version(1);
507        saved_model.mut_meta_graphs().push(meta_graph);
508        let saved_model_bytes = saved_model
509            .write_to_bytes()
510            .map_err(SaveModelError::from_protobuf_error)?;
511        fs::create_dir(save_dir.as_ref())?;
512        let mut file = File::create(save_dir.as_ref().join("saved_model.pb"))?;
513        file.write_all(&saved_model_bytes)?;
514        let prefix = Tensor::from(
515            save_dir
516                .as_ref()
517                .join("variables/variables")
518                .to_str()
519                .ok_or_else(|| {
520                    Status::new_set(Code::OutOfRange, "Path is not valid Unicode").unwrap()
521                })?
522                .to_string(),
523        );
524
525        let mut run_args = SessionRunArgs::new();
526        run_args.add_feed(&self.prefix, 0, &prefix);
527        run_args.add_target(&self.save_op);
528        session.run(&mut run_args)?;
529        Ok(())
530    }
531}