1use std::fmt;
2use tract_core::internal::*;
3
/// Where the data of a device tensor originates.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceTensorOrigin {
    /// The tensor originates on the device itself (presumably produced by a device kernel —
    /// TODO confirm against callers).
    FromDevice,
    /// The tensor originates on the host (presumably uploaded to the device before use —
    /// TODO confirm against callers).
    FromHost,
}
16
17#[derive(Clone, PartialEq, Eq, Hash)]
18pub struct DeviceFact {
19 pub origin: DeviceTensorOrigin,
20 pub fact: TypedFact,
21 pub state_owned: bool,
22}
23
24impl DeviceFact {
25 pub fn new(origin: DeviceTensorOrigin, fact: TypedFact) -> TractResult<Self> {
26 ensure!(fact.as_device_fact().is_none());
27 let new_fact = fact.without_value();
28 Ok(Self { origin, fact: new_fact, state_owned: false })
29 }
30
31 pub fn from_host(fact: TypedFact) -> TractResult<Self> {
32 Self::new(DeviceTensorOrigin::FromHost, fact)
33 }
34
35 pub fn is_from_device(&self) -> bool {
36 matches!(self.origin, DeviceTensorOrigin::FromDevice)
37 }
38
39 pub fn is_state_owned(&self) -> bool {
40 self.state_owned
41 }
42
43 pub fn is_from_host(&self) -> bool {
44 matches!(self.origin, DeviceTensorOrigin::FromHost)
45 }
46
47 pub fn into_typed_fact(self) -> TypedFact {
48 self.fact
49 }
50
51 pub fn into_exotic_fact(self) -> TypedFact {
52 let dt = self.fact.datum_type;
53 let shape = self.fact.shape.clone();
54 TypedFact::dt_shape(dt, shape).with_exotic_fact(self)
55 }
56}
57
58impl ExoticFact for DeviceFact {
59 fn clarify_dt_shape(&self) -> Option<(DatumType, TVec<TDim>)> {
60 Some((self.fact.datum_type, self.fact.shape.to_tvec()))
61 }
62
63 fn buffer_sizes(&self) -> TVec<TDim> {
64 let inner_fact = &self.fact;
65 let mut sizes = tvec!(inner_fact.shape.volume() * inner_fact.datum_type.size_of());
66 if let Some(of) = inner_fact.exotic_fact() {
67 sizes.extend(of.buffer_sizes());
68 }
69 sizes
70 }
71 fn compatible_with(&self, other: &dyn ExoticFact) -> bool {
72 other.is::<Self>()
73 }
74}
75
76impl fmt::Debug for DeviceFact {
77 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
78 match self.origin {
79 DeviceTensorOrigin::FromHost => write!(fmt, "FromHost({:?})", self.without_value()),
80 DeviceTensorOrigin::FromDevice => {
81 write!(fmt, "FromDevice({:?})", self.fact.without_value())
82 }
83 }
84 }
85}
86
87pub trait DeviceTypedFactExt {
88 fn to_device_fact(&self) -> TractResult<&DeviceFact>;
89 fn as_device_fact(&self) -> Option<&DeviceFact>;
90 fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact>;
91}
92
93impl DeviceTypedFactExt for TypedFact {
94 fn to_device_fact(&self) -> TractResult<&DeviceFact> {
95 self.exotic_fact
96 .as_ref()
97 .and_then(|m| m.downcast_ref::<DeviceFact>())
98 .ok_or_else(|| anyhow!("DeviceFact not found"))
99 }
100 fn as_device_fact(&self) -> Option<&DeviceFact> {
101 self.exotic_fact.as_ref().and_then(|m| m.downcast_ref::<DeviceFact>())
102 }
103 fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact> {
104 self.exotic_fact.as_mut().and_then(|m| m.downcast_mut::<DeviceFact>())
105 }
106}
107
108impl std::ops::Deref for DeviceFact {
109 type Target = TypedFact;
110 fn deref(&self) -> &Self::Target {
111 &self.fact
112 }
113}
114
115impl std::convert::AsRef<TypedFact> for DeviceFact {
116 fn as_ref(&self) -> &TypedFact {
117 &self.fact
118 }
119}