1use super::buffer::Buffer;
2use super::AnyTensor;
3use super::Code;
4use super::DataType;
5use super::Result;
6use super::Shape;
7use super::Status;
8use super::Tensor;
9use super::TensorType;
10use libc::c_char;
11use libc::c_float;
12use libc::c_int;
13use libc::c_uchar;
14use libc::c_uint;
15use libc::c_void;
16use libc::size_t;
17use std::cmp;
18use std::ffi::CStr;
19use std::ffi::CString;
20use std::ffi::NulError;
21use std::fmt;
22use std::fmt::Display;
23use std::fmt::Formatter;
24use std::mem::MaybeUninit;
25use std::os::raw::c_void as std_c_void;
26use std::ptr;
27use std::slice;
28use std::str::FromStr;
29use std::str::Utf8Error;
30use std::sync::Arc;
31#[cfg(feature = "default")]
32use tensorflow_sys as tf;
33#[cfg(feature = "tensorflow_runtime_linking")]
34use tensorflow_sys_runtime as tf;
35
36#[derive(Debug)]
37struct GraphImpl {
38 inner: *mut tf::TF_Graph,
39 owned: bool,
40}
41
42unsafe impl Send for GraphImpl {}
43unsafe impl Sync for GraphImpl {}
44
45impl Drop for GraphImpl {
46 fn drop(&mut self) {
48 if self.owned {
49 unsafe {
50 tf::TF_DeleteGraph(self.inner);
51 }
52 }
53 }
54}
55
56#[derive(Debug)]
61pub struct ImportGraphDefOptions {
62 inner: *mut tf::TF_ImportGraphDefOptions,
63}
64
65impl_new!(
66 ImportGraphDefOptions,
67 TF_NewImportGraphDefOptions,
68 "Creates a default ImportGraphDefOptions."
69);
70impl_drop!(ImportGraphDefOptions, TF_DeleteImportGraphDefOptions);
71
72impl ImportGraphDefOptions {
73 pub fn set_prefix(&mut self, prefix: &str) -> std::result::Result<(), NulError> {
76 let s = CString::new(prefix)?;
77 unsafe {
78 tf::TF_ImportGraphDefOptionsSetPrefix(self.inner, s.as_ptr());
79 }
80 Ok(())
81 }
82
83 pub fn add_input_mapping(
87 &mut self,
88 src_name: &str,
89 src_index: usize,
90 dst: &Output,
91 ) -> std::result::Result<(), NulError> {
92 let s = CString::new(src_name)?;
93 unsafe {
94 tf::TF_ImportGraphDefOptionsAddInputMapping(
95 self.inner,
96 s.as_ptr(),
97 src_index as c_int,
98 dst.to_c(),
99 );
100 }
101 Ok(())
102 }
103
104 pub fn remap_control_dependency(
109 &mut self,
110 src_name: &str,
111 dst: &Operation,
112 ) -> std::result::Result<(), NulError> {
113 let s = CString::new(src_name)?;
114 unsafe {
115 tf::TF_ImportGraphDefOptionsRemapControlDependency(self.inner, s.as_ptr(), dst.inner);
116 }
117 Ok(())
118 }
119
120 pub fn add_control_dependency(&mut self, oper: &Operation) {
123 unsafe {
124 tf::TF_ImportGraphDefOptionsAddControlDependency(self.inner, oper.inner);
125 }
126 }
127
128 pub fn add_return_output(
132 &mut self,
133 oper_name: &str,
134 index: usize,
135 ) -> std::result::Result<(), NulError> {
136 let s = CString::new(oper_name)?;
137 unsafe {
138 tf::TF_ImportGraphDefOptionsAddReturnOutput(self.inner, s.as_ptr(), index as c_int);
139 }
140 Ok(())
141 }
142
143 pub fn add_return_operation(&mut self, oper_name: &str) -> std::result::Result<(), NulError> {
146 let s = CString::new(oper_name)?;
147 unsafe {
148 tf::TF_ImportGraphDefOptionsAddReturnOperation(self.inner, s.as_ptr());
149 }
150 Ok(())
151 }
152
153 pub fn num_return_outputs(&self) -> usize {
155 unsafe { tf::TF_ImportGraphDefOptionsNumReturnOutputs(self.inner) as usize }
156 }
157
158 pub fn num_return_operations(&self) -> usize {
160 unsafe { tf::TF_ImportGraphDefOptionsNumReturnOperations(self.inner) as usize }
161 }
162
163 pub fn set_uniquify_names(&mut self, uniquify_names: bool) {
169 unsafe {
170 tf::TF_ImportGraphDefOptionsSetUniquifyNames(self.inner, u8::from(uniquify_names));
171 }
172 }
173
174 pub fn set_uniquify_prefix(&mut self, uniquify_prefix: bool) {
178 unsafe {
179 tf::TF_ImportGraphDefOptionsSetUniquifyPrefix(self.inner, u8::from(uniquify_prefix));
180 }
181 }
182
183 pub fn set_default_device(&mut self, device: &str) -> std::result::Result<(), NulError> {
186 let s = CString::new(device)?;
187 unsafe {
188 tf::TF_ImportGraphDefOptionsSetDefaultDevice(self.inner, s.as_ptr());
189 }
190 Ok(())
191 }
192}
193
194#[derive(Debug)]
199pub struct ImportGraphDefResults {
200 inner: *mut tf::TF_ImportGraphDefResults,
201 gimpl: Arc<GraphImpl>,
202}
203
204impl ImportGraphDefResults {
205 pub fn return_outputs(&self) -> Vec<Output> {
207 unsafe {
208 let mut num_outputs: c_int = 0;
209 let mut c_outputs: *mut tf::TF_Output = ptr::null_mut();
210 tf::TF_ImportGraphDefResultsReturnOutputs(self.inner, &mut num_outputs, &mut c_outputs);
211 slice::from_raw_parts(c_outputs, num_outputs as usize)
212 .iter()
213 .map(|output| Output {
214 operation: Operation {
215 inner: output.oper,
216 gimpl: self.gimpl.clone(),
217 },
218 index: output.index,
219 })
220 .collect()
221 }
222 }
223
224 pub fn return_operations(&self) -> Vec<Operation> {
226 unsafe {
227 let mut num_operations: c_int = 0;
228 let mut c_operations: *mut *mut tf::TF_Operation = ptr::null_mut();
229 tf::TF_ImportGraphDefResultsReturnOperations(
230 self.inner,
231 &mut num_operations,
232 &mut c_operations,
233 );
234 slice::from_raw_parts(c_operations, num_operations as usize)
235 .iter()
236 .map(|operation| Operation {
237 inner: *operation,
238 gimpl: self.gimpl.clone(),
239 })
240 .collect()
241 }
242 }
243
244 pub fn missing_unused_input_mappings(
248 &self,
249 ) -> std::result::Result<Vec<(&str, c_int)>, Utf8Error> {
250 unsafe {
251 let mut n: c_int = 0;
252 let mut c_src_names: *mut *const c_char = ptr::null_mut();
253 let mut src_indexes: *mut c_int = ptr::null_mut();
254 tf::TF_ImportGraphDefResultsMissingUnusedInputMappings(
255 self.inner,
256 &mut n,
257 &mut c_src_names,
258 &mut src_indexes,
259 );
260 let c_name_slice = slice::from_raw_parts(c_src_names, n as usize);
261 let index_slice = slice::from_raw_parts(src_indexes, n as usize);
262 let mut v = Vec::new();
263 for i in 0..n as usize {
264 let s = CStr::from_ptr(c_name_slice[i]).to_str()?;
265 v.push((s, index_slice[i]));
266 }
267 Ok(v)
268 }
269 }
270}
271
272impl_drop!(ImportGraphDefResults, TF_DeleteImportGraphDefResults);
273
274#[derive(Debug)]
279pub struct Graph {
280 gimpl: Arc<GraphImpl>,
281}
282
283impl Default for Graph {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289impl Graph {
290 pub fn new() -> Graph {
292 unsafe {
293 Graph {
294 gimpl: Arc::new(GraphImpl {
295 inner: tf::TF_NewGraph(),
296 owned: true,
297 }),
298 }
299 }
300 }
301
302 pub fn new_operation(
306 &mut self,
307 op_type: &str,
308 operation_name: &str,
309 ) -> std::result::Result<OperationDescription<'_>, NulError> {
310 let c_op_type = CString::new(op_type)?;
311 let c_operation_name = CString::new(operation_name)?;
312 unsafe {
313 Ok(OperationDescription {
314 inner: tf::TF_NewOperation(
315 self.gimpl.inner,
316 c_op_type.as_ptr(),
317 c_operation_name.as_ptr(),
318 ),
319 graph: self,
320 finished: false,
321 })
322 }
323 }
324
325 pub fn operation_by_name(
328 &self,
329 operation_name: &str,
330 ) -> std::result::Result<Option<Operation>, NulError> {
331 let c_operation_name = CString::new(operation_name)?;
332 unsafe {
333 let operation =
334 tf::TF_GraphOperationByName(self.gimpl.inner, c_operation_name.as_ptr());
335 if operation.is_null() {
336 Ok(None)
337 } else {
338 Ok(Some(Operation {
339 inner: operation,
340 gimpl: self.gimpl.clone(),
341 }))
342 }
343 }
344 }
345
346 pub fn operation_by_name_required(
348 &self,
349 operation_name: &str,
350 ) -> std::result::Result<Operation, Status> {
351 match self.operation_by_name(operation_name)? {
352 Some(operation) => Ok(operation),
353 None => Err(Status::new_set(
354 Code::Unavailable,
355 &format!("Operation {:?} not found", operation_name),
356 )
357 .unwrap()),
358 }
359 }
360
361 pub(crate) fn generate_operation_name(&self, operation_name_pattern: &str) -> Result<i64> {
367 let parts: Vec<_> = operation_name_pattern.split("{}").collect();
368 if parts.len() != 2 {
369 return Err(invalid_arg!(
370 "operation_name_pattern must contain placeholder"
371 ));
372 }
373 let mut i = 0;
375 loop {
376 let name = format!("{}{}{}", parts[0], i, parts[1]);
377 let c_name = CString::new(name)?;
378 unsafe {
379 if tf::TF_GraphOperationByName(self.gimpl.inner, c_name.as_ptr()).is_null() {
380 return Ok(i);
381 }
382 }
383 i += 1;
384 }
385 }
386
387 pub fn operation_iter(&self) -> OperationIter<'_> {
389 OperationIter {
390 graph: self,
391 pos: 0,
392 }
393 }
394
395 pub fn graph_def(&self) -> Result<Vec<u8>> {
397 let mut status = Status::new();
398 unsafe {
399 let c_buffer = tf::TF_NewBuffer();
400 tf::TF_GraphToGraphDef(self.gimpl.inner, c_buffer, status.inner());
401 if status.is_ok() {
402 Ok(Buffer::from_c(c_buffer, true).into())
403 } else {
404 tf::TF_DeleteBuffer(c_buffer);
405 Err(status)
406 }
407 }
408 }
409
410 pub fn num_dims<I: Into<Output>>(&self, output: I) -> Result<c_int> {
418 let mut status = Status::new();
419 unsafe {
420 let val = tf::TF_GraphGetTensorNumDims(
421 self.gimpl.inner,
422 output.into().to_c(),
423 status.inner(),
424 );
425 if status.is_ok() {
426 Ok(val)
427 } else {
428 Err(status)
429 }
430 }
431 }
432
433 pub fn tensor_shape<I: Into<Output>>(&self, output: I) -> Result<Shape> {
439 let mut status = Status::new();
440 let output = output.into();
441 let n = self.num_dims(output.clone())?;
442 if n == -1 {
443 return Ok(Shape(None));
444 }
445 let mut dims = Vec::with_capacity(n as usize);
446 unsafe {
447 tf::TF_GraphGetTensorShape(
448 self.gimpl.inner,
449 output.to_c(),
450 dims.as_mut_ptr(),
451 n,
452 status.inner(),
453 );
454 if status.is_ok() {
455 dims.set_len(n as usize);
456 Ok(Shape(Some(
457 dims.iter()
458 .map(|x| if *x < 0 { None } else { Some(*x) })
459 .collect(),
460 )))
461 } else {
462 Err(status)
463 }
464 }
465 }
466
467 pub fn import_graph_def(
469 &mut self,
470 graph_def: &[u8],
471 options: &ImportGraphDefOptions,
472 ) -> Result<()> {
473 let buf = Buffer::from(graph_def);
474 let mut status = Status::new();
475 unsafe {
476 tf::TF_GraphImportGraphDef(
477 self.gimpl.inner,
478 buf.inner(),
479 options.inner,
480 status.inner(),
481 );
482 status.into_result()
483 }
484 }
485
486 pub fn import_graph_def_with_results(
488 &mut self,
489 graph_def: &[u8],
490 options: &ImportGraphDefOptions,
491 ) -> Result<ImportGraphDefResults> {
492 let buf = Buffer::from(graph_def);
493 let mut status = Status::new();
494 unsafe {
495 let result = tf::TF_GraphImportGraphDefWithResults(
496 self.gimpl.inner,
497 buf.inner(),
498 options.inner,
499 status.inner(),
500 );
501 status.into_result().map(|()| ImportGraphDefResults {
502 inner: result,
503 gimpl: self.gimpl.clone(),
504 })
505 }
506 }
507
508 pub fn import_graph_def_with_return_outputs(
510 &mut self,
511 graph_def: &[u8],
512 options: &ImportGraphDefOptions,
513 ) -> Result<Vec<Output>> {
514 let buf = Buffer::from(graph_def);
515 let mut status = Status::new();
516 let n = options.num_return_outputs();
517 let mut c_return_outputs: Vec<MaybeUninit<tf::TF_Output>> = Vec::with_capacity(n);
518 unsafe {
519 c_return_outputs.set_len(n);
520 tf::TF_GraphImportGraphDefWithReturnOutputs(
521 self.gimpl.inner,
522 buf.inner(),
523 options.inner,
524 c_return_outputs.as_mut_ptr() as *mut tf::TF_Output,
525 n as c_int,
526 status.inner(),
527 );
528 status.into_result()?;
529 Ok(c_return_outputs
530 .iter()
531 .map(|x| Output::from_c(self, &x.assume_init()))
532 .collect())
533 }
534 }
535
536 pub fn copy_function(&mut self, func: &Function, grad: Option<&Function>) -> Result<()> {
555 let mut status = Status::new();
556 unsafe {
557 tf::TF_GraphCopyFunction(
558 self.inner(),
559 func.inner,
560 match grad {
561 None => ptr::null(),
562 Some(g) => g.inner,
563 },
564 status.inner(),
565 );
566 }
567 status.into_result()
568 }
569
570 pub fn to_function<S: AsRef<str>>(
636 &self,
637 fn_name: &str,
638 append_hash_to_fn_name: bool,
639 opers: Option<&[&Operation]>,
640 inputs: &[Output],
641 outputs: &[Output],
642 output_names: Option<&[S]>,
643 opts: &FunctionOptions,
644 description: Option<&str>,
645 ) -> Result<Function> {
646 let fn_name_cstr = CString::new(fn_name)?;
647 let num_opers: c_int = if let Some(ops) = &opers {
648 ops.len() as c_int
649 } else {
650 -1
651 };
652 #[allow(trivial_casts)]
653 let c_opers: Option<Vec<_>> =
654 opers.map(|s| s.iter().map(|op| op.inner as *const _).collect());
655 let c_opers_ptr: *const *const tf::TF_Operation = if let Some(ref ops) = &c_opers {
656 ops.as_ptr()
657 } else {
658 ptr::null()
659 };
660 let c_inputs: Vec<_> = inputs.iter().map(|x| x.to_c()).collect();
661 let c_outputs: Vec<_> = outputs.iter().map(|x| x.to_c()).collect();
662 let output_names_cstrs: Option<::std::result::Result<Vec<CString>, NulError>> =
663 output_names
664 .map(|slice: &[S]| slice.iter().map(|s: &S| CString::new(s.as_ref())).collect());
665 let output_names_cstrs: Option<Vec<CString>> = match output_names_cstrs {
666 None => None,
667 Some(r) => Some(r?),
668 };
669 let output_names_ptrs: Option<Vec<*const c_char>> = output_names_cstrs
670 .as_ref()
671 .map(|slice| slice.iter().map(|s| s.as_ptr()).collect());
672 let output_names_ptrs_ptr = match &output_names_ptrs {
673 None => ptr::null(),
674 Some(ref v) => v.as_ptr(),
675 };
676 let description_cstr = match description {
677 None => None,
678 Some(d) => Some(CString::new(d)?),
679 };
680 let description_ptr: *const c_char = if let Some(ref cstr) = &description_cstr {
681 cstr.as_ptr()
682 } else {
683 ptr::null()
684 };
685 let status = Status::new();
686 let f = unsafe {
687 tf::TF_GraphToFunction(
688 self.inner(),
689 fn_name_cstr.as_ptr(),
690 u8::from(append_hash_to_fn_name),
691 num_opers,
692 c_opers_ptr,
693 c_inputs.len() as c_int,
694 c_inputs.as_ptr(),
695 c_outputs.len() as c_int,
696 c_outputs.as_ptr(),
697 output_names_ptrs_ptr,
698 opts.inner,
699 description_ptr,
700 status.inner,
701 )
702 };
703 status.into_result()?;
704 Ok(Function { inner: f })
705 }
706
707 pub fn num_functions(&self) -> c_int {
709 unsafe { tf::TF_GraphNumFunctions(self.inner()) }
710 }
711
712 pub fn get_functions(&self) -> Result<Vec<Function>> {
714 unsafe {
715 let num = tf::TF_GraphNumFunctions(self.inner());
716 let mut funcs = Vec::with_capacity(num as usize);
717 let status = Status::new();
718 let num = tf::TF_GraphGetFunctions(self.inner(), funcs.as_mut_ptr(), num, status.inner);
719 status.into_result()?;
720 funcs.set_len(num as usize);
721 Ok(funcs.iter().map(|f| Function { inner: *f }).collect())
722 }
723 }
724
725 pub fn get_op_def(&self, op_name: &str) -> Result<Vec<u8>> {
728 let status = Status::new();
729 let c_op_name = CString::new(op_name)?;
730 unsafe {
731 let mut buffer = Buffer::new_unallocated();
732 tf::TF_GraphGetOpDef(
733 self.inner(),
734 c_op_name.as_ptr(),
735 buffer.inner_mut(),
736 status.inner,
737 );
738 status.into_result().map(|()| buffer.into())
739 }
740 }
741
742 pub fn versions(&self) -> Result<Vec<u8>> {
744 let status = Status::new();
745 unsafe {
746 let mut buffer = Buffer::new_unallocated();
747 tf::TF_GraphVersions(self.inner(), buffer.inner_mut(), status.inner);
748 status.into_result().map(|()| buffer.into())
749 }
750 }
751
752 pub fn try_evaluate_constant<T: TensorType>(
761 &self,
762 output: &Output,
763 ) -> Result<Option<Tensor<T>>> {
764 let status = Status::new();
765 unsafe {
766 let mut c_tensor: *mut tf::TF_Tensor = ptr::null_mut();
767 let success = tf::TF_TryEvaluateConstant(
768 self.inner(),
769 output.to_c(),
770 &mut c_tensor,
771 status.inner,
772 );
773 status.into_result()?;
774 if success != 0 {
775 match Tensor::from_tf_tensor(c_tensor) {
776 None => Err(invalid_arg!("Tensor types do not match")),
777 Some(t) => Ok(Some(t)),
778 }
779 } else {
780 Ok(None)
781 }
782 }
783 }
784
785 pub fn add_gradients(
805 &mut self,
806 prefix: Option<&str>,
807 y: &[Output],
808 x: &[Output],
809 dx: Option<&[Output]>,
810 ) -> Result<Vec<Option<Output>>> {
811 if let Some(dx) = dx {
812 if dx.len() != y.len() {
813 return Err(invalid_arg!(
814 "dx.len() must equal y.len() ({} vs. {})",
815 dx.len(),
816 y.len()
817 ));
818 }
819 }
820 let c_y: Vec<_> = y.iter().map(Output::to_c).collect();
821 let c_x: Vec<_> = x.iter().map(Output::to_c).collect();
822 let c_dx: Option<Vec<_>> = dx.map(|v| v.iter().map(Output::to_c).collect());
823 let dx_ptr = match c_dx {
824 Some(v) => v.as_ptr(),
825 None => ptr::null(),
826 };
827 let prefix_cstr = match prefix {
828 Some(s) => Some(CString::new(s)?),
829 None => None,
830 };
831 let prefix_ptr: *const c_char = if let Some(ref cstr) = &prefix_cstr {
832 cstr.as_ptr()
833 } else {
834 ptr::null()
835 };
836 let mut dy = Vec::with_capacity(x.len());
837 let mut status = Status::new();
838 unsafe {
839 tf::TF_AddGradientsWithPrefix(
840 self.inner(),
841 prefix_ptr,
842 c_y.as_ptr() as *mut _,
843 y.len() as i32,
844 c_x.as_ptr() as *mut _,
845 x.len() as i32,
846 dx_ptr as *mut _,
847 status.inner(),
848 dy.as_mut_ptr(),
849 );
850 if status.is_ok() {
851 dy.set_len(x.len());
852 Ok(dy
853 .iter()
854 .map(|o| Output::from_c_optional(self, o))
855 .collect())
856 } else {
857 Err(status)
858 }
859 }
860 }
861
862 pub(crate) fn inner(&self) -> *mut tf::TF_Graph {
863 self.gimpl.inner
864 }
865
866 pub(crate) unsafe fn from_c(inner: *mut tf::TF_Graph) -> Self {
867 Graph {
868 gimpl: Arc::new(GraphImpl {
869 inner,
870 owned: false,
871 }),
872 }
873 }
874}
875
876#[derive(Debug)]
880pub struct OperationIter<'a> {
881 graph: &'a Graph,
884 pos: size_t,
885}
886
887impl<'a> Iterator for OperationIter<'a> {
888 type Item = Operation;
889
890 fn next(&mut self) -> Option<Self::Item> {
891 unsafe {
892 let operation = tf::TF_GraphNextOperation(self.graph.gimpl.inner, &mut self.pos);
893 if operation.is_null() {
894 None
895 } else {
896 Some(Operation {
897 inner: operation,
898 gimpl: self.graph.gimpl.clone(),
899 })
900 }
901 }
902 }
903}
904
905c_enum!(
908TF_AttrType,
909#[allow(missing_docs)]
912AttrType {
913 String = 0,
914 Int = 1,
915 Float = 2,
916 Bool = 3,
917 Type = 4,
918 Shape = 5,
919 Tensor = 6,
920 Placeholder = 7,
921 Func = 8,
922});
923
924#[derive(Clone, Debug, Copy)]
926pub struct AttrMetadata {
927 pub list_size: Option<i64>,
929
930 pub attr_type: AttrType,
933
934 pub total_size: i64,
950}
951
952impl AttrMetadata {
953 fn from_c(metadata: tf::TF_AttrMetadata) -> Self {
954 AttrMetadata {
955 list_size: if metadata.is_list == 0 {
956 None
957 } else {
958 Some(metadata.list_size)
959 },
960 attr_type: AttrType::from_c(metadata.type_),
961 total_size: metadata.total_size,
962 }
963 }
964}
965
966#[derive(Debug, Clone)]
971pub struct Operation {
972 inner: *mut tf::TF_Operation,
973 gimpl: Arc<GraphImpl>,
974}
975
976unsafe impl Send for Operation {}
977unsafe impl Sync for Operation {}
978
979impl Operation {
980 pub fn name(&self) -> std::result::Result<String, Utf8Error> {
986 unsafe {
987 CStr::from_ptr(tf::TF_OperationName(self.inner))
988 .to_str()
989 .map(|x| x.to_string())
990 }
991 }
992
993 pub fn op_type(&self) -> std::result::Result<String, Utf8Error> {
996 unsafe {
997 CStr::from_ptr(tf::TF_OperationOpType(self.inner))
998 .to_str()
999 .map(|x| x.to_string())
1000 }
1001 }
1002
1003 pub fn device(&self) -> std::result::Result<String, Utf8Error> {
1006 unsafe {
1007 CStr::from_ptr(tf::TF_OperationDevice(self.inner))
1008 .to_str()
1009 .map(|x| x.to_string())
1010 }
1011 }
1012
1013 pub fn num_outputs(&self) -> usize {
1015 unsafe { tf::TF_OperationNumOutputs(self.inner) as usize }
1016 }
1017
1018 pub fn output_type(&self, index: usize) -> DataType {
1020 unsafe {
1021 DataType::from_c(tf::TF_OperationOutputType(tf::TF_Output {
1022 oper: self.inner,
1023 index: index as c_int,
1024 }))
1025 }
1026 }
1027
1028 pub fn output(&self, index: usize) -> Output {
1031 crate::Output {
1032 operation: self.clone(),
1033 index: index as c_int,
1034 }
1035 }
1036
1037 #[allow(missing_docs)]
1039 pub fn output_list_length(&self, arg_name: &str) -> Result<usize> {
1040 let c_arg_name = CString::new(arg_name)?;
1041 let mut status = Status::new();
1042 let length = unsafe {
1043 tf::TF_OperationOutputListLength(self.inner, c_arg_name.as_ptr(), status.inner())
1044 };
1045 if status.is_ok() {
1046 Ok(length as usize)
1047 } else {
1048 Err(status)
1049 }
1050 }
1051
1052 pub fn num_inputs(&self) -> usize {
1054 unsafe { tf::TF_OperationNumInputs(self.inner) as usize }
1055 }
1056
1057 pub fn input_type(&self, index: usize) -> DataType {
1059 unsafe {
1060 DataType::from_c(tf::TF_OperationInputType(tf::TF_Input {
1061 oper: self.inner,
1062 index: index as c_int,
1063 }))
1064 }
1065 }
1066
1067 #[allow(missing_docs)]
1069 pub fn input_list_length(&self, arg_name: &str) -> Result<usize> {
1070 let c_arg_name = CString::new(arg_name)?;
1071 let mut status = Status::new();
1072 let length = unsafe {
1073 tf::TF_OperationInputListLength(self.inner, c_arg_name.as_ptr(), status.inner())
1074 };
1075 if status.is_ok() {
1076 Ok(length as usize)
1077 } else {
1078 Err(status)
1079 }
1080 }
1081
1082 pub fn input(&self, index: usize) -> (Operation, usize) {
1086 unsafe {
1087 let port = tf::TF_OperationInput(tf::TF_Input {
1088 oper: self.inner,
1089 index: index as c_int,
1090 });
1091 (
1092 Operation {
1093 inner: port.oper,
1094 gimpl: self.gimpl.clone(),
1095 },
1096 port.index as usize,
1097 )
1098 }
1099 }
1100
1101 pub fn output_num_consumers(&self, index: usize) -> usize {
1103 unsafe {
1104 tf::TF_OperationOutputNumConsumers(tf::TF_Output {
1105 oper: self.inner,
1106 index: index as c_int,
1107 }) as usize
1108 }
1109 }
1110
1111 pub fn output_consumers(&self, index: usize) -> Vec<(Operation, usize)> {
1116 unsafe {
1117 let num_consumers = tf::TF_OperationOutputNumConsumers(tf::TF_Output {
1118 oper: self.inner,
1119 index: index as c_int,
1120 });
1121 let mut vec = <Vec<tf::TF_Input>>::with_capacity(num_consumers as usize);
1122 let len = tf::TF_OperationOutputConsumers(
1123 tf::TF_Output {
1124 oper: self.inner,
1125 index: index as c_int,
1126 },
1127 vec.as_mut_ptr(),
1128 num_consumers as c_int,
1129 );
1130 vec.set_len(len as usize);
1131 vec.into_iter()
1132 .map(|port| {
1133 (
1134 Operation {
1135 inner: port.oper,
1136 gimpl: self.gimpl.clone(),
1137 },
1138 port.index as usize,
1139 )
1140 })
1141 .collect()
1142 }
1143 }
1144
1145 pub fn num_control_inputs(&self) -> usize {
1147 unsafe { tf::TF_OperationNumControlInputs(self.inner) as usize }
1148 }
1149
1150 pub fn control_inputs(&self) -> Vec<Operation> {
1152 unsafe {
1153 let num_consumers = tf::TF_OperationNumControlInputs(self.inner);
1154 let mut vec = <Vec<*mut tf::TF_Operation>>::with_capacity(num_consumers as usize);
1155 let len = tf::TF_OperationGetControlInputs(
1156 self.inner,
1157 vec.as_mut_ptr(),
1158 num_consumers as c_int,
1159 );
1160 vec.set_len(cmp::min(num_consumers, len) as usize);
1161 vec.into_iter()
1162 .map(|operation| Operation {
1163 inner: operation,
1164 gimpl: self.gimpl.clone(),
1165 })
1166 .collect()
1167 }
1168 }
1169
1170 pub fn num_control_outputs(&self) -> usize {
1172 unsafe { tf::TF_OperationNumControlOutputs(self.inner) as usize }
1173 }
1174
1175 pub fn control_outputs(&self) -> Vec<Operation> {
1177 unsafe {
1178 let num_consumers = tf::TF_OperationNumControlOutputs(self.inner);
1179 let mut vec = <Vec<*mut tf::TF_Operation>>::with_capacity(num_consumers as usize);
1180 let len =
1181 tf::TF_OperationGetControlOutputs(self.inner, vec.as_mut_ptr(), vec.len() as c_int);
1182 vec.set_len(len as usize);
1183 vec.into_iter()
1184 .map(|operation| Operation {
1185 inner: operation,
1186 gimpl: self.gimpl.clone(),
1187 })
1188 .collect()
1189 }
1190 }
1191
1192 pub fn get_attr_metadata(&self, attr_name: &str) -> Result<AttrMetadata> {
1194 let c_attr_name = CString::new(attr_name)?;
1195 let mut status = Status::new();
1196 unsafe {
1197 let metadata =
1198 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1199 if status.is_ok() {
1200 Ok(AttrMetadata::from_c(metadata))
1201 } else {
1202 Err(status)
1203 }
1204 }
1205 }
1206
1207 pub fn get_attr_string(&self, attr_name: &str) -> Result<String> {
1209 let c_attr_name = CString::new(attr_name)?;
1210 let mut status = Status::new();
1211 unsafe {
1212 let metadata =
1213 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1214 if !status.is_ok() {
1215 return Err(status);
1216 }
1217 let mut v: Vec<MaybeUninit<u8>> = Vec::with_capacity(metadata.total_size as usize);
1218 v.set_len(metadata.total_size as usize);
1219 tf::TF_OperationGetAttrString(
1220 self.inner,
1221 c_attr_name.as_ptr(),
1222 v.as_mut_ptr() as *mut std::os::raw::c_void,
1223 metadata.total_size as usize,
1224 status.inner(),
1225 );
1226 if !status.is_ok() {
1227 return Err(status);
1228 }
1229 Ok(CString::new(
1230 v.into_iter()
1231 .map(|x| MaybeUninit::assume_init(x))
1232 .collect::<Vec<_>>(),
1233 )?
1234 .into_string()?)
1235 }
1236 }
1237
1238 pub fn get_attr_string_list(&self, attr_name: &str) -> Result<Vec<String>> {
1240 let c_attr_name = CString::new(attr_name)?;
1241 let mut status = Status::new();
1242 unsafe {
1243 let metadata =
1244 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1245 if !status.is_ok() {
1246 return Err(status);
1247 }
1248 let mut storage: Vec<MaybeUninit<u8>> =
1249 Vec::with_capacity(metadata.total_size as usize);
1250 storage.set_len(metadata.total_size as usize);
1251 let mut values: Vec<*const std::os::raw::c_char> =
1252 Vec::with_capacity(metadata.list_size as usize);
1253 let mut lengths: Vec<size_t> = Vec::with_capacity(metadata.list_size as usize);
1254 tf::TF_OperationGetAttrStringList(
1255 self.inner,
1256 c_attr_name.as_ptr(),
1257 values.as_mut_ptr() as *mut *mut std::os::raw::c_void,
1258 lengths.as_mut_ptr(),
1259 metadata.list_size as i32,
1260 storage.as_mut_ptr() as *mut std::os::raw::c_void,
1261 metadata.total_size as usize,
1262 status.inner(),
1263 );
1264 if !status.is_ok() {
1265 return Err(status);
1266 }
1267 values.set_len(metadata.list_size as usize);
1268 lengths.set_len(metadata.list_size as usize);
1269 let mut strings = Vec::with_capacity(metadata.list_size as usize);
1270 for i in 0..metadata.list_size as usize {
1271 let s = slice::from_raw_parts(values[i] as *const u8, lengths[i]);
1272 strings.push(std::str::from_utf8(s)?.to_string());
1273 }
1274 Ok(strings)
1275 }
1276 }
1277
1278 pub fn get_attr_int(&self, attr_name: &str) -> Result<i64> {
1280 let c_attr_name = CString::new(attr_name)?;
1281 let mut status = Status::new();
1282 let mut value: i64 = 0;
1283 unsafe {
1284 tf::TF_OperationGetAttrInt(
1285 self.inner,
1286 c_attr_name.as_ptr(),
1287 &mut value,
1288 status.inner(),
1289 );
1290 }
1291 if !status.is_ok() {
1292 return Err(status);
1293 }
1294 Ok(value)
1295 }
1296
1297 pub fn get_attr_int_list(&self, attr_name: &str) -> Result<Vec<i64>> {
1299 let c_attr_name = CString::new(attr_name)?;
1300 let mut status = Status::new();
1301 unsafe {
1302 let metadata =
1303 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1304 if !status.is_ok() {
1305 return Err(status);
1306 }
1307 let mut values: Vec<MaybeUninit<i64>> = Vec::with_capacity(metadata.list_size as usize);
1308 values.set_len(metadata.list_size as usize);
1309 tf::TF_OperationGetAttrIntList(
1310 self.inner,
1311 c_attr_name.as_ptr(),
1312 values.as_mut_ptr() as *mut i64,
1313 metadata.list_size as c_int,
1314 status.inner(),
1315 );
1316 if !status.is_ok() {
1317 return Err(status);
1318 }
1319 Ok(values
1320 .into_iter()
1321 .map(|x| MaybeUninit::assume_init(x))
1322 .collect())
1323 }
1324 }
1325
1326 pub fn get_attr_float(&self, attr_name: &str) -> Result<f32> {
1328 let c_attr_name = CString::new(attr_name)?;
1329 let mut status = Status::new();
1330 let mut value: c_float = 0.0;
1331 unsafe {
1332 tf::TF_OperationGetAttrFloat(
1333 self.inner,
1334 c_attr_name.as_ptr(),
1335 &mut value,
1336 status.inner(),
1337 );
1338 }
1339 if !status.is_ok() {
1340 return Err(status);
1341 }
1342 #[allow(trivial_numeric_casts)]
1343 #[allow(clippy::unnecessary_cast)]
1344 Ok(value as f32)
1345 }
1346
1347 pub fn get_attr_float_list(&self, attr_name: &str) -> Result<Vec<f32>> {
1349 let c_attr_name = CString::new(attr_name)?;
1350 let mut status = Status::new();
1351 unsafe {
1352 let metadata =
1353 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1354 if !status.is_ok() {
1355 return Err(status);
1356 }
1357 let mut values: Vec<MaybeUninit<c_float>> =
1358 Vec::with_capacity(metadata.list_size as usize);
1359 values.set_len(metadata.list_size as usize);
1360 tf::TF_OperationGetAttrFloatList(
1361 self.inner,
1362 c_attr_name.as_ptr(),
1363 values.as_mut_ptr() as *mut c_float,
1364 metadata.list_size as c_int,
1365 status.inner(),
1366 );
1367 if !status.is_ok() {
1368 return Err(status);
1369 }
1370 #[allow(trivial_numeric_casts)]
1371 #[allow(clippy::unnecessary_cast)]
1372 Ok(values.iter().map(|f| f.assume_init() as f32).collect())
1373 }
1374 }
1375
1376 pub fn get_attr_bool(&self, attr_name: &str) -> Result<bool> {
1378 let c_attr_name = CString::new(attr_name)?;
1379 let mut status = Status::new();
1380 let mut value: c_uchar = 0;
1381 unsafe {
1382 tf::TF_OperationGetAttrBool(
1383 self.inner,
1384 c_attr_name.as_ptr(),
1385 &mut value,
1386 status.inner(),
1387 );
1388 }
1389 if !status.is_ok() {
1390 return Err(status);
1391 }
1392 Ok(value != 0)
1393 }
1394
1395 pub fn get_attr_bool_list(&self, attr_name: &str) -> Result<Vec<bool>> {
1397 let c_attr_name = CString::new(attr_name)?;
1398 let mut status = Status::new();
1399 unsafe {
1400 let metadata =
1401 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1402 if !status.is_ok() {
1403 return Err(status);
1404 }
1405 let mut values: Vec<MaybeUninit<c_uchar>> =
1406 Vec::with_capacity(metadata.list_size as usize);
1407 values.set_len(metadata.list_size as usize);
1408 tf::TF_OperationGetAttrBoolList(
1409 self.inner,
1410 c_attr_name.as_ptr(),
1411 values.as_mut_ptr() as *mut c_uchar,
1412 metadata.list_size as c_int,
1413 status.inner(),
1414 );
1415 if !status.is_ok() {
1416 return Err(status);
1417 }
1418 #[allow(trivial_numeric_casts)]
1419 Ok(values.iter().map(|f| f.assume_init() != 0).collect())
1420 }
1421 }
1422
1423 pub fn get_attr_type(&self, attr_name: &str) -> Result<DataType> {
1425 let c_attr_name = CString::new(attr_name)?;
1426 let mut status = Status::new();
1427 let mut value: tf::TF_DataType = tf::TF_FLOAT;
1428 unsafe {
1429 tf::TF_OperationGetAttrType(
1430 self.inner,
1431 c_attr_name.as_ptr(),
1432 &mut value,
1433 status.inner(),
1434 );
1435 }
1436 if !status.is_ok() {
1437 return Err(status);
1438 }
1439 Ok(DataType::from_c(value))
1440 }
1441
1442 pub fn get_attr_type_list(&self, attr_name: &str) -> Result<Vec<DataType>> {
1444 let c_attr_name = CString::new(attr_name)?;
1445 let mut status = Status::new();
1446 unsafe {
1447 let metadata =
1448 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1449 if !status.is_ok() {
1450 return Err(status);
1451 }
1452 let mut values: Vec<MaybeUninit<tf::TF_DataType>> =
1453 Vec::with_capacity(metadata.list_size as usize);
1454 values.set_len(metadata.list_size as usize);
1455 tf::TF_OperationGetAttrTypeList(
1456 self.inner,
1457 c_attr_name.as_ptr(),
1458 values.as_mut_ptr() as *mut tf::TF_DataType,
1459 metadata.list_size as c_int,
1460 status.inner(),
1461 );
1462 if !status.is_ok() {
1463 return Err(status);
1464 }
1465 Ok(values
1466 .iter()
1467 .map(|x| DataType::from_c(x.assume_init()))
1468 .collect())
1469 }
1470 }
1471
1472 pub fn get_attr_shape(&self, attr_name: &str) -> Result<Shape> {
1474 let c_attr_name = CString::new(attr_name)?;
1475 let mut status = Status::new();
1476 unsafe {
1477 let metadata =
1478 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1479 if !status.is_ok() {
1480 return Err(status);
1481 }
1482 if metadata.total_size == -1 {
1483 return Ok(Shape(None));
1484 }
1485 let mut v: Vec<MaybeUninit<i64>> = Vec::with_capacity(metadata.total_size as usize);
1486 v.set_len(metadata.total_size as usize);
1487 tf::TF_OperationGetAttrShape(
1488 self.inner,
1489 c_attr_name.as_ptr(),
1490 v.as_mut_ptr() as *mut i64,
1491 metadata.total_size as c_int,
1492 status.inner(),
1493 );
1494 if !status.is_ok() {
1495 return Err(status);
1496 }
1497 Ok(Shape(Some(
1498 v.iter()
1499 .map(|x| {
1500 let x = x.assume_init();
1501 if x < 0 {
1502 None
1503 } else {
1504 Some(x)
1505 }
1506 })
1507 .collect(),
1508 )))
1509 }
1510 }
1511
1512 pub fn get_attr_shape_list(&self, attr_name: &str) -> Result<Vec<Shape>> {
1514 let c_attr_name = CString::new(attr_name)?;
1515 let mut status = Status::new();
1516 unsafe {
1517 let metadata =
1518 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1519 if !status.is_ok() {
1520 return Err(status);
1521 }
1522 let mut storage: Vec<MaybeUninit<i64>> =
1523 Vec::with_capacity(metadata.total_size as usize);
1524 storage.set_len(metadata.total_size as usize);
1525 let mut dims: Vec<*mut i64> = Vec::with_capacity(metadata.list_size as usize);
1526 let mut num_dims: Vec<c_int> = Vec::with_capacity(metadata.list_size as usize);
1527 tf::TF_OperationGetAttrShapeList(
1528 self.inner,
1529 c_attr_name.as_ptr(),
1530 dims.as_mut_ptr(),
1531 num_dims.as_mut_ptr(),
1532 metadata.list_size as i32,
1533 storage.as_mut_ptr() as *mut i64,
1534 metadata.total_size as c_int,
1535 status.inner(),
1536 );
1537 if !status.is_ok() {
1538 return Err(status);
1539 }
1540 dims.set_len(metadata.list_size as usize);
1541 num_dims.set_len(metadata.list_size as usize);
1542 let mut shapes = Vec::with_capacity(metadata.list_size as usize);
1543 for i in 0..metadata.list_size as usize {
1544 shapes.push(Shape(if num_dims[i] == -1 {
1545 None
1546 } else {
1547 let mut v = Vec::new();
1548 for j in 0..num_dims[i] {
1549 v.push(match *dims[i].offset(j as isize) {
1550 -1 => None,
1551 x => Some(x),
1552 });
1553 }
1554 Some(v)
1555 }));
1556 }
1557 Ok(shapes)
1558 }
1559 }
1560
1561 pub fn get_attr_tensor_shape_proto(&self, attr_name: &str) -> Result<Vec<u8>> {
1564 let c_attr_name = CString::new(attr_name)?;
1565 let mut status = Status::new();
1566 unsafe {
1567 let mut buf = Buffer::<u8>::new_unallocated();
1568 tf::TF_OperationGetAttrTensorShapeProto(
1569 self.inner,
1570 c_attr_name.as_ptr(),
1571 buf.inner_mut(),
1572 status.inner(),
1573 );
1574 if !status.is_ok() {
1575 return Err(status);
1576 }
1577 Ok(buf.into())
1578 }
1579 }
1580
1581 pub fn get_attr_tensor_shape_proto_list(&self, attr_name: &str) -> Result<Vec<Vec<u8>>> {
1584 let c_attr_name = CString::new(attr_name)?;
1585 let mut status = Status::new();
1586 unsafe {
1587 let metadata =
1588 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1589 if !status.is_ok() {
1590 return Err(status);
1591 }
1592 let mut c_buffers = Vec::with_capacity(metadata.list_size as usize);
1593 for _ in 0..metadata.list_size {
1594 c_buffers.push(ptr::null_mut());
1595 }
1596 tf::TF_OperationGetAttrTensorShapeProtoList(
1597 self.inner,
1598 c_attr_name.as_ptr(),
1599 c_buffers.as_mut_ptr(),
1600 metadata.list_size as c_int,
1601 status.inner(),
1602 );
1603 if !status.is_ok() {
1604 return Err(status);
1605 }
1606 Ok(c_buffers
1607 .iter()
1608 .map(|b| Buffer::from_c(*b, true).into())
1609 .collect())
1610 }
1611 }
1612
1613 pub fn get_attr_tensor<T: TensorType>(&self, attr_name: &str) -> Result<Tensor<T>> {
1617 let c_attr_name = CString::new(attr_name)?;
1618 let mut status = Status::new();
1619 unsafe {
1620 let mut c_tensor: *mut tf::TF_Tensor = ptr::null_mut();
1621 tf::TF_OperationGetAttrTensor(
1622 self.inner,
1623 c_attr_name.as_ptr(),
1624 &mut c_tensor,
1625 status.inner(),
1626 );
1627 if !status.is_ok() {
1628 return Err(status);
1629 }
1630 match Tensor::from_tf_tensor(c_tensor) {
1631 None => Err(invalid_arg!("Tensor types do not match")),
1632 Some(t) => Ok(t),
1633 }
1634 }
1635 }
1636
1637 pub fn get_attr_tensor_list<T: TensorType>(&self, attr_name: &str) -> Result<Vec<Tensor<T>>> {
1641 let c_attr_name = CString::new(attr_name)?;
1642 let mut status = Status::new();
1643 unsafe {
1644 let metadata =
1645 tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1646 if !status.is_ok() {
1647 return Err(status);
1648 }
1649 let mut c_tensors = Vec::with_capacity(metadata.list_size as usize);
1650 for _ in 0..metadata.list_size {
1651 c_tensors.push(ptr::null_mut());
1652 }
1653 tf::TF_OperationGetAttrTensorList(
1654 self.inner,
1655 c_attr_name.as_ptr(),
1656 c_tensors.as_mut_ptr(),
1657 metadata.list_size as c_int,
1658 status.inner(),
1659 );
1660 if !status.is_ok() {
1661 return Err(status);
1662 }
1663 c_tensors
1664 .iter()
1665 .map(|t| match Tensor::from_tf_tensor(*t) {
1666 None => Err(invalid_arg!("Tensor types do not match")),
1667 Some(t) => Ok(t),
1668 })
1669 .collect()
1670 }
1671 }
1672
1673 pub fn get_attr_value_proto(&self, attr_name: &str) -> Result<Vec<u8>> {
1676 let status = Status::new();
1677 let attr_name_cstr = CString::new(attr_name)?;
1678 unsafe {
1679 let mut buf = Buffer::new_unallocated();
1680 tf::TF_OperationGetAttrValueProto(
1681 self.inner,
1682 attr_name_cstr.as_ptr(),
1683 buf.inner_mut(),
1684 status.inner,
1685 );
1686 status.into_result()?;
1687 Ok(buf.into())
1688 }
1689 }
1690
1691 pub(crate) fn inner(&self) -> *mut tf::TF_Operation {
1692 self.inner
1693 }
1694}
1695
1696impl From<Operation> for Output {
1697 fn from(operation: Operation) -> Output {
1699 Output {
1700 operation,
1701 index: 0,
1702 }
1703 }
1704}
1705
/// A borrowed reference to an `Operation` paired with an input index.
///
/// NOTE(review): presumably identifies one input slot of `operation` —
/// confirm against the code that constructs `Input` values.
#[derive(Debug, Copy, Clone)]
pub struct Input<'a> {
    /// The operation this input refers to.
    pub operation: &'a Operation,

    /// Index of the input within `operation`.
    pub index: c_int,
}
1718
/// An `Operation` paired with an output index.
///
/// `Output::to_c` converts this to a C `TF_Output`. Converting an
/// `Operation` with `From` yields its output at index 0.
#[derive(Debug, Clone)]
pub struct Output {
    /// The operation that produces this output.
    pub operation: Operation,

    /// Index of the output within `operation`.
    pub index: c_int,
}
1731
1732impl Output {
1733 pub(crate) fn to_c(&self) -> tf::TF_Output {
1734 tf::TF_Output {
1735 oper: self.operation.inner,
1736 index: self.index,
1737 }
1738 }
1739
1740 pub(crate) fn from_c(graph: &Graph, output: &tf::TF_Output) -> Self {
1741 Output {
1742 operation: Operation {
1743 inner: output.oper,
1744 gimpl: graph.gimpl.clone(),
1745 },
1746 index: output.index,
1747 }
1748 }
1749
1750 pub(crate) fn from_c_optional(graph: &Graph, output: &tf::TF_Output) -> Option<Self> {
1751 if output.oper.is_null() {
1752 None
1753 } else {
1754 Some(Output {
1755 operation: Operation {
1756 inner: output.oper,
1757 gimpl: graph.gimpl.clone(),
1758 },
1759 index: output.index,
1760 })
1761 }
1762 }
1763
1764 pub fn name(&self) -> Result<OutputName> {
1766 Ok(OutputName {
1767 name: self.operation.name()?,
1768 index: self.index,
1769 })
1770 }
1771}
1772
/// The textual name of an output: an operation name plus an output index.
///
/// Displayed as `name:index`; parsed from either `name` (index 0) or
/// `name:index`.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)]
pub struct OutputName {
    /// Name of the operation.
    pub name: String,

    /// Index of the output within the operation.
    pub index: c_int,
}
1784
1785impl FromStr for OutputName {
1786 type Err = Status;
1787 fn from_str(s: &str) -> Result<Self> {
1788 let splits: Vec<_> = s.split(':').collect();
1789 let index = match splits.len() {
1790 2 => splits[1].parse::<c_int>()?,
1791 1 => 0,
1792 _ => {
1793 return Err(Status::new_set_lossy(
1794 Code::InvalidArgument,
1795 "Name contains more than one colon (':')",
1796 ))
1797 }
1798 };
1799 Ok(Self {
1800 name: splits[0].to_string(),
1801 index,
1802 })
1803 }
1804}
1805
1806impl Display for OutputName {
1807 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
1808 write!(f, "{}:{}", self.name, self.index)
1809 }
1810}
1811
/// A builder for an operation that has not yet been added to a graph.
///
/// Configure it with the `add_input*` / `set_attr*` methods, then call
/// `finish` to obtain the `Operation`.
#[derive(Debug)]
pub struct OperationDescription<'a> {
    inner: *mut tf::TF_OperationDescription,
    // Graph that `finish` attaches the resulting `Operation` to (its `gimpl`
    // is cloned into the operation).
    graph: &'a Graph,
    // Set by `finish`. `Drop` only calls `TF_FinishOperation` itself when
    // this is still false, so the description is never finished twice.
    finished: bool,
}
1828
1829impl<'a> Drop for OperationDescription<'a> {
1830 fn drop(&mut self) {
1831 if !self.finished {
1832 unsafe {
1833 let status = tf::TF_NewStatus();
1837 tf::TF_FinishOperation(self.inner, status);
1838 tf::TF_DeleteStatus(status);
1839 }
1840 }
1841 }
1842}
1843
1844impl<'a> OperationDescription<'a> {
1845 pub fn finish(mut self) -> Result<Operation> {
1847 self.finished = true; let mut status = Status::new();
1849 let operation = unsafe { tf::TF_FinishOperation(self.inner, status.inner()) };
1850 if status.is_ok() {
1851 Ok(Operation {
1852 inner: operation,
1853 gimpl: self.graph.gimpl.clone(),
1854 })
1855 } else {
1856 Err(status)
1857 }
1858 }
1859
1860 pub fn set_device(&mut self, device: &str) -> std::result::Result<(), NulError> {
1863 let c_device = CString::new(device)?;
1864 unsafe {
1865 tf::TF_SetDevice(self.inner, c_device.as_ptr());
1866 }
1867 Ok(())
1868 }
1869
1870 pub fn add_input<I: Into<Output>>(&mut self, input: I) {
1874 unsafe {
1875 tf::TF_AddInput(self.inner, input.into().to_c());
1876 }
1877 }
1878
1879 pub fn add_input_list(&mut self, inputs: &[Output]) {
1883 let c_inputs: Vec<tf::TF_Output> = inputs.iter().map(|x| x.to_c()).collect();
1884 unsafe {
1885 tf::TF_AddInputList(self.inner, c_inputs.as_ptr(), c_inputs.len() as c_int);
1886 }
1887 }
1888
1889 pub fn add_control_input(&mut self, input: &Operation) {
1891 unsafe {
1892 tf::TF_AddControlInput(self.inner, input.inner);
1893 }
1894 }
1895
1896 #[allow(trivial_numeric_casts)]
1898 pub fn set_attr_string(
1899 &mut self,
1900 attr_name: &str,
1901 value: &str,
1902 ) -> std::result::Result<(), NulError> {
1903 let c_attr_name = CString::new(attr_name)?;
1904 let c_value = value.as_bytes();
1905 unsafe {
1906 tf::TF_SetAttrString(
1907 self.inner,
1908 c_attr_name.as_ptr(),
1909 c_value.as_ptr() as *const std_c_void,
1910 c_value.len() as size_t,
1911 );
1912 }
1913 Ok(())
1914 }
1915
1916 #[allow(trivial_numeric_casts)]
1918 pub fn set_attr_string_list<S: AsRef<str>>(
1919 &mut self,
1920 attr_name: &str,
1921 value: &[S],
1922 ) -> std::result::Result<(), NulError> {
1923 let c_attr_name = CString::new(attr_name)?;
1924 let bytes: Vec<&[u8]> = value.iter().map(|x| x.as_ref().as_bytes()).collect();
1925 let ptrs: Vec<*const c_void> = bytes.iter().map(|x| x.as_ptr() as *const c_void).collect();
1926 let lens: Vec<size_t> = bytes.iter().map(|x| x.len() as size_t).collect();
1927 unsafe {
1928 tf::TF_SetAttrStringList(
1929 self.inner,
1930 c_attr_name.as_ptr(),
1931 ptrs.as_ptr() as *const *const std_c_void,
1932 lens.as_ptr(),
1933 ptrs.len() as c_int,
1934 );
1935 }
1936 Ok(())
1937 }
1938
1939 #[allow(trivial_numeric_casts)]
1941 pub fn set_attr_func_name(
1942 &mut self,
1943 attr_name: &str,
1944 value: &str,
1945 ) -> std::result::Result<(), NulError> {
1946 let c_attr_name = CString::new(attr_name)?;
1947 let c_value = value.as_bytes();
1948 unsafe {
1949 tf::TF_SetAttrFuncName(
1950 self.inner,
1951 c_attr_name.as_ptr(),
1952 c_value.as_ptr() as *const c_char,
1953 c_value.len() as size_t,
1954 );
1955 }
1956 Ok(())
1957 }
1958
1959 pub fn set_attr_int(
1961 &mut self,
1962 attr_name: &str,
1963 value: i64,
1964 ) -> std::result::Result<(), NulError> {
1965 let c_attr_name = CString::new(attr_name)?;
1966 unsafe {
1967 tf::TF_SetAttrInt(self.inner, c_attr_name.as_ptr(), value);
1968 }
1969 Ok(())
1970 }
1971
1972 pub fn set_attr_int_list(
1974 &mut self,
1975 attr_name: &str,
1976 value: &[i64],
1977 ) -> std::result::Result<(), NulError> {
1978 let c_attr_name = CString::new(attr_name)?;
1979 unsafe {
1980 tf::TF_SetAttrIntList(
1981 self.inner,
1982 c_attr_name.as_ptr(),
1983 value.as_ptr(),
1984 value.len() as i32,
1985 );
1986 }
1987 Ok(())
1988 }
1989
1990 pub fn set_attr_float(
1992 &mut self,
1993 attr_name: &str,
1994 value: f32,
1995 ) -> std::result::Result<(), NulError> {
1996 let c_attr_name = CString::new(attr_name)?;
1997 unsafe {
1998 tf::TF_SetAttrFloat(self.inner, c_attr_name.as_ptr(), value);
1999 }
2000 Ok(())
2001 }
2002
2003 #[allow(trivial_numeric_casts)]
2005 pub fn set_attr_float_list(
2006 &mut self,
2007 attr_name: &str,
2008 value: &[f32],
2009 ) -> std::result::Result<(), NulError> {
2010 let c_attr_name = CString::new(attr_name)?;
2011 let c_value: Vec<c_float> = value.iter().map(|x| *x as c_float).collect();
2013 unsafe {
2014 tf::TF_SetAttrFloatList(
2015 self.inner,
2016 c_attr_name.as_ptr(),
2017 c_value.as_ptr(),
2018 c_value.len() as i32,
2019 );
2020 }
2021 Ok(())
2022 }
2023
2024 pub fn set_attr_bool(
2026 &mut self,
2027 attr_name: &str,
2028 value: bool,
2029 ) -> std::result::Result<(), NulError> {
2030 let c_attr_name = CString::new(attr_name)?;
2031 unsafe {
2032 tf::TF_SetAttrBool(self.inner, c_attr_name.as_ptr(), u8::from(value));
2033 }
2034 Ok(())
2035 }
2036
2037 pub fn set_attr_bool_list(
2039 &mut self,
2040 attr_name: &str,
2041 value: &[bool],
2042 ) -> std::result::Result<(), NulError> {
2043 let c_attr_name = CString::new(attr_name)?;
2044 let c_value: Vec<c_uchar> = value.iter().map(|x| u8::from(*x)).collect();
2045 unsafe {
2046 tf::TF_SetAttrBoolList(
2047 self.inner,
2048 c_attr_name.as_ptr(),
2049 c_value.as_ptr(),
2050 c_value.len() as c_int,
2051 );
2052 }
2053 Ok(())
2054 }
2055
2056 pub fn set_attr_type(
2058 &mut self,
2059 attr_name: &str,
2060 value: DataType,
2061 ) -> std::result::Result<(), NulError> {
2062 let c_attr_name = CString::new(attr_name)?;
2063 unsafe {
2064 tf::TF_SetAttrType(self.inner, c_attr_name.as_ptr(), value.to_c());
2065 }
2066 Ok(())
2067 }
2068
2069 pub fn set_attr_type_list(
2071 &mut self,
2072 attr_name: &str,
2073 value: &[DataType],
2074 ) -> std::result::Result<(), NulError> {
2075 let c_attr_name = CString::new(attr_name)?;
2076 let c_value: Vec<tf::TF_DataType> = value.iter().map(|x| x.to_c()).collect();
2077 unsafe {
2078 tf::TF_SetAttrTypeList(
2079 self.inner,
2080 c_attr_name.as_ptr(),
2081 c_value.as_ptr(),
2082 c_value.len() as i32,
2083 );
2084 }
2085 Ok(())
2086 }
2087
2088 pub fn set_attr_shape(
2090 &mut self,
2091 attr_name: &str,
2092 value: &Shape,
2093 ) -> std::result::Result<(), NulError> {
2094 let c_attr_name = CString::new(attr_name)?;
2095 unsafe {
2096 match value.0 {
2097 None => tf::TF_SetAttrShape(self.inner, c_attr_name.as_ptr(), ptr::null(), -1),
2098 Some(ref dims) => {
2099 let c_dims: Vec<i64> = dims.iter().map(|x| (*x).unwrap_or(-1)).collect();
2100 tf::TF_SetAttrShape(
2101 self.inner,
2102 c_attr_name.as_ptr(),
2103 c_dims.as_ptr(),
2104 c_dims.len() as i32,
2105 );
2106 }
2107 }
2108 }
2109 Ok(())
2110 }
2111
2112 pub fn set_attr_shape_list(
2114 &mut self,
2115 attr_name: &str,
2116 value: &[Shape],
2117 ) -> std::result::Result<(), NulError> {
2118 let c_attr_name = CString::new(attr_name)?;
2119 let c_dims: Vec<Option<Vec<i64>>> = value
2121 .iter()
2122 .map(|x| {
2123 x.0.as_ref()
2124 .map(|dims| dims.iter().map(|x| (*x).unwrap_or(-1)).collect())
2125 })
2126 .collect();
2127 let ptrs: Vec<*const i64> = c_dims
2128 .iter()
2129 .map(|x| match *x {
2130 None => ptr::null(),
2131 Some(ref dims) => dims.as_ptr(),
2132 })
2133 .collect();
2134 let lens: Vec<c_int> = value
2135 .iter()
2136 .map(|x| match x.0 {
2137 None => -1,
2138 Some(ref dims) => dims.len() as c_int,
2139 })
2140 .collect();
2141 unsafe {
2142 tf::TF_SetAttrShapeList(
2143 self.inner,
2144 c_attr_name.as_ptr(),
2145 ptrs.as_ptr(),
2146 lens.as_ptr(),
2147 ptrs.len() as c_int,
2148 );
2149 }
2150 Ok(())
2151 }
2152
2153 #[allow(trivial_numeric_casts)]
2155 pub fn set_attr_tensor_shape_proto(&mut self, attr_name: &str, value: &[u8]) -> Result<()> {
2156 let c_attr_name = CString::new(attr_name)?;
2157 let mut status = Status::new();
2158 unsafe {
2159 tf::TF_SetAttrTensorShapeProto(
2160 self.inner,
2161 c_attr_name.as_ptr(),
2162 value.as_ptr() as *const std_c_void,
2163 value.len() as size_t,
2164 status.inner(),
2165 );
2166 }
2167 status.into_result()
2168 }
2169
2170 #[allow(trivial_numeric_casts)]
2172 pub fn set_attr_tensor_shape_proto_list<T: AsRef<[u8]>>(
2173 &mut self,
2174 attr_name: &str,
2175 value: &[T],
2176 ) -> Result<()> {
2177 let c_attr_name = CString::new(attr_name)?;
2178 let ptrs: Vec<*const c_void> = value
2179 .iter()
2180 .map(|x| x.as_ref().as_ptr() as *const c_void)
2181 .collect();
2182 let lens: Vec<size_t> = value.iter().map(|x| x.as_ref().len() as size_t).collect();
2183 let mut status = Status::new();
2184 unsafe {
2185 tf::TF_SetAttrTensorShapeProtoList(
2186 self.inner,
2187 c_attr_name.as_ptr(),
2188 ptrs.as_ptr() as *const *const std_c_void,
2189 lens.as_ptr(),
2190 ptrs.len() as c_int,
2191 status.inner(),
2192 );
2193 }
2194 status.into_result()
2195 }
2196
2197 pub fn set_attr_tensor<T: TensorType>(
2199 &mut self,
2200 attr_name: &str,
2201 value: Tensor<T>,
2202 ) -> Result<()> {
2203 self.set_attr_any_tensor(attr_name, &value)
2204 }
2205
2206 pub(crate) fn set_attr_any_tensor(
2208 &mut self,
2209 attr_name: &str,
2210 value: &dyn AnyTensor,
2211 ) -> Result<()> {
2212 let c_attr_name = CString::new(attr_name)?;
2213 let mut status = Status::new();
2214 unsafe {
2215 tf::TF_SetAttrTensor(
2216 self.inner,
2217 c_attr_name.as_ptr(),
2218 value.inner()?,
2219 status.inner(),
2220 );
2221 }
2222 status.into_result()
2223 }
2224
2225 pub fn set_attr_tensor_list<I, T>(&mut self, attr_name: &str, value: I) -> Result<()>
2227 where
2228 I: IntoIterator<Item = Tensor<T>>,
2229 T: TensorType,
2230 {
2231 let c_attr_name = CString::new(attr_name)?;
2232 let mut status = Status::new();
2233 unsafe {
2234 let tensors: Vec<_> = value.into_iter().collect();
2236 let maybe_ptrs: Result<_> = tensors.iter().map(|x| x.inner()).collect();
2237 let ptrs: Vec<*mut tf::TF_Tensor> = maybe_ptrs?;
2238 tf::TF_SetAttrTensorList(
2239 self.inner,
2240 c_attr_name.as_ptr(),
2241 ptrs.as_ptr() as *const *mut tf::TF_Tensor,
2242 ptrs.len() as c_int,
2243 status.inner(),
2244 );
2245 }
2246 status.into_result()
2247 }
2248
2249 #[deprecated(since = "0.7.0", note = "Use set_attr_value_proto instead.")]
2251 pub fn set_attr_to_attr_value_proto(&mut self, attr_name: &str, value: &[u8]) -> Result<()> {
2252 self.set_attr_value_proto(attr_name, value)
2253 }
2254
2255 #[allow(trivial_numeric_casts)]
2257 pub fn set_attr_value_proto(&mut self, attr_name: &str, value: &[u8]) -> Result<()> {
2258 let c_attr_name = CString::new(attr_name)?;
2259 let mut status = Status::new();
2260 unsafe {
2261 tf::TF_SetAttrValueProto(
2262 self.inner,
2263 c_attr_name.as_ptr(),
2264 value.as_ptr() as *const std_c_void,
2265 value.len() as size_t,
2268 status.inner(),
2269 );
2270 }
2271 status.into_result()
2272 }
2273}
2274
/// Options passed to `Graph::to_function` when building a `Function`.
///
/// NOTE(review): `FunctionOptions::new` leaves `inner` null and nothing else
/// here initializes it; presumably the C API treats null as "default
/// options" — confirm.
#[derive(Debug)]
// Kept non-`Copy` deliberately; the lint would otherwise suggest deriving it.
#[allow(missing_copy_implementations)]
pub struct FunctionOptions {
    inner: *mut tf::TF_FunctionOptions,
}
2283
2284impl Default for FunctionOptions {
2285 fn default() -> Self {
2286 Self::new()
2287 }
2288}
2289
2290impl FunctionOptions {
2291 pub fn new() -> Self {
2293 FunctionOptions {
2294 inner: ptr::null_mut(), }
2296 }
2297}
2298
/// A TensorFlow function (`TF_Function`).
///
/// Obtained from `Graph::to_function` or `Function::import_function_def`.
#[derive(Debug)]
pub struct Function {
    inner: *mut tf::TF_Function,
}

// `impl_drop!` is defined elsewhere; presumably it implements `Drop` by calling
// `TF_DeleteFunction` on `inner` — confirm against the macro definition.
impl_drop!(Function, TF_DeleteFunction);
2310
2311impl Function {
2312 pub fn to_function_def(&self) -> Result<Vec<u8>> {
2317 let status = Status::new();
2318 unsafe {
2319 let mut buf = Buffer::from_ptr(ptr::null_mut(), 0);
2320 tf::TF_FunctionToFunctionDef(self.inner, buf.inner_mut(), status.inner);
2321 status.into_result()?;
2322 Ok(buf.into())
2323 }
2324 }
2325
2326 pub fn import_function_def(proto: &[u8]) -> Result<Function> {
2329 let status = Status::new();
2330 unsafe {
2331 let inner = tf::TF_FunctionImportFunctionDef(
2332 proto.as_ptr() as *const std_c_void,
2333 proto.len(),
2334 status.inner,
2335 );
2336 status.into_result()?;
2337 Ok(Function { inner })
2338 }
2339 }
2340
2341 pub fn set_attr_value_proto(&mut self, attr_name: &str, proto: &[u8]) -> Result<()> {
2346 let status = Status::new();
2347 let attr_name_cstr = CString::new(attr_name)?;
2348 unsafe {
2349 tf::TF_FunctionSetAttrValueProto(
2350 self.inner,
2351 attr_name_cstr.as_ptr(),
2352 proto.as_ptr() as *const std_c_void,
2353 proto.len(),
2354 status.inner,
2355 );
2356 }
2357 status.into_result()
2358 }
2359
2360 pub fn get_attr_value_proto(&self, attr_name: &str) -> Result<Vec<u8>> {
2364 let status = Status::new();
2365 let attr_name_cstr = CString::new(attr_name)?;
2366 unsafe {
2367 let mut buf = Buffer::from_ptr(ptr::null_mut(), 0);
2368 tf::TF_FunctionGetAttrValueProto(
2369 self.inner,
2370 attr_name_cstr.as_ptr(),
2371 buf.inner_mut(),
2372 status.inner,
2373 );
2374 status.into_result()?;
2375 Ok(buf.into())
2376 }
2377 }
2378
2379 pub fn get_name(&self) -> std::result::Result<String, Utf8Error> {
2381 unsafe {
2382 CStr::from_ptr(tf::TF_FunctionName(self.inner))
2383 .to_str()
2384 .map(|s| s.to_string())
2385 }
2386 }
2387}
2388
2389#[cfg(test)]
2392mod tests {
2393 use super::super::DataType;
2394 use super::super::Shape;
2395 use super::*;
2396
2397 fn add_operation(g: &mut Graph) {
2398 g.new_operation("Variable", "foo").unwrap();
2399 }
2400
2401 fn add(g: &mut Graph, op1: Operation, op2: Operation, name: &str) -> Result<Operation> {
2402 let mut nd = g.new_operation("Add", name)?;
2403 nd.add_input(op1);
2404 nd.add_input(op2);
2405 nd.finish()
2406 }
2407
2408 fn multiply(g: &mut Graph, op1: Operation, op2: Operation, name: &str) -> Result<Operation> {
2409 let mut nd = g.new_operation("Mul", name)?;
2410 nd.add_input(op1);
2411 nd.add_input(op2);
2412 nd.finish()
2413 }
2414
2415 #[test]
2416 fn smoke() {
2417 let mut g = Graph::new();
2418 add_operation(&mut g);
2419 let operation = {
2420 let mut nd = g.new_operation("Variable", "foo").unwrap();
2421 nd.set_attr_type("dtype", DataType::Float).unwrap();
2422 nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2423 nd.finish().unwrap()
2424 };
2425 let mut nd2 = g.new_operation("Variable", "foo2").unwrap();
2426 nd2.set_attr_type("dtype", DataType::Float).unwrap();
2427 nd2.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2428 let operation2 = nd2.finish().unwrap();
2429 assert_eq!("foo", operation.name().unwrap());
2430 assert_eq!("foo2", operation2.name().unwrap());
2431 }
2432
2433 #[test]
2434 fn test_import_graph_def() {
2435 let mut g = Graph::new();
2436 let opts = ImportGraphDefOptions::new();
2437 let status = g.import_graph_def(&[], &opts);
2439 assert!(status.is_ok());
2440 }
2441
2442 #[test]
2443 fn test_get_tensor_shape() {
2444 fn constant<T: TensorType>(graph: &mut Graph, name: &str, value: Tensor<T>) -> Operation {
2445 let mut c = graph.new_operation("Const", name).unwrap();
2446 c.set_attr_tensor("value", value).unwrap();
2447 c.set_attr_type("dtype", T::data_type()).unwrap();
2448 c.finish().unwrap()
2449 }
2450
2451 let mut graph = Graph::new();
2452 let x_init = Tensor::<i32>::new(&[3, 3]);
2453 let x = constant(&mut graph, "x/assign_0", x_init);
2454 assert_eq!(1, x.num_outputs());
2455 assert_eq!(x.output_type(0), DataType::Int32);
2456 let dims = graph.num_dims(x.clone()).unwrap();
2457 assert_eq!(dims, 2);
2458 let shape = graph.tensor_shape(x.clone()).unwrap();
2459 assert_eq!(shape, Shape(Some(vec![Some(3_i64), Some(3_i64)])));
2460 }
2461
2462 #[test]
2463 fn graph_to_function() {
2464 let mut g = Graph::new();
2465 let x = {
2466 let mut nd = g.new_operation("Placeholder", "x").unwrap();
2467 nd.set_attr_type("dtype", DataType::Float).unwrap();
2468 nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2469 nd.finish().unwrap()
2470 };
2471 let two = {
2472 let mut nd = g.new_operation("Const", "two").unwrap();
2473 nd.set_attr_type("dtype", DataType::Float).unwrap();
2474 let mut value = Tensor::new(&[1]);
2475 value[0] = 2.0f32;
2476 nd.set_attr_tensor("value", value).unwrap();
2477 nd.finish().unwrap()
2478 };
2479 let y = multiply(&mut g, two.clone(), x.clone(), "y").unwrap();
2480 let opers = vec![&y];
2481 let inputs = vec![x.clone().into(), two.clone().into()];
2482 let outputs = vec![y.clone().into()];
2483 let output_names = vec!["result"];
2484 let description = "Multiplies by 2";
2485 let opts = FunctionOptions::new();
2486 let f = g
2487 .to_function(
2488 "times_two",
2489 false,
2490 Some(&opers),
2491 &inputs,
2492 &outputs,
2493 Some(&output_names),
2494 &opts,
2495 Some(description),
2496 )
2497 .unwrap();
2498 assert_eq!("times_two", f.get_name().unwrap());
2499 let mut g2 = Graph::new();
2500 assert_eq!(0, g2.num_functions());
2501 assert_eq!(0, g2.get_functions().unwrap().len());
2502 g2.copy_function(&f, None).unwrap();
2503 assert_eq!(1, g2.num_functions());
2504 assert_eq!(1, g2.get_functions().unwrap().len());
2505 }
2506
2507 #[test]
2514 #[allow(trivial_casts)] fn operation_attributes() {
2516 let mut g = Graph::new();
2517
2518 let shape = Shape(Some(vec![None, Some(3)]));
2519 let variable_op = {
2520 let mut nd = g.new_operation("Variable", "Variable").unwrap();
2521 nd.set_attr_type("dtype", DataType::Int32).unwrap();
2522 nd.set_attr_shape("shape", &shape).unwrap();
2523 nd.set_attr_string("shared_name", "bar").unwrap();
2524 nd.finish().unwrap()
2525 };
2526 assert_eq!("bar", variable_op.get_attr_string("shared_name").unwrap());
2527 assert_eq!(DataType::Int32, variable_op.get_attr_type("dtype").unwrap());
2528 assert_eq!(shape, variable_op.get_attr_shape("shape").unwrap());
2529
2530 let op = {
2531 let mut nd = g
2532 .new_operation("Variable", "Variable_unknown_rank")
2533 .unwrap();
2534 nd.set_attr_type("dtype", DataType::Int32).unwrap();
2535 nd.set_attr_shape("shape", &Shape(None)).unwrap();
2536 nd.finish().unwrap()
2537 };
2538 assert_eq!(Shape(None), op.get_attr_shape("shape").unwrap());
2539
2540 let value = Tensor::<i32>::new(&[1, 3]).with_values(&[1, 2, 3]).unwrap();
2541 let const_op = {
2542 let mut nd = g.new_operation("Const", "Const").unwrap();
2543 nd.set_attr_tensor("value", value.clone()).unwrap();
2544 nd.set_attr_type("dtype", DataType::Int32).unwrap();
2545 nd.finish().unwrap()
2546 };
2547 assert_eq!(value, const_op.get_attr_tensor("value").unwrap());
2548
2549 let op = {
2550 let mut nd = g.new_operation("Assign", "Assign").unwrap();
2551 nd.add_input(variable_op.clone());
2552 nd.add_input(variable_op.clone());
2553 nd.set_attr_bool("validate_shape", true).unwrap();
2554 nd.set_attr_bool("use_locking", false).unwrap();
2555 nd.finish().unwrap()
2556 };
2557 assert_eq!(true, op.get_attr_bool("validate_shape").unwrap());
2558 assert_eq!(false, op.get_attr_bool("use_locking").unwrap());
2559
2560 let op = {
2561 let variable_op = {
2562 let mut nd = g.new_operation("Variable", "MaxPool_in1").unwrap();
2563 nd.set_attr_type("dtype", DataType::Int32).unwrap();
2564 nd.set_attr_shape(
2565 "shape",
2566 &Shape(Some(vec![Some(5), Some(5), Some(5), Some(5)])),
2567 )
2568 .unwrap();
2569 nd.finish().unwrap()
2570 };
2571 let mut nd = g.new_operation("MaxPool", "MaxPool").unwrap();
2572 nd.add_input(variable_op);
2573 nd.set_attr_int_list("ksize", &[1, 2, 3, 4]).unwrap();
2574 nd.set_attr_int_list("strides", &[1, 1, 1, 1]).unwrap();
2575 nd.set_attr_string("padding", "VALID").unwrap();
2576 nd.finish().unwrap()
2577 };
2578 assert_eq!(
2579 &[1, 2, 3, 4],
2580 &op.get_attr_int_list("ksize").unwrap() as &[i64]
2581 );
2582
2583 let op = {
2584 let mut nd = g.new_operation("TensorSummary", "TensorSummary").unwrap();
2585 nd.add_input(variable_op.clone());
2586 nd.set_attr_string_list("labels", &["foo", "bar"]).unwrap();
2587 nd.finish().unwrap()
2588 };
2589 assert_eq!(
2590 &["foo".to_string(), "bar".to_string()],
2591 &op.get_attr_string_list("labels").unwrap() as &[_]
2592 );
2593
2594 let op = {
2595 let mut nd = g
2596 .new_operation("ApproximateEqual", "ApproximateEqual")
2597 .unwrap();
2598 nd.add_input(variable_op.clone());
2599 nd.add_input(variable_op.clone());
2600 nd.set_attr_float("tolerance", 3.14).unwrap();
2601 nd.finish().unwrap()
2602 };
2603 assert_eq!(3.14, op.get_attr_float("tolerance").unwrap());
2604
2605 let op = {
2606 let mut nd = g.new_operation("Bucketize", "Bucketize").unwrap();
2607 nd.add_input(variable_op.clone());
2608 nd.set_attr_float_list("boundaries", &[0.1, 2.3]).unwrap();
2609 nd.finish().unwrap()
2610 };
2611 assert_eq!(
2612 &[0.1f32, 2.3],
2613 &op.get_attr_float_list("boundaries").unwrap() as &[_]
2614 );
2615
2616 let shape_list = &[
2617 Shape(None),
2618 Shape(Some(vec![])),
2619 Shape(Some(vec![None])),
2620 Shape(Some(vec![Some(1)])),
2621 ];
2622 let op = {
2623 let mut nd = g
2624 .new_operation("RandomShuffleQueue", "RandomShuffleQueue")
2625 .unwrap();
2626 nd.set_attr_shape_list("shapes", shape_list).unwrap();
2627 nd.set_attr_type_list("component_types", &[DataType::Float, DataType::Int32])
2628 .unwrap();
2629 nd.set_attr_int("seed", 42).unwrap();
2630 nd.finish().unwrap()
2631 };
2632 assert_eq!(
2633 shape_list,
2634 &op.get_attr_shape_list("shapes").unwrap() as &[_]
2635 );
2636 assert_eq!(
2637 &[DataType::Float, DataType::Int32],
2638 &op.get_attr_type_list("component_types").unwrap() as &[_]
2639 );
2640 assert_eq!(42, op.get_attr_int("seed").unwrap());
2641
2642 }
2654
2655 fn graph_def() -> Vec<u8> {
2657 let mut g = Graph::new();
2658 let a = {
2659 let mut nd = g.new_operation("Variable", "a").unwrap();
2660 nd.set_attr_type("dtype", DataType::Int32).unwrap();
2661 nd.set_attr_shape("shape", &Shape(None)).unwrap();
2662 nd.finish().unwrap()
2663 };
2664 let b = {
2665 let mut nd = g.new_operation("Variable", "b").unwrap();
2666 nd.set_attr_type("dtype", DataType::Int32).unwrap();
2667 nd.set_attr_shape("shape", &Shape(None)).unwrap();
2668 nd.finish().unwrap()
2669 };
2670 multiply(&mut g, a, b, "a_times_b").unwrap();
2671 g.graph_def().unwrap()
2672 }
2673
2674 #[test]
2675 fn import_graph_def_uniquify_names() {
2676 let mut g = Graph::new();
2677 let mut opts = ImportGraphDefOptions::new();
2678 g.import_graph_def(&graph_def(), &opts).unwrap();
2679 opts.set_uniquify_names(true);
2680 g.import_graph_def(&graph_def(), &opts).unwrap();
2681 g.operation_by_name_required("a_1").unwrap();
2682 }
2683
2684 #[test]
2685 fn import_graph_def_uniquify_prefix() {
2686 let mut g = Graph::new();
2687 let mut opts = ImportGraphDefOptions::new();
2688 opts.set_prefix("prefix").unwrap();
2689 g.import_graph_def(&graph_def(), &opts).unwrap();
2690 opts.set_uniquify_prefix(true);
2691 g.import_graph_def(&graph_def(), &opts).unwrap();
2692 g.operation_by_name_required("prefix_1/a").unwrap();
2693 }
2694
2695 #[test]
2696 fn import_graph_def_set_default_device() {
2697 let mut g = Graph::new();
2698 let mut opts = ImportGraphDefOptions::new();
2699 opts.set_default_device("fake_device").unwrap();
2700 g.import_graph_def(&graph_def(), &opts).unwrap();
2701 assert_eq!(
2702 g.operation_by_name_required("a").unwrap().device().unwrap(),
2703 "fake_device"
2704 );
2705 }
2706
2707 #[test]
2708 fn import_graph_def_results_return_outputs() {
2709 let mut g = Graph::new();
2710 let mut opts = ImportGraphDefOptions::new();
2711 assert_eq!(opts.num_return_outputs(), 0);
2712 opts.add_return_output("a_times_b", 0).unwrap();
2713 assert_eq!(opts.num_return_outputs(), 1);
2714 let result = g
2715 .import_graph_def_with_results(&graph_def(), &opts)
2716 .unwrap();
2717 let ops = result.return_outputs();
2718 assert_eq!(ops.len(), 1);
2719 assert_eq!(ops[0].operation.name().unwrap(), "a_times_b");
2720 assert_eq!(ops[0].index, 0);
2721 }
2722
2723 #[test]
2724 fn import_graph_def_results_return_operations() {
2725 let mut g = Graph::new();
2726 let mut opts = ImportGraphDefOptions::new();
2727 assert_eq!(opts.num_return_operations(), 0);
2728 opts.add_return_operation("a_times_b").unwrap();
2729 assert_eq!(opts.num_return_operations(), 1);
2730 let result = g
2731 .import_graph_def_with_results(&graph_def(), &opts)
2732 .unwrap();
2733 let ops = result.return_operations();
2734 assert_eq!(ops.len(), 1);
2735 assert_eq!(ops[0].name().unwrap(), "a_times_b");
2736 }
2737
/// An input mapping for `bar:3` is registered, and then an empty GraphDef is
/// imported. `missing_unused_input_mappings` must report that single,
/// never-used mapping.
#[test]
fn import_graph_def_results_missing_unused_input_mappings() {
    let mut graph = Graph::new();
    // Any existing output works as the mapping target; a Variable op is used.
    let variable = {
        let mut description = graph.new_operation("Variable", "foo").unwrap();
        description.set_attr_type("dtype", DataType::Int32).unwrap();
        description.set_attr_shape("shape", &Shape(None)).unwrap();
        description.finish().unwrap()
    };
    let output = variable.into();
    let mut options = ImportGraphDefOptions::new();
    options.add_input_mapping("bar", 3, &output).unwrap();

    // The imported GraphDef is empty, so nothing can consume the mapping.
    let results = graph.import_graph_def_with_results(&[], &options).unwrap();
    let missing = results.missing_unused_input_mappings().unwrap();
    assert_eq!(missing.len(), 1);
    let mapping = &missing[0];
    assert_eq!(mapping.0, "bar");
    assert_eq!(mapping.1, 3);
}
2757
/// `import_graph_def_with_return_outputs` returns the outputs requested via
/// `add_return_output`. Here that is exactly `a_times_b:0`.
#[test]
fn import_graph_def_with_return_outputs() {
    let mut graph = Graph::new();
    let mut options = ImportGraphDefOptions::new();
    assert_eq!(options.num_return_outputs(), 0);
    options.add_return_output("a_times_b", 0).unwrap();
    assert_eq!(options.num_return_outputs(), 1);

    let outputs = graph
        .import_graph_def_with_return_outputs(&graph_def(), &options)
        .unwrap();
    assert_eq!(outputs.len(), 1);
    let returned = &outputs[0];
    assert_eq!(returned.operation.name().unwrap(), "a_times_b");
    assert_eq!(returned.index, 0);
}
2772
/// `get_op_def` succeeds for the built-in `Const` op and returns a non-empty
/// result. This is presumably the serialized OpDef bytes (confirm).
#[test]
fn graph_get_op_def() {
    let g = Graph::new();
    let op_def = g.get_op_def("Const").unwrap();
    // `!is_empty()` instead of `len() > 0` (clippy::len_zero).
    assert!(!op_def.is_empty());
}
2779
/// `versions` succeeds on a fresh graph and returns a non-empty result.
/// This is presumably the serialized VersionDef bytes (confirm).
#[test]
fn graph_versions() {
    let g = Graph::new();
    let versions = g.versions().unwrap();
    // `!is_empty()` instead of `len() > 0` (clippy::len_zero).
    assert!(!versions.is_empty());
}
2786
/// `generate_operation_name("foo_{}")` is expected to return `i` once
/// `foo_0` .. `foo_{i-1}` already exist. The test checks this for i in 0..5,
/// claiming each returned name before asking again.
#[test]
fn graph_generate_operation_name() {
    let mut graph = Graph::new();
    for expected in 0..5 {
        let generated = graph.generate_operation_name("foo_{}").unwrap();
        assert_eq!(expected, generated);

        // Occupy the name just handed out so the next call must move past it.
        let name = format!("foo_{}", expected);
        let mut description = graph.new_operation("Placeholder", &name).unwrap();
        description.set_attr_type("dtype", DataType::Float).unwrap();
        description.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
        description.finish().unwrap();
    }
}
2800
/// `add_gradients` of three functions of scalar placeholders `x` and `y`
/// yields one gradient per `x` input. Each gradient is output 0 of an
/// operation whose name starts with the requested prefix. With no prefix,
/// the name starts with `gradients/`.
#[test]
fn graph_add_gradients() {
    let cases = [
        (Some("arbitrary_prefix"), "arbitrary_prefix/"),
        (None, "gradients/"),
    ];
    for &(prefix, expected_prefix) in cases.iter() {
        let mut graph = Graph::new();
        let x = {
            let mut description = graph.new_operation("Placeholder", "x").unwrap();
            description.set_attr_type("dtype", DataType::Float).unwrap();
            description.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
            description.finish().unwrap()
        };
        let y = {
            let mut description = graph.new_operation("Placeholder", "y").unwrap();
            description.set_attr_type("dtype", DataType::Float).unwrap();
            description.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
            description.finish().unwrap()
        };

        // Functions to differentiate: x*x, x*y and x+y.
        let x_squared = multiply(&mut graph, x.clone(), x.clone(), "x_squared").unwrap();
        let x_times_y = multiply(&mut graph, x.clone(), y.clone(), "x_times_y").unwrap();
        let x_plus_y = add(&mut graph, x.clone(), y.clone(), "x_plus_y").unwrap();
        let y_outs = vec![x_squared.into(), x_times_y.into(), x_plus_y.into()];
        let x_outs = vec![x.into(), y.into()];

        let gradients = graph.add_gradients(prefix, &y_outs, &x_outs, None).unwrap();
        assert_eq!(gradients.len(), 2);
        for gradient in gradients {
            let gradient = gradient.unwrap();
            assert_eq!(gradient.index, 0);
            let name = gradient.operation.name().unwrap();
            assert!(
                name.starts_with(expected_prefix),
                "name = {}, expected prefix = {}",
                name,
                expected_prefix
            );
        }
    }
}
2842
/// For y = StopGradient(ArgMax(x, 0)), `add_gradients` succeeds and returns
/// a single `None` gradient for `x` rather than an error. This is checked
/// both with an explicit prefix and with the default one.
#[test]
fn graph_add_gradients_stopped_gradient() {
    for prefix in &[Some("arbitrary_prefix"), None] {
        let mut g = Graph::new();
        let zero = {
            let mut nd = g.new_operation("Const", "zero").unwrap();
            nd.set_attr_type("dtype", DataType::Int32).unwrap();
            nd.set_attr_tensor("value", Tensor::<i32>::from(0)).unwrap();
            nd.finish().unwrap()
        };
        let x = {
            let mut nd = g.new_operation("Placeholder", "x").unwrap();
            nd.set_attr_type("dtype", DataType::Float).unwrap();
            nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
            nd.finish().unwrap()
        };
        let argmax_x = {
            let mut nd = g.new_operation("ArgMax", "argmax_x").unwrap();
            // `x` is needed again below as the gradient target, so clone it.
            nd.add_input(x.clone());
            nd.add_input(zero);
            nd.finish().unwrap()
        };
        let stopped_gradient = {
            let mut nd = g.new_operation("StopGradient", "stopped").unwrap();
            // `argmax_x` has no later use, so move it instead of cloning.
            nd.add_input(argmax_x);
            nd.finish().unwrap()
        };
        let y_outs = vec![stopped_gradient.into()];
        let x_outs = vec![x.into()];
        let dy = g.add_gradients(*prefix, &y_outs, &x_outs, None).unwrap();
        assert_eq!(dy.len(), 1);
        for d in &dy {
            assert!(d.is_none());
        }
    }
}
2880
/// Without a StopGradient, asking for the gradient of ArgMax(x, 0) with
/// respect to `x` makes `add_gradients` return an error. This is presumably
/// because ArgMax has no registered gradient (confirm). It is checked both
/// with an explicit prefix and with the default one.
#[test]
fn graph_add_gradients_no_gradient() {
    let prefixes = [Some("arbitrary_prefix"), None];
    for prefix in prefixes.iter() {
        let mut graph = Graph::new();
        let axis = {
            let mut description = graph.new_operation("Const", "zero").unwrap();
            description.set_attr_type("dtype", DataType::Int32).unwrap();
            description
                .set_attr_tensor("value", Tensor::<i32>::from(0))
                .unwrap();
            description.finish().unwrap()
        };
        let x = {
            let mut description = graph.new_operation("Placeholder", "x").unwrap();
            description.set_attr_type("dtype", DataType::Float).unwrap();
            description.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
            description.finish().unwrap()
        };
        let argmax = {
            let mut description = graph.new_operation("ArgMax", "argmax_x").unwrap();
            description.add_input(x.clone());
            description.add_input(axis);
            description.finish().unwrap()
        };

        let y_outs = vec![argmax.into()];
        let x_outs = vec![x.into()];
        let outcome = graph.add_gradients(*prefix, &y_outs, &x_outs, None);
        assert!(outcome.is_err());
    }
}
2909
/// A string placeholder `x` feeding one `EncodeBase64` op `y` has one
/// output. That output's consumer list holds exactly `y`, with second tuple
/// field 0 (presumably the consumer's input index).
#[test]
fn output_consumers() {
    let mut graph = Graph::new();
    let placeholder = {
        let mut description = graph.new_operation("Placeholder", "x").unwrap();
        description.set_attr_type("dtype", DataType::String).unwrap();
        description.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
        description.finish().unwrap()
    };
    // Only its presence in the graph matters; the handle itself is unused.
    let _encoder = {
        let mut description = graph.new_operation("EncodeBase64", "y").unwrap();
        description.add_input(placeholder.clone());
        description.finish().unwrap()
    };

    assert_eq!(placeholder.num_outputs(), 1);
    let consumers = placeholder.output_consumers(0);
    assert_eq!(consumers.len(), 1);
    let only = &consumers[0];
    assert_eq!(only.0.name().unwrap(), "y");
    assert_eq!(only.1, 0);
}
2930
/// `OutputName` behaves as follows:
/// - `"foo:1"` parses to `{ name: "foo", index: 1 }`, which formats back to
///   `"foo:1"`.
/// - A bare `"foo"` parses with index 0.
/// - `"foo:bar"` (non-numeric index) and `"foo:0:1"` (extra separator) are
///   rejected.
#[test]
fn output_name() {
    let foo_1 = OutputName {
        name: "foo".to_string(),
        index: 1,
    };
    assert_eq!("foo:1".parse::<OutputName>().unwrap(), foo_1);
    assert_eq!(foo_1.to_string(), "foo:1");

    let foo_0 = OutputName {
        name: "foo".to_string(),
        index: 0,
    };
    assert_eq!("foo".parse::<OutputName>().unwrap(), foo_0);

    for malformed in &["foo:bar", "foo:0:1"] {
        assert!(malformed.parse::<OutputName>().is_err());
    }
}
2958
/// A device set on the operation description via `set_device` is reported
/// back by the finished operation's `device()`.
#[test]
fn device() {
    let mut graph = Graph::new();
    let mut description = graph.new_operation("NoOp", "x").unwrap();
    description.set_device("foo").unwrap();
    let noop = description.finish().unwrap();
    assert_eq!(noop.device().unwrap(), "foo");
}
2969
/// An operation given `x` as a control input (`add_control_input`) lists
/// exactly `x` in its `control_inputs()`.
#[test]
fn control_inputs() {
    let mut graph = Graph::new();
    let x = graph.new_operation("NoOp", "x").unwrap().finish().unwrap();
    let y = {
        let mut description = graph.new_operation("NoOp", "y").unwrap();
        description.add_control_input(&x);
        description.finish().unwrap()
    };

    let control_names = y
        .control_inputs()
        .iter()
        .map(|op| op.name().unwrap())
        .collect::<Vec<_>>();
    assert_eq!(control_names, &["x"]);
}
2987}