tract_core/
value.rs

1use crate::internal::*;
2use std::ops::Deref;
3use std::rc::Rc;
4
5use tract_ndarray::Array;
6use TValue::*;
7
8#[derive(Clone, Eq)]
9pub enum TValue {
10    Const(Arc<Tensor>),
11    Var(Rc<Tensor>),
12}
13
14impl std::fmt::Debug for TValue {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        (**self).fmt(f)
17    }
18}
19
20impl PartialEq for TValue {
21    fn eq(&self, other: &Self) -> bool {
22        self.deref() == other.deref()
23    }
24}
25
26impl TValue {
27    pub fn is_exclusive(&self) -> bool {
28        match self {
29            Var(it) => Rc::strong_count(it) == 1,
30            Const(_) => false,
31        }
32    }
33
34    pub fn from_const(t: Arc<Tensor>) -> Self {
35        Const(t)
36    }
37
38    pub fn as_arc_tensor(&self) -> Option<&Arc<Tensor>> {
39        match self {
40            Const(t) => Some(t),
41            Var(_) => None,
42        }
43    }
44}
45
46impl From<Tensor> for TValue {
47    fn from(t: Tensor) -> Self {
48        TValue::Var(std::rc::Rc::new(t))
49    }
50}
51
52impl std::ops::Deref for TValue {
53    type Target = Tensor;
54    fn deref(&self) -> &Self::Target {
55        match self {
56            Const(it) => it,
57            Var(it) => it,
58        }
59    }
60}
61
62impl std::borrow::Borrow<Tensor> for TValue {
63    fn borrow(&self) -> &Tensor {
64        self
65    }
66}
67
68impl IntoTensor for TValue {
69    fn into_tensor(self) -> Tensor {
70        match self {
71            Var(it) => Rc::try_unwrap(it).unwrap_or_else(|t| (*t).clone()),
72            Const(it) => it.into_tensor(),
73        }
74    }
75}
76
77impl IntoArcTensor for TValue {
78    fn into_arc_tensor(self) -> Arc<Tensor> {
79        match self {
80            Var(ref _it) => self.into_tensor().into_arc_tensor(),
81            Const(t) => t,
82        }
83    }
84}
85
86pub trait IntoTValue {
87    fn into_tvalue(self) -> TValue;
88}
89
90impl IntoTValue for Tensor {
91    fn into_tvalue(self) -> TValue {
92        self.into_tensor().into()
93    }
94}
95
96impl IntoTValue for Arc<Tensor> {
97    fn into_tvalue(self) -> TValue {
98        Const(self)
99    }
100}
101
102impl<D: ::ndarray::Dimension, T: Datum> IntoTValue for Array<T, D> {
103    fn into_tvalue(self) -> TValue {
104        Tensor::from(self).into_tvalue()
105    }
106}