tract_core/
value.rs

1use crate::internal::*;
2use std::ops::Deref;
3use std::rc::Rc;
4
5use tract_ndarray::Array;
6use TValue::*;
7
8#[derive(Clone, Eq)]
9pub enum TValue {
10    Const(Arc<Tensor>),
11    Var(Rc<Tensor>),
12}
13
14impl std::fmt::Debug for TValue {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        (**self).fmt(f)
17    }
18}
19
20impl PartialEq for TValue {
21    fn eq(&self, other: &Self) -> bool {
22        self.deref() == other.deref()
23    }
24}
25
26impl TValue {
27    pub fn is_exclusive(&self) -> bool {
28        match self {
29            Var(it) => Rc::strong_count(it) == 1,
30            Const(_) => false,
31        }
32    }
33
34    pub fn from_const(t: Arc<Tensor>) -> Self {
35        Const(t)
36    }
37}
38
39impl From<Tensor> for TValue {
40    fn from(t: Tensor) -> Self {
41        TValue::Var(std::rc::Rc::new(t))
42    }
43}
44
45impl std::ops::Deref for TValue {
46    type Target = Tensor;
47    fn deref(&self) -> &Self::Target {
48        match self {
49            Const(it) => it,
50            Var(it) => it,
51        }
52    }
53}
54
55impl std::borrow::Borrow<Tensor> for TValue {
56    fn borrow(&self) -> &Tensor {
57        self
58    }
59}
60
61impl IntoTensor for TValue {
62    fn into_tensor(self) -> Tensor {
63        match self {
64            Var(it) => Rc::try_unwrap(it).unwrap_or_else(|t| (*t).clone()),
65            Const(it) => it.into_tensor(),
66        }
67    }
68}
69
70impl IntoArcTensor for TValue {
71    fn into_arc_tensor(self) -> Arc<Tensor> {
72        match self {
73            Var(ref _it) => self.into_tensor().into_arc_tensor(),
74            Const(t) => t,
75        }
76    }
77}
78
79pub trait IntoTValue {
80    fn into_tvalue(self) -> TValue;
81}
82
83impl IntoTValue for Tensor {
84    fn into_tvalue(self) -> TValue {
85        self.into_tensor().into()
86    }
87}
88
89impl IntoTValue for Arc<Tensor> {
90    fn into_tvalue(self) -> TValue {
91        Const(self)
92    }
93}
94
95impl<D: ::ndarray::Dimension, T: Datum> IntoTValue for Array<T, D> {
96    fn into_tvalue(self) -> TValue {
97        Tensor::from(self).into_tvalue()
98    }
99}