1use crate::fact::{DeviceFact, DeviceTypedFactExt};
2use crate::tensor::{DeviceTensorExt, IntoDevice};
3use derive_new::new;
4use std::collections::HashMap;
5use std::fmt;
6use std::sync::Arc;
7use tract_core::internal::*;
8
/// Direction of the transfer performed by a [`DeviceSync`] operator.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DeviceSyncKind {
    /// Transfer a device tensor to the host.
    ToHost,
    /// Transfer a host tensor to the device.
    ToDevice,
}
14
15impl fmt::Display for DeviceSyncKind {
16 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
17 write!(f, "{self:?}")
18 }
19}
20
21#[derive(Debug, Clone, new, Copy, PartialEq, Eq, Hash)]
22pub struct DeviceSync {
23 pub kind: DeviceSyncKind,
24}
25
26impl Op for DeviceSync {
27 fn name(&self) -> StaticName {
28 format!("DeviceSync{}", self.kind).into()
29 }
30
31 op_as_typed_op!();
32}
33
34impl EvalOp for DeviceSync {
35 fn is_stateless(&self) -> bool {
36 true
37 }
38
39 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
40 let input = args_1!(inputs);
41 match self.kind {
42 DeviceSyncKind::ToHost => {
43 let device_tensor = input.to_device_tensor()?;
44
45 let tensor = device_tensor
46 .to_host()
47 .with_context(|| "Error while syncing device tensor to host")?;
48 Ok(tvec![tensor.into_tvalue()])
49 }
50 DeviceSyncKind::ToDevice => {
51 let device_input = if let Some(t) = input.as_arc_tensor() {
52 Arc::clone(t).into_device()?
53 } else {
54 input.into_tensor().into_device()?
55 };
56 Ok(tvec![device_input.into_tensor().into()])
57 }
58 }
59 }
60}
61
62impl TypedOp for DeviceSync {
63 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
64 let input = inputs[0];
65 match self.kind {
66 DeviceSyncKind::ToHost => {
67 let mut typed_fact = input
68 .to_device_fact()
69 .with_context(|| {
70 "Cannot sync to Host a tensor without DeviceFact as metadata in its TypedFact"
71 })?
72 .clone()
73 .into_typed_fact();
74 if let Some(konst) = input.konst.clone() {
75 if let Some(dt) = konst.as_device_tensor() {
76 typed_fact.konst = Some(dt.to_host()?);
77 } else {
78 typed_fact.konst = Some(konst);
79 }
80 }
81 Ok(tvec!(typed_fact))
82 }
83 DeviceSyncKind::ToDevice => {
84 ensure!(
85 input.as_device_fact().is_none(),
86 "Cannot sync to Device a tensor already on Device"
87 );
88 Ok(tvec![DeviceFact::from_host(input.clone())?.into_exotic_fact()])
89 }
90 }
91 }
92
93 as_op!();
94}
95
96pub fn sync_inputs_if_required(
99 model: &mut TypedModel,
100 node: &TypedNode,
101 mapping: &HashMap<OutletId, OutletId>,
102 sync_kind: DeviceSyncKind,
103) -> TractResult<TVec<OutletId>> {
104 let mut mapped_inputs = tvec![];
105 for (i_idx, i) in node.inputs.iter().enumerate() {
106 let in_fact = model.outlet_fact_mut(mapping[i])?;
107 match sync_kind {
108 DeviceSyncKind::ToHost if in_fact.as_device_fact().is_some() => {
109 mapped_inputs.push(
110 model.wire_node(
111 format!("{}.to-cpu-{i_idx}", node.name),
112 DeviceSync::new(sync_kind),
113 &[mapping[i]],
114 )?[0],
115 );
116 }
117 DeviceSyncKind::ToDevice if in_fact.as_device_fact().is_none() => {
118 if let Some(ref konst) = in_fact.konst
119 && konst.as_device_tensor().is_none()
120 {
121 let device_konst = konst.as_ref().clone().into_device()?.into_tensor();
122 let device_fact = DeviceFact::from_host(in_fact.clone())?;
123
124 *in_fact = device_fact.into_exotic_fact();
125
126 in_fact.konst = Some(Arc::new(device_konst));
127 mapped_inputs.push(mapping[i]);
128 continue;
129 }
130 ensure!(
131 in_fact.datum_type.is_copy(),
132 "Only copy DatumType can be sync to Device: {:?}",
133 in_fact.datum_type
134 );
135
136 mapped_inputs.push(
137 model.wire_node(
138 format!("{}.to-device-{i_idx}", node.name),
139 DeviceSync::new(sync_kind),
140 &[mapping[i]],
141 )?[0],
142 );
143 }
144 _ => mapped_inputs.push(mapping[i]),
145 }
146 }
147 Ok(mapped_inputs)
148}
149
150pub fn sync_model_outputs_if_required(
152 src: &TypedModel,
153 node: &TypedNode,
154 target: &mut TypedModel,
155 target_node_outlet_ids: TVec<OutletId>,
156) -> TractResult<TVec<OutletId>> {
157 let mut outputs = tvec![];
158 for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
159 let is_src_output = src.outputs.contains(&OutletId::new(node.id, o_idx));
160 if target.outlet_fact(o)?.as_device_fact().is_some() && is_src_output {
161 let sync_output = target.wire_node(
162 format!("{}.to-host-{o_idx}-out", node.name),
163 DeviceSync::new(DeviceSyncKind::ToHost),
164 &[o],
165 )?[0];
166 outputs.push(sync_output);
167 } else {
168 outputs.push(o)
169 }
170 }
171 Ok(outputs)
172}