1use std::any::type_name;
51use std::marker::PhantomData;
52
53pub use crate::error::OpError;
54use crate::ndarray_ext::{NdArrayView, NdArrayViewMut};
55use crate::smallvec::SmallVec as RawSmallVec;
56use crate::tensor::Tensor;
57use crate::{Float, NdArray};
58
59pub(crate) const DEFAULT_NUM_EDGES: usize = 2;
60
61pub(crate) type SmallVec<T> = RawSmallVec<[T; DEFAULT_NUM_EDGES]>;
62
/// Interface implemented by every differentiable operation: `compute` runs
/// the forward pass, `grad` produces input gradients for backpropagation.
pub trait Op<F: Float> {
    /// Human-readable name of this op; defaults to the Rust type name.
    fn name(&self) -> &'static str {
        type_name::<Self>()
    }

    /// Forward pass: reads inputs from `ctx` and appends output arrays to it.
    ///
    /// # Errors
    /// Returns an [`OpError`] when the computation fails.
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError>;

    /// Backward pass: computes gradients w.r.t. each input and stores them
    /// into `ctx` (see `GradientContext::append_input_grad`).
    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>);

    /// Downcasting hook. Ops needing runtime type identification override
    /// this to return `Some(self)`; the default is `None`.
    fn as_any(&self) -> Option<&dyn std::any::Any> {
        None
    }
}
81
82#[allow(dead_code)]
/// One input edge of an op: either a registered variable (by ID) or an
/// ordinary graph tensor.
#[allow(dead_code)]
pub(crate) enum OpInput<'graph, F: Float> {
    /// Input backed by a registered variable.
    Variable(crate::variable::VariableID),
    /// Input backed by a plain graph tensor; the `usize` is an index
    /// (presumably the producing tensor's output slot — TODO confirm).
    NonVariable(usize, &'graph Tensor<'graph, F>),
}
87
88#[allow(dead_code)]
/// Placeholder accessor for [`OpInput`] values; currently carries only a
/// zero scalar and a lifetime marker.
#[allow(dead_code)]
pub(crate) struct OpInputGetter<'a, F: Float> {
    /// Scalar payload; both constructors initialise this to `F::zero()`.
    f: F,
    /// Ties the getter to the lifetime of the borrowed `OpInput`.
    _marker: PhantomData<&'a ()>,
}
94
95impl<F: Float> OpInputGetter<'_, F> {
96 #[allow(dead_code)]
97 pub fn new(_: F) -> Self {
98 Self {
99 f: F::zero(),
100 _marker: PhantomData,
101 }
102 }
103}
104
105impl<'a, 'graph, F: Float> From<&'a OpInput<'graph, F>> for OpInputGetter<'a, F> {
106 fn from(x: &'a OpInput<'graph, F>) -> Self {
107 let _ = x;
108 Self {
109 f: F::zero(),
110 _marker: PhantomData,
111 }
112 }
113}
114
/// Execution context for an op's forward pass: holds owned input arrays and
/// collects the output arrays the op appends.
pub struct ComputeContext<F: Float> {
    /// Input arrays, indexed by the op's input-edge order.
    pub(crate) inputs: Vec<NdArray<F>>,
    /// Output arrays appended via `append_output`, in append order.
    pub(crate) outputs: Vec<NdArray<F>>,
}
120
121impl<F: Float> ComputeContext<F> {
122 pub fn new(inputs: &[NdArray<F>], outputs: &mut [NdArray<F>]) -> Self {
124 let input_arrays = inputs.to_vec();
126 Self {
127 inputs: input_arrays,
128 outputs: Vec::new(),
129 }
130 }
131
132 pub fn with_inputs(input_arrays: Vec<NdArray<F>>) -> Self {
134 Self {
135 inputs: input_arrays,
136 outputs: Vec::new(),
137 }
138 }
139
140 pub fn input(&self, i: usize) -> NdArrayView<F> {
144 if i >= self.inputs.len() {
145 static DUMMY_SCALAR: once_cell::sync::Lazy<NdArray<f32>> =
148 once_cell::sync::Lazy::new(|| crate::ndarray_ext::zeros::<f32>(&[]));
149
150 #[allow(clippy::transmute_ptr_to_ref)]
151 unsafe {
152 std::mem::transmute::<
153 scirs2_core::ndarray::ArrayBase<
154 scirs2_core::ndarray::ViewRepr<&f32>,
155 scirs2_core::ndarray::Dim<scirs2_core::ndarray::IxDynImpl>,
156 >,
157 scirs2_core::ndarray::ArrayBase<
158 scirs2_core::ndarray::ViewRepr<&F>,
159 scirs2_core::ndarray::Dim<scirs2_core::ndarray::IxDynImpl>,
160 >,
161 >(DUMMY_SCALAR.view())
162 }
163 } else {
164 self.inputs[i].view()
165 }
166 }
167
168 pub fn input_mut(&mut self, i: usize) -> NdArrayViewMut<'_, F> {
171 let _ = i; panic!("input_mut is not supported in the new ComputeContext implementation");
173 }
174
175 pub fn inputs(&self) -> Vec<NdArrayView<F>> {
177 self.inputs.iter().map(|arr| arr.view()).collect()
178 }
179
180 pub fn append_output<A>(&mut self, output: A)
182 where
183 A: Into<NdArray<F>>,
184 {
185 self.outputs.push(output.into());
186 }
187
188 pub fn get_outputs(&self) -> &[NdArray<F>] {
190 &self.outputs
191 }
192}
193
/// Context handed to `Op::grad` during backpropagation.
pub struct GradientContext<'a, 'graph, F: Float> {
    /// Output tensors (`z`s) of the op being differentiated.
    pub(crate) zs: &'a [&'graph Tensor<'graph, F>],

    /// Input tensors (`x`s) of the op.
    pub(crate) xs: &'a [&'graph Tensor<'graph, F>],

    /// Graph context the tensors belong to.
    pub(crate) context: &'graph crate::Context<'graph, F>,

    /// Incoming gradients w.r.t. each output (`gz`s), aligned with `zs`.
    pub(crate) gzs: &'a [&'graph Tensor<'graph, F>],

    /// Per-input gradient slots, filled via `append_input_grad`.
    pub(crate) results: &'a mut Vec<Option<Tensor<'graph, F>>>,

    /// Index selecting which output/gradient pair `output()` and
    /// `output_grad()` refer to.
    pub(crate) array_field_id: usize,

    /// Keeps both lifetime parameters anchored in the type.
    pub(crate) _marker: PhantomData<&'a mut &'graph F>,
}
217
218impl<'graph, F: Float> GradientContext<'_, 'graph, F> {
219 #[doc(hidden)]
222 pub fn _new_stub() {}
223
224 pub fn compute_input_grads(&self) -> Vec<Option<Tensor<'graph, F>>> {
226 self.results.clone().into_iter().collect()
227 }
228}
229
230impl<'graph, F: Float> GradientContext<'_, 'graph, F> {
231 pub fn output(&self) -> &'graph Tensor<'graph, F> {
233 self.zs[self.array_field_id]
234 }
235
236 pub fn output_grad(&self) -> &'graph Tensor<'graph, F> {
238 self.gzs[self.array_field_id]
239 }
240
241 pub fn input(&self, i: usize) -> &'graph Tensor<'graph, F> {
243 self.xs[i]
244 }
245
246 pub fn num_inputs(&self) -> usize {
248 self.xs.len()
249 }
250
251 pub fn num_outputs(&self) -> usize {
253 self.zs.len()
254 }
255
256 pub fn graph(&self) -> &'graph crate::Context<'graph, F> {
258 self.context
259 }
260
261 pub fn append_input_grad(&mut self, i: usize, gx: Option<Tensor<'graph, F>>) {
263 for _ in self.results.len()..=i {
264 self.results.push(None);
265 }
266 self.results[i] = gx;
267 }
268
269 pub fn append_input_grad_by_ref(&mut self, gx: Option<&Tensor<'graph, F>>) {
272 self.append_input_grad(0, gx.cloned());
273 }
274
275 pub fn append_input_grad_0(&mut self, gx: Option<Tensor<'graph, F>>) {
278 self.append_input_grad(0, gx);
279 }
280
281 pub fn inputs(&self) -> &[&'graph Tensor<'graph, F>] {
283 self.xs
284 }
285}
286
/// Owned wrapper around a single op output array.
#[derive(Clone)]
#[allow(dead_code)]
pub struct OpOutput<F: Float> {
    /// The produced array.
    pub(crate) output: NdArray<F>,
}
293
294impl<F: Float> OpOutput<F> {
295 #[allow(dead_code)]
296 pub(crate) fn new(output: NdArray<F>) -> Self {
297 Self { output }
298 }
299}