1use crate::fact::{DeviceFact, DeviceTypedFactExt};
2use crate::tensor::{DeviceTensorExt, IntoDevice};
3use derive_new::new;
4use std::fmt;
5use tract_core::internal::*;
6
/// Direction of a [`DeviceSync`] transfer.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DeviceSyncKind {
    /// Bring a device tensor back to the host.
    ToHost,
    /// Move a host tensor onto the device.
    ToDevice,
}
12
13impl fmt::Display for DeviceSyncKind {
14 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
15 write!(f, "{self:?}")
16 }
17}
18
19#[derive(Debug, Clone, new, Copy, PartialEq, Eq, Hash)]
20pub struct DeviceSync {
21 pub kind: DeviceSyncKind,
22}
23
24impl Op for DeviceSync {
25 fn name(&self) -> StaticName {
26 format!("DeviceSync{}", self.kind).into()
27 }
28
29 fn same_as(&self, other: &dyn Op) -> bool {
30 let Some(other) = other.downcast_ref::<DeviceSync>() else { return false };
31 self == other
32 }
33
34 op_as_typed_op!();
35}
36
37impl EvalOp for DeviceSync {
38 fn is_stateless(&self) -> bool {
39 true
40 }
41
42 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
43 let input = args_1!(inputs);
44 match self.kind {
45 DeviceSyncKind::ToHost => {
46 let device_tensor = input.to_device_tensor()?;
47
48 let tensor = device_tensor
49 .to_host()
50 .with_context(|| "Error while syncing device tensor to host")?;
51 Ok(tvec![tensor.into_tvalue()])
52 }
53 DeviceSyncKind::ToDevice => {
54 let device_input = if let Some(t) = input.as_arc_tensor() {
55 Arc::clone(t).into_device()?
56 } else {
57 input.into_tensor().into_device()?
58 };
59 Ok(tvec![device_input.into_opaque_tensor().into()])
60 }
61 }
62 }
63}
64
65impl TypedOp for DeviceSync {
66 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
67 let input = inputs[0];
68 match self.kind {
69 DeviceSyncKind::ToHost => Ok(tvec![input
70 .to_device_fact()
71 .with_context(|| {
72 "Cannot sync to Host a tensor without DeviceFact as metadata in its TypedFact"
73 })?
74 .clone()
75 .into_typed_fact()]),
76 DeviceSyncKind::ToDevice => {
77 ensure!(
78 input.datum_type != DatumType::Opaque,
79 "Cannot sync Opaque Tensor to Device"
80 );
81 Ok(tvec![DeviceFact::from_host(input.clone())?.into_opaque_fact()])
82 }
83 }
84 }
85
86 as_op!();
87}