Skip to main content

tract_gpu/
sync.rs

1use crate::fact::{DeviceFact, DeviceTypedFactExt};
2use crate::tensor::{DeviceTensorExt, IntoDevice};
3use derive_new::new;
4use std::fmt;
5use tract_core::internal::*;
6
/// Direction in which a tensor is transferred by a sync operator.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DeviceSyncKind {
    /// Transfer from the device to the host.
    ToHost,
    /// Transfer from the host to the device.
    ToDevice,
}
12
13impl fmt::Display for DeviceSyncKind {
14    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
15        write!(f, "{self:?}")
16    }
17}
18
19#[derive(Debug, Clone, new, Copy, PartialEq, Eq, Hash)]
20pub struct DeviceSync {
21    pub kind: DeviceSyncKind,
22}
23
24impl Op for DeviceSync {
25    fn name(&self) -> StaticName {
26        format!("DeviceSync{}", self.kind).into()
27    }
28
29    fn same_as(&self, other: &dyn Op) -> bool {
30        let Some(other) = other.downcast_ref::<DeviceSync>() else { return false };
31        self == other
32    }
33
34    op_as_typed_op!();
35}
36
37impl EvalOp for DeviceSync {
38    fn is_stateless(&self) -> bool {
39        true
40    }
41
42    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
43        let input = args_1!(inputs);
44        match self.kind {
45            DeviceSyncKind::ToHost => {
46                let device_tensor = input.to_device_tensor()?;
47
48                let tensor = device_tensor
49                    .to_host()
50                    .with_context(|| "Error while syncing device tensor to host")?;
51                Ok(tvec![tensor.into_tvalue()])
52            }
53            DeviceSyncKind::ToDevice => {
54                let device_input = if let Some(t) = input.as_arc_tensor() {
55                    Arc::clone(t).into_device()?
56                } else {
57                    input.into_tensor().into_device()?
58                };
59                Ok(tvec![device_input.into_opaque_tensor().into()])
60            }
61        }
62    }
63}
64
65impl TypedOp for DeviceSync {
66    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
67        let input = inputs[0];
68        match self.kind {
69            DeviceSyncKind::ToHost => {
70                let mut typed_fact = input
71                    .to_device_fact()
72                    .with_context(|| {
73                        "Cannot sync to Host a tensor without DeviceFact as metadata in its TypedFact"
74                    })?
75                    .clone()
76                    .into_typed_fact();
77                if let Some(konst) = input.konst.clone() {
78                    typed_fact.konst = Some(konst.to_device_tensor()?.to_host()?);
79                }
80                Ok(tvec!(typed_fact))
81            }
82            DeviceSyncKind::ToDevice => {
83                ensure!(
84                    input.datum_type != DatumType::Opaque,
85                    "Cannot sync Opaque Tensor to Device"
86                );
87                Ok(tvec![DeviceFact::from_host(input.clone())?.into_opaque_fact()])
88            }
89        }
90    }
91
92    as_op!();
93}