1use super::AnyTensor;
2use super::Buffer;
3use super::Code;
4use super::DataType;
5use super::Graph;
6use super::MetaGraphDef;
7use super::Operation;
8use super::Result;
9use super::SessionOptions;
10use super::Status;
11use super::Tensor;
12use super::TensorType;
13use crate::tf;
14use libc::{c_char, c_int};
15use std::ffi::CStr;
16use std::ffi::CString;
17use std::marker;
18use std::path::Path;
19use std::ptr;
20
/// Aggregates a loaded saved model: the session it runs in plus the
/// `MetaGraphDef` describing the loaded graph.
#[derive(Debug)]
pub struct SavedModelBundle {
    /// Session created for the loaded model.
    pub session: Session,
    /// Raw serialized `MetaGraphDef` proto bytes, kept for backward
    /// compatibility with older callers.
    #[deprecated(
        note = "Please use SavedModelBundle::meta_graph_def() instead",
        since = "0.16.0"
    )]
    pub meta_graph_def: Vec<u8>,
    // Parsed form of the same proto; exposed via meta_graph_def().
    meta_graph: MetaGraphDef,
}
36
impl SavedModelBundle {
    /// Loads a session from an exported SavedModel directory.
    ///
    /// `options` configures the new session, `tags` selects which meta graph
    /// to load (e.g. "serve"), `graph` receives the loaded graph definition,
    /// and `export_dir` is the path to the SavedModel directory.
    ///
    /// Returns an error if the path or a tag cannot be converted to a C
    /// string, if the C API fails to load the model, or if the returned
    /// `MetaGraphDef` bytes fail to parse.
    pub fn load<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>(
        options: &SessionOptions,
        tags: Tags,
        graph: &mut Graph,
        export_dir: P,
    ) -> Result<SavedModelBundle> {
        let mut status = Status::new();

        // The C API needs a NUL-terminated string; reject paths that are not
        // valid UTF-8 or contain interior NUL bytes.
        let export_dir_cstr = export_dir
            .as_ref()
            .to_str()
            .and_then(|s| CString::new(s.as_bytes()).ok())
            .ok_or_else(|| invalid_arg!("Invalid export directory path"))?;

        // Convert every tag to a C string, failing on interior NUL bytes.
        let tags_cstr: Vec<_> = tags
            .into_iter()
            .map(|t| CString::new(t.as_ref()))
            .collect::<::std::result::Result<_, _>>()
            .map_err(|_| invalid_arg!("Invalid tag name"))?;
        // `tags_cstr` must outlive these raw pointers; it lives until the end
        // of this function, past the FFI call below.
        let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();

        // Empty buffer that the C API fills with the serialized MetaGraphDef.
        let mut meta = unsafe { Buffer::<u8>::from_ptr(ptr::null_mut(), 0) };

        // Second argument (serialized RunOptions) is intentionally null.
        let inner = unsafe {
            tf::TF_LoadSessionFromSavedModel(
                options.inner,
                ptr::null(),
                export_dir_cstr.as_ptr(),
                tags_ptr.as_ptr(),
                tags_ptr.len() as c_int,
                graph.inner(),
                meta.inner_mut(),
                status.inner(),
            )
        };
        if inner.is_null() {
            Err(status)
        } else {
            let session = Session { inner };
            // The deprecated raw-bytes field is still populated alongside the
            // parsed MetaGraphDef for backward compatibility.
            #[allow(deprecated)]
            Ok(SavedModelBundle {
                session,
                meta_graph_def: Vec::from(meta.as_ref()),
                meta_graph: MetaGraphDef::from_serialized_proto(meta.as_ref())?,
            })
        }
    }

    /// Returns the parsed `MetaGraphDef` of the loaded saved model.
    pub fn meta_graph_def(&self) -> &MetaGraphDef {
        &self.meta_graph
    }
}
94
/// A TensorFlow session: owns the underlying `TF_Session` handle through
/// which graph computations are executed.
#[derive(Debug)]
pub struct Session {
    // Raw handle owned by this struct; freed in Drop.
    inner: *mut tf::TF_Session,
}
100
impl Session {
    /// Creates a session for the given graph with the given options.
    pub fn new(options: &SessionOptions, graph: &Graph) -> Result<Self> {
        let mut status = Status::new();
        let inner = unsafe { tf::TF_NewSession(graph.inner(), options.inner, status.inner()) };
        if inner.is_null() {
            Err(status)
        } else {
            Ok(Session { inner })
        }
    }

    /// Loads a session from an exported model.
    #[deprecated(note = "Please use SavedModelBundle::load() instead", since = "0.17.0")]
    pub fn from_saved_model<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>(
        options: &SessionOptions,
        tags: Tags,
        graph: &mut Graph,
        export_dir: P,
    ) -> Result<Self> {
        // Delegates to SavedModelBundle::load and keeps only the session.
        Ok(SavedModelBundle::load(options, tags, graph, export_dir)?.session)
    }

    /// Closes the session, releasing its resources on the TensorFlow side.
    pub fn close(&mut self) -> Result<()> {
        let mut status = Status::new();
        unsafe {
            tf::TF_CloseSession(self.inner, status.inner());
        }
        status.into_result()
    }

    /// Runs one step: feeds the inputs registered in `step`, executes the
    /// requested target operations, and stores the fetched output tensors
    /// (and RunMetadata, if requested) back into `step`.
    pub fn run(&self, step: &mut SessionRunArgs<'_>) -> Result<()> {
        // Discard results of any previous run so stale tensors/metadata
        // cannot be observed after this call.
        step.drop_output_tensors();
        step.maybe_reset_run_metadata();

        let mut status = Status::new();
        // Gather raw TF_Tensor pointers; AnyTensor::inner() is fallible, so
        // collect into a Result first.
        let maybe_tensors: Result<_> = step.input_tensors.iter().map(|t| t.inner()).collect();
        let input_tensors: Vec<_> = maybe_tensors?;
        // Null pointer means "no RunOptions proto".
        let run_options_ptr = match step.run_options.as_ref() {
            Some(buf) => buf.inner(),
            None => ptr::null(),
        };

        // Only allocate a metadata buffer when requested; a null pointer
        // tells the C API not to produce RunMetadata at all.
        let mut run_metadata_buf = if step.request_metadata {
            Some(unsafe { Buffer::new_unallocated() })
        } else {
            None
        };
        let run_metadata_ptr = match run_metadata_buf.as_mut() {
            Some(meta) => meta.inner_mut(),
            None => ptr::null_mut(),
        };
        unsafe {
            tf::TF_SessionRun(
                self.inner,
                run_options_ptr,
                step.input_ports.as_ptr(),
                input_tensors.as_ptr() as *const *mut tf::TF_Tensor,
                input_tensors.len() as c_int,
                step.output_ports.as_ptr(),
                step.output_tensors.as_mut_ptr(),
                step.output_tensors.len() as c_int,
                step.target_operations.as_mut_ptr(),
                step.target_operations.len() as c_int,
                run_metadata_ptr,
                status.inner(),
            );
            // Hand the (possibly filled) metadata buffer back to the caller.
            step.run_metadata = run_metadata_buf.map(Into::into);
        }

        status.into_result()
    }

    /// Lists all devices available to this session.
    pub fn device_list(&self) -> Result<Vec<Device>> {
        let status = Status::new();
        unsafe {
            let list = tf::TF_SessionListDevices(self.inner, status.inner);
            if !status.is_ok() {
                return Err(status);
            }
            // Build the result inside an immediately-invoked closure so that
            // TF_DeleteDeviceList below runs on both the success path and
            // every early-return error path.
            let result = (|| {
                let n = tf::TF_DeviceListCount(list);
                let mut devices = Vec::with_capacity(n as usize);
                for i in 0..n {
                    let c_name = tf::TF_DeviceListName(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    let c_type = tf::TF_DeviceListType(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    let bytes = tf::TF_DeviceListMemoryBytes(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    let incarnation = tf::TF_DeviceListIncarnation(list, i, status.inner);
                    if !status.is_ok() {
                        return Err(status);
                    }
                    devices.push(Device {
                        // `?` here propagates a Utf8Error if the C strings are
                        // not valid UTF-8.
                        name: CStr::from_ptr(c_name).to_str()?.to_string(),
                        device_type: CStr::from_ptr(c_type).to_str()?.to_string(),
                        memory_bytes: bytes,
                        incarnation,
                    });
                }
                Ok(devices)
            })();
            tf::TF_DeleteDeviceList(list);
            result
        }
    }
}
225
impl Drop for Session {
    /// Deletes the underlying TF_Session.
    fn drop(&mut self) {
        let mut status = Status::new();
        unsafe {
            tf::TF_DeleteSession(self.inner, status.inner());
        }
        // Any error reported by TF_DeleteSession is deliberately ignored:
        // drop must not panic and there is no caller to report it to.
    }
}
235
// SAFETY: relies on the underlying TF_Session being safe to use from
// multiple threads concurrently, as the TensorFlow C API intends —
// NOTE(review): confirm against the threading contract of the linked
// TensorFlow version.
unsafe impl Send for Session {}

unsafe impl Sync for Session {}
239
/// Opaque handle returned by `SessionRunArgs::request_fetch`; used after the
/// step has run to retrieve the corresponding output tensor via
/// `SessionRunArgs::fetch`.
#[derive(Copy, Clone, Debug)]
pub struct FetchToken {
    // Index into SessionRunArgs::output_tensors.
    index: usize,
}

/// Deprecated alias kept for backward compatibility.
#[deprecated(note = "Use FetchToken instead.", since = "0.10.0")]
pub type OutputToken = FetchToken;
251
/// Manages the inputs, requested outputs, and target operations for a single
/// call to `Session::run`.
#[derive(Debug)]
pub struct SessionRunArgs<'l> {
    // Graph endpoints to feed; parallel to `input_tensors`.
    input_ports: Vec<tf::TF_Output>,
    input_tensors: Vec<&'l dyn AnyTensor>,

    // Graph endpoints to fetch; `output_tensors` holds the raw fetched
    // tensors (null until a run succeeds, or after a tensor is taken).
    output_ports: Vec<tf::TF_Output>,
    output_tensors: Vec<*mut tf::TF_Tensor>,

    // Operations executed for their side effects only.
    target_operations: Vec<*const tf::TF_Operation>,

    // Serialized RunOptions proto for the next run, if any.
    run_options: Option<Buffer<u8>>,
    // Serialized RunMetadata from the last run, if it was requested.
    run_metadata: Option<Vec<u8>>,
    request_metadata: bool,

    // Ties this struct's lifetime to the borrowed input tensors.
    phantom: marker::PhantomData<&'l ()>,
}
286
// SAFETY: the raw TF_Tensor/TF_Operation pointers held here refer to objects
// managed by the TensorFlow runtime; presumably they may be moved and shared
// across threads — NOTE(review): confirm against the TF C API's threading
// guarantees.
unsafe impl<'l> Send for SessionRunArgs<'l> {}
unsafe impl<'l> Sync for SessionRunArgs<'l> {}
289
290impl<'l> Default for SessionRunArgs<'l> {
291 fn default() -> Self {
292 Self::new()
293 }
294}
295
296impl<'l> SessionRunArgs<'l> {
297 pub fn new() -> Self {
299 SessionRunArgs {
300 input_ports: vec![],
301 input_tensors: vec![],
302
303 output_ports: vec![],
304 output_tensors: vec![],
305
306 run_options: None,
307 run_metadata: None,
308 request_metadata: false,
309
310 target_operations: vec![],
311
312 phantom: marker::PhantomData,
313 }
314 }
315
316 pub fn add_feed<T: TensorType>(
320 &mut self,
321 operation: &Operation,
322 index: c_int,
323 tensor: &'l Tensor<T>,
324 ) {
325 self.input_ports.push(tf::TF_Output {
326 oper: operation.inner(),
327 index,
328 });
329 self.input_tensors.push(tensor);
330 }
331
332 #[deprecated(note = "Use add_feed instead.", since = "0.10.0")]
334 pub fn add_input<T: TensorType>(
335 &mut self,
336 operation: &Operation,
337 index: c_int,
338 tensor: &'l Tensor<T>,
339 ) {
340 self.add_feed(operation, index, tensor)
341 }
342
343 pub fn request_fetch(&mut self, operation: &Operation, index: c_int) -> FetchToken {
349 self.output_ports.push(tf::TF_Output {
350 oper: operation.inner(),
351 index,
352 });
353 self.output_tensors.push(ptr::null_mut());
354 FetchToken {
355 index: self.output_tensors.len() - 1,
356 }
357 }
358
359 #[deprecated(note = "Use request_fetch instead.", since = "0.10.0")]
361 #[allow(deprecated)]
362 pub fn request_output(&mut self, operation: &Operation, index: c_int) -> OutputToken {
363 self.request_fetch(operation, index)
364 }
365
366 pub fn fetch<T: TensorType>(&mut self, token: FetchToken) -> Result<Tensor<T>> {
371 let output_idx = token.index;
372 if output_idx >= self.output_tensors.len() {
373 return Err(Status::new_set(
374 Code::OutOfRange,
375 &format!(
376 "Requested output index is out of range: {} vs \
377 {}",
378 output_idx,
379 self.output_tensors.len()
380 ),
381 )
382 .unwrap());
383 }
384 if self.output_tensors[output_idx].is_null() {
385 return Err(Status::new_set(
386 Code::Unavailable,
387 "Output not available. Either it was already taken, or \
388 this step has not been sucessfully run yet.",
389 )
390 .unwrap());
391 }
392 let actual_data_type = self.output_data_type(output_idx).unwrap();
393 if actual_data_type != T::data_type() {
394 return Err(invalid_arg!(
395 "Requested tensor type does not match actual tensor type: \
396 {} vs {}",
397 actual_data_type,
398 T::data_type()
399 ));
400 }
401 let tensor = unsafe { Tensor::from_tf_tensor(self.output_tensors[output_idx]).unwrap() };
402 self.output_tensors[output_idx] = ptr::null_mut();
403 Ok(tensor)
404 }
405
406 #[deprecated(note = "Use fetch instead.", since = "0.10.0")]
408 #[allow(deprecated)]
409 pub fn take_output<T: TensorType>(&mut self, token: OutputToken) -> Result<Tensor<T>> {
410 self.fetch(token)
411 }
412
413 pub fn add_target(&mut self, operation: &Operation) {
415 self.target_operations.push(operation.inner());
416 }
417
418 pub fn output_data_type(&self, output_idx: usize) -> Option<DataType> {
421 if output_idx >= self.output_tensors.len() {
422 return None;
423 }
424 if self.output_tensors[output_idx].is_null() {
425 return None;
426 }
427 unsafe {
428 Some(DataType::from_c(tf::TF_TensorType(
429 self.output_tensors[output_idx],
430 )))
431 }
432 }
433
434 pub fn set_run_options(&mut self, run_options: &[u8]) {
436 self.run_options = Some(Buffer::from(run_options))
437 }
438
439 pub fn get_run_options(&self) -> Option<&[u8]> {
442 self.run_options.as_ref().map(std::convert::AsRef::as_ref)
443 }
444
445 pub fn get_metadata(&mut self) -> Option<&[u8]> {
448 self.run_metadata.as_ref().map(std::convert::AsRef::as_ref)
449 }
450
451 pub fn set_request_metadata(&mut self, request: bool) {
454 self.request_metadata = request;
455 }
456
457 pub fn is_request_metadata(&self) -> bool {
459 self.request_metadata
460 }
461
462 fn drop_output_tensors(&mut self) {
463 for tensor in &mut self.output_tensors {
464 if !tensor.is_null() {
466 unsafe {
467 tf::TF_DeleteTensor(*tensor);
468 }
469 }
470 *tensor = ptr::null_mut();
471 }
472 }
473
474 fn maybe_reset_run_metadata(&mut self) {
475 self.run_metadata = None;
476 }
477}
478
impl<'l> Drop for SessionRunArgs<'l> {
    /// Frees any fetched tensors that were never taken by the caller.
    fn drop(&mut self) {
        self.drop_output_tensors();
    }
}
484
/// Deprecated alias kept for backward compatibility.
#[deprecated(note = "Use SessionRunArgs instead.", since = "0.10.0")]
pub type StepWithGraph<'l> = SessionRunArgs<'l>;
488
/// Metadata about a device available to a session, as reported by
/// `Session::device_list`.
#[derive(Debug, Eq, PartialEq, Clone, Hash)]
pub struct Device {
    /// Full device name as reported by the runtime.
    pub name: String,

    /// Device type, e.g. "CPU" or "GPU".
    pub device_type: String,

    /// Amount of memory on the device, in bytes.
    pub memory_bytes: i64,

    /// Incarnation number of the device.
    // NOTE(review): presumably a runtime-assigned unique instance id — confirm
    // against the TF C API docs for TF_DeviceListIncarnation.
    pub incarnation: u64,
}
506
#[cfg(test)]
mod tests {
    use super::super::DataType;
    use super::super::Graph;
    use super::super::Operation;
    use super::super::SessionOptions;
    use super::super::Shape;
    use super::super::Tensor;
    use super::*;
    use serial_test::serial;

    /// Builds a graph computing y = 2 * x and returns a session for it along
    /// with the `x` (placeholder) and `y` (result) operations.
    fn create_session() -> (Session, Operation, Operation) {
        let mut g = Graph::new();
        let two = {
            let mut nd = g.new_operation("Const", "two").unwrap();
            nd.set_attr_type("dtype", DataType::Float).unwrap();
            let mut value = Tensor::new(&[1]);
            value[0] = 2.0f32;
            nd.set_attr_tensor("value", value).unwrap();
            nd.finish().unwrap()
        };
        let x = {
            let mut nd = g.new_operation("Placeholder", "x").unwrap();
            nd.set_attr_type("dtype", DataType::Float).unwrap();
            nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
            nd.finish().unwrap()
        };
        let y = {
            let mut nd = g.new_operation("Mul", "y").unwrap();
            nd.add_input(two);
            nd.add_input(x.clone());
            nd.finish().unwrap()
        };
        let options = SessionOptions::new();
        match Session::new(&options, &g) {
            Ok(session) => (session, x, y),
            Err(status) => panic!("Creating session failed with status: {}", status),
        }
    }

    #[test]
    fn smoke() {
        create_session();
    }

    #[test]
    fn test_close() {
        let (mut session, _, _) = create_session();
        let status = session.close();
        assert!(status.is_ok());
    }

    #[test]
    fn test_run() {
        let (session, x_operation, y_operation) = create_session();
        let mut x = <Tensor<f32>>::new(&[2]);
        x[0] = 2.0;
        x[1] = 3.0;
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    #[serial]
    fn test_run_metadata() {
        let (session, x_operation, y_operation) = create_session();
        let x = Tensor::<f32>::from(&[2.0, 3.0][..]);
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        step.set_run_options(&[8u8, 3u8]);
        // Fixed: this call was accidentally duplicated; once is enough.
        step.set_request_metadata(true);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        step.get_metadata().unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();

        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);

        // Run a second step with the same args to ensure state (outputs and
        // metadata) is correctly reset and repopulated.
        session.run(&mut step).unwrap();
        step.get_metadata().unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    #[serial]
    fn test_run_options() {
        let (session, x_operation, y_operation) = create_session();
        let x = Tensor::<f32>::from(&[2.0, 3.0][..]);
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        step.set_run_options(&[8u8, 3u8]);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    fn test_run_metadata_no_run_options() {
        let (session, x_operation, y_operation) = create_session();
        let x = Tensor::<f32>::from(&[2.0, 3.0][..]);
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_operation, 0, &x);
        step.set_request_metadata(true);
        let output_token = step.request_fetch(&y_operation, 0);
        session.run(&mut step).unwrap();
        step.get_metadata().unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 2);
        assert_eq!(output_tensor[0], 4.0);
        assert_eq!(output_tensor[1], 6.0);
    }

    #[test]
    fn test_savedmodelbundle() {
        let mut graph = Graph::new();
        let bundle = SavedModelBundle::load(
            &SessionOptions::new(),
            &["train", "serve"],
            &mut graph,
            "test_resources/regression-model",
        )
        .unwrap();

        let x_op = graph.operation_by_name_required("x").unwrap();
        let y_op = graph.operation_by_name_required("y").unwrap();
        let y_hat_op = graph.operation_by_name_required("y_hat").unwrap();
        let _train_op = graph.operation_by_name_required("train").unwrap();

        // Destructure to verify the deprecated raw field is still populated.
        #[allow(deprecated)]
        let SavedModelBundle {
            session,
            meta_graph_def,
            meta_graph: _,
        } = bundle;

        assert!(!meta_graph_def.is_empty());

        let mut x = <Tensor<f32>>::new(&[1]);
        x[0] = 2.0;
        let mut y = <Tensor<f32>>::new(&[1]);
        y[0] = 4.0;
        let mut step = SessionRunArgs::new();
        step.add_feed(&x_op, 0, &x);
        step.add_feed(&y_op, 0, &y);
        let output_token = step.request_fetch(&y_hat_op, 0);
        session.run(&mut step).unwrap();
        let output_tensor = step.fetch::<f32>(output_token).unwrap();
        assert_eq!(output_tensor.len(), 1);
    }

    #[test]
    fn test_device_list() {
        let (session, _, _) = create_session();
        let devices = session.device_list().unwrap();
        assert!(
            devices.iter().any(|d| d.device_type == "CPU"),
            "devices: {:?}",
            devices
        );
    }
}