1use crate::internal::*;
2use std::ops::Deref;
3use std::rc::Rc;
4
5use tract_ndarray::Array;
6use TValue::*;
7
8#[derive(Clone, Eq)]
9pub enum TValue {
10 Const(Arc<Tensor>),
11 Var(Rc<Tensor>),
12}
13
14impl std::fmt::Debug for TValue {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 (**self).fmt(f)
17 }
18}
19
20impl PartialEq for TValue {
21 fn eq(&self, other: &Self) -> bool {
22 self.deref() == other.deref()
23 }
24}
25
26impl TValue {
27 pub fn is_exclusive(&self) -> bool {
28 match self {
29 Var(it) => Rc::strong_count(it) == 1,
30 Const(_) => false,
31 }
32 }
33
34 pub fn from_const(t: Arc<Tensor>) -> Self {
35 Const(t)
36 }
37}
38
39impl From<Tensor> for TValue {
40 fn from(t: Tensor) -> Self {
41 TValue::Var(std::rc::Rc::new(t))
42 }
43}
44
45impl std::ops::Deref for TValue {
46 type Target = Tensor;
47 fn deref(&self) -> &Self::Target {
48 match self {
49 Const(it) => it,
50 Var(it) => it,
51 }
52 }
53}
54
55impl std::borrow::Borrow<Tensor> for TValue {
56 fn borrow(&self) -> &Tensor {
57 self
58 }
59}
60
61impl IntoTensor for TValue {
62 fn into_tensor(self) -> Tensor {
63 match self {
64 Var(it) => Rc::try_unwrap(it).unwrap_or_else(|t| (*t).clone()),
65 Const(it) => it.into_tensor(),
66 }
67 }
68}
69
70impl IntoArcTensor for TValue {
71 fn into_arc_tensor(self) -> Arc<Tensor> {
72 match self {
73 Var(ref _it) => self.into_tensor().into_arc_tensor(),
74 Const(t) => t,
75 }
76 }
77}
78
79pub trait IntoTValue {
80 fn into_tvalue(self) -> TValue;
81}
82
83impl IntoTValue for Tensor {
84 fn into_tvalue(self) -> TValue {
85 self.into_tensor().into()
86 }
87}
88
89impl IntoTValue for Arc<Tensor> {
90 fn into_tvalue(self) -> TValue {
91 Const(self)
92 }
93}
94
95impl<D: ::ndarray::Dimension, T: Datum> IntoTValue for Array<T, D> {
96 fn into_tvalue(self) -> TValue {
97 Tensor::from(self).into_tvalue()
98 }
99}