1use crate::internal::*;
2use std::ops::Deref;
3use std::rc::Rc;
4
5use tract_ndarray::Array;
6use TValue::*;
7
8#[derive(Clone, Eq)]
9pub enum TValue {
10 Const(Arc<Tensor>),
11 Var(Rc<Tensor>),
12}
13
14impl std::fmt::Debug for TValue {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 (**self).fmt(f)
17 }
18}
19
20impl PartialEq for TValue {
21 fn eq(&self, other: &Self) -> bool {
22 self.deref() == other.deref()
23 }
24}
25
26impl TValue {
27 pub fn is_exclusive(&self) -> bool {
28 match self {
29 Var(it) => Rc::strong_count(it) == 1,
30 Const(_) => false,
31 }
32 }
33
34 pub fn from_const(t: Arc<Tensor>) -> Self {
35 Const(t)
36 }
37
38 pub fn as_arc_tensor(&self) -> Option<&Arc<Tensor>> {
39 match self {
40 Const(t) => Some(t),
41 Var(_) => None,
42 }
43 }
44}
45
46impl From<Tensor> for TValue {
47 fn from(t: Tensor) -> Self {
48 TValue::Var(std::rc::Rc::new(t))
49 }
50}
51
52impl std::ops::Deref for TValue {
53 type Target = Tensor;
54 fn deref(&self) -> &Self::Target {
55 match self {
56 Const(it) => it,
57 Var(it) => it,
58 }
59 }
60}
61
62impl std::borrow::Borrow<Tensor> for TValue {
63 fn borrow(&self) -> &Tensor {
64 self
65 }
66}
67
68impl IntoTensor for TValue {
69 fn into_tensor(self) -> Tensor {
70 match self {
71 Var(it) => Rc::try_unwrap(it).unwrap_or_else(|t| (*t).clone()),
72 Const(it) => it.into_tensor(),
73 }
74 }
75}
76
77impl IntoArcTensor for TValue {
78 fn into_arc_tensor(self) -> Arc<Tensor> {
79 match self {
80 Var(ref _it) => self.into_tensor().into_arc_tensor(),
81 Const(t) => t,
82 }
83 }
84}
85
/// Consuming conversion into a [`TValue`].
///
/// Implemented in this file for `Tensor` (yielding a `Var`), `Arc<Tensor>` (yielding a
/// `Const`), and ndarray `Array`s (converted to a `Tensor` first).
pub trait IntoTValue {
    /// Consumes `self` and wraps it as a `TValue`.
    fn into_tvalue(self) -> TValue;
}
89
90impl IntoTValue for Tensor {
91 fn into_tvalue(self) -> TValue {
92 self.into_tensor().into()
93 }
94}
95
96impl IntoTValue for Arc<Tensor> {
97 fn into_tvalue(self) -> TValue {
98 Const(self)
99 }
100}
101
102impl<D: ::ndarray::Dimension, T: Datum> IntoTValue for Array<T, D> {
103 fn into_tvalue(self) -> TValue {
104 Tensor::from(self).into_tvalue()
105 }
106}