1use crate::tensor::{DeviceTensor, DeviceTensorExt};
2use std::fmt;
3use tract_core::internal::*;
4use tract_core::ops::nn as core_ops_nn;
5use tract_itertools::Itertools;
6
/// Backend-supplied kernel launcher: (reducer, input, reduced axis, preallocated output).
/// The callee writes the reduction result into the output tensor.
pub type DispatchReduceFn = fn(&Reducer, &DeviceTensor, usize, &DeviceTensor) -> TractResult<()>;
8
/// Reduction kinds implemented by the GPU backends.
///
/// `All`/`Any` are the "logic" reducers operating on `bool` inputs; the rest
/// operate on floating-point inputs (see `is_supported_dt`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Reducer {
    MeanOfSquares,
    Sum,
    Prod,
    Min,
    Max,
    All,
    Any,
}
19
20impl fmt::Display for Reducer {
21 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
22 match self {
23 Self::MeanOfSquares => write!(f, "mean_of_squares"),
24 Self::Sum => write!(f, "sum"),
25 Self::Prod => write!(f, "prod"),
26 Self::Min => write!(f, "min"),
27 Self::Max => write!(f, "max"),
28 Self::All => write!(f, "all"),
29 Self::Any => write!(f, "any"),
30 }
31 }
32}
33
34impl Reducer {
35 pub const ALL: [Reducer; 7] =
36 [Self::MeanOfSquares, Self::Sum, Self::Prod, Self::Min, Self::Max, Self::All, Self::Any];
37
38 pub fn is_logic(&self) -> bool {
39 *self == Reducer::All || *self == Reducer::Any
40 }
41
42 pub fn is_supported_dt(&self, dt: DatumType) -> bool {
43 if self.is_logic() { dt.is::<bool>() } else { dt.is::<f32>() || dt.is::<f16>() }
44 }
45
46 pub fn from_tract_core(reducer: &core_ops_nn::Reducer) -> TractResult<Self> {
47 match reducer {
48 core_ops_nn::Reducer::Sum => Ok(Reducer::Sum),
49 core_ops_nn::Reducer::MeanOfSquares => Ok(Reducer::MeanOfSquares),
50 core_ops_nn::Reducer::Prod => Ok(Reducer::Prod),
51 core_ops_nn::Reducer::Min => Ok(Reducer::Min),
52 core_ops_nn::Reducer::Max => Ok(Reducer::Max),
53 core_ops_nn::Reducer::All => Ok(Reducer::All),
54 core_ops_nn::Reducer::Any => Ok(Reducer::Any),
55 _ => bail!("Unsupported reducer {:?} on GPU", reducer),
56 }
57 }
58}
59
/// GPU implementation of a reduction op, parameterized by the backend
/// that provides the actual kernel dispatch function.
#[derive(Clone, Debug)]
pub struct GpuReduce {
    // Axes to reduce; `GpuReduce::new` enforces exactly one entry.
    pub axes: TVec<usize>,
    // Which reduction to perform.
    pub reducer: Reducer,
    // Backend label; used in the op name, equality and hashing.
    pub backend_name: &'static str,
    // Backend-provided kernel launcher invoked at eval time.
    pub dispatch: DispatchReduceFn,
}
67
68impl PartialEq for GpuReduce {
69 fn eq(&self, other: &Self) -> bool {
70 self.axes == other.axes
71 && self.reducer == other.reducer
72 && self.backend_name == other.backend_name
73 }
74}
75
// The manual PartialEq compares only Eq-able fields (TVec<usize>, Reducer, &str),
// so total equality holds.
impl Eq for GpuReduce {}
77
78impl std::hash::Hash for GpuReduce {
79 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
80 self.axes.hash(state);
81 self.reducer.hash(state);
82 self.backend_name.hash(state);
83 }
84}
85
86impl GpuReduce {
87 pub fn new(
88 axes: TVec<usize>,
89 reducer: Reducer,
90 backend_name: &'static str,
91 dispatch: DispatchReduceFn,
92 ) -> TractResult<Self> {
93 ensure!(axes.len() == 1, "Only one axis of reduce is supported by {backend_name}Reduce");
94 Ok(Self { axes, reducer, backend_name, dispatch })
95 }
96
97 pub fn from_tract_core(
98 core_reduce: &core_ops_nn::Reduce,
99 backend_name: &'static str,
100 dispatch: DispatchReduceFn,
101 ) -> TractResult<Self> {
102 let reducer = Reducer::from_tract_core(&core_reduce.reducer)?;
103 Self::new(core_reduce.axes.clone(), reducer, backend_name, dispatch)
104 }
105}
106
impl Op for GpuReduce {
    // Op name combines backend and reducer, e.g. "<backend>Reduce<Sum>".
    fn name(&self) -> StaticName {
        format!("{}Reduce<{:?}>", self.backend_name, self.reducer).into()
    }
    // Extra debug info shown in model dumps: the reduced axes.
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("axes: {:?}", self.axes)])
    }
    op_as_typed_op!();
}
116
impl EvalOp for GpuReduce {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Runs the reduction on the device tensor held by the single input.
    fn eval_with_session(
        &self,
        node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        // Exactly one input; view it as a device tensor.
        let input_value = args_1!(inputs);
        let input = input_value.to_device_tensor()?;
        // Output keeps the input shape with the reduced axis collapsed to 1.
        // Indexing axes[0] is safe: `GpuReduce::new` guarantees a single axis.
        let mut output_shape = input.shape().to_vec();
        output_shape[self.axes[0]] = 1;
        // Allocate the output through the session helper, keyed by node id
        // (presumably so the session can manage per-node device buffers —
        // NOTE(review): confirm against session_handler).
        let output = crate::session_handler::make_tensor_for_node(
            session,
            node_id,
            input.datum_type(),
            &output_shape,
        )?;
        // Hand off to the backend-specific kernel, which fills `output`.
        (self.dispatch)(&self.reducer, input, self.axes[0], &output)?;
        Ok(tvec!(output.into_tensor().into_tvalue()))
    }
}
142
143impl TypedOp for GpuReduce {
144 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
145 ensure!(self.axes.iter().tuple_windows().all(|(a, b)| a < b));
146 crate::utils::facts_to_device_facts(inputs, |facts| {
147 let mut shape: TVec<_> = facts[0].shape.to_tvec();
148 for &ax in &self.axes {
149 shape[ax] = 1.to_dim();
150 }
151 let dt = facts[0].datum_type;
152 Ok(tvec!(dt.fact(shape)))
153 })
154 .with_context(|| format!("Error while computing facts for {:?}", self.name()))
155 }
156
157 as_op!();
158}