Skip to main content

scirs2_autograd/
op.rs

1//! # Implementing differentiable operations
2//!
3//! Many of well-known ops are pre-defined in [crate::tensor_ops], but you can also
4//! implement custom ops by hand.
5//! See also [crate::tensor::TensorBuilder].
6//!
7//! ```
8//! use scirs2_core::ndarray;
9//! use scirs2_autograd as ag;
10//! use ag::error::OpError;
11//! use ag::tensor_ops::*;
12//!
13//! type NdArray<T: ag::Float> = scirs2_core::ndarray::Array<T, scirs2_core::ndarray::IxDyn>;
14//!
15//! // Implements `Op` trait for `Sigmoid`.
16//! struct Sigmoid;
17//!
18//! impl<T: ag::Float> ag::op::Op<T> for Sigmoid {
19//!     fn compute(
20//!         &self,
21//!         ctx: &mut ag::op::ComputeContext<T>,
22//!     ) -> Result<(), OpError> {
23//!         let x: &ag::NdArrayView<_> = &ctx.input(0);
24//!         // Use `scirs2_core::ndarray::Array::mapv` for element-wise computation.
25//!         let half = T::from(0.5).expect("Operation failed");
26//!         let y = x.mapv(move |a| ((a * half).tanh() * half) + half);
27//!         ctx.append_output(y);
28//!         Ok(())
29//!     }
30//!
31//!     fn grad(&self, ctx: &mut ag::op::GradientContext<T>) {
32//!         // gradient of the output of Sigmoid
33//!         let gy = ctx.output_grad();
34//!         let y = ctx.output();
35//!         // gradient of the input of Sigmoid
36//!         let gx = gy * (y - square(y));
37//!         ctx.append_input_grad(0, Some(gx));
38//!     }
39//! }
40//!
41//! // `sigmoid` function for end-user.
42//! fn sigmoid<'graph, F: ag::Float>(x: &ag::Tensor<'graph, F>, g: &'graph ag::Context<F>)
43//! -> ag::Tensor<'graph, F> {
44//!     ag::Tensor::builder(g)
45//!            .append_input(x, false)
46//!            .build(Sigmoid)
47//! }
48//! ```
49//!
50use std::any::type_name;
51use std::marker::PhantomData;
52
53pub use crate::error::OpError;
54use crate::ndarray_ext::{NdArrayView, NdArrayViewMut};
55use crate::smallvec::SmallVec as RawSmallVec;
56use crate::tensor::Tensor;
57use crate::{Float, NdArray};
58
59pub(crate) const DEFAULT_NUM_EDGES: usize = 2;
60
61pub(crate) type SmallVec<T> = RawSmallVec<[T; DEFAULT_NUM_EDGES]>;
62
/// Trait for tensor operations. `Tensor` structs wrap this.
pub trait Op<F: Float> {
    /// Name of this op.
    ///
    /// Defaults to the implementor's full type path via [`std::any::type_name`].
    fn name(&self) -> &'static str {
        type_name::<Self>()
    }

    /// Runs this op with `ComputeContext`.
    ///
    /// Implementations read inputs through `ComputeContext::input` and
    /// register results through `ComputeContext::append_output`.
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError>;

    /// Returns gradients for input nodes by use of output's gradients etc.
    ///
    /// Implementations register one gradient per input via
    /// `GradientContext::append_input_grad` (use `None` for non-differentiable inputs).
    fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>);

    /// Returns self as Any for downcasting. Default returns None.
    ///
    /// Override to return `Some(self)` when callers need to recover the
    /// concrete op type from a `dyn Op<F>`.
    fn as_any(&self) -> Option<&dyn std::any::Any> {
        None
    }
}
81
/// Variable or non-variable tensor input to an op.
#[allow(dead_code)]
pub(crate) enum OpInput<'graph, F: Float> {
    /// A registered variable, identified by its variable ID.
    Variable(crate::variable::VariableID),
    /// A plain graph tensor.
    // NOTE(review): the usize looks like an input-slot index — confirm at call sites.
    NonVariable(usize, &'graph Tensor<'graph, F>),
}
87
/// Placeholder accessor tied to the lifetime of an [`OpInput`].
///
/// NOTE(review): currently stores only a zero value and a lifetime marker;
/// no real accessor logic is implemented yet.
#[allow(dead_code)]
pub(crate) struct OpInputGetter<'a, F: Float> {
    // Always initialized to `F::zero()` by the constructors in this file.
    f: F,
    // Ties this getter's lifetime to the borrowed `OpInput`.
    _marker: PhantomData<&'a ()>,
}
94
95impl<F: Float> OpInputGetter<'_, F> {
96    #[allow(dead_code)]
97    pub fn new(_: F) -> Self {
98        Self {
99            f: F::zero(),
100            _marker: PhantomData,
101        }
102    }
103}
104
105impl<'a, 'graph, F: Float> From<&'a OpInput<'graph, F>> for OpInputGetter<'a, F> {
106    fn from(x: &'a OpInput<'graph, F>) -> Self {
107        let _ = x;
108        Self {
109            f: F::zero(),
110            _marker: PhantomData,
111        }
112    }
113}
114
/// Context given to `Op::compute`.
///
/// Owns the input arrays (cloned or moved in at construction) and collects
/// the output arrays that the op appends.
pub struct ComputeContext<F: Float> {
    /// Owned input arrays, read via `input` / `inputs`.
    pub(crate) inputs: Vec<NdArray<F>>,
    /// Output arrays accumulated via `append_output`.
    pub(crate) outputs: Vec<NdArray<F>>,
}
120
121impl<F: Float> ComputeContext<F> {
122    /// Creates new ComputeContext.
123    pub fn new(inputs: &[NdArray<F>], outputs: &mut [NdArray<F>]) -> Self {
124        // Clone all inputs to own the data
125        let input_arrays = inputs.to_vec();
126        Self {
127            inputs: input_arrays,
128            outputs: Vec::new(),
129        }
130    }
131
132    /// Creates a new ComputeContext with prepared inputs.
133    pub fn with_inputs(input_arrays: Vec<NdArray<F>>) -> Self {
134        Self {
135            inputs: input_arrays,
136            outputs: Vec::new(),
137        }
138    }
139
140    /// Returns `i`-th input array.
141    /// If index is out of bounds, returns an empty scalar array.
142    /// This can happen when operations are created dynamically during gradient computation.
143    pub fn input(&self, i: usize) -> NdArrayView<F> {
144        if i >= self.inputs.len() {
145            // Return an empty scalar instead of panicking or warning
146            // This handles the case where some operations may not have all inputs during evaluation
147            static DUMMY_SCALAR: once_cell::sync::Lazy<NdArray<f32>> =
148                once_cell::sync::Lazy::new(|| crate::ndarray_ext::zeros::<f32>(&[]));
149
150            #[allow(clippy::transmute_ptr_to_ref)]
151            unsafe {
152                std::mem::transmute::<
153                    scirs2_core::ndarray::ArrayBase<
154                        scirs2_core::ndarray::ViewRepr<&f32>,
155                        scirs2_core::ndarray::Dim<scirs2_core::ndarray::IxDynImpl>,
156                    >,
157                    scirs2_core::ndarray::ArrayBase<
158                        scirs2_core::ndarray::ViewRepr<&F>,
159                        scirs2_core::ndarray::Dim<scirs2_core::ndarray::IxDynImpl>,
160                    >,
161                >(DUMMY_SCALAR.view())
162            }
163        } else {
164            self.inputs[i].view()
165        }
166    }
167
168    /// Note: This method is deprecated and will panic.
169    /// With the new architecture, inputs are immutable.
170    pub fn input_mut(&mut self, i: usize) -> NdArrayViewMut<'_, F> {
171        let _ = i; // Suppress unused parameter warning
172        panic!("input_mut is not supported in the new ComputeContext implementation");
173    }
174
175    /// Returns all input array views.
176    pub fn inputs(&self) -> Vec<NdArrayView<F>> {
177        self.inputs.iter().map(|arr| arr.view()).collect()
178    }
179
180    /// Appends an output array.
181    pub fn append_output<A>(&mut self, output: A)
182    where
183        A: Into<NdArray<F>>,
184    {
185        self.outputs.push(output.into());
186    }
187
188    /// Get all outputs
189    pub fn get_outputs(&self) -> &[NdArray<F>] {
190        &self.outputs
191    }
192}
193
/// Context given to `Op::grad`.
pub struct GradientContext<'a, 'graph, F: Float> {
    /// Tensor outputs of the op. No owned data.
    pub(crate) zs: &'a [&'graph Tensor<'graph, F>],

    /// Tensor inputs of the op. No owned data.
    pub(crate) xs: &'a [&'graph Tensor<'graph, F>],

    /// Context graph reference.
    pub(crate) context: &'graph crate::Context<'graph, F>,

    /// Gradients of the outputs (same indexing as `zs`). No owned data.
    pub(crate) gzs: &'a [&'graph Tensor<'graph, F>],

    /// Gradient tensors to be the result: one slot per input in `xs`,
    /// filled in by `append_input_grad`.
    pub(crate) results: &'a mut Vec<Option<Tensor<'graph, F>>>,

    /// Index selecting which output/gradient pair `output()` and
    /// `output_grad()` refer to.
    pub(crate) array_field_id: usize,

    /// This is needed to constrain type parameters.
    pub(crate) _marker: PhantomData<&'a mut &'graph F>,
}
217
218impl<'graph, F: Float> GradientContext<'_, 'graph, F> {
219    // We can't implement the new method with the current struct design due to lifetime issues
220    // Just implement a stub method to support backward compatibility
221    #[doc(hidden)]
222    pub fn _new_stub() {}
223
224    /// Compute input gradients
225    pub fn compute_input_grads(&self) -> Vec<Option<Tensor<'graph, F>>> {
226        self.results.clone().into_iter().collect()
227    }
228}
229
230impl<'graph, F: Float> GradientContext<'_, 'graph, F> {
231    /// Returns the output array.
232    pub fn output(&self) -> &'graph Tensor<'graph, F> {
233        self.zs[self.array_field_id]
234    }
235
236    /// Returns the gradient of output array.
237    pub fn output_grad(&self) -> &'graph Tensor<'graph, F> {
238        self.gzs[self.array_field_id]
239    }
240
241    /// Returns the `i`-th input array.
242    pub fn input(&self, i: usize) -> &'graph Tensor<'graph, F> {
243        self.xs[i]
244    }
245
246    /// Returns the number of inputs.
247    pub fn num_inputs(&self) -> usize {
248        self.xs.len()
249    }
250
251    /// Returns the number of outputs.
252    pub fn num_outputs(&self) -> usize {
253        self.zs.len()
254    }
255
256    /// Returns the context graph.
257    pub fn graph(&self) -> &'graph crate::Context<'graph, F> {
258        self.context
259    }
260
261    /// Appends a gradient for the input indexed by `i`.
262    pub fn append_input_grad(&mut self, i: usize, gx: Option<Tensor<'graph, F>>) {
263        for _ in self.results.len()..=i {
264            self.results.push(None);
265        }
266        self.results[i] = gx;
267    }
268
269    /// Appends a gradient for the input indexed by 0.
270    /// Short-hand for `append_input_grad(0, gx)`.
271    pub fn append_input_grad_by_ref(&mut self, gx: Option<&Tensor<'graph, F>>) {
272        self.append_input_grad(0, gx.cloned());
273    }
274
275    /// Appends a gradient for the input indexed by 0.
276    /// Short-hand for `append_input_grad(0, gx)`.
277    pub fn append_input_grad_0(&mut self, gx: Option<Tensor<'graph, F>>) {
278        self.append_input_grad(0, gx);
279    }
280
281    /// Returns all input tensors.
282    pub fn inputs(&self) -> &[&'graph Tensor<'graph, F>] {
283        self.xs
284    }
285}
286
/// Output from op.
#[derive(Clone)]
#[allow(dead_code)]
pub struct OpOutput<F: Float> {
    /// The computed output array.
    pub(crate) output: NdArray<F>,
}
293
294impl<F: Float> OpOutput<F> {
295    #[allow(dead_code)]
296    pub(crate) fn new(output: NdArray<F>) -> Self {
297        Self { output }
298    }
299}