tensorflow/
session.rs

1use super::AnyTensor;
2use super::Buffer;
3use super::Code;
4use super::DataType;
5use super::Graph;
6use super::MetaGraphDef;
7use super::Operation;
8use super::Result;
9use super::SessionOptions;
10use super::Status;
11use super::Tensor;
12use super::TensorType;
13use crate::tf;
14use libc::{c_char, c_int};
15use std::ffi::CStr;
16use std::ffi::CString;
17use std::marker;
18use std::path::Path;
19use std::ptr;
20
21/// Aggregation type for a saved model bundle.
22#[derive(Debug)]
23pub struct SavedModelBundle {
24    /// The loaded session.
25    pub session: Session,
26    /// A meta graph definition as raw protocol buffer. This is deprecated in favour of the
27    /// deserialized type.
28    #[deprecated(
29        note = "Please use SavedModelBundle::meta_graph_def() instead",
30        since = "0.16.0"
31    )]
32    pub meta_graph_def: Vec<u8>,
33    /// A decoded meta-graph definition.
34    meta_graph: MetaGraphDef,
35}
36
37impl SavedModelBundle {
38    /// Loads a session from an exported model, creating a bundle
39    pub fn load<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>(
40        options: &SessionOptions,
41        tags: Tags,
42        graph: &mut Graph,
43        export_dir: P,
44    ) -> Result<SavedModelBundle> {
45        let mut status = Status::new();
46
47        let export_dir_cstr = export_dir
48            .as_ref()
49            .to_str()
50            .and_then(|s| CString::new(s.as_bytes()).ok())
51            .ok_or_else(|| invalid_arg!("Invalid export directory path"))?;
52
53        let tags_cstr: Vec<_> = tags
54            .into_iter()
55            .map(|t| CString::new(t.as_ref()))
56            .collect::<::std::result::Result<_, _>>()
57            .map_err(|_| invalid_arg!("Invalid tag name"))?;
58        // keeping tags_cstr to retain strings in memory
59        let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();
60
61        // The empty TF_Buffer will be filled by LoadSessionFromSavedModel
62        let mut meta = unsafe { Buffer::<u8>::from_ptr(ptr::null_mut(), 0) };
63
64        let inner = unsafe {
65            tf::TF_LoadSessionFromSavedModel(
66                options.inner,
67                ptr::null(),
68                export_dir_cstr.as_ptr(),
69                tags_ptr.as_ptr(),
70                tags_ptr.len() as c_int,
71                graph.inner(),
72                meta.inner_mut(),
73                status.inner(),
74            )
75        };
76        if inner.is_null() {
77            Err(status)
78        } else {
79            let session = Session { inner };
80            #[allow(deprecated)]
81            Ok(SavedModelBundle {
82                session,
83                meta_graph_def: Vec::from(meta.as_ref()),
84                meta_graph: MetaGraphDef::from_serialized_proto(meta.as_ref())?,
85            })
86        }
87    }
88
89    /// Returns the metagraph definition for the saved model.
90    pub fn meta_graph_def(&self) -> &MetaGraphDef {
91        &self.meta_graph
92    }
93}
94
95/// Manages a single graph and execution.
96#[derive(Debug)]
97pub struct Session {
98    inner: *mut tf::TF_Session,
99}
100
101impl Session {
102    /// Creates a session.
103    /// `graph` will be be kept alive for the lifetime of the returned session.
104    /// New nodes can still be added to `graph` after this call.
105    pub fn new(options: &SessionOptions, graph: &Graph) -> Result<Self> {
106        let mut status = Status::new();
107        let inner = unsafe { tf::TF_NewSession(graph.inner(), options.inner, status.inner()) };
108        if inner.is_null() {
109            Err(status)
110        } else {
111            Ok(Session { inner })
112        }
113    }
114
115    /// Loads a session from an exported model.
116    #[deprecated(note = "Please use SavedModelBundle::load() instead", since = "0.17.0")]
117    pub fn from_saved_model<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>(
118        options: &SessionOptions,
119        tags: Tags,
120        graph: &mut Graph,
121        export_dir: P,
122    ) -> Result<Self> {
123        Ok(SavedModelBundle::load(options, tags, graph, export_dir)?.session)
124    }
125
126    /// Closes the session.
127    pub fn close(&mut self) -> Result<()> {
128        let mut status = Status::new();
129        unsafe {
130            tf::TF_CloseSession(self.inner, status.inner());
131        }
132        status.into_result()
133    }
134
135    /// Runs the graph, feeding the inputs and then fetching the outputs
136    /// requested in the step.  Note that the session has interior mutability;
137    /// this may mutate variables in the graph, and the caller is responsible
138    /// for handling race conditions.
139    pub fn run(&self, step: &mut SessionRunArgs<'_>) -> Result<()> {
140        // In case we're running it a second time and not all outputs were taken out.
141        step.drop_output_tensors();
142        // make sure run_metadata is either None or an empty TF_Buffer
143        step.maybe_reset_run_metadata();
144
145        let mut status = Status::new();
146        let maybe_tensors: Result<_> = step.input_tensors.iter().map(|t| t.inner()).collect();
147        let input_tensors: Vec<_> = maybe_tensors?;
148        let run_options_ptr = match step.run_options.as_ref() {
149            Some(buf) => buf.inner(),
150            None => ptr::null(),
151        };
152
153        let mut run_metadata_buf = if step.request_metadata {
154            Some(unsafe { Buffer::new_unallocated() })
155        } else {
156            None
157        };
158        let run_metadata_ptr = match run_metadata_buf.as_mut() {
159            Some(meta) => meta.inner_mut(),
160            None => ptr::null_mut(),
161        };
162        unsafe {
163            tf::TF_SessionRun(
164                self.inner,
165                run_options_ptr,
166                step.input_ports.as_ptr(),
167                input_tensors.as_ptr() as *const *mut tf::TF_Tensor,
168                input_tensors.len() as c_int,
169                step.output_ports.as_ptr(),
170                step.output_tensors.as_mut_ptr(),
171                step.output_tensors.len() as c_int,
172                step.target_operations.as_mut_ptr(),
173                step.target_operations.len() as c_int,
174                run_metadata_ptr,
175                status.inner(),
176            );
177            step.run_metadata = run_metadata_buf.map(Into::into);
178        }
179
180        status.into_result()
181    }
182
183    /// Lists all devices in a session.
184    pub fn device_list(&self) -> Result<Vec<Device>> {
185        let status = Status::new();
186        unsafe {
187            let list = tf::TF_SessionListDevices(self.inner, status.inner);
188            if !status.is_ok() {
189                return Err(status);
190            }
191            let result = (|| {
192                let n = tf::TF_DeviceListCount(list);
193                let mut devices = Vec::with_capacity(n as usize);
194                for i in 0..n {
195                    let c_name = tf::TF_DeviceListName(list, i, status.inner);
196                    if !status.is_ok() {
197                        return Err(status);
198                    }
199                    let c_type = tf::TF_DeviceListType(list, i, status.inner);
200                    if !status.is_ok() {
201                        return Err(status);
202                    }
203                    let bytes = tf::TF_DeviceListMemoryBytes(list, i, status.inner);
204                    if !status.is_ok() {
205                        return Err(status);
206                    }
207                    let incarnation = tf::TF_DeviceListIncarnation(list, i, status.inner);
208                    if !status.is_ok() {
209                        return Err(status);
210                    }
211                    devices.push(Device {
212                        name: CStr::from_ptr(c_name).to_str()?.to_string(),
213                        device_type: CStr::from_ptr(c_type).to_str()?.to_string(),
214                        memory_bytes: bytes,
215                        incarnation,
216                    });
217                }
218                Ok(devices)
219            })();
220            tf::TF_DeleteDeviceList(list);
221            result
222        }
223    }
224}
225
226impl Drop for Session {
227    fn drop(&mut self) {
228        let mut status = Status::new();
229        unsafe {
230            tf::TF_DeleteSession(self.inner, status.inner());
231        }
232        // TODO: What do we do with the status?
233    }
234}
235
236unsafe impl Send for Session {}
237
238unsafe impl Sync for Session {}
239
240////////////////////////
241
/// An opaque handle, returned by `SessionRunArgs::request_fetch`, used to retrieve one
/// output of a computation after the session has run.
#[derive(Clone, Copy, Debug)]
pub struct FetchToken {
    /// Position of the requested output within the run's list of fetches.
    index: usize,
}

/// Old name of `FetchToken`, kept for backward compatibility.
#[deprecated(since = "0.10.0", note = "Use FetchToken instead.")]
pub type OutputToken = FetchToken;
251
252/// Manages the inputs and outputs for a single execution of a graph.
253///
254/// Typical usage involves creating an instance of this struct,
255/// adding some inputs to it, requesting some outputs, passing it to `Session::run`
256/// and then taking the outputs out of it.
257///
258/// Example:
259///
260/// ```rust,ignore
261/// let mut args = SessionRunArgs::new();
262/// args.add_feed(&op1, 0, &tensor1);
263/// args.add_feed(&op2, 0, &tensor2);
264/// let result_token = args.request_fetch(&op3, 0);
265/// session.run(&mut args)?;
266/// let result_tensor = args.fetch(result_token)?;
267/// ```
268///
269/// See examples/addition.rs for a more concrete example.
270#[derive(Debug)]
271pub struct SessionRunArgs<'l> {
272    input_ports: Vec<tf::TF_Output>,
273    input_tensors: Vec<&'l dyn AnyTensor>,
274
275    output_ports: Vec<tf::TF_Output>,
276    output_tensors: Vec<*mut tf::TF_Tensor>,
277
278    target_operations: Vec<*const tf::TF_Operation>,
279
280    run_options: Option<Buffer<u8>>,
281    run_metadata: Option<Vec<u8>>,
282    request_metadata: bool,
283
284    phantom: marker::PhantomData<&'l ()>,
285}
286
287unsafe impl<'l> Send for SessionRunArgs<'l> {}
288unsafe impl<'l> Sync for SessionRunArgs<'l> {}
289
290impl<'l> Default for SessionRunArgs<'l> {
291    fn default() -> Self {
292        Self::new()
293    }
294}
295
296impl<'l> SessionRunArgs<'l> {
297    /// Creates a SessionRunArgs.
298    pub fn new() -> Self {
299        SessionRunArgs {
300            input_ports: vec![],
301            input_tensors: vec![],
302
303            output_ports: vec![],
304            output_tensors: vec![],
305
306            run_options: None,
307            run_metadata: None,
308            request_metadata: false,
309
310            target_operations: vec![],
311
312            phantom: marker::PhantomData,
313        }
314    }
315
316    /// Adds an input to be fed to the graph. The index selects which output of
317    /// the operation to feed. For most operations, there is only one output,
318    /// so the index should be 0.
319    pub fn add_feed<T: TensorType>(
320        &mut self,
321        operation: &Operation,
322        index: c_int,
323        tensor: &'l Tensor<T>,
324    ) {
325        self.input_ports.push(tf::TF_Output {
326            oper: operation.inner(),
327            index,
328        });
329        self.input_tensors.push(tensor);
330    }
331
332    /// Deprecated alias for add_feed.
333    #[deprecated(note = "Use add_feed instead.", since = "0.10.0")]
334    pub fn add_input<T: TensorType>(
335        &mut self,
336        operation: &Operation,
337        index: c_int,
338        tensor: &'l Tensor<T>,
339    ) {
340        self.add_feed(operation, index, tensor)
341    }
342
343    /// Requests that an output is fetched from the graph after running this
344    /// step. The index selects which output of the operation to return. For
345    /// most operations, there is only one output, so the index should be 0.
346    /// Returns a token that you can then use to fetch this output from the args
347    /// after running it.
348    pub fn request_fetch(&mut self, operation: &Operation, index: c_int) -> FetchToken {
349        self.output_ports.push(tf::TF_Output {
350            oper: operation.inner(),
351            index,
352        });
353        self.output_tensors.push(ptr::null_mut());
354        FetchToken {
355            index: self.output_tensors.len() - 1,
356        }
357    }
358
359    /// Deprecated alias for request_fetch.
360    #[deprecated(note = "Use request_fetch instead.", since = "0.10.0")]
361    #[allow(deprecated)]
362    pub fn request_output(&mut self, operation: &Operation, index: c_int) -> OutputToken {
363        self.request_fetch(operation, index)
364    }
365
366    /// Extracts a tensor output given a token. A given token can only be
367    /// extracted once per `Session::run`. Returns an error if the token is
368    /// invalid, output is unavailable or the requested type does not match the
369    /// type of the actual tensor.
370    pub fn fetch<T: TensorType>(&mut self, token: FetchToken) -> Result<Tensor<T>> {
371        let output_idx = token.index;
372        if output_idx >= self.output_tensors.len() {
373            return Err(Status::new_set(
374                Code::OutOfRange,
375                &format!(
376                    "Requested output index is out of range: {} vs \
377                     {}",
378                    output_idx,
379                    self.output_tensors.len()
380                ),
381            )
382            .unwrap());
383        }
384        if self.output_tensors[output_idx].is_null() {
385            return Err(Status::new_set(
386                Code::Unavailable,
387                "Output not available. Either it was already taken, or \
388                 this step has not been sucessfully run yet.",
389            )
390            .unwrap());
391        }
392        let actual_data_type = self.output_data_type(output_idx).unwrap();
393        if actual_data_type != T::data_type() {
394            return Err(invalid_arg!(
395                "Requested tensor type does not match actual tensor type: \
396                 {} vs {}",
397                actual_data_type,
398                T::data_type()
399            ));
400        }
401        let tensor = unsafe { Tensor::from_tf_tensor(self.output_tensors[output_idx]).unwrap() };
402        self.output_tensors[output_idx] = ptr::null_mut();
403        Ok(tensor)
404    }
405
406    /// Deprecated alias for fetch.
407    #[deprecated(note = "Use fetch instead.", since = "0.10.0")]
408    #[allow(deprecated)]
409    pub fn take_output<T: TensorType>(&mut self, token: OutputToken) -> Result<Tensor<T>> {
410        self.fetch(token)
411    }
412
413    /// Adds a target operation to be executed when running the graph.
414    pub fn add_target(&mut self, operation: &Operation) {
415        self.target_operations.push(operation.inner());
416    }
417
418    /// Retuns the type of the tensor given an index.
419    /// Returns `None` if the index is out of range or the output is not yet available.
420    pub fn output_data_type(&self, output_idx: usize) -> Option<DataType> {
421        if output_idx >= self.output_tensors.len() {
422            return None;
423        }
424        if self.output_tensors[output_idx].is_null() {
425            return None;
426        }
427        unsafe {
428            Some(DataType::from_c(tf::TF_TensorType(
429                self.output_tensors[output_idx],
430            )))
431        }
432    }
433
434    /// Sets the `RunOptions`. `run_options` is a serialized [`RunOptions` proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto).
435    pub fn set_run_options(&mut self, run_options: &[u8]) {
436        self.run_options = Some(Buffer::from(run_options))
437    }
438
439    /// Returns the serialized [`RunOptions` proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto)
440    /// Returns none if `RunOption` are not set.
441    pub fn get_run_options(&self) -> Option<&[u8]> {
442        self.run_options.as_ref().map(std::convert::AsRef::as_ref)
443    }
444
445    /// Returns the serialized [`RunMetadata` proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto)
446    /// Returns none if `self::set_request_metadata` is not set to true.
447    pub fn get_metadata(&mut self) -> Option<&[u8]> {
448        self.run_metadata.as_ref().map(std::convert::AsRef::as_ref)
449    }
450
451    /// Requests `run_metadata`. The serialized [`RunMetadata` proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto)
452    /// can be retrieved via `self::get_metadata` after calling `Session::run`.
453    pub fn set_request_metadata(&mut self, request: bool) {
454        self.request_metadata = request;
455    }
456
457    /// Returns whether `RunMetadata` should be stored.
458    pub fn is_request_metadata(&self) -> bool {
459        self.request_metadata
460    }
461
462    fn drop_output_tensors(&mut self) {
463        for tensor in &mut self.output_tensors {
464            // TODO: Is TF_DeleteTensor NULL safe?
465            if !tensor.is_null() {
466                unsafe {
467                    tf::TF_DeleteTensor(*tensor);
468                }
469            }
470            *tensor = ptr::null_mut();
471        }
472    }
473
474    fn maybe_reset_run_metadata(&mut self) {
475        self.run_metadata = None;
476    }
477}
478
479impl<'l> Drop for SessionRunArgs<'l> {
480    fn drop(&mut self) {
481        self.drop_output_tensors();
482    }
483}
484
485/// Deprecated alias for SessionRunArgs.
486#[deprecated(note = "Use SessionRunArgs instead.", since = "0.10.0")]
487pub type StepWithGraph<'l> = SessionRunArgs<'l>;
488
489////////////////////////
490
/// Description of one device visible to a session, as returned by `Session::device_list`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Device {
    /// Fully qualified device name (e.g. /job:worker/replica:0/...).
    pub name: String,

    /// Kind of device, as reported by TensorFlow (for example "CPU").
    pub device_type: String,

    /// Memory available on the device, in bytes.
    pub memory_bytes: i64,

    /// Incarnation number of the device.
    pub incarnation: u64,
}
506
507////////////////////////
508
#[cfg(test)]
mod tests {
    use super::super::DataType;
    use super::super::Graph;
    use super::super::Operation;
    use super::super::SessionOptions;
    use super::super::Shape;
    use super::super::Tensor;
    use super::*;
    use serial_test::serial;

    /// Builds a graph computing `y = 2 * x` and returns a session for it together with the
    /// `x` placeholder and `y` operations.
    fn create_session() -> (Session, Operation, Operation) {
        let mut g = Graph::new();
        let two = {
            let mut nd = g.new_operation("Const", "two").unwrap();
            nd.set_attr_type("dtype", DataType::Float).unwrap();
            let mut value = Tensor::new(&[1]);
            value[0] = 2.0f32;
            nd.set_attr_tensor("value", value).unwrap();
            nd.finish().unwrap()
        };
        let x = {
            let mut nd = g.new_operation("Placeholder", "x").unwrap();
            nd.set_attr_type("dtype", DataType::Float).unwrap();
            nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
            nd.finish().unwrap()
        };
        let y = {
            let mut nd = g.new_operation("Mul", "y").unwrap();
            nd.add_input(two);
            nd.add_input(x.clone());
            nd.finish().unwrap()
        };
        let options = SessionOptions::new();
        match Session::new(&options, &g) {
            Ok(session) => (session, x, y),
            Err(status) => panic!("Creating session failed with status: {}", status),
        }
    }

    #[test]
    fn smoke() {
        create_session();
    }

    #[test]
    fn test_close() {
        let (mut session, _, _) = create_session();
        let status = session.close();
        assert!(status.is_ok());
    }

    #[test]
    fn test_run() {
        let (session, x_operation, y_operation) = create_session();
        let mut x = <Tensor<f32>>::new(&[2]);
        x[0] = 2.0;
        x[1] = 3.0;
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    #[serial] // Full trace enable profile session
    fn test_run_metadata() {
        let (session, x_operation, y_operation) = create_session();
        let x = Tensor::<f32>::from(&[2.0, 3.0][..]);
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        // hard coded RunOptions proto with full tracelevel
        step.set_run_options(&[8u8, 3u8]);
        step.set_request_metadata(true);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        step.get_metadata().unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();

        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);

        // ensure multiple calls with the same SessionRunArgs work
        session.run(&mut step).unwrap();
        step.get_metadata().unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    #[serial] // Full_trace enable profile session
    fn test_run_options() {
        let (session, x_operation, y_operation) = create_session();
        let x = Tensor::<f32>::from(&[2.0, 3.0][..]);
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        // hard coded RunOptions proto with full tracelevel
        step.set_run_options(&[8u8, 3u8]);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        // Metadata was never requested, so none should be recorded.
        assert!(step.get_metadata().is_none());
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    fn test_run_metadata_no_run_options() {
        let (session, x_operation, y_operation) = create_session();
        let x = Tensor::<f32>::from(&[2.0, 3.0][..]);
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        step.set_request_metadata(true);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        step.get_metadata().unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    fn test_savedmodelbundle() {
        let mut graph = Graph::new();
        let bundle = SavedModelBundle::load(
            &SessionOptions::new(),
            &["train", "serve"],
            &mut graph,
            "test_resources/regression-model",
        )
        .unwrap();

        let x_op = graph.operation_by_name_required("x").unwrap();
        let y_op = graph.operation_by_name_required("y").unwrap();
        let y_hat_op = graph.operation_by_name_required("y_hat").unwrap();
        let _train_op = graph.operation_by_name_required("train").unwrap();

        #[allow(deprecated)]
        let SavedModelBundle {
            session,
            meta_graph_def,
            meta_graph: _,
        } = bundle;

        assert!(!meta_graph_def.is_empty());

        let mut x = <Tensor<f32>>::new(&[1]);
        x[0] = 2.0;
        let mut y = <Tensor<f32>>::new(&[1]);
        y[0] = 4.0;
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_op, 0, &x);
        step.add_feed(&y_op, 0, &y);
        let output_token = step.request_fetch(&y_hat_op, 0);
        session.run(&mut step).unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 1);
    }

    #[test]
    fn test_device_list() {
        let (session, _, _) = create_session();
        let devices = session.device_list().unwrap();
        assert!(
            devices.iter().any(|d| d.device_type == "CPU"),
            "devices: {:?}",
            devices
        );
    }
}