1use std::fmt;
2use tract_core::internal::*;
3
/// Where the tensor described by a [`DeviceFact`] originates.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub enum DeviceTensorOrigin {
    /// The tensor originates on the device side.
    FromDevice,
    /// The tensor originates on the host side (presumably transferred to the device
    /// afterwards — confirm with the transfer ops).
    FromHost,
}
16
17#[derive(Clone, PartialEq, Eq, Hash)]
18pub struct DeviceFact {
19 pub origin: DeviceTensorOrigin,
20 pub fact: TypedFact,
21}
22
23impl DeviceFact {
24 pub fn new(origin: DeviceTensorOrigin, fact: TypedFact) -> TractResult<Self> {
25 ensure!(fact.as_device_fact().is_none());
26 let mut fact_wo_cst = fact.clone();
27 if fact.opaque_fact.is_some() {
28 fact_wo_cst.konst = None;
29 fact_wo_cst.uniform = None;
30 }
31 Ok(Self { origin, fact: fact_wo_cst })
32 }
33
34 pub fn from_host(fact: TypedFact) -> TractResult<Self> {
35 Self::new(DeviceTensorOrigin::FromHost, fact)
36 }
37
38 pub fn is_from_device(&self) -> bool {
39 matches!(self.origin, DeviceTensorOrigin::FromDevice)
40 }
41
42 pub fn is_from_host(&self) -> bool {
43 matches!(self.origin, DeviceTensorOrigin::FromHost)
44 }
45
46 pub fn into_typed_fact(self) -> TypedFact {
47 self.fact
48 }
49
50 pub fn into_opaque_fact(self) -> TypedFact {
51 TypedFact::dt_scalar(DatumType::Opaque).with_opaque_fact(self)
52 }
53}
54
55impl OpaqueFact for DeviceFact {
56 fn clarify_dt_shape(&self) -> Option<(DatumType, &[usize])> {
57 self.fact.shape.as_concrete().map(|s| (self.fact.datum_type, s))
58 }
59
60 fn mem_size(&self) -> TDim {
61 self.fact.mem_size()
62 }
63 fn same_as(&self, other: &dyn OpaqueFact) -> bool {
64 other.downcast_ref::<Self>().is_some_and(|o| o == self)
65 }
66 fn compatible_with(&self, other: &dyn OpaqueFact) -> bool {
67 other.is::<Self>()
68 }
69}
70
71impl fmt::Debug for DeviceFact {
72 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
73 match self.origin {
74 DeviceTensorOrigin::FromHost => write!(fmt, "FromHost({:?})", self.without_value()),
75 DeviceTensorOrigin::FromDevice => {
76 write!(fmt, "FromDevice({:?})", self.fact.without_value())
77 }
78 }
79 }
80}
81
82pub trait DeviceTypedFactExt {
83 fn to_device_fact(&self) -> TractResult<&DeviceFact>;
84 fn as_device_fact(&self) -> Option<&DeviceFact>;
85}
86
87impl DeviceTypedFactExt for TypedFact {
88 fn to_device_fact(&self) -> TractResult<&DeviceFact> {
89 ensure!(
90 self.datum_type == DatumType::Opaque,
91 "Cannot retrieve DeviceFact from a non Opaque Tensor"
92 );
93 self.opaque_fact
94 .as_ref()
95 .and_then(|m| m.downcast_ref::<DeviceFact>())
96 .ok_or_else(|| anyhow!("DeviceFact not found in Opaque Tensor"))
97 }
98 fn as_device_fact(&self) -> Option<&DeviceFact> {
99 self.opaque_fact.as_ref().and_then(|m| m.downcast_ref::<DeviceFact>())
100 }
101}
102
103impl std::ops::Deref for DeviceFact {
104 type Target = TypedFact;
105 fn deref(&self) -> &Self::Target {
106 &self.fact
107 }
108}
109
110impl std::convert::AsRef<TypedFact> for DeviceFact {
111 fn as_ref(&self) -> &TypedFact {
112 &self.fact
113 }
114}