1use crate::tensor::{DeviceTensorExt, IntoDevice};
2use tract_core::internal::*;
3use tract_core::ops::array::{Pad, PadMode};
4
5#[derive(Clone, Debug, PartialEq, Eq, Hash)]
9pub struct GpuPad {
10 pub pads: Vec<(usize, usize)>,
11 pub value: Arc<Tensor>,
12}
13
14impl GpuPad {
15 pub fn from_core(op: &Pad) -> Option<Self> {
17 let PadMode::Constant(value) = &op.mode else { return None };
18 Some(Self { pads: op.pads.clone(), value: value.clone() })
19 }
20
21 fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
22 input.iter().zip(&self.pads).map(|(d, (a, b))| d.clone() + *a + *b).collect()
23 }
24}
25
26impl Op for GpuPad {
27 fn name(&self) -> StaticName {
28 "GpuPad".into()
29 }
30
31 op_as_typed_op!();
32}
33
34impl EvalOp for GpuPad {
35 fn is_stateless(&self) -> bool {
36 true
37 }
38
39 fn eval_with_session(
40 &self,
41 node_id: usize,
42 session: &TurnState,
43 inputs: TVec<TValue>,
44 ) -> TractResult<TVec<TValue>> {
45 let input_value = args_1!(inputs);
46 let input = input_value.to_device_tensor()?;
47 let dt = input.datum_type();
48 let out_shape = self.output_shape(input.shape());
49
50 let output =
51 crate::session_handler::make_tensor_for_node(session, node_id, dt, &out_shape)?;
52
53 let ctx = crate::device::get_context()?;
54
55 let value = self.value.cast_to_dt(dt)?.into_owned().into_device()?;
57 let zero_strides = vec![0isize; out_shape.len()];
58 ctx.copy_nd(&value, 0, &zero_strides, &output, 0, &out_shape, output.strides())?;
59
60 if input.len() != 0 {
62 let interior: usize = self
63 .pads
64 .iter()
65 .enumerate()
66 .map(|(axis, (before, _))| before * output.strides()[axis] as usize)
67 .sum();
68 ctx.copy_nd(
69 input,
70 0,
71 input.strides(),
72 &output,
73 interior * dt.size_of(),
74 input.shape(),
75 output.strides(),
76 )?;
77 }
78 Ok(tvec![output.into_tensor().into_tvalue()])
79 }
80}
81
82impl TypedOp for GpuPad {
83 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
84 crate::utils::facts_to_device_facts(inputs, |facts| {
85 Ok(tvec!(facts[0].datum_type.fact(self.output_shape(&facts[0].shape.to_tvec()))))
86 })
87 .with_context(|| format!("Error while computing facts for {:?}", self.name()))
88 }
89
90 as_op!();
91}