tensorflow/
graph.rs

1use super::buffer::Buffer;
2use super::AnyTensor;
3use super::Code;
4use super::DataType;
5use super::Result;
6use super::Shape;
7use super::Status;
8use super::Tensor;
9use super::TensorType;
10use libc::c_char;
11use libc::c_float;
12use libc::c_int;
13use libc::c_uchar;
14use libc::c_uint;
15use libc::c_void;
16use libc::size_t;
17use std::cmp;
18use std::ffi::CStr;
19use std::ffi::CString;
20use std::ffi::NulError;
21use std::fmt;
22use std::fmt::Display;
23use std::fmt::Formatter;
24use std::mem::MaybeUninit;
25use std::os::raw::c_void as std_c_void;
26use std::ptr;
27use std::slice;
28use std::str::FromStr;
29use std::str::Utf8Error;
30use std::sync::Arc;
31#[cfg(feature = "default")]
32use tensorflow_sys as tf;
33#[cfg(feature = "tensorflow_runtime_linking")]
34use tensorflow_sys_runtime as tf;
35
36#[derive(Debug)]
37struct GraphImpl {
38    inner: *mut tf::TF_Graph,
39    owned: bool,
40}
41
42unsafe impl Send for GraphImpl {}
43unsafe impl Sync for GraphImpl {}
44
45impl Drop for GraphImpl {
46    /// Graph will be deleted once no more Sessions are referencing it.
47    fn drop(&mut self) {
48        if self.owned {
49            unsafe {
50                tf::TF_DeleteGraph(self.inner);
51            }
52        }
53    }
54}
55
56////////////////////////
57
58/// `ImportGraphDefOptions` holds options that can be passed to
59/// `Graph::import_graph_def`.
60#[derive(Debug)]
61pub struct ImportGraphDefOptions {
62    inner: *mut tf::TF_ImportGraphDefOptions,
63}
64
65impl_new!(
66    ImportGraphDefOptions,
67    TF_NewImportGraphDefOptions,
68    "Creates a default ImportGraphDefOptions."
69);
70impl_drop!(ImportGraphDefOptions, TF_DeleteImportGraphDefOptions);
71
72impl ImportGraphDefOptions {
73    /// Set the prefix to be prepended to the names of nodes in `graph_def` that will
74    /// be imported into `graph`.
75    pub fn set_prefix(&mut self, prefix: &str) -> std::result::Result<(), NulError> {
76        let s = CString::new(prefix)?;
77        unsafe {
78            tf::TF_ImportGraphDefOptionsSetPrefix(self.inner, s.as_ptr());
79        }
80        Ok(())
81    }
82
83    /// Set any imported nodes with input `src_name:src_index` to have that input
84    /// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
85    /// `dst` references a node already existing in the graph being imported into.
86    pub fn add_input_mapping(
87        &mut self,
88        src_name: &str,
89        src_index: usize,
90        dst: &Output,
91    ) -> std::result::Result<(), NulError> {
92        let s = CString::new(src_name)?;
93        unsafe {
94            tf::TF_ImportGraphDefOptionsAddInputMapping(
95                self.inner,
96                s.as_ptr(),
97                src_index as c_int,
98                dst.to_c(),
99            );
100        }
101        Ok(())
102    }
103
104    /// Set any imported nodes with control input `src_name` to have that input
105    /// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
106    /// `dst` references an operation already existing in the graph being imported
107    /// into.
108    pub fn remap_control_dependency(
109        &mut self,
110        src_name: &str,
111        dst: &Operation,
112    ) -> std::result::Result<(), NulError> {
113        let s = CString::new(src_name)?;
114        unsafe {
115            tf::TF_ImportGraphDefOptionsRemapControlDependency(self.inner, s.as_ptr(), dst.inner);
116        }
117        Ok(())
118    }
119
120    /// Cause the imported graph to have a control dependency on `oper`. `oper`
121    /// should exist in the graph being imported into.
122    pub fn add_control_dependency(&mut self, oper: &Operation) {
123        unsafe {
124            tf::TF_ImportGraphDefOptionsAddControlDependency(self.inner, oper.inner);
125        }
126    }
127
128    /// Add an output in `graph_def` to be returned via the `return_outputs` output
129    /// parameter of `import_graph_def()`. If the output is remapped via an input
130    /// mapping, the corresponding existing tensor in `graph` will be returned.
131    pub fn add_return_output(
132        &mut self,
133        oper_name: &str,
134        index: usize,
135    ) -> std::result::Result<(), NulError> {
136        let s = CString::new(oper_name)?;
137        unsafe {
138            tf::TF_ImportGraphDefOptionsAddReturnOutput(self.inner, s.as_ptr(), index as c_int);
139        }
140        Ok(())
141    }
142
143    /// Add an operation in `graph_def` to be returned via the `return_opers` output
144    /// parameter of import_graph_def().
145    pub fn add_return_operation(&mut self, oper_name: &str) -> std::result::Result<(), NulError> {
146        let s = CString::new(oper_name)?;
147        unsafe {
148            tf::TF_ImportGraphDefOptionsAddReturnOperation(self.inner, s.as_ptr());
149        }
150        Ok(())
151    }
152
153    /// Returns the number of return outputs added via `add_return_output()`.
154    pub fn num_return_outputs(&self) -> usize {
155        unsafe { tf::TF_ImportGraphDefOptionsNumReturnOutputs(self.inner) as usize }
156    }
157
158    /// Returns the number of return operations added via `add_return_operation()`.
159    pub fn num_return_operations(&self) -> usize {
160        unsafe { tf::TF_ImportGraphDefOptionsNumReturnOperations(self.inner) as usize }
161    }
162
163    /// Set whether to uniquify imported operation names. If true, imported operation
164    /// names will be modified if their name already exists in the graph. If false,
165    /// conflicting names will be treated as an error. Note that this option has no
166    /// effect if a prefix is set, since the prefix will guarantee all names are
167    /// unique. Defaults to false.
168    pub fn set_uniquify_names(&mut self, uniquify_names: bool) {
169        unsafe {
170            tf::TF_ImportGraphDefOptionsSetUniquifyNames(self.inner, u8::from(uniquify_names));
171        }
172    }
173
174    /// If true, the specified prefix will be modified if it already exists as an
175    /// operation name or prefix in the graph. If false, a conflicting prefix will be
176    /// treated as an error. This option has no effect if no prefix is specified.
177    pub fn set_uniquify_prefix(&mut self, uniquify_prefix: bool) {
178        unsafe {
179            tf::TF_ImportGraphDefOptionsSetUniquifyPrefix(self.inner, u8::from(uniquify_prefix));
180        }
181    }
182
183    /// Set the execution device for nodes.
184    /// Only applies to nodes where a device was not already explicitly specified.
185    pub fn set_default_device(&mut self, device: &str) -> std::result::Result<(), NulError> {
186        let s = CString::new(device)?;
187        unsafe {
188            tf::TF_ImportGraphDefOptionsSetDefaultDevice(self.inner, s.as_ptr());
189        }
190        Ok(())
191    }
192}
193
194////////////////////////
195
196/// ImportGraphDefResults holds results that are generated by
197/// Graph::import_graph_def_with_results().
198#[derive(Debug)]
199pub struct ImportGraphDefResults {
200    inner: *mut tf::TF_ImportGraphDefResults,
201    gimpl: Arc<GraphImpl>,
202}
203
204impl ImportGraphDefResults {
205    /// Fetches the return outputs requested via ImportGraphDefOptions::add_return_output().
206    pub fn return_outputs(&self) -> Vec<Output> {
207        unsafe {
208            let mut num_outputs: c_int = 0;
209            let mut c_outputs: *mut tf::TF_Output = ptr::null_mut();
210            tf::TF_ImportGraphDefResultsReturnOutputs(self.inner, &mut num_outputs, &mut c_outputs);
211            slice::from_raw_parts(c_outputs, num_outputs as usize)
212                .iter()
213                .map(|output| Output {
214                    operation: Operation {
215                        inner: output.oper,
216                        gimpl: self.gimpl.clone(),
217                    },
218                    index: output.index,
219                })
220                .collect()
221        }
222    }
223
224    /// Fetches the return operations requested via ImportGraphDefOptions::add_return_operation().
225    pub fn return_operations(&self) -> Vec<Operation> {
226        unsafe {
227            let mut num_operations: c_int = 0;
228            let mut c_operations: *mut *mut tf::TF_Operation = ptr::null_mut();
229            tf::TF_ImportGraphDefResultsReturnOperations(
230                self.inner,
231                &mut num_operations,
232                &mut c_operations,
233            );
234            slice::from_raw_parts(c_operations, num_operations as usize)
235                .iter()
236                .map(|operation| Operation {
237                    inner: *operation,
238                    gimpl: self.gimpl.clone(),
239                })
240                .collect()
241        }
242    }
243
244    /// Fetches any input mappings requested via
245    /// ImportGraphDefOptions::add_input_mapping() that didn't appear in the GraphDef
246    /// and weren't used as input to any node in the imported graph def.
247    pub fn missing_unused_input_mappings(
248        &self,
249    ) -> std::result::Result<Vec<(&str, c_int)>, Utf8Error> {
250        unsafe {
251            let mut n: c_int = 0;
252            let mut c_src_names: *mut *const c_char = ptr::null_mut();
253            let mut src_indexes: *mut c_int = ptr::null_mut();
254            tf::TF_ImportGraphDefResultsMissingUnusedInputMappings(
255                self.inner,
256                &mut n,
257                &mut c_src_names,
258                &mut src_indexes,
259            );
260            let c_name_slice = slice::from_raw_parts(c_src_names, n as usize);
261            let index_slice = slice::from_raw_parts(src_indexes, n as usize);
262            let mut v = Vec::new();
263            for i in 0..n as usize {
264                let s = CStr::from_ptr(c_name_slice[i]).to_str()?;
265                v.push((s, index_slice[i]));
266            }
267            Ok(v)
268        }
269    }
270}
271
272impl_drop!(ImportGraphDefResults, TF_DeleteImportGraphDefResults);
273
274////////////////////////
275
276/// Represents a computation graph.  Graphs may be shared between sessions.
277/// Graphs are thread-safe when used as directed.
278#[derive(Debug)]
279pub struct Graph {
280    gimpl: Arc<GraphImpl>,
281}
282
283impl Default for Graph {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
289impl Graph {
290    /// Creates a new graph.
291    pub fn new() -> Graph {
292        unsafe {
293            Graph {
294                gimpl: Arc::new(GraphImpl {
295                    inner: tf::TF_NewGraph(),
296                    owned: true,
297                }),
298            }
299        }
300    }
301
302    /// Operation will only be added to graph when finish_operation() is called
303    /// (assuming finish_operation() does not return an error).  graph must
304    /// not be deleted until after finish_operation() is called.
305    pub fn new_operation(
306        &mut self,
307        op_type: &str,
308        operation_name: &str,
309    ) -> std::result::Result<OperationDescription<'_>, NulError> {
310        let c_op_type = CString::new(op_type)?;
311        let c_operation_name = CString::new(operation_name)?;
312        unsafe {
313            Ok(OperationDescription {
314                inner: tf::TF_NewOperation(
315                    self.gimpl.inner,
316                    c_op_type.as_ptr(),
317                    c_operation_name.as_ptr(),
318                ),
319                graph: self,
320                finished: false,
321            })
322        }
323    }
324
325    /// Returns the operation in the graph with the given name, if it exists.
326    /// If the operation does not exist, returns `Ok(None)`.
327    pub fn operation_by_name(
328        &self,
329        operation_name: &str,
330    ) -> std::result::Result<Option<Operation>, NulError> {
331        let c_operation_name = CString::new(operation_name)?;
332        unsafe {
333            let operation =
334                tf::TF_GraphOperationByName(self.gimpl.inner, c_operation_name.as_ptr());
335            if operation.is_null() {
336                Ok(None)
337            } else {
338                Ok(Some(Operation {
339                    inner: operation,
340                    gimpl: self.gimpl.clone(),
341                }))
342            }
343        }
344    }
345
346    /// Like `operation_by_name`, except that failure to find the operation is considered an error.
347    pub fn operation_by_name_required(
348        &self,
349        operation_name: &str,
350    ) -> std::result::Result<Operation, Status> {
351        match self.operation_by_name(operation_name)? {
352            Some(operation) => Ok(operation),
353            None => Err(Status::new_set(
354                Code::Unavailable,
355                &format!("Operation {:?} not found", operation_name),
356            )
357            .unwrap()),
358        }
359    }
360
361    /// Finds a unique operation name.  The pattern must contain exactly one
362    /// '{}' placeholder to indicate where a unique ID can be inserted, e.g.
363    /// 'Add_{}' or 'while_loop_{}/Merge', and the function returns an integer
364    /// which, when inserted into the placeholder, yields an operation name
365    /// which does not appear in the graph.
366    pub(crate) fn generate_operation_name(&self, operation_name_pattern: &str) -> Result<i64> {
367        let parts: Vec<_> = operation_name_pattern.split("{}").collect();
368        if parts.len() != 2 {
369            return Err(invalid_arg!(
370                "operation_name_pattern must contain placeholder"
371            ));
372        }
373        // Can't use format! because its argument must be a string literal.
374        let mut i = 0;
375        loop {
376            let name = format!("{}{}{}", parts[0], i, parts[1]);
377            let c_name = CString::new(name)?;
378            unsafe {
379                if tf::TF_GraphOperationByName(self.gimpl.inner, c_name.as_ptr()).is_null() {
380                    return Ok(i);
381                }
382            }
383            i += 1;
384        }
385    }
386
387    /// Iterates over the operations in the graph.
388    pub fn operation_iter(&self) -> OperationIter<'_> {
389        OperationIter {
390            graph: self,
391            pos: 0,
392        }
393    }
394
395    /// Returns the graph definition as a protobuf.
396    pub fn graph_def(&self) -> Result<Vec<u8>> {
397        let mut status = Status::new();
398        unsafe {
399            let c_buffer = tf::TF_NewBuffer();
400            tf::TF_GraphToGraphDef(self.gimpl.inner, c_buffer, status.inner());
401            if status.is_ok() {
402                Ok(Buffer::from_c(c_buffer, true).into())
403            } else {
404                tf::TF_DeleteBuffer(c_buffer);
405                Err(status)
406            }
407        }
408    }
409
410    /// Returns the number of dimensions of the Tensor referenced by `output`.
411    ///
412    /// If the number of dimensions in the shape is unknown, returns -1.
413    ///
414    /// Returns an error if:
415    ///
416    ///   * `output` is not in `graph`.
417    pub fn num_dims<I: Into<Output>>(&self, output: I) -> Result<c_int> {
418        let mut status = Status::new();
419        unsafe {
420            let val = tf::TF_GraphGetTensorNumDims(
421                self.gimpl.inner,
422                output.into().to_c(),
423                status.inner(),
424            );
425            if status.is_ok() {
426                Ok(val)
427            } else {
428                Err(status)
429            }
430        }
431    }
432
433    /// Returns the shape of the Tensor referenced by `output`.
434    ///
435    /// Returns an error if:
436    ///
437    ///   * `output` is not in `graph`.
438    pub fn tensor_shape<I: Into<Output>>(&self, output: I) -> Result<Shape> {
439        let mut status = Status::new();
440        let output = output.into();
441        let n = self.num_dims(output.clone())?;
442        if n == -1 {
443            return Ok(Shape(None));
444        }
445        let mut dims = Vec::with_capacity(n as usize);
446        unsafe {
447            tf::TF_GraphGetTensorShape(
448                self.gimpl.inner,
449                output.to_c(),
450                dims.as_mut_ptr(),
451                n,
452                status.inner(),
453            );
454            if status.is_ok() {
455                dims.set_len(n as usize);
456                Ok(Shape(Some(
457                    dims.iter()
458                        .map(|x| if *x < 0 { None } else { Some(*x) })
459                        .collect(),
460                )))
461            } else {
462                Err(status)
463            }
464        }
465    }
466
467    /// Import the graph serialized in `graph_def`.
468    pub fn import_graph_def(
469        &mut self,
470        graph_def: &[u8],
471        options: &ImportGraphDefOptions,
472    ) -> Result<()> {
473        let buf = Buffer::from(graph_def);
474        let mut status = Status::new();
475        unsafe {
476            tf::TF_GraphImportGraphDef(
477                self.gimpl.inner,
478                buf.inner(),
479                options.inner,
480                status.inner(),
481            );
482            status.into_result()
483        }
484    }
485
486    /// Import the graph serialized in `graph_def`.
487    pub fn import_graph_def_with_results(
488        &mut self,
489        graph_def: &[u8],
490        options: &ImportGraphDefOptions,
491    ) -> Result<ImportGraphDefResults> {
492        let buf = Buffer::from(graph_def);
493        let mut status = Status::new();
494        unsafe {
495            let result = tf::TF_GraphImportGraphDefWithResults(
496                self.gimpl.inner,
497                buf.inner(),
498                options.inner,
499                status.inner(),
500            );
501            status.into_result().map(|()| ImportGraphDefResults {
502                inner: result,
503                gimpl: self.gimpl.clone(),
504            })
505        }
506    }
507
508    /// Import the graph serialized in `graph_def`.
509    pub fn import_graph_def_with_return_outputs(
510        &mut self,
511        graph_def: &[u8],
512        options: &ImportGraphDefOptions,
513    ) -> Result<Vec<Output>> {
514        let buf = Buffer::from(graph_def);
515        let mut status = Status::new();
516        let n = options.num_return_outputs();
517        let mut c_return_outputs: Vec<MaybeUninit<tf::TF_Output>> = Vec::with_capacity(n);
518        unsafe {
519            c_return_outputs.set_len(n);
520            tf::TF_GraphImportGraphDefWithReturnOutputs(
521                self.gimpl.inner,
522                buf.inner(),
523                options.inner,
524                c_return_outputs.as_mut_ptr() as *mut tf::TF_Output,
525                n as c_int,
526                status.inner(),
527            );
528            status.into_result()?;
529            Ok(c_return_outputs
530                .iter()
531                .map(|x| Output::from_c(self, &x.assume_init()))
532                .collect())
533        }
534    }
535
536    /// Adds a copy of function `func` and optionally its gradient function
537    /// `grad` to the graph. Once `func`/`grad` is added to the graph, it can be
538    /// called by creating an operation using the function's name. Any changes
539    /// to `func`/`grad` (including deleting it) done after this method returns,
540    /// won't affect the copy of `func`/`grad` in the graph. If `func` or `grad`
541    /// are already in the graph, `copy_function` has no effect on them, but can
542    /// establish the function->gradient relationship between them if `func`
543    /// does not already have a gradient. If `func` already has a gradient
544    /// different from `grad`, an error is returned.
545    ///
546    /// If `grad` is None and `func` is not in the graph, `func` is added
547    /// without a gradient. If `grad` is None and `func` is in the graph,
548    /// `copy_function` is a noop. `grad` must have appropriate signature as
549    /// described in the doc of GradientDef in
550    /// tensorflow/core/framework/function.proto.
551    ///
552    /// If successful, returns () and `func` and `grad` are added to the graph.
553    /// Otherwise, an error is returned and the graph is unmodified.
554    pub fn copy_function(&mut self, func: &Function, grad: Option<&Function>) -> Result<()> {
555        let mut status = Status::new();
556        unsafe {
557            tf::TF_GraphCopyFunction(
558                self.inner(),
559                func.inner,
560                match grad {
561                    None => ptr::null(),
562                    Some(g) => g.inner,
563                },
564                status.inner(),
565            );
566        }
567        status.into_result()
568    }
569
570    /// Create a `Function` from a `Graph`.
571    ///
572    /// # Arguments
573    ///
574    /// * `fn_name` - the name of the new `Function`. Should match the operation
575    ///   name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. If
576    ///   `append_hash_to_fn_name` is false, `fn_name` must be distinct from
577    ///   other function and operation names (at least those registered in
578    ///   graphs where this function will be used).
579    /// * `append_hash_to_fn_name` - If true, the actual name of the function
580    ///   will be `fn_name` appended with
581    ///   '_&lt;hash_of_this_function's_definition&gt;'. If false, the
582    ///   function's name will be `fn_name`.
583    /// * `opers` - Array of operations to become the body of the function or
584    ///   null.
585    ///   * If `None`, all the operations in the graph will become part of the
586    ///     function except operations referenced in `inputs`. These operations
587    ///     must have a single output (these operations are typically
588    ///     placeholders created for the sole purpose of representing an input.
589    ///     We can relax this constraint if there are compelling use cases).
590    ///   * If `Some`, all operations in it will become part of the function. In
591    ///     particular, no automatic skipping of dummy input operations is
592    ///     performed.
593    /// * `inputs` - array of `Output`s that specify the inputs to the function.
594    ///   The names used for function inputs are normalized names of the
595    ///   operations (usually placeholders) pointed to by `inputs`. These
596    ///   operation names should start with a letter. Normalization will convert
597    ///   all letters to lowercase and non-alphanumeric characters to '\_' to
598    ///   make resulting names match the "[a-z][a-z0-9_]*" pattern for operation
599    ///   argument names. `inputs` cannot contain the same tensor twice.
600    /// * `outputs` - array of `Output`s that specify the outputs of the
601    ///   function. `outputs` can contain the same tensor more than once.
602    /// * `output_names` - The names of the function's outputs. `output_names`
603    ///   array must either have the same length as `outputs` or be None. In the
604    ///   former case, the names should match the regular expression for ArgDef
605    ///   names - "[a-z][a-z0-9_]*". In the latter case, names for outputs will
606    ///   be generated automatically.
607    /// * `opts` - various options for the function, e.g. XLA's inlining control.
608    /// * `description` - optional human-readable description of this function.
609    ///
610    /// Note that when the same `Output` is listed as both an input and an
611    /// output, the corresponding function's output will equal to this input,
612    /// instead of the original node's output.
613    ///
614    /// Callers must also satisfy the following constraints:
615    ///
616    /// * `inputs` cannot refer to `Output`s within a control flow context. For
617    ///   example, one cannot use the output of "switch" node as input.
618    /// * `inputs` and `outputs` cannot have reference types. Reference types
619    ///   are not exposed through C API and are being replaced with Resources.
620    ///   We support reference types inside function's body to support legacy
621    ///   code. Do not use them in new code.
622    /// * Every node in the function's body must have all of its inputs
623    ///   (including control inputs). In other words, for every node in the
624    ///   body, each input must be either listed in `inputs` or must come from
625    ///   another node in the body. In particular, it is an error to have a
626    ///   control edge going from a node outside of the body into a node in the
627    ///   body. This applies to control edges going from nodes referenced in
628    ///   `inputs` to nodes in the body when the former nodes are not in the
629    ///   body (automatically skipped or not included in explicitly specified
630    ///   body).
631    ///
632    /// # Returns
633    ///
634    ///  A newly created `Function` instance.
635    pub fn to_function<S: AsRef<str>>(
636        &self,
637        fn_name: &str,
638        append_hash_to_fn_name: bool,
639        opers: Option<&[&Operation]>,
640        inputs: &[Output],
641        outputs: &[Output],
642        output_names: Option<&[S]>,
643        opts: &FunctionOptions,
644        description: Option<&str>,
645    ) -> Result<Function> {
646        let fn_name_cstr = CString::new(fn_name)?;
647        let num_opers: c_int = if let Some(ops) = &opers {
648            ops.len() as c_int
649        } else {
650            -1
651        };
652        #[allow(trivial_casts)]
653        let c_opers: Option<Vec<_>> =
654            opers.map(|s| s.iter().map(|op| op.inner as *const _).collect());
655        let c_opers_ptr: *const *const tf::TF_Operation = if let Some(ref ops) = &c_opers {
656            ops.as_ptr()
657        } else {
658            ptr::null()
659        };
660        let c_inputs: Vec<_> = inputs.iter().map(|x| x.to_c()).collect();
661        let c_outputs: Vec<_> = outputs.iter().map(|x| x.to_c()).collect();
662        let output_names_cstrs: Option<::std::result::Result<Vec<CString>, NulError>> =
663            output_names
664                .map(|slice: &[S]| slice.iter().map(|s: &S| CString::new(s.as_ref())).collect());
665        let output_names_cstrs: Option<Vec<CString>> = match output_names_cstrs {
666            None => None,
667            Some(r) => Some(r?),
668        };
669        let output_names_ptrs: Option<Vec<*const c_char>> = output_names_cstrs
670            .as_ref()
671            .map(|slice| slice.iter().map(|s| s.as_ptr()).collect());
672        let output_names_ptrs_ptr = match &output_names_ptrs {
673            None => ptr::null(),
674            Some(ref v) => v.as_ptr(),
675        };
676        let description_cstr = match description {
677            None => None,
678            Some(d) => Some(CString::new(d)?),
679        };
680        let description_ptr: *const c_char = if let Some(ref cstr) = &description_cstr {
681            cstr.as_ptr()
682        } else {
683            ptr::null()
684        };
685        let status = Status::new();
686        let f = unsafe {
687            tf::TF_GraphToFunction(
688                self.inner(),
689                fn_name_cstr.as_ptr(),
690                u8::from(append_hash_to_fn_name),
691                num_opers,
692                c_opers_ptr,
693                c_inputs.len() as c_int,
694                c_inputs.as_ptr(),
695                c_outputs.len() as c_int,
696                c_outputs.as_ptr(),
697                output_names_ptrs_ptr,
698                opts.inner,
699                description_ptr,
700                status.inner,
701            )
702        };
703        status.into_result()?;
704        Ok(Function { inner: f })
705    }
706
707    /// Returns the number of functions registered in the graph.
708    pub fn num_functions(&self) -> c_int {
709        unsafe { tf::TF_GraphNumFunctions(self.inner()) }
710    }
711
712    /// Returns functions registered in the graph.
713    pub fn get_functions(&self) -> Result<Vec<Function>> {
714        unsafe {
715            let num = tf::TF_GraphNumFunctions(self.inner());
716            let mut funcs = Vec::with_capacity(num as usize);
717            let status = Status::new();
718            let num = tf::TF_GraphGetFunctions(self.inner(), funcs.as_mut_ptr(), num, status.inner);
719            status.into_result()?;
720            funcs.set_len(num as usize);
721            Ok(funcs.iter().map(|f| Function { inner: *f }).collect())
722        }
723    }
724
725    /// Returns the serialized OpDef proto with name `op_name`, or a bad status if no
726    /// such op exists. This can return OpDefs of functions copied into the graph.
727    pub fn get_op_def(&self, op_name: &str) -> Result<Vec<u8>> {
728        let status = Status::new();
729        let c_op_name = CString::new(op_name)?;
730        unsafe {
731            let mut buffer = Buffer::new_unallocated();
732            tf::TF_GraphGetOpDef(
733                self.inner(),
734                c_op_name.as_ptr(),
735                buffer.inner_mut(),
736                status.inner,
737            );
738            status.into_result().map(|()| buffer.into())
739        }
740    }
741
742    /// Returns the serialized VersionDef proto for this graph.
743    pub fn versions(&self) -> Result<Vec<u8>> {
744        let status = Status::new();
745        unsafe {
746            let mut buffer = Buffer::new_unallocated();
747            tf::TF_GraphVersions(self.inner(), buffer.inner_mut(), status.inner);
748            status.into_result().map(|()| buffer.into())
749        }
750    }
751
752    /// Attempts to evaluate `output`. This will only be possible if `output`
753    /// doesn't depend on any graph inputs (this function is safe to call if
754    /// this isn't the case though).
755    ///
756    /// If the evaluation is successful, this function returns the tensor.
757    /// Otherwise returns None. An error status is returned if something is
758    /// wrong with the graph or input or the type requested doesn't match the
759    /// type of the tensor.
760    pub fn try_evaluate_constant<T: TensorType>(
761        &self,
762        output: &Output,
763    ) -> Result<Option<Tensor<T>>> {
764        let status = Status::new();
765        unsafe {
766            let mut c_tensor: *mut tf::TF_Tensor = ptr::null_mut();
767            let success = tf::TF_TryEvaluateConstant(
768                self.inner(),
769                output.to_c(),
770                &mut c_tensor,
771                status.inner,
772            );
773            status.into_result()?;
774            if success != 0 {
775                match Tensor::from_tf_tensor(c_tensor) {
776                    None => Err(invalid_arg!("Tensor types do not match")),
777                    Some(t) => Ok(Some(t)),
778                }
779            } else {
780                Ok(None)
781            }
782        }
783    }
784
785    /// Adds operations to compute the partial derivatives of sum of `y`s
786    /// w.r.t `x`s, i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
787    ///
788    /// `dx` are used as initial gradients (which represent the symbolic partial
789    /// derivatives of some loss function `L` w.r.t. `y`).
790    /// `dx` must be None or have the same length as `y`.
791    /// If `dx` is None, the implementation will use dx of `OnesLike` for all
792    /// shapes in `y`.
793    /// `prefix` names the scope into which all gradients operations are being
794    /// added.  `prefix` must be unique within the provided graph otherwise this
795    /// operation will fail. If `prefix` is None, gradient nodes are
796    /// automatically named under the "gradients/" prefix. To guarantee name
797    /// uniqueness, subsequent calls to the same graph will append an
798    /// incremental tag to the prefix: "gradients_1/", "gradients_2/", ...
799    ///
800    /// WARNING: This function does not yet support all the gradients that
801    /// python supports. See
802    /// <https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md>
803    /// for instructions on how to add C++ more gradients.
804    pub fn add_gradients(
805        &mut self,
806        prefix: Option<&str>,
807        y: &[Output],
808        x: &[Output],
809        dx: Option<&[Output]>,
810    ) -> Result<Vec<Option<Output>>> {
811        if let Some(dx) = dx {
812            if dx.len() != y.len() {
813                return Err(invalid_arg!(
814                    "dx.len() must equal y.len() ({} vs. {})",
815                    dx.len(),
816                    y.len()
817                ));
818            }
819        }
820        let c_y: Vec<_> = y.iter().map(Output::to_c).collect();
821        let c_x: Vec<_> = x.iter().map(Output::to_c).collect();
822        let c_dx: Option<Vec<_>> = dx.map(|v| v.iter().map(Output::to_c).collect());
823        let dx_ptr = match c_dx {
824            Some(v) => v.as_ptr(),
825            None => ptr::null(),
826        };
827        let prefix_cstr = match prefix {
828            Some(s) => Some(CString::new(s)?),
829            None => None,
830        };
831        let prefix_ptr: *const c_char = if let Some(ref cstr) = &prefix_cstr {
832            cstr.as_ptr()
833        } else {
834            ptr::null()
835        };
836        let mut dy = Vec::with_capacity(x.len());
837        let mut status = Status::new();
838        unsafe {
839            tf::TF_AddGradientsWithPrefix(
840                self.inner(),
841                prefix_ptr,
842                c_y.as_ptr() as *mut _,
843                y.len() as i32,
844                c_x.as_ptr() as *mut _,
845                x.len() as i32,
846                dx_ptr as *mut _,
847                status.inner(),
848                dy.as_mut_ptr(),
849            );
850            if status.is_ok() {
851                dy.set_len(x.len());
852                Ok(dy
853                    .iter()
854                    .map(|o| Output::from_c_optional(self, o))
855                    .collect())
856            } else {
857                Err(status)
858            }
859        }
860    }
861
862    pub(crate) fn inner(&self) -> *mut tf::TF_Graph {
863        self.gimpl.inner
864    }
865
866    pub(crate) unsafe fn from_c(inner: *mut tf::TF_Graph) -> Self {
867        Graph {
868            gimpl: Arc::new(GraphImpl {
869                inner,
870                owned: false,
871            }),
872        }
873    }
874}
875
876////////////////////////
877
878/// Iterator over the operations in a `Graph`.
879#[derive(Debug)]
880pub struct OperationIter<'a> {
881    // We could just have a gimpl field, but keeping a reference to the Graph
882    // means that the graph can't be modified while iterating through it.
883    graph: &'a Graph,
884    pos: size_t,
885}
886
887impl<'a> Iterator for OperationIter<'a> {
888    type Item = Operation;
889
890    fn next(&mut self) -> Option<Self::Item> {
891        unsafe {
892            let operation = tf::TF_GraphNextOperation(self.graph.gimpl.inner, &mut self.pos);
893            if operation.is_null() {
894                None
895            } else {
896                Some(Operation {
897                    inner: operation,
898                    gimpl: self.graph.gimpl.clone(),
899                })
900            }
901        }
902    }
903}
904
905////////////////////////
906
c_enum!(
TF_AttrType,
// TODO: Provide docs on variants once they are added to c_api.h.
// The explicit discriminants below are the raw `TF_AttrType` values from the
// C API. `AttrType::from_c` (used by `AttrMetadata::from_c`) presumably maps
// C values onto these variants via these numbers, so they must not change.
/// Describes the type of the value of an attribute on an operation.
#[allow(missing_docs)]
AttrType {
    String = 0,
    Int = 1,
    Float = 2,
    Bool = 3,
    Type = 4,
    Shape = 5,
    Tensor = 6,
    Placeholder = 7,
    Func = 8,
});
923
924/// AttrMetadata describes the value of an attribute on an operation.
925#[derive(Clone, Debug, Copy)]
926pub struct AttrMetadata {
927    /// Length of the list, or None if the attribute is not a list.
928    pub list_size: Option<i64>,
929
930    /// Type of elements of the list if the attribute is a list.
931    /// Type of the single value stored in the attribute if not a list.
932    pub attr_type: AttrType,
933
934    /// Total size the attribute value.
935    /// The units of total_size depend on list_size and attr_type.
936    ///
937    /// 1. If attr_type == AttrType::String and list_size == None
938    ///    then total_size is the byte size of the string valued attribute.
939    /// 2. If attr_type == AttrType::String and list_size == Some(_)
940    ///    then total_size is the cumulative byte size of all the strings in the
941    ///    list.
942    /// 3. If attr_type == AttrType::Shape and list_size == None
943    ///    then total_size is the number of dimensions of the shape valued
944    ///    attribute, or -1 if its rank is unknown.
945    /// 4. If attr_type == AttrType::SHAPE and list_size == Some(_)
946    ///    then total_size is the cumulative number of dimensions of all shapes
947    ///    in the list.
948    /// 4. Otherwise, total_size is undefined.
949    pub total_size: i64,
950}
951
952impl AttrMetadata {
953    fn from_c(metadata: tf::TF_AttrMetadata) -> Self {
954        AttrMetadata {
955            list_size: if metadata.is_list == 0 {
956                None
957            } else {
958                Some(metadata.list_size)
959            },
960            attr_type: AttrType::from_c(metadata.type_),
961            total_size: metadata.total_size,
962        }
963    }
964}
965
966////////////////////////
967
968/// An `Operation` is a node in a `Graph`.
969/// It is a computation which accepts inputs and produces outputs.
970#[derive(Debug, Clone)]
971pub struct Operation {
972    inner: *mut tf::TF_Operation,
973    gimpl: Arc<GraphImpl>,
974}
975
976unsafe impl Send for Operation {}
977unsafe impl Sync for Operation {}
978
979impl Operation {
980    /// Returns the name of the operation.
981    ///
982    /// This is the name of the specific computational step,
983    /// not an operation type, so it may look like `'add_x_and_y'` instead of `'Add'`,
984    /// although it may be a generated ID like `'Add_123'`.
985    pub fn name(&self) -> std::result::Result<String, Utf8Error> {
986        unsafe {
987            CStr::from_ptr(tf::TF_OperationName(self.inner))
988                .to_str()
989                .map(|x| x.to_string())
990        }
991    }
992
993    /// Returns the type of operation.
994    /// This will be something like `'Add'`, `'Mul'`, etc.
995    pub fn op_type(&self) -> std::result::Result<String, Utf8Error> {
996        unsafe {
997            CStr::from_ptr(tf::TF_OperationOpType(self.inner))
998                .to_str()
999                .map(|x| x.to_string())
1000        }
1001    }
1002
1003    /// Returns the device for this operation.
1004    /// The empty string means unconstrained.
1005    pub fn device(&self) -> std::result::Result<String, Utf8Error> {
1006        unsafe {
1007            CStr::from_ptr(tf::TF_OperationDevice(self.inner))
1008                .to_str()
1009                .map(|x| x.to_string())
1010        }
1011    }
1012
1013    /// Returns the number of outputs.
1014    pub fn num_outputs(&self) -> usize {
1015        unsafe { tf::TF_OperationNumOutputs(self.inner) as usize }
1016    }
1017
1018    /// Returns the type of a specific output.
1019    pub fn output_type(&self, index: usize) -> DataType {
1020        unsafe {
1021            DataType::from_c(tf::TF_OperationOutputType(tf::TF_Output {
1022                oper: self.inner,
1023                index: index as c_int,
1024            }))
1025        }
1026    }
1027
1028    /// Returns the given output edge.
1029    /// The index argument is the index into the current operation's output array,
1030    pub fn output(&self, index: usize) -> Output {
1031        crate::Output {
1032            operation: self.clone(),
1033            index: index as c_int,
1034        }
1035    }
1036
1037    // TODO: Figure out what this does and document it.
1038    #[allow(missing_docs)]
1039    pub fn output_list_length(&self, arg_name: &str) -> Result<usize> {
1040        let c_arg_name = CString::new(arg_name)?;
1041        let mut status = Status::new();
1042        let length = unsafe {
1043            tf::TF_OperationOutputListLength(self.inner, c_arg_name.as_ptr(), status.inner())
1044        };
1045        if status.is_ok() {
1046            Ok(length as usize)
1047        } else {
1048            Err(status)
1049        }
1050    }
1051
1052    /// Returns the number of inputs.
1053    pub fn num_inputs(&self) -> usize {
1054        unsafe { tf::TF_OperationNumInputs(self.inner) as usize }
1055    }
1056
1057    /// Returns the type of a specific input.
1058    pub fn input_type(&self, index: usize) -> DataType {
1059        unsafe {
1060            DataType::from_c(tf::TF_OperationInputType(tf::TF_Input {
1061                oper: self.inner,
1062                index: index as c_int,
1063            }))
1064        }
1065    }
1066
1067    // TODO: Figure out what this does and document it.
1068    #[allow(missing_docs)]
1069    pub fn input_list_length(&self, arg_name: &str) -> Result<usize> {
1070        let c_arg_name = CString::new(arg_name)?;
1071        let mut status = Status::new();
1072        let length = unsafe {
1073            tf::TF_OperationInputListLength(self.inner, c_arg_name.as_ptr(), status.inner())
1074        };
1075        if status.is_ok() {
1076            Ok(length as usize)
1077        } else {
1078            Err(status)
1079        }
1080    }
1081
1082    /// Returns the given input edge.
1083    /// The index argument is the index into the current operation's input array,
1084    /// and the return value is the source operation and the index into its output array.
1085    pub fn input(&self, index: usize) -> (Operation, usize) {
1086        unsafe {
1087            let port = tf::TF_OperationInput(tf::TF_Input {
1088                oper: self.inner,
1089                index: index as c_int,
1090            });
1091            (
1092                Operation {
1093                    inner: port.oper,
1094                    gimpl: self.gimpl.clone(),
1095                },
1096                port.index as usize,
1097            )
1098        }
1099    }
1100
1101    /// Returns the number of consumers of a specific output.
1102    pub fn output_num_consumers(&self, index: usize) -> usize {
1103        unsafe {
1104            tf::TF_OperationOutputNumConsumers(tf::TF_Output {
1105                oper: self.inner,
1106                index: index as c_int,
1107            }) as usize
1108        }
1109    }
1110
1111    /// Returns the consumers of a specific output.
1112    /// The index argument is the index into the current operation's output array,
1113    /// and the return value is a vector of the destination operation and the index
1114    /// into its input array.
1115    pub fn output_consumers(&self, index: usize) -> Vec<(Operation, usize)> {
1116        unsafe {
1117            let num_consumers = tf::TF_OperationOutputNumConsumers(tf::TF_Output {
1118                oper: self.inner,
1119                index: index as c_int,
1120            });
1121            let mut vec = <Vec<tf::TF_Input>>::with_capacity(num_consumers as usize);
1122            let len = tf::TF_OperationOutputConsumers(
1123                tf::TF_Output {
1124                    oper: self.inner,
1125                    index: index as c_int,
1126                },
1127                vec.as_mut_ptr(),
1128                num_consumers as c_int,
1129            );
1130            vec.set_len(len as usize);
1131            vec.into_iter()
1132                .map(|port| {
1133                    (
1134                        Operation {
1135                            inner: port.oper,
1136                            gimpl: self.gimpl.clone(),
1137                        },
1138                        port.index as usize,
1139                    )
1140                })
1141                .collect()
1142        }
1143    }
1144
1145    /// Returns the number of control inputs.
1146    pub fn num_control_inputs(&self) -> usize {
1147        unsafe { tf::TF_OperationNumControlInputs(self.inner) as usize }
1148    }
1149
1150    /// Returns the control inputs.
1151    pub fn control_inputs(&self) -> Vec<Operation> {
1152        unsafe {
1153            let num_consumers = tf::TF_OperationNumControlInputs(self.inner);
1154            let mut vec = <Vec<*mut tf::TF_Operation>>::with_capacity(num_consumers as usize);
1155            let len = tf::TF_OperationGetControlInputs(
1156                self.inner,
1157                vec.as_mut_ptr(),
1158                num_consumers as c_int,
1159            );
1160            vec.set_len(cmp::min(num_consumers, len) as usize);
1161            vec.into_iter()
1162                .map(|operation| Operation {
1163                    inner: operation,
1164                    gimpl: self.gimpl.clone(),
1165                })
1166                .collect()
1167        }
1168    }
1169
1170    /// Returns the number of control outputs.
1171    pub fn num_control_outputs(&self) -> usize {
1172        unsafe { tf::TF_OperationNumControlOutputs(self.inner) as usize }
1173    }
1174
1175    /// Returns the control outputs.
1176    pub fn control_outputs(&self) -> Vec<Operation> {
1177        unsafe {
1178            let num_consumers = tf::TF_OperationNumControlOutputs(self.inner);
1179            let mut vec = <Vec<*mut tf::TF_Operation>>::with_capacity(num_consumers as usize);
1180            let len =
1181                tf::TF_OperationGetControlOutputs(self.inner, vec.as_mut_ptr(), vec.len() as c_int);
1182            vec.set_len(len as usize);
1183            vec.into_iter()
1184                .map(|operation| Operation {
1185                    inner: operation,
1186                    gimpl: self.gimpl.clone(),
1187                })
1188                .collect()
1189        }
1190    }
1191
1192    /// Returns metadata about the value of the attribute `attr_name`.
1193    pub fn get_attr_metadata(&self, attr_name: &str) -> Result<AttrMetadata> {
1194        let c_attr_name = CString::new(attr_name)?;
1195        let mut status = Status::new();
1196        unsafe {
1197            let metadata =
1198                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1199            if status.is_ok() {
1200                Ok(AttrMetadata::from_c(metadata))
1201            } else {
1202                Err(status)
1203            }
1204        }
1205    }
1206
1207    /// Returns the value of the attribute `attr_name`.
1208    pub fn get_attr_string(&self, attr_name: &str) -> Result<String> {
1209        let c_attr_name = CString::new(attr_name)?;
1210        let mut status = Status::new();
1211        unsafe {
1212            let metadata =
1213                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1214            if !status.is_ok() {
1215                return Err(status);
1216            }
1217            let mut v: Vec<MaybeUninit<u8>> = Vec::with_capacity(metadata.total_size as usize);
1218            v.set_len(metadata.total_size as usize);
1219            tf::TF_OperationGetAttrString(
1220                self.inner,
1221                c_attr_name.as_ptr(),
1222                v.as_mut_ptr() as *mut std::os::raw::c_void,
1223                metadata.total_size as usize,
1224                status.inner(),
1225            );
1226            if !status.is_ok() {
1227                return Err(status);
1228            }
1229            Ok(CString::new(
1230                v.into_iter()
1231                    .map(|x| MaybeUninit::assume_init(x))
1232                    .collect::<Vec<_>>(),
1233            )?
1234            .into_string()?)
1235        }
1236    }
1237
1238    /// Get the list of strings in the value of the attribute `attr_name`.
1239    pub fn get_attr_string_list(&self, attr_name: &str) -> Result<Vec<String>> {
1240        let c_attr_name = CString::new(attr_name)?;
1241        let mut status = Status::new();
1242        unsafe {
1243            let metadata =
1244                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1245            if !status.is_ok() {
1246                return Err(status);
1247            }
1248            let mut storage: Vec<MaybeUninit<u8>> =
1249                Vec::with_capacity(metadata.total_size as usize);
1250            storage.set_len(metadata.total_size as usize);
1251            let mut values: Vec<*const std::os::raw::c_char> =
1252                Vec::with_capacity(metadata.list_size as usize);
1253            let mut lengths: Vec<size_t> = Vec::with_capacity(metadata.list_size as usize);
1254            tf::TF_OperationGetAttrStringList(
1255                self.inner,
1256                c_attr_name.as_ptr(),
1257                values.as_mut_ptr() as *mut *mut std::os::raw::c_void,
1258                lengths.as_mut_ptr(),
1259                metadata.list_size as i32,
1260                storage.as_mut_ptr() as *mut std::os::raw::c_void,
1261                metadata.total_size as usize,
1262                status.inner(),
1263            );
1264            if !status.is_ok() {
1265                return Err(status);
1266            }
1267            values.set_len(metadata.list_size as usize);
1268            lengths.set_len(metadata.list_size as usize);
1269            let mut strings = Vec::with_capacity(metadata.list_size as usize);
1270            for i in 0..metadata.list_size as usize {
1271                let s = slice::from_raw_parts(values[i] as *const u8, lengths[i]);
1272                strings.push(std::str::from_utf8(s)?.to_string());
1273            }
1274            Ok(strings)
1275        }
1276    }
1277
1278    /// Returns the value of the attribute `attr_name`.
1279    pub fn get_attr_int(&self, attr_name: &str) -> Result<i64> {
1280        let c_attr_name = CString::new(attr_name)?;
1281        let mut status = Status::new();
1282        let mut value: i64 = 0;
1283        unsafe {
1284            tf::TF_OperationGetAttrInt(
1285                self.inner,
1286                c_attr_name.as_ptr(),
1287                &mut value,
1288                status.inner(),
1289            );
1290        }
1291        if !status.is_ok() {
1292            return Err(status);
1293        }
1294        Ok(value)
1295    }
1296
1297    /// Get the list of ints in the value of the attribute `attr_name`.
1298    pub fn get_attr_int_list(&self, attr_name: &str) -> Result<Vec<i64>> {
1299        let c_attr_name = CString::new(attr_name)?;
1300        let mut status = Status::new();
1301        unsafe {
1302            let metadata =
1303                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1304            if !status.is_ok() {
1305                return Err(status);
1306            }
1307            let mut values: Vec<MaybeUninit<i64>> = Vec::with_capacity(metadata.list_size as usize);
1308            values.set_len(metadata.list_size as usize);
1309            tf::TF_OperationGetAttrIntList(
1310                self.inner,
1311                c_attr_name.as_ptr(),
1312                values.as_mut_ptr() as *mut i64,
1313                metadata.list_size as c_int,
1314                status.inner(),
1315            );
1316            if !status.is_ok() {
1317                return Err(status);
1318            }
1319            Ok(values
1320                .into_iter()
1321                .map(|x| MaybeUninit::assume_init(x))
1322                .collect())
1323        }
1324    }
1325
1326    /// Returns the value of the attribute `attr_name`.
1327    pub fn get_attr_float(&self, attr_name: &str) -> Result<f32> {
1328        let c_attr_name = CString::new(attr_name)?;
1329        let mut status = Status::new();
1330        let mut value: c_float = 0.0;
1331        unsafe {
1332            tf::TF_OperationGetAttrFloat(
1333                self.inner,
1334                c_attr_name.as_ptr(),
1335                &mut value,
1336                status.inner(),
1337            );
1338        }
1339        if !status.is_ok() {
1340            return Err(status);
1341        }
1342        #[allow(trivial_numeric_casts)]
1343        #[allow(clippy::unnecessary_cast)]
1344        Ok(value as f32)
1345    }
1346
1347    /// Get the list of floats in the value of the attribute `attr_name`.
1348    pub fn get_attr_float_list(&self, attr_name: &str) -> Result<Vec<f32>> {
1349        let c_attr_name = CString::new(attr_name)?;
1350        let mut status = Status::new();
1351        unsafe {
1352            let metadata =
1353                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1354            if !status.is_ok() {
1355                return Err(status);
1356            }
1357            let mut values: Vec<MaybeUninit<c_float>> =
1358                Vec::with_capacity(metadata.list_size as usize);
1359            values.set_len(metadata.list_size as usize);
1360            tf::TF_OperationGetAttrFloatList(
1361                self.inner,
1362                c_attr_name.as_ptr(),
1363                values.as_mut_ptr() as *mut c_float,
1364                metadata.list_size as c_int,
1365                status.inner(),
1366            );
1367            if !status.is_ok() {
1368                return Err(status);
1369            }
1370            #[allow(trivial_numeric_casts)]
1371            #[allow(clippy::unnecessary_cast)]
1372            Ok(values.iter().map(|f| f.assume_init() as f32).collect())
1373        }
1374    }
1375
1376    /// Returns the value of the attribute `attr_name`.
1377    pub fn get_attr_bool(&self, attr_name: &str) -> Result<bool> {
1378        let c_attr_name = CString::new(attr_name)?;
1379        let mut status = Status::new();
1380        let mut value: c_uchar = 0;
1381        unsafe {
1382            tf::TF_OperationGetAttrBool(
1383                self.inner,
1384                c_attr_name.as_ptr(),
1385                &mut value,
1386                status.inner(),
1387            );
1388        }
1389        if !status.is_ok() {
1390            return Err(status);
1391        }
1392        Ok(value != 0)
1393    }
1394
1395    /// Get the list of bools in the value of the attribute `attr_name`.
1396    pub fn get_attr_bool_list(&self, attr_name: &str) -> Result<Vec<bool>> {
1397        let c_attr_name = CString::new(attr_name)?;
1398        let mut status = Status::new();
1399        unsafe {
1400            let metadata =
1401                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1402            if !status.is_ok() {
1403                return Err(status);
1404            }
1405            let mut values: Vec<MaybeUninit<c_uchar>> =
1406                Vec::with_capacity(metadata.list_size as usize);
1407            values.set_len(metadata.list_size as usize);
1408            tf::TF_OperationGetAttrBoolList(
1409                self.inner,
1410                c_attr_name.as_ptr(),
1411                values.as_mut_ptr() as *mut c_uchar,
1412                metadata.list_size as c_int,
1413                status.inner(),
1414            );
1415            if !status.is_ok() {
1416                return Err(status);
1417            }
1418            #[allow(trivial_numeric_casts)]
1419            Ok(values.iter().map(|f| f.assume_init() != 0).collect())
1420        }
1421    }
1422
1423    /// Returns the value of the attribute `attr_name`.
1424    pub fn get_attr_type(&self, attr_name: &str) -> Result<DataType> {
1425        let c_attr_name = CString::new(attr_name)?;
1426        let mut status = Status::new();
1427        let mut value: tf::TF_DataType = tf::TF_FLOAT;
1428        unsafe {
1429            tf::TF_OperationGetAttrType(
1430                self.inner,
1431                c_attr_name.as_ptr(),
1432                &mut value,
1433                status.inner(),
1434            );
1435        }
1436        if !status.is_ok() {
1437            return Err(status);
1438        }
1439        Ok(DataType::from_c(value))
1440    }
1441
1442    /// Get the list of types in the value of the attribute `attr_name`.
1443    pub fn get_attr_type_list(&self, attr_name: &str) -> Result<Vec<DataType>> {
1444        let c_attr_name = CString::new(attr_name)?;
1445        let mut status = Status::new();
1446        unsafe {
1447            let metadata =
1448                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1449            if !status.is_ok() {
1450                return Err(status);
1451            }
1452            let mut values: Vec<MaybeUninit<tf::TF_DataType>> =
1453                Vec::with_capacity(metadata.list_size as usize);
1454            values.set_len(metadata.list_size as usize);
1455            tf::TF_OperationGetAttrTypeList(
1456                self.inner,
1457                c_attr_name.as_ptr(),
1458                values.as_mut_ptr() as *mut tf::TF_DataType,
1459                metadata.list_size as c_int,
1460                status.inner(),
1461            );
1462            if !status.is_ok() {
1463                return Err(status);
1464            }
1465            Ok(values
1466                .iter()
1467                .map(|x| DataType::from_c(x.assume_init()))
1468                .collect())
1469        }
1470    }
1471
1472    /// Returns the value of the attribute `attr_name`.
1473    pub fn get_attr_shape(&self, attr_name: &str) -> Result<Shape> {
1474        let c_attr_name = CString::new(attr_name)?;
1475        let mut status = Status::new();
1476        unsafe {
1477            let metadata =
1478                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1479            if !status.is_ok() {
1480                return Err(status);
1481            }
1482            if metadata.total_size == -1 {
1483                return Ok(Shape(None));
1484            }
1485            let mut v: Vec<MaybeUninit<i64>> = Vec::with_capacity(metadata.total_size as usize);
1486            v.set_len(metadata.total_size as usize);
1487            tf::TF_OperationGetAttrShape(
1488                self.inner,
1489                c_attr_name.as_ptr(),
1490                v.as_mut_ptr() as *mut i64,
1491                metadata.total_size as c_int,
1492                status.inner(),
1493            );
1494            if !status.is_ok() {
1495                return Err(status);
1496            }
1497            Ok(Shape(Some(
1498                v.iter()
1499                    .map(|x| {
1500                        let x = x.assume_init();
1501                        if x < 0 {
1502                            None
1503                        } else {
1504                            Some(x)
1505                        }
1506                    })
1507                    .collect(),
1508            )))
1509        }
1510    }
1511
1512    /// Get the list of shapes in the value of the attribute `attr_name`.
1513    pub fn get_attr_shape_list(&self, attr_name: &str) -> Result<Vec<Shape>> {
1514        let c_attr_name = CString::new(attr_name)?;
1515        let mut status = Status::new();
1516        unsafe {
1517            let metadata =
1518                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1519            if !status.is_ok() {
1520                return Err(status);
1521            }
1522            let mut storage: Vec<MaybeUninit<i64>> =
1523                Vec::with_capacity(metadata.total_size as usize);
1524            storage.set_len(metadata.total_size as usize);
1525            let mut dims: Vec<*mut i64> = Vec::with_capacity(metadata.list_size as usize);
1526            let mut num_dims: Vec<c_int> = Vec::with_capacity(metadata.list_size as usize);
1527            tf::TF_OperationGetAttrShapeList(
1528                self.inner,
1529                c_attr_name.as_ptr(),
1530                dims.as_mut_ptr(),
1531                num_dims.as_mut_ptr(),
1532                metadata.list_size as i32,
1533                storage.as_mut_ptr() as *mut i64,
1534                metadata.total_size as c_int,
1535                status.inner(),
1536            );
1537            if !status.is_ok() {
1538                return Err(status);
1539            }
1540            dims.set_len(metadata.list_size as usize);
1541            num_dims.set_len(metadata.list_size as usize);
1542            let mut shapes = Vec::with_capacity(metadata.list_size as usize);
1543            for i in 0..metadata.list_size as usize {
1544                shapes.push(Shape(if num_dims[i] == -1 {
1545                    None
1546                } else {
1547                    let mut v = Vec::new();
1548                    for j in 0..num_dims[i] {
1549                        v.push(match *dims[i].offset(j as isize) {
1550                            -1 => None,
1551                            x => Some(x),
1552                        });
1553                    }
1554                    Some(v)
1555                }));
1556            }
1557            Ok(shapes)
1558        }
1559    }
1560
1561    /// Returns the binary-serialized TensorShapeProto value of the attribute
1562    /// `attr_name`.
1563    pub fn get_attr_tensor_shape_proto(&self, attr_name: &str) -> Result<Vec<u8>> {
1564        let c_attr_name = CString::new(attr_name)?;
1565        let mut status = Status::new();
1566        unsafe {
1567            let mut buf = Buffer::<u8>::new_unallocated();
1568            tf::TF_OperationGetAttrTensorShapeProto(
1569                self.inner,
1570                c_attr_name.as_ptr(),
1571                buf.inner_mut(),
1572                status.inner(),
1573            );
1574            if !status.is_ok() {
1575                return Err(status);
1576            }
1577            Ok(buf.into())
1578        }
1579    }
1580
1581    /// Get the list of binary-serialized TensorShapeProtos in the value of the
1582    /// attribute `attr_name`.
1583    pub fn get_attr_tensor_shape_proto_list(&self, attr_name: &str) -> Result<Vec<Vec<u8>>> {
1584        let c_attr_name = CString::new(attr_name)?;
1585        let mut status = Status::new();
1586        unsafe {
1587            let metadata =
1588                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1589            if !status.is_ok() {
1590                return Err(status);
1591            }
1592            let mut c_buffers = Vec::with_capacity(metadata.list_size as usize);
1593            for _ in 0..metadata.list_size {
1594                c_buffers.push(ptr::null_mut());
1595            }
1596            tf::TF_OperationGetAttrTensorShapeProtoList(
1597                self.inner,
1598                c_attr_name.as_ptr(),
1599                c_buffers.as_mut_ptr(),
1600                metadata.list_size as c_int,
1601                status.inner(),
1602            );
1603            if !status.is_ok() {
1604                return Err(status);
1605            }
1606            Ok(c_buffers
1607                .iter()
1608                .map(|b| Buffer::from_c(*b, true).into())
1609                .collect())
1610        }
1611    }
1612
1613    /// Returns the value of the attribute `attr_name`. Returns an error if the
1614    /// type of the tensor value does not match the type of the generic
1615    /// argument.
1616    pub fn get_attr_tensor<T: TensorType>(&self, attr_name: &str) -> Result<Tensor<T>> {
1617        let c_attr_name = CString::new(attr_name)?;
1618        let mut status = Status::new();
1619        unsafe {
1620            let mut c_tensor: *mut tf::TF_Tensor = ptr::null_mut();
1621            tf::TF_OperationGetAttrTensor(
1622                self.inner,
1623                c_attr_name.as_ptr(),
1624                &mut c_tensor,
1625                status.inner(),
1626            );
1627            if !status.is_ok() {
1628                return Err(status);
1629            }
1630            match Tensor::from_tf_tensor(c_tensor) {
1631                None => Err(invalid_arg!("Tensor types do not match")),
1632                Some(t) => Ok(t),
1633            }
1634        }
1635    }
1636
1637    /// Get the list of tensors in the value of the attribute `attr_name`.
1638    /// Returns an error if the type of the tensor value does not match the type
1639    /// of the generic argument.
1640    pub fn get_attr_tensor_list<T: TensorType>(&self, attr_name: &str) -> Result<Vec<Tensor<T>>> {
1641        let c_attr_name = CString::new(attr_name)?;
1642        let mut status = Status::new();
1643        unsafe {
1644            let metadata =
1645                tf::TF_OperationGetAttrMetadata(self.inner, c_attr_name.as_ptr(), status.inner());
1646            if !status.is_ok() {
1647                return Err(status);
1648            }
1649            let mut c_tensors = Vec::with_capacity(metadata.list_size as usize);
1650            for _ in 0..metadata.list_size {
1651                c_tensors.push(ptr::null_mut());
1652            }
1653            tf::TF_OperationGetAttrTensorList(
1654                self.inner,
1655                c_attr_name.as_ptr(),
1656                c_tensors.as_mut_ptr(),
1657                metadata.list_size as c_int,
1658                status.inner(),
1659            );
1660            if !status.is_ok() {
1661                return Err(status);
1662            }
1663            c_tensors
1664                .iter()
1665                .map(|t| match Tensor::from_tf_tensor(*t) {
1666                    None => Err(invalid_arg!("Tensor types do not match")),
1667                    Some(t) => Ok(t),
1668                })
1669                .collect()
1670        }
1671    }
1672
1673    /// Returns the binary-serialized AttrValue proto representation of the
1674    /// value of the `attr_name` attr.
1675    pub fn get_attr_value_proto(&self, attr_name: &str) -> Result<Vec<u8>> {
1676        let status = Status::new();
1677        let attr_name_cstr = CString::new(attr_name)?;
1678        unsafe {
1679            let mut buf = Buffer::new_unallocated();
1680            tf::TF_OperationGetAttrValueProto(
1681                self.inner,
1682                attr_name_cstr.as_ptr(),
1683                buf.inner_mut(),
1684                status.inner,
1685            );
1686            status.into_result()?;
1687            Ok(buf.into())
1688        }
1689    }
1690
    /// Returns the raw `TF_Operation` pointer wrapped by this `Operation`.
    ///
    /// Crate-internal, so other modules can hand this operation to the C API
    /// without the pointer being part of the public interface.
    pub(crate) fn inner(&self) -> *mut tf::TF_Operation {
        self.inner
    }
1694}
1695
1696impl From<Operation> for Output {
1697    /// Creates an Output for index 0.
1698    fn from(operation: Operation) -> Output {
1699        Output {
1700            operation,
1701            index: 0,
1702        }
1703    }
1704}
1705
1706////////////////////////
1707
/// An `Input` is one end of a graph edge.
/// It holds a borrowed operation and an index into the inputs of that
/// operation.
#[derive(Debug, Copy, Clone)]
pub struct Input<'a> {
    /// Operation the edge connects to.
    pub operation: &'a Operation,

    /// Index into the inputs of the operation.
    pub index: c_int,
}
1718
1719////////////////////////
1720
/// An `Output` is one end of a graph edge.
/// It holds an operation and an index into the outputs of that operation.
///
/// Unlike `Input`, an `Output` owns its `Operation` value rather than
/// borrowing it.
#[derive(Debug, Clone)]
pub struct Output {
    /// Operation the edge connects to.
    pub operation: Operation,

    /// Index into the outputs of the operation.
    pub index: c_int,
}
1731
1732impl Output {
1733    pub(crate) fn to_c(&self) -> tf::TF_Output {
1734        tf::TF_Output {
1735            oper: self.operation.inner,
1736            index: self.index,
1737        }
1738    }
1739
1740    pub(crate) fn from_c(graph: &Graph, output: &tf::TF_Output) -> Self {
1741        Output {
1742            operation: Operation {
1743                inner: output.oper,
1744                gimpl: graph.gimpl.clone(),
1745            },
1746            index: output.index,
1747        }
1748    }
1749
1750    pub(crate) fn from_c_optional(graph: &Graph, output: &tf::TF_Output) -> Option<Self> {
1751        if output.oper.is_null() {
1752            None
1753        } else {
1754            Some(Output {
1755                operation: Operation {
1756                    inner: output.oper,
1757                    gimpl: graph.gimpl.clone(),
1758                },
1759                index: output.index,
1760            })
1761        }
1762    }
1763
1764    /// Returns the name of this output.
1765    pub fn name(&self) -> Result<OutputName> {
1766        Ok(OutputName {
1767            name: self.operation.name()?,
1768            index: self.index,
1769        })
1770    }
1771}
1772
1773////////////////////////
1774
/// Names a specific Output in the graph.
///
/// Its textual form is `name:index`. Parsing also accepts a bare `name`,
/// which refers to index 0.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)]
pub struct OutputName {
    /// Name of the operation the edge connects to.
    pub name: String,

    /// Index into the outputs of the operation.
    pub index: c_int,
}
1784
1785impl FromStr for OutputName {
1786    type Err = Status;
1787    fn from_str(s: &str) -> Result<Self> {
1788        let splits: Vec<_> = s.split(':').collect();
1789        let index = match splits.len() {
1790            2 => splits[1].parse::<c_int>()?,
1791            1 => 0,
1792            _ => {
1793                return Err(Status::new_set_lossy(
1794                    Code::InvalidArgument,
1795                    "Name contains more than one colon (':')",
1796                ))
1797            }
1798        };
1799        Ok(Self {
1800            name: splits[0].to_string(),
1801            index,
1802        })
1803    }
1804}
1805
1806impl Display for OutputName {
1807    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
1808        write!(f, "{}:{}", self.name, self.index)
1809    }
1810}
1811
1812////////////////////////
1813
/// An `OperationDescription` is an `Operation` in the process of being built
/// (i.e. the builder pattern).
///
/// An `OperationDescription` is required to be finished before the graph
/// goes out of scope. If `finish()` was not called, dropping the description
/// calls `TF_FinishOperation` itself and discards any error it reports.
#[derive(Debug)]
pub struct OperationDescription<'a> {
    // Raw TensorFlow description handle, passed to `TF_FinishOperation`
    // either by `finish()` or by the `Drop` impl.
    inner: *mut tf::TF_OperationDescription,
    // This keeps self from outliving the Graph, which is required by
    // the docs on TF_NewOperation.
    graph: &'a Graph,
    // Set by `finish()` so that `Drop` does not finish the description again.
    finished: bool,
}
1828
1829impl<'a> Drop for OperationDescription<'a> {
1830    fn drop(&mut self) {
1831        if !self.finished {
1832            unsafe {
1833                // TF_NewOperation requires us to make sure TF_FinishOperation is called before the
1834                // graph is deleted.  Combined with guaranteeing that OperationDescription does
1835                // not outlive Graph, this ensures that the contract is held.
1836                let status = tf::TF_NewStatus();
1837                tf::TF_FinishOperation(self.inner, status);
1838                tf::TF_DeleteStatus(status);
1839            }
1840        }
1841    }
1842}
1843
1844impl<'a> OperationDescription<'a> {
1845    /// Builds the operation and adds it to the graph.
1846    pub fn finish(mut self) -> Result<Operation> {
1847        self.finished = true; // used by the drop code
1848        let mut status = Status::new();
1849        let operation = unsafe { tf::TF_FinishOperation(self.inner, status.inner()) };
1850        if status.is_ok() {
1851            Ok(Operation {
1852                inner: operation,
1853                gimpl: self.graph.gimpl.clone(),
1854            })
1855        } else {
1856            Err(status)
1857        }
1858    }
1859
1860    /// Sets the preferred device.
1861    /// The empty string means unconstrained.
1862    pub fn set_device(&mut self, device: &str) -> std::result::Result<(), NulError> {
1863        let c_device = CString::new(device)?;
1864        unsafe {
1865            tf::TF_SetDevice(self.inner, c_device.as_ptr());
1866        }
1867        Ok(())
1868    }
1869
1870    /// Adds an input to this operation.
1871    ///
1872    /// The index in the port is an index into the source operation's output array.
1873    pub fn add_input<I: Into<Output>>(&mut self, input: I) {
1874        unsafe {
1875            tf::TF_AddInput(self.inner, input.into().to_c());
1876        }
1877    }
1878
1879    /// Adds multiple inputs to this operation.
1880    ///
1881    /// The index in the ports is an index into the source operation's output array.
1882    pub fn add_input_list(&mut self, inputs: &[Output]) {
1883        let c_inputs: Vec<tf::TF_Output> = inputs.iter().map(|x| x.to_c()).collect();
1884        unsafe {
1885            tf::TF_AddInputList(self.inner, c_inputs.as_ptr(), c_inputs.len() as c_int);
1886        }
1887    }
1888
1889    /// Adds a control input.
1890    pub fn add_control_input(&mut self, input: &Operation) {
1891        unsafe {
1892            tf::TF_AddControlInput(self.inner, input.inner);
1893        }
1894    }
1895
1896    /// Sets the value of a string attribute.
1897    #[allow(trivial_numeric_casts)]
1898    pub fn set_attr_string(
1899        &mut self,
1900        attr_name: &str,
1901        value: &str,
1902    ) -> std::result::Result<(), NulError> {
1903        let c_attr_name = CString::new(attr_name)?;
1904        let c_value = value.as_bytes();
1905        unsafe {
1906            tf::TF_SetAttrString(
1907                self.inner,
1908                c_attr_name.as_ptr(),
1909                c_value.as_ptr() as *const std_c_void,
1910                c_value.len() as size_t,
1911            );
1912        }
1913        Ok(())
1914    }
1915
1916    /// Sets the value of an attribute which holds a list of strings.
1917    #[allow(trivial_numeric_casts)]
1918    pub fn set_attr_string_list<S: AsRef<str>>(
1919        &mut self,
1920        attr_name: &str,
1921        value: &[S],
1922    ) -> std::result::Result<(), NulError> {
1923        let c_attr_name = CString::new(attr_name)?;
1924        let bytes: Vec<&[u8]> = value.iter().map(|x| x.as_ref().as_bytes()).collect();
1925        let ptrs: Vec<*const c_void> = bytes.iter().map(|x| x.as_ptr() as *const c_void).collect();
1926        let lens: Vec<size_t> = bytes.iter().map(|x| x.len() as size_t).collect();
1927        unsafe {
1928            tf::TF_SetAttrStringList(
1929                self.inner,
1930                c_attr_name.as_ptr(),
1931                ptrs.as_ptr() as *const *const std_c_void,
1932                lens.as_ptr(),
1933                ptrs.len() as c_int,
1934            );
1935        }
1936        Ok(())
1937    }
1938
1939    /// Sets the value of a function attribute.
1940    #[allow(trivial_numeric_casts)]
1941    pub fn set_attr_func_name(
1942        &mut self,
1943        attr_name: &str,
1944        value: &str,
1945    ) -> std::result::Result<(), NulError> {
1946        let c_attr_name = CString::new(attr_name)?;
1947        let c_value = value.as_bytes();
1948        unsafe {
1949            tf::TF_SetAttrFuncName(
1950                self.inner,
1951                c_attr_name.as_ptr(),
1952                c_value.as_ptr() as *const c_char,
1953                c_value.len() as size_t,
1954            );
1955        }
1956        Ok(())
1957    }
1958
1959    /// Sets an int-valued attribute.
1960    pub fn set_attr_int(
1961        &mut self,
1962        attr_name: &str,
1963        value: i64,
1964    ) -> std::result::Result<(), NulError> {
1965        let c_attr_name = CString::new(attr_name)?;
1966        unsafe {
1967            tf::TF_SetAttrInt(self.inner, c_attr_name.as_ptr(), value);
1968        }
1969        Ok(())
1970    }
1971
1972    /// Sets an attribute which holds an array of ints.
1973    pub fn set_attr_int_list(
1974        &mut self,
1975        attr_name: &str,
1976        value: &[i64],
1977    ) -> std::result::Result<(), NulError> {
1978        let c_attr_name = CString::new(attr_name)?;
1979        unsafe {
1980            tf::TF_SetAttrIntList(
1981                self.inner,
1982                c_attr_name.as_ptr(),
1983                value.as_ptr(),
1984                value.len() as i32,
1985            );
1986        }
1987        Ok(())
1988    }
1989
1990    /// Sets a float-valued attribute.
1991    pub fn set_attr_float(
1992        &mut self,
1993        attr_name: &str,
1994        value: f32,
1995    ) -> std::result::Result<(), NulError> {
1996        let c_attr_name = CString::new(attr_name)?;
1997        unsafe {
1998            tf::TF_SetAttrFloat(self.inner, c_attr_name.as_ptr(), value);
1999        }
2000        Ok(())
2001    }
2002
2003    /// Sets an attribute which holds an array of floats.
2004    #[allow(trivial_numeric_casts)]
2005    pub fn set_attr_float_list(
2006        &mut self,
2007        attr_name: &str,
2008        value: &[f32],
2009    ) -> std::result::Result<(), NulError> {
2010        let c_attr_name = CString::new(attr_name)?;
2011        // Allow trivial_numeric_casts here because f32 is not necessarily equal to c_float.
2012        let c_value: Vec<c_float> = value.iter().map(|x| *x as c_float).collect();
2013        unsafe {
2014            tf::TF_SetAttrFloatList(
2015                self.inner,
2016                c_attr_name.as_ptr(),
2017                c_value.as_ptr(),
2018                c_value.len() as i32,
2019            );
2020        }
2021        Ok(())
2022    }
2023
2024    /// Sets a boolean-valued attribute.
2025    pub fn set_attr_bool(
2026        &mut self,
2027        attr_name: &str,
2028        value: bool,
2029    ) -> std::result::Result<(), NulError> {
2030        let c_attr_name = CString::new(attr_name)?;
2031        unsafe {
2032            tf::TF_SetAttrBool(self.inner, c_attr_name.as_ptr(), u8::from(value));
2033        }
2034        Ok(())
2035    }
2036
2037    /// Sets an attribute which holds an array of booleans.
2038    pub fn set_attr_bool_list(
2039        &mut self,
2040        attr_name: &str,
2041        value: &[bool],
2042    ) -> std::result::Result<(), NulError> {
2043        let c_attr_name = CString::new(attr_name)?;
2044        let c_value: Vec<c_uchar> = value.iter().map(|x| u8::from(*x)).collect();
2045        unsafe {
2046            tf::TF_SetAttrBoolList(
2047                self.inner,
2048                c_attr_name.as_ptr(),
2049                c_value.as_ptr(),
2050                c_value.len() as c_int,
2051            );
2052        }
2053        Ok(())
2054    }
2055
2056    /// Sets a type-valued attribute.
2057    pub fn set_attr_type(
2058        &mut self,
2059        attr_name: &str,
2060        value: DataType,
2061    ) -> std::result::Result<(), NulError> {
2062        let c_attr_name = CString::new(attr_name)?;
2063        unsafe {
2064            tf::TF_SetAttrType(self.inner, c_attr_name.as_ptr(), value.to_c());
2065        }
2066        Ok(())
2067    }
2068
2069    /// Sets an attribute which holds an array of types.
2070    pub fn set_attr_type_list(
2071        &mut self,
2072        attr_name: &str,
2073        value: &[DataType],
2074    ) -> std::result::Result<(), NulError> {
2075        let c_attr_name = CString::new(attr_name)?;
2076        let c_value: Vec<tf::TF_DataType> = value.iter().map(|x| x.to_c()).collect();
2077        unsafe {
2078            tf::TF_SetAttrTypeList(
2079                self.inner,
2080                c_attr_name.as_ptr(),
2081                c_value.as_ptr(),
2082                c_value.len() as i32,
2083            );
2084        }
2085        Ok(())
2086    }
2087
2088    /// Sets a shape-valued attribute.
2089    pub fn set_attr_shape(
2090        &mut self,
2091        attr_name: &str,
2092        value: &Shape,
2093    ) -> std::result::Result<(), NulError> {
2094        let c_attr_name = CString::new(attr_name)?;
2095        unsafe {
2096            match value.0 {
2097                None => tf::TF_SetAttrShape(self.inner, c_attr_name.as_ptr(), ptr::null(), -1),
2098                Some(ref dims) => {
2099                    let c_dims: Vec<i64> = dims.iter().map(|x| (*x).unwrap_or(-1)).collect();
2100                    tf::TF_SetAttrShape(
2101                        self.inner,
2102                        c_attr_name.as_ptr(),
2103                        c_dims.as_ptr(),
2104                        c_dims.len() as i32,
2105                    );
2106                }
2107            }
2108        }
2109        Ok(())
2110    }
2111
2112    /// Sets an attribute which holds an array of shapes.
2113    pub fn set_attr_shape_list(
2114        &mut self,
2115        attr_name: &str,
2116        value: &[Shape],
2117    ) -> std::result::Result<(), NulError> {
2118        let c_attr_name = CString::new(attr_name)?;
2119        // Convert Option<i64> in each shape to i64 with None becoming -1.
2120        let c_dims: Vec<Option<Vec<i64>>> = value
2121            .iter()
2122            .map(|x| {
2123                x.0.as_ref()
2124                    .map(|dims| dims.iter().map(|x| (*x).unwrap_or(-1)).collect())
2125            })
2126            .collect();
2127        let ptrs: Vec<*const i64> = c_dims
2128            .iter()
2129            .map(|x| match *x {
2130                None => ptr::null(),
2131                Some(ref dims) => dims.as_ptr(),
2132            })
2133            .collect();
2134        let lens: Vec<c_int> = value
2135            .iter()
2136            .map(|x| match x.0 {
2137                None => -1,
2138                Some(ref dims) => dims.len() as c_int,
2139            })
2140            .collect();
2141        unsafe {
2142            tf::TF_SetAttrShapeList(
2143                self.inner,
2144                c_attr_name.as_ptr(),
2145                ptrs.as_ptr(),
2146                lens.as_ptr(),
2147                ptrs.len() as c_int,
2148            );
2149        }
2150        Ok(())
2151    }
2152
2153    /// Sets an attribute with a `TensorShapeProto` protobuf.
2154    #[allow(trivial_numeric_casts)]
2155    pub fn set_attr_tensor_shape_proto(&mut self, attr_name: &str, value: &[u8]) -> Result<()> {
2156        let c_attr_name = CString::new(attr_name)?;
2157        let mut status = Status::new();
2158        unsafe {
2159            tf::TF_SetAttrTensorShapeProto(
2160                self.inner,
2161                c_attr_name.as_ptr(),
2162                value.as_ptr() as *const std_c_void,
2163                value.len() as size_t,
2164                status.inner(),
2165            );
2166        }
2167        status.into_result()
2168    }
2169
2170    /// Sets an attribute with an array of `TensorShapeProto` protobufs.
2171    #[allow(trivial_numeric_casts)]
2172    pub fn set_attr_tensor_shape_proto_list<T: AsRef<[u8]>>(
2173        &mut self,
2174        attr_name: &str,
2175        value: &[T],
2176    ) -> Result<()> {
2177        let c_attr_name = CString::new(attr_name)?;
2178        let ptrs: Vec<*const c_void> = value
2179            .iter()
2180            .map(|x| x.as_ref().as_ptr() as *const c_void)
2181            .collect();
2182        let lens: Vec<size_t> = value.iter().map(|x| x.as_ref().len() as size_t).collect();
2183        let mut status = Status::new();
2184        unsafe {
2185            tf::TF_SetAttrTensorShapeProtoList(
2186                self.inner,
2187                c_attr_name.as_ptr(),
2188                ptrs.as_ptr() as *const *const std_c_void,
2189                lens.as_ptr(),
2190                ptrs.len() as c_int,
2191                status.inner(),
2192            );
2193        }
2194        status.into_result()
2195    }
2196
2197    /// Sets a tensor-valued attribute.
2198    pub fn set_attr_tensor<T: TensorType>(
2199        &mut self,
2200        attr_name: &str,
2201        value: Tensor<T>,
2202    ) -> Result<()> {
2203        self.set_attr_any_tensor(attr_name, &value)
2204    }
2205
2206    /// Sets a tensor-valued attribute.
2207    pub(crate) fn set_attr_any_tensor(
2208        &mut self,
2209        attr_name: &str,
2210        value: &dyn AnyTensor,
2211    ) -> Result<()> {
2212        let c_attr_name = CString::new(attr_name)?;
2213        let mut status = Status::new();
2214        unsafe {
2215            tf::TF_SetAttrTensor(
2216                self.inner,
2217                c_attr_name.as_ptr(),
2218                value.inner()?,
2219                status.inner(),
2220            );
2221        }
2222        status.into_result()
2223    }
2224
2225    /// Sets an attribute which holds an array of tensors.
2226    pub fn set_attr_tensor_list<I, T>(&mut self, attr_name: &str, value: I) -> Result<()>
2227    where
2228        I: IntoIterator<Item = Tensor<T>>,
2229        T: TensorType,
2230    {
2231        let c_attr_name = CString::new(attr_name)?;
2232        let mut status = Status::new();
2233        unsafe {
2234            // These have to stay alive durng the TF_SetAttrTensorList call.
2235            let tensors: Vec<_> = value.into_iter().collect();
2236            let maybe_ptrs: Result<_> = tensors.iter().map(|x| x.inner()).collect();
2237            let ptrs: Vec<*mut tf::TF_Tensor> = maybe_ptrs?;
2238            tf::TF_SetAttrTensorList(
2239                self.inner,
2240                c_attr_name.as_ptr(),
2241                ptrs.as_ptr() as *const *mut tf::TF_Tensor,
2242                ptrs.len() as c_int,
2243                status.inner(),
2244            );
2245        }
2246        status.into_result()
2247    }
2248
2249    /// Sets an attribute with an `AttrValue` proto.
2250    #[deprecated(since = "0.7.0", note = "Use set_attr_value_proto instead.")]
2251    pub fn set_attr_to_attr_value_proto(&mut self, attr_name: &str, value: &[u8]) -> Result<()> {
2252        self.set_attr_value_proto(attr_name, value)
2253    }
2254
2255    /// Sets an attribute with an `AttrValue` proto.
2256    #[allow(trivial_numeric_casts)]
2257    pub fn set_attr_value_proto(&mut self, attr_name: &str, value: &[u8]) -> Result<()> {
2258        let c_attr_name = CString::new(attr_name)?;
2259        let mut status = Status::new();
2260        unsafe {
2261            tf::TF_SetAttrValueProto(
2262                self.inner,
2263                c_attr_name.as_ptr(),
2264                value.as_ptr() as *const std_c_void,
2265                // Allow trivial_numeric_casts because usize is not
2266                // necessarily size_t.
2267                value.len() as size_t,
2268                status.inner(),
2269            );
2270        }
2271        status.into_result()
2272    }
2273}
2274
2275////////////////////////
2276
/// Options that can be passed during function creation.
///
/// NOTE(review): `FunctionOptions::new` currently leaves `inner` null (see its
/// TODO), so no options are actually configurable yet.
#[derive(Debug)]
#[allow(missing_copy_implementations)]
pub struct FunctionOptions {
    // Raw TensorFlow options handle; null for now.
    inner: *mut tf::TF_FunctionOptions,
}
2283
2284impl Default for FunctionOptions {
2285    fn default() -> Self {
2286        Self::new()
2287    }
2288}
2289
2290impl FunctionOptions {
2291    /// Creates a blank set of options.
2292    pub fn new() -> Self {
2293        FunctionOptions {
2294            inner: ptr::null_mut(), // TODO: Use real options when they become available
2295        }
2296    }
2297}
2298
2299////////////////////////
2300
/// Function is a grouping of operations with defined inputs and outputs.
/// Once created and added to graphs, functions can be invoked by creating an
/// operation whose operation type matches the function name.
#[derive(Debug)]
pub struct Function {
    // Raw TensorFlow function handle owned by this value.
    inner: *mut tf::TF_Function,
}

// Releases the wrapped `TF_Function` with `TF_DeleteFunction` on drop.
impl_drop!(Function, TF_DeleteFunction);
2310
2311impl Function {
2312    /// Returns a serialized representation of the function (as a FunctionDef
2313    /// protocol message).
2314    ///
2315    /// May fail on very large graphs in the future.
2316    pub fn to_function_def(&self) -> Result<Vec<u8>> {
2317        let status = Status::new();
2318        unsafe {
2319            let mut buf = Buffer::from_ptr(ptr::null_mut(), 0);
2320            tf::TF_FunctionToFunctionDef(self.inner, buf.inner_mut(), status.inner);
2321            status.into_result()?;
2322            Ok(buf.into())
2323        }
2324    }
2325
2326    /// Construct and return the function whose FunctionDef representation is
2327    /// serialized in `proto`. Returns a newly created `Function` instance.
2328    pub fn import_function_def(proto: &[u8]) -> Result<Function> {
2329        let status = Status::new();
2330        unsafe {
2331            let inner = tf::TF_FunctionImportFunctionDef(
2332                proto.as_ptr() as *const std_c_void,
2333                proto.len(),
2334                status.inner,
2335            );
2336            status.into_result()?;
2337            Ok(Function { inner })
2338        }
2339    }
2340
2341    /// Sets function attribute named `attr_name` to value stored in `proto`. If
2342    /// this attribute is already set to another value, it is overriden. `proto`
2343    /// should be a sequence of bytes representing a binary serialization of an
2344    /// AttrValue protocol buffer.
2345    pub fn set_attr_value_proto(&mut self, attr_name: &str, proto: &[u8]) -> Result<()> {
2346        let status = Status::new();
2347        let attr_name_cstr = CString::new(attr_name)?;
2348        unsafe {
2349            tf::TF_FunctionSetAttrValueProto(
2350                self.inner,
2351                attr_name_cstr.as_ptr(),
2352                proto.as_ptr() as *const std_c_void,
2353                proto.len(),
2354                status.inner,
2355            );
2356        }
2357        status.into_result()
2358    }
2359
2360    /// Returns the binary-serialized AttrValue proto representation of the
2361    /// value of the `attr_name` attr of the function. If `attr_name` attribute
2362    /// is not present, returns an error.
2363    pub fn get_attr_value_proto(&self, attr_name: &str) -> Result<Vec<u8>> {
2364        let status = Status::new();
2365        let attr_name_cstr = CString::new(attr_name)?;
2366        unsafe {
2367            let mut buf = Buffer::from_ptr(ptr::null_mut(), 0);
2368            tf::TF_FunctionGetAttrValueProto(
2369                self.inner,
2370                attr_name_cstr.as_ptr(),
2371                buf.inner_mut(),
2372                status.inner,
2373            );
2374            status.into_result()?;
2375            Ok(buf.into())
2376        }
2377    }
2378
2379    /// Returns the name of the graph function.
2380    pub fn get_name(&self) -> std::result::Result<String, Utf8Error> {
2381        unsafe {
2382            CStr::from_ptr(tf::TF_FunctionName(self.inner))
2383                .to_str()
2384                .map(|s| s.to_string())
2385        }
2386    }
2387}
2388
2389////////////////////////
2390
2391#[cfg(test)]
2392mod tests {
2393    use super::super::DataType;
2394    use super::super::Shape;
2395    use super::*;
2396
2397    fn add_operation(g: &mut Graph) {
2398        g.new_operation("Variable", "foo").unwrap();
2399    }
2400
2401    fn add(g: &mut Graph, op1: Operation, op2: Operation, name: &str) -> Result<Operation> {
2402        let mut nd = g.new_operation("Add", name)?;
2403        nd.add_input(op1);
2404        nd.add_input(op2);
2405        nd.finish()
2406    }
2407
2408    fn multiply(g: &mut Graph, op1: Operation, op2: Operation, name: &str) -> Result<Operation> {
2409        let mut nd = g.new_operation("Mul", name)?;
2410        nd.add_input(op1);
2411        nd.add_input(op2);
2412        nd.finish()
2413    }
2414
2415    #[test]
2416    fn smoke() {
2417        let mut g = Graph::new();
2418        add_operation(&mut g);
2419        let operation = {
2420            let mut nd = g.new_operation("Variable", "foo").unwrap();
2421            nd.set_attr_type("dtype", DataType::Float).unwrap();
2422            nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2423            nd.finish().unwrap()
2424        };
2425        let mut nd2 = g.new_operation("Variable", "foo2").unwrap();
2426        nd2.set_attr_type("dtype", DataType::Float).unwrap();
2427        nd2.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2428        let operation2 = nd2.finish().unwrap();
2429        assert_eq!("foo", operation.name().unwrap());
2430        assert_eq!("foo2", operation2.name().unwrap());
2431    }
2432
2433    #[test]
2434    fn test_import_graph_def() {
2435        let mut g = Graph::new();
2436        let opts = ImportGraphDefOptions::new();
2437        // An empty array is a valid proto, since all fields are optional.
2438        let status = g.import_graph_def(&[], &opts);
2439        assert!(status.is_ok());
2440    }
2441
2442    #[test]
2443    fn test_get_tensor_shape() {
2444        fn constant<T: TensorType>(graph: &mut Graph, name: &str, value: Tensor<T>) -> Operation {
2445            let mut c = graph.new_operation("Const", name).unwrap();
2446            c.set_attr_tensor("value", value).unwrap();
2447            c.set_attr_type("dtype", T::data_type()).unwrap();
2448            c.finish().unwrap()
2449        }
2450
2451        let mut graph = Graph::new();
2452        let x_init = Tensor::<i32>::new(&[3, 3]);
2453        let x = constant(&mut graph, "x/assign_0", x_init);
2454        assert_eq!(1, x.num_outputs());
2455        assert_eq!(x.output_type(0), DataType::Int32);
2456        let dims = graph.num_dims(x.clone()).unwrap();
2457        assert_eq!(dims, 2);
2458        let shape = graph.tensor_shape(x.clone()).unwrap();
2459        assert_eq!(shape, Shape(Some(vec![Some(3_i64), Some(3_i64)])));
2460    }
2461
2462    #[test]
2463    fn graph_to_function() {
2464        let mut g = Graph::new();
2465        let x = {
2466            let mut nd = g.new_operation("Placeholder", "x").unwrap();
2467            nd.set_attr_type("dtype", DataType::Float).unwrap();
2468            nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2469            nd.finish().unwrap()
2470        };
2471        let two = {
2472            let mut nd = g.new_operation("Const", "two").unwrap();
2473            nd.set_attr_type("dtype", DataType::Float).unwrap();
2474            let mut value = Tensor::new(&[1]);
2475            value[0] = 2.0f32;
2476            nd.set_attr_tensor("value", value).unwrap();
2477            nd.finish().unwrap()
2478        };
2479        let y = multiply(&mut g, two.clone(), x.clone(), "y").unwrap();
2480        let opers = vec![&y];
2481        let inputs = vec![x.clone().into(), two.clone().into()];
2482        let outputs = vec![y.clone().into()];
2483        let output_names = vec!["result"];
2484        let description = "Multiplies by 2";
2485        let opts = FunctionOptions::new();
2486        let f = g
2487            .to_function(
2488                "times_two",
2489                false,
2490                Some(&opers),
2491                &inputs,
2492                &outputs,
2493                Some(&output_names),
2494                &opts,
2495                Some(description),
2496            )
2497            .unwrap();
2498        assert_eq!("times_two", f.get_name().unwrap());
2499        let mut g2 = Graph::new();
2500        assert_eq!(0, g2.num_functions());
2501        assert_eq!(0, g2.get_functions().unwrap().len());
2502        g2.copy_function(&f, None).unwrap();
2503        assert_eq!(1, g2.num_functions());
2504        assert_eq!(1, g2.get_functions().unwrap().len());
2505    }
2506
2507    // This test checks that Operation::get_attr_* returns the value passed in
2508    // by OperationDescription::set_attr_*.  It's long and tedious because we
2509    // need to create several different ops to cover all the different types,
2510    // and the ops have requirements that have to be set up, first.  Once we can
2511    // define our own ops, we may be able to just define a single op with
2512    // attributes for all of the types.
2513    #[test]
2514    #[allow(trivial_casts)] // so we can do assert_eq!(slice, &some_vec as &[_])
2515    fn operation_attributes() {
2516        let mut g = Graph::new();
2517
2518        let shape = Shape(Some(vec![None, Some(3)]));
2519        let variable_op = {
2520            let mut nd = g.new_operation("Variable", "Variable").unwrap();
2521            nd.set_attr_type("dtype", DataType::Int32).unwrap();
2522            nd.set_attr_shape("shape", &shape).unwrap();
2523            nd.set_attr_string("shared_name", "bar").unwrap();
2524            nd.finish().unwrap()
2525        };
2526        assert_eq!("bar", variable_op.get_attr_string("shared_name").unwrap());
2527        assert_eq!(DataType::Int32, variable_op.get_attr_type("dtype").unwrap());
2528        assert_eq!(shape, variable_op.get_attr_shape("shape").unwrap());
2529
2530        let op = {
2531            let mut nd = g
2532                .new_operation("Variable", "Variable_unknown_rank")
2533                .unwrap();
2534            nd.set_attr_type("dtype", DataType::Int32).unwrap();
2535            nd.set_attr_shape("shape", &Shape(None)).unwrap();
2536            nd.finish().unwrap()
2537        };
2538        assert_eq!(Shape(None), op.get_attr_shape("shape").unwrap());
2539
2540        let value = Tensor::<i32>::new(&[1, 3]).with_values(&[1, 2, 3]).unwrap();
2541        let const_op = {
2542            let mut nd = g.new_operation("Const", "Const").unwrap();
2543            nd.set_attr_tensor("value", value.clone()).unwrap();
2544            nd.set_attr_type("dtype", DataType::Int32).unwrap();
2545            nd.finish().unwrap()
2546        };
2547        assert_eq!(value, const_op.get_attr_tensor("value").unwrap());
2548
2549        let op = {
2550            let mut nd = g.new_operation("Assign", "Assign").unwrap();
2551            nd.add_input(variable_op.clone());
2552            nd.add_input(variable_op.clone());
2553            nd.set_attr_bool("validate_shape", true).unwrap();
2554            nd.set_attr_bool("use_locking", false).unwrap();
2555            nd.finish().unwrap()
2556        };
2557        assert_eq!(true, op.get_attr_bool("validate_shape").unwrap());
2558        assert_eq!(false, op.get_attr_bool("use_locking").unwrap());
2559
2560        let op = {
2561            let variable_op = {
2562                let mut nd = g.new_operation("Variable", "MaxPool_in1").unwrap();
2563                nd.set_attr_type("dtype", DataType::Int32).unwrap();
2564                nd.set_attr_shape(
2565                    "shape",
2566                    &Shape(Some(vec![Some(5), Some(5), Some(5), Some(5)])),
2567                )
2568                .unwrap();
2569                nd.finish().unwrap()
2570            };
2571            let mut nd = g.new_operation("MaxPool", "MaxPool").unwrap();
2572            nd.add_input(variable_op);
2573            nd.set_attr_int_list("ksize", &[1, 2, 3, 4]).unwrap();
2574            nd.set_attr_int_list("strides", &[1, 1, 1, 1]).unwrap();
2575            nd.set_attr_string("padding", "VALID").unwrap();
2576            nd.finish().unwrap()
2577        };
2578        assert_eq!(
2579            &[1, 2, 3, 4],
2580            &op.get_attr_int_list("ksize").unwrap() as &[i64]
2581        );
2582
2583        let op = {
2584            let mut nd = g.new_operation("TensorSummary", "TensorSummary").unwrap();
2585            nd.add_input(variable_op.clone());
2586            nd.set_attr_string_list("labels", &["foo", "bar"]).unwrap();
2587            nd.finish().unwrap()
2588        };
2589        assert_eq!(
2590            &["foo".to_string(), "bar".to_string()],
2591            &op.get_attr_string_list("labels").unwrap() as &[_]
2592        );
2593
2594        let op = {
2595            let mut nd = g
2596                .new_operation("ApproximateEqual", "ApproximateEqual")
2597                .unwrap();
2598            nd.add_input(variable_op.clone());
2599            nd.add_input(variable_op.clone());
2600            nd.set_attr_float("tolerance", 3.14).unwrap();
2601            nd.finish().unwrap()
2602        };
2603        assert_eq!(3.14, op.get_attr_float("tolerance").unwrap());
2604
2605        let op = {
2606            let mut nd = g.new_operation("Bucketize", "Bucketize").unwrap();
2607            nd.add_input(variable_op.clone());
2608            nd.set_attr_float_list("boundaries", &[0.1, 2.3]).unwrap();
2609            nd.finish().unwrap()
2610        };
2611        assert_eq!(
2612            &[0.1f32, 2.3],
2613            &op.get_attr_float_list("boundaries").unwrap() as &[_]
2614        );
2615
2616        let shape_list = &[
2617            Shape(None),
2618            Shape(Some(vec![])),
2619            Shape(Some(vec![None])),
2620            Shape(Some(vec![Some(1)])),
2621        ];
2622        let op = {
2623            let mut nd = g
2624                .new_operation("RandomShuffleQueue", "RandomShuffleQueue")
2625                .unwrap();
2626            nd.set_attr_shape_list("shapes", shape_list).unwrap();
2627            nd.set_attr_type_list("component_types", &[DataType::Float, DataType::Int32])
2628                .unwrap();
2629            nd.set_attr_int("seed", 42).unwrap();
2630            nd.finish().unwrap()
2631        };
2632        assert_eq!(
2633            shape_list,
2634            &op.get_attr_shape_list("shapes").unwrap() as &[_]
2635        );
2636        assert_eq!(
2637            &[DataType::Float, DataType::Int32],
2638            &op.get_attr_type_list("component_types").unwrap() as &[_]
2639        );
2640        assert_eq!(42, op.get_attr_int("seed").unwrap());
2641
2642        // TODO: Support get_attr_*/set_attr_*:
2643        // - bool_list
2644        // - tensor_list
2645        // - tensor_shape_proto
2646        // - tensor_shape_proto_list
2647        // - value_proto
2648        // - func_name
2649        // The protos are tricky because we don't currently support proto
2650        // serialization/deserialization, and bool_list and tensor_list (a.k.a.
2651        // list(bool) and list(tensor)) don't seem to be used for any standard
2652        // ops. TF_GetAttrFuncName doesn't exist yet.
2653    }
2654
2655    // Returns a serialized GraphDef proto with variables "a" and "b" and op "a_times_b".
2656    fn graph_def() -> Vec<u8> {
2657        let mut g = Graph::new();
2658        let a = {
2659            let mut nd = g.new_operation("Variable", "a").unwrap();
2660            nd.set_attr_type("dtype", DataType::Int32).unwrap();
2661            nd.set_attr_shape("shape", &Shape(None)).unwrap();
2662            nd.finish().unwrap()
2663        };
2664        let b = {
2665            let mut nd = g.new_operation("Variable", "b").unwrap();
2666            nd.set_attr_type("dtype", DataType::Int32).unwrap();
2667            nd.set_attr_shape("shape", &Shape(None)).unwrap();
2668            nd.finish().unwrap()
2669        };
2670        multiply(&mut g, a, b, "a_times_b").unwrap();
2671        g.graph_def().unwrap()
2672    }
2673
2674    #[test]
2675    fn import_graph_def_uniquify_names() {
2676        let mut g = Graph::new();
2677        let mut opts = ImportGraphDefOptions::new();
2678        g.import_graph_def(&graph_def(), &opts).unwrap();
2679        opts.set_uniquify_names(true);
2680        g.import_graph_def(&graph_def(), &opts).unwrap();
2681        g.operation_by_name_required("a_1").unwrap();
2682    }
2683
2684    #[test]
2685    fn import_graph_def_uniquify_prefix() {
2686        let mut g = Graph::new();
2687        let mut opts = ImportGraphDefOptions::new();
2688        opts.set_prefix("prefix").unwrap();
2689        g.import_graph_def(&graph_def(), &opts).unwrap();
2690        opts.set_uniquify_prefix(true);
2691        g.import_graph_def(&graph_def(), &opts).unwrap();
2692        g.operation_by_name_required("prefix_1/a").unwrap();
2693    }
2694
2695    #[test]
2696    fn import_graph_def_set_default_device() {
2697        let mut g = Graph::new();
2698        let mut opts = ImportGraphDefOptions::new();
2699        opts.set_default_device("fake_device").unwrap();
2700        g.import_graph_def(&graph_def(), &opts).unwrap();
2701        assert_eq!(
2702            g.operation_by_name_required("a").unwrap().device().unwrap(),
2703            "fake_device"
2704        );
2705    }
2706
2707    #[test]
2708    fn import_graph_def_results_return_outputs() {
2709        let mut g = Graph::new();
2710        let mut opts = ImportGraphDefOptions::new();
2711        assert_eq!(opts.num_return_outputs(), 0);
2712        opts.add_return_output("a_times_b", 0).unwrap();
2713        assert_eq!(opts.num_return_outputs(), 1);
2714        let result = g
2715            .import_graph_def_with_results(&graph_def(), &opts)
2716            .unwrap();
2717        let ops = result.return_outputs();
2718        assert_eq!(ops.len(), 1);
2719        assert_eq!(ops[0].operation.name().unwrap(), "a_times_b");
2720        assert_eq!(ops[0].index, 0);
2721    }
2722
2723    #[test]
2724    fn import_graph_def_results_return_operations() {
2725        let mut g = Graph::new();
2726        let mut opts = ImportGraphDefOptions::new();
2727        assert_eq!(opts.num_return_operations(), 0);
2728        opts.add_return_operation("a_times_b").unwrap();
2729        assert_eq!(opts.num_return_operations(), 1);
2730        let result = g
2731            .import_graph_def_with_results(&graph_def(), &opts)
2732            .unwrap();
2733        let ops = result.return_operations();
2734        assert_eq!(ops.len(), 1);
2735        assert_eq!(ops[0].name().unwrap(), "a_times_b");
2736    }
2737
2738    #[test]
2739    fn import_graph_def_results_missing_unused_input_mappings() {
2740        let mut g = Graph::new();
2741        let op = {
2742            let mut nd = g.new_operation("Variable", "foo").unwrap();
2743            nd.set_attr_type("dtype", DataType::Int32).unwrap();
2744            nd.set_attr_shape("shape", &Shape(None)).unwrap();
2745            nd.finish().unwrap()
2746        };
2747        let output = op.into();
2748        let mut opts = ImportGraphDefOptions::new();
2749        opts.add_input_mapping("bar", 3, &output).unwrap();
2750        // An empty array is a valid proto, since all fields are optional.
2751        let result = g.import_graph_def_with_results(&[], &opts).unwrap();
2752        let missing = result.missing_unused_input_mappings().unwrap();
2753        assert_eq!(missing.len(), 1);
2754        assert_eq!(missing[0].0, "bar");
2755        assert_eq!(missing[0].1, 3);
2756    }
2757
2758    #[test]
2759    fn import_graph_def_with_return_outputs() {
2760        let mut g = Graph::new();
2761        let mut opts = ImportGraphDefOptions::new();
2762        assert_eq!(opts.num_return_outputs(), 0);
2763        opts.add_return_output("a_times_b", 0).unwrap();
2764        assert_eq!(opts.num_return_outputs(), 1);
2765        let ops = g
2766            .import_graph_def_with_return_outputs(&graph_def(), &opts)
2767            .unwrap();
2768        assert_eq!(ops.len(), 1);
2769        assert_eq!(ops[0].operation.name().unwrap(), "a_times_b");
2770        assert_eq!(ops[0].index, 0);
2771    }
2772
2773    #[test]
2774    fn graph_get_op_def() {
2775        let g = Graph::new();
2776        // We don't want to compare the actual proto because it may change across releases.
2777        assert!(g.get_op_def("Const").unwrap().len() > 0);
2778    }
2779
2780    #[test]
2781    fn graph_versions() {
2782        let g = Graph::new();
2783        // We don't want to compare the actual proto because it may change across releases.
2784        assert!(g.versions().unwrap().len() > 0);
2785    }
2786
2787    #[test]
2788    fn graph_generate_operation_name() {
2789        let mut g = Graph::new();
2790        for i in 0..5 {
2791            assert_eq!(i, g.generate_operation_name("foo_{}").unwrap());
2792            let mut nd = g
2793                .new_operation("Placeholder", &format!("foo_{}", i))
2794                .unwrap();
2795            nd.set_attr_type("dtype", DataType::Float).unwrap();
2796            nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2797            nd.finish().unwrap();
2798        }
2799    }
2800
2801    #[test]
2802    fn graph_add_gradients() {
2803        // TODO: Add an integration test to verify that the gradient behaves as expected.
2804        for (prefix, expected_prefix) in &[
2805            (Some("arbitrary_prefix"), "arbitrary_prefix/"),
2806            (None, "gradients/"),
2807        ] {
2808            let mut g = Graph::new();
2809            let x = {
2810                let mut nd = g.new_operation("Placeholder", "x").unwrap();
2811                nd.set_attr_type("dtype", DataType::Float).unwrap();
2812                nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2813                nd.finish().unwrap()
2814            };
2815            let y = {
2816                let mut nd = g.new_operation("Placeholder", "y").unwrap();
2817                nd.set_attr_type("dtype", DataType::Float).unwrap();
2818                nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2819                nd.finish().unwrap()
2820            };
2821            let x_squared = multiply(&mut g, x.clone(), x.clone(), "x_squared").unwrap();
2822            let x_times_y = multiply(&mut g, x.clone(), y.clone(), "x_times_y").unwrap();
2823            let x_plus_y = add(&mut g, x.clone(), y.clone(), "x_plus_y").unwrap();
2824            // y_outs and x_outs are intentionally different lengths, so we can test that the lengths line up properly.
2825            let y_outs = vec![x_squared.into(), x_times_y.into(), x_plus_y.into()];
2826            let x_outs = vec![x.into(), y.into()];
2827            let dy = g.add_gradients(*prefix, &y_outs, &x_outs, None).unwrap();
2828            assert_eq!(dy.len(), 2);
2829            for d in dy {
2830                let d = d.unwrap();
2831                assert_eq!(d.index, 0);
2832                let name = d.operation.name().unwrap();
2833                assert!(
2834                    name.starts_with(expected_prefix),
2835                    "name = {}, expected prefix = {}",
2836                    name,
2837                    expected_prefix
2838                );
2839            }
2840        }
2841    }
2842
2843    #[test]
2844    fn graph_add_gradients_stopped_gradient() {
2845        // TODO: Add an integration test to verify that the gradient behaves as expected.
2846        for prefix in &[Some("arbitrary_prefix"), None] {
2847            let mut g = Graph::new();
2848            let zero = {
2849                let mut nd = g.new_operation("Const", "zero").unwrap();
2850                nd.set_attr_type("dtype", DataType::Int32).unwrap();
2851                nd.set_attr_tensor("value", Tensor::<i32>::from(0)).unwrap();
2852                nd.finish().unwrap()
2853            };
2854            let x = {
2855                let mut nd = g.new_operation("Placeholder", "x").unwrap();
2856                nd.set_attr_type("dtype", DataType::Float).unwrap();
2857                nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2858                nd.finish().unwrap()
2859            };
2860            let argmax_x = {
2861                let mut nd = g.new_operation("ArgMax", "argmax_x").unwrap();
2862                nd.add_input(x.clone());
2863                nd.add_input(zero);
2864                nd.finish().unwrap()
2865            };
2866            let stopped_gradient = {
2867                let mut nd = g.new_operation("StopGradient", "stopped").unwrap();
2868                nd.add_input(argmax_x.clone());
2869                nd.finish().unwrap()
2870            };
2871            let y_outs = vec![stopped_gradient.into()];
2872            let x_outs = vec![x.into()];
2873            let dy = g.add_gradients(*prefix, &y_outs, &x_outs, None).unwrap();
2874            assert_eq!(dy.len(), 1);
2875            for d in &dy {
2876                assert!(d.is_none());
2877            }
2878        }
2879    }
2880
2881    #[test]
2882    fn graph_add_gradients_no_gradient() {
2883        // TODO: Add an integration test to verify that the gradient behaves as expected.
2884        for prefix in &[Some("arbitrary_prefix"), None] {
2885            let mut g = Graph::new();
2886            let zero = {
2887                let mut nd = g.new_operation("Const", "zero").unwrap();
2888                nd.set_attr_type("dtype", DataType::Int32).unwrap();
2889                nd.set_attr_tensor("value", Tensor::<i32>::from(0)).unwrap();
2890                nd.finish().unwrap()
2891            };
2892            let x = {
2893                let mut nd = g.new_operation("Placeholder", "x").unwrap();
2894                nd.set_attr_type("dtype", DataType::Float).unwrap();
2895                nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2896                nd.finish().unwrap()
2897            };
2898            let argmax_x = {
2899                let mut nd = g.new_operation("ArgMax", "argmax_x").unwrap();
2900                nd.add_input(x.clone());
2901                nd.add_input(zero);
2902                nd.finish().unwrap()
2903            };
2904            let y_outs = vec![argmax_x.into()];
2905            let x_outs = vec![x.into()];
2906            assert!(g.add_gradients(*prefix, &y_outs, &x_outs, None).is_err());
2907        }
2908    }
2909
2910    #[test]
2911    fn output_consumers() {
2912        let mut graph = Graph::new();
2913        let x_op = {
2914            let mut nd = graph.new_operation("Placeholder", "x").unwrap();
2915            nd.set_attr_type("dtype", DataType::String).unwrap();
2916            nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2917            nd.finish().unwrap()
2918        };
2919        let _y_op = {
2920            let mut nd = graph.new_operation("EncodeBase64", "y").unwrap();
2921            nd.add_input(x_op.clone());
2922            nd.finish().unwrap()
2923        };
2924        assert_eq!(x_op.num_outputs(), 1);
2925        let consumers = x_op.output_consumers(0);
2926        assert_eq!(consumers.len(), 1);
2927        assert_eq!(consumers[0].0.name().unwrap(), "y");
2928        assert_eq!(consumers[0].1, 0);
2929    }
2930
2931    #[test]
2932    fn output_name() {
2933        assert_eq!(
2934            "foo:1".parse::<OutputName>().unwrap(),
2935            OutputName {
2936                name: "foo".to_string(),
2937                index: 1
2938            }
2939        );
2940        assert_eq!(
2941            OutputName {
2942                name: "foo".to_string(),
2943                index: 1
2944            }
2945            .to_string(),
2946            "foo:1"
2947        );
2948        assert_eq!(
2949            "foo".parse::<OutputName>().unwrap(),
2950            OutputName {
2951                name: "foo".to_string(),
2952                index: 0
2953            }
2954        );
2955        assert!("foo:bar".parse::<OutputName>().is_err());
2956        assert!("foo:0:1".parse::<OutputName>().is_err());
2957    }
2958
2959    #[test]
2960    fn device() {
2961        let mut graph = Graph::new();
2962        let op = {
2963            let mut nd = graph.new_operation("NoOp", "x").unwrap();
2964            nd.set_device("foo").unwrap();
2965            nd.finish().unwrap()
2966        };
2967        assert_eq!(op.device().unwrap(), "foo");
2968    }
2969
2970    #[test]
2971    fn control_inputs() {
2972        let mut graph = Graph::new();
2973        let x = graph.new_operation("NoOp", "x").unwrap().finish().unwrap();
2974        let y = {
2975            let mut nd = graph.new_operation("NoOp", "y").unwrap();
2976            nd.add_control_input(&x);
2977            nd.finish().unwrap()
2978        };
2979        assert_eq!(
2980            y.control_inputs()
2981                .iter()
2982                .map(|n| n.name().unwrap())
2983                .collect::<Vec<_>>(),
2984            &["x"]
2985        );
2986    }
2987}