Skip to main content

tract_gpu/
fact.rs

1use std::fmt;
2use tract_core::internal::*;
3
/// Where a device (GPU) tensor comes from.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum DeviceTensorOrigin {
    /// Produced by a device operator.
    ///
    /// The tensor may be either a Host or an ArenaView tensor.
    /// Note: `FromDevice` tensors come from asynchronous operations.
    FromDevice,
    /// Built from a CPU tensor (the output of a CPU op, or a `Const`).
    ///
    /// The tensor can only be a Host tensor.
    /// Note: `FromHost` tensors come from synchronous operations.
    FromHost,
}
16
17#[derive(Clone, PartialEq, Eq, Hash)]
18pub struct DeviceFact {
19    pub origin: DeviceTensorOrigin,
20    pub fact: TypedFact,
21    pub state_owned: bool,
22}
23
24impl DeviceFact {
25    pub fn new(origin: DeviceTensorOrigin, fact: TypedFact) -> TractResult<Self> {
26        ensure!(fact.as_device_fact().is_none());
27        let new_fact = fact.without_value();
28        Ok(Self { origin, fact: new_fact, state_owned: false })
29    }
30
31    pub fn from_host(fact: TypedFact) -> TractResult<Self> {
32        Self::new(DeviceTensorOrigin::FromHost, fact)
33    }
34
35    pub fn is_from_device(&self) -> bool {
36        matches!(self.origin, DeviceTensorOrigin::FromDevice)
37    }
38
39    pub fn is_state_owned(&self) -> bool {
40        self.state_owned
41    }
42
43    pub fn is_from_host(&self) -> bool {
44        matches!(self.origin, DeviceTensorOrigin::FromHost)
45    }
46
47    pub fn into_typed_fact(self) -> TypedFact {
48        self.fact
49    }
50
51    pub fn into_exotic_fact(self) -> TypedFact {
52        let dt = self.fact.datum_type;
53        let shape = self.fact.shape.clone();
54        TypedFact::dt_shape(dt, shape).with_exotic_fact(self)
55    }
56}
57
58impl ExoticFact for DeviceFact {
59    fn clarify_dt_shape(&self) -> Option<(DatumType, TVec<TDim>)> {
60        Some((self.fact.datum_type, self.fact.shape.to_tvec()))
61    }
62
63    fn buffer_sizes(&self) -> TVec<TDim> {
64        let inner_fact = &self.fact;
65        let mut sizes = tvec!(inner_fact.shape.volume() * inner_fact.datum_type.size_of());
66        if let Some(of) = inner_fact.exotic_fact() {
67            sizes.extend(of.buffer_sizes());
68        }
69        sizes
70    }
71    fn compatible_with(&self, other: &dyn ExoticFact) -> bool {
72        other.is::<Self>()
73    }
74}
75
76impl fmt::Debug for DeviceFact {
77    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
78        match self.origin {
79            DeviceTensorOrigin::FromHost => write!(fmt, "FromHost({:?})", self.without_value()),
80            DeviceTensorOrigin::FromDevice => {
81                write!(fmt, "FromDevice({:?})", self.fact.without_value())
82            }
83        }
84    }
85}
86
87pub trait DeviceTypedFactExt {
88    fn to_device_fact(&self) -> TractResult<&DeviceFact>;
89    fn as_device_fact(&self) -> Option<&DeviceFact>;
90    fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact>;
91}
92
93impl DeviceTypedFactExt for TypedFact {
94    fn to_device_fact(&self) -> TractResult<&DeviceFact> {
95        self.exotic_fact
96            .as_ref()
97            .and_then(|m| m.downcast_ref::<DeviceFact>())
98            .ok_or_else(|| anyhow!("DeviceFact not found"))
99    }
100    fn as_device_fact(&self) -> Option<&DeviceFact> {
101        self.exotic_fact.as_ref().and_then(|m| m.downcast_ref::<DeviceFact>())
102    }
103    fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact> {
104        self.exotic_fact.as_mut().and_then(|m| m.downcast_mut::<DeviceFact>())
105    }
106}
107
108impl std::ops::Deref for DeviceFact {
109    type Target = TypedFact;
110    fn deref(&self) -> &Self::Target {
111        &self.fact
112    }
113}
114
115impl std::convert::AsRef<TypedFact> for DeviceFact {
116    fn as_ref(&self) -> &TypedFact {
117        &self.fact
118    }
119}