// shrew_core/backprop.rs
1// Backpropagation — Reverse-mode automatic differentiation
2//
3// This module implements the backward pass (backpropagation), computing
4// gradients of a scalar loss with respect to all tensors in the computation
5// graph. This is the core of training any neural network.
6//
7// HOW IT WORKS:
8//
9//   1. Forward pass: tensor operations build a DAG (directed acyclic graph)
10//      where each tensor stores its Op (the operation that created it).
11//
12//   2. backward() topologically sorts the DAG from the loss tensor to the
13//      leaves (input data and trainable parameters).
14//
15//   3. Starting with grad(loss) = 1.0, we walk the graph in reverse order.
16//      For each tensor, we apply the chain rule to compute gradients for
17//      its inputs and accumulate them.
18//
// AUTOGRAD COVERAGE (all 22 Op variants):
20//
21//   Leaf:        Op::None        — no gradient propagation needed
22//   Contiguous:  Op::Contiguous  — pass-through (identity gradient)
23//
24//   Binary ops:  Op::Binary      — Add, Sub, Mul, Div (with broadcast reduction)
25//   Unary ops:   Op::Unary       — Neg, Abs, Exp, Log, Sqrt, Square, Relu,
26//                                   Sigmoid, Tanh, Gelu, Silu, Sin, Cos,
27//                                   Floor, Ceil, Round
28//   Reduce ops:  Op::Reduce      — Sum, Mean, Max, Min (with keepdim support)
29//
30//   Linear algebra:              — Matmul (with batched support)
31//   Shape ops:                   — Reshape, Transpose, Narrow
32//   Affine:                      — scale + bias (LayerNorm/BatchNorm support)
33//
34//   Convolutions:                — Conv2d (input + weight + bias grads)
35//                                — Conv1d (input + weight + bias grads)
36//   Pooling:                     — MaxPool2d (via saved indices)
37//                                — AvgPool2d (uniform gradient distribution)
38//
//   Composite:                   — Cat (split gradient by sizes)
//                                — Powf (power rule)
//                                — Clamp (zero grad outside bounds)
//                                — WhereCond (route grad through mask)
//                                — Gather (scatter gradient to source)
//                                — IndexSelect (scatter-add to source)
//                                — Pad (narrow gradient to unpadded region)
//
//   Dtype:                       — ToDtype (cast gradient back to source dtype)
45//
46// GRADIENT CHECKPOINTING:
47//   checkpoint()            — wrap a forward fn for recomputation in backward
48//   checkpoint_sequential() — split sequential layers into segments
49//   is_checkpointing()      — check if inside a recomputation pass
50//
51// GRADIENT RULES (chain rule applied for each Op):
52//
53//   Binary Add:  grad_a += grad_out, grad_b += grad_out
54//   Binary Sub:  grad_a += grad_out, grad_b += -grad_out
55//   Binary Mul:  grad_a += grad_out * b, grad_b += grad_out * a
56//   Binary Div:  grad_a += grad_out / b, grad_b += -grad_out * a / b²
57//   Unary Neg:   grad_in += -grad_out
58//   Unary Exp:   grad_in += grad_out * exp(input)
59//   Unary Log:   grad_in += grad_out / input
60//   Matmul:      grad_A += grad_C @ B^T, grad_B += A^T @ grad_C
61//   Sum:         grad_in += broadcast(grad_out)
62//   Reshape:     grad_in += reshape(grad_out, original_shape)
63//   Transpose:   grad_in += transpose(grad_out)
64//   ... and many more (see compute_* functions below)
65//
66// ACCUMULATION: If a tensor is used in multiple operations, its gradient
67// is the SUM of contributions from each use (multivariate chain rule).
68//
69// For example: c = a * a, then grad_a = grad_c * a + grad_c * a = 2 * a * grad_c
70
71use std::collections::{HashMap, HashSet};
72
73use crate::backend::{Backend, BinaryOp, ReduceOp, UnaryOp};
74use crate::error::Result;
75use crate::op::{Op, TensorId};
76use crate::shape::Shape;
77use crate::tensor::Tensor;
78
/// Stores gradients for all tensors in a computation graph.
///
/// After calling `tensor.backward()`, you receive a GradStore.
/// Use `grads.get(&tensor)` to retrieve the gradient for any tensor.
pub struct GradStore<B: Backend> {
    // Maps a tensor's unique id to its accumulated gradient tensor.
    // Populated during the reverse pass; tensors that receive no gradient
    // flow have no entry here.
    grads: HashMap<TensorId, Tensor<B>>,
}
86
// Manual Clone impl: #[derive(Clone)] would add a `B: Clone` bound on the
// backend type parameter, which is unnecessary — only the contained
// `Tensor<B>` values need to be cloneable, not the backend marker itself.
impl<B: Backend> Clone for GradStore<B> {
    fn clone(&self) -> Self {
        GradStore {
            grads: self.grads.clone(),
        }
    }
}
94
95impl<B: Backend> Default for GradStore<B> {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101impl<B: Backend> GradStore<B> {
102    /// Create a new empty GradStore.
103    pub fn new() -> Self {
104        GradStore {
105            grads: HashMap::new(),
106        }
107    }
108
109    /// Get the gradient of a tensor (if it exists).
110    pub fn get(&self, tensor: &Tensor<B>) -> Option<&Tensor<B>> {
111        self.grads.get(&tensor.id())
112    }
113
114    fn get_by_id(&self, id: &TensorId) -> Option<&Tensor<B>> {
115        self.grads.get(id)
116    }
117
118    /// Accumulate gradient for a tensor.
119    /// If a gradient already exists for this tensor, add the new one to it.
120    /// This handles the case where a tensor is used in multiple operations.
121    pub fn accumulate(&mut self, id: TensorId, grad: Tensor<B>) -> Result<()> {
122        if let Some(existing) = self.grads.get(&id) {
123            let new_grad = existing.add(&grad)?;
124            self.grads.insert(id, new_grad);
125        } else {
126            self.grads.insert(id, grad);
127        }
128        Ok(())
129    }
130}
131
132/// Build a topological ordering of the computation graph.
133///
134/// Uses depth-first search from the root tensor. Returns tensors in
135/// order such that every tensor appears AFTER all its inputs.
136/// (Leaves first, root last.)
137fn build_topo<B: Backend>(root: &Tensor<B>) -> Vec<Tensor<B>> {
138    let mut visited = HashSet::new();
139    let mut order = Vec::new();
140
141    fn visit<B: Backend>(
142        t: &Tensor<B>,
143        visited: &mut HashSet<TensorId>,
144        order: &mut Vec<Tensor<B>>,
145    ) {
146        if visited.contains(&t.id()) {
147            return;
148        }
149        visited.insert(t.id());
150        // Visit inputs first (depth-first)
151        for input in t.op().inputs() {
152            visit(input, visited, order);
153        }
154        // Then add this tensor (post-order)
155        order.push(t.clone());
156    }
157
158    visit(root, &mut visited, &mut order);
159    order
160}
161
/// Compute gradients of `root` with respect to all tensors in the graph.
///
/// `root` must be a scalar tensor (single element). This is the main entry
/// point for backpropagation, called by `tensor.backward()`.
///
/// # Errors
/// Returns an error if `root` is not a scalar, or if any per-op gradient
/// computation fails (shape mismatch, backend failure, etc.).
#[allow(clippy::needless_range_loop)]
pub fn backward<B: Backend>(root: &Tensor<B>) -> Result<GradStore<B>> {
    // Backward only works from a scalar (otherwise, which output element?)
    if root.elem_count() != 1 {
        return Err(crate::Error::msg(
            "backward() requires a scalar tensor (single element). \
             Use .sum_all() or .mean_all() to reduce to a scalar first.",
        ));
    }

    // Step 1: Topological sort (leaves first, root last)
    let topo = build_topo(root);

    // Step 2: Initialize — grad(root) = 1.0 (dL/dL = 1)
    let mut grads = GradStore::new();
    let ones = Tensor::<B>::ones(root.shape().clone(), root.dtype(), root.device())?;
    grads.grads.insert(root.id(), ones);

    // Step 3: Walk in reverse topological order (root first, leaves last).
    // By the time we visit a tensor, every consumer of that tensor has
    // already contributed to its gradient, so grad_output is complete.
    for tensor in topo.iter().rev() {
        let grad_output = match grads.get_by_id(&tensor.id()) {
            Some(g) => g.clone(),
            None => continue, // No gradient flows to this tensor
        };

        match tensor.op() {
            Op::None => {
                // Leaf — nothing to propagate
            }

            Op::Contiguous { input } => {
                // Identity-like: gradient passes through unchanged
                grads.accumulate(input.id(), grad_output)?;
            }

            Op::Binary { lhs, rhs, op } => {
                compute_binary_grad(*op, &grad_output, lhs, rhs, &mut grads)?;
            }

            Op::Unary { input, op } => {
                compute_unary_grad(*op, &grad_output, input, &mut grads)?;
            }

            Op::Reduce {
                input,
                op,
                dims,
                keep_dim,
            } => {
                compute_reduce_grad(*op, &grad_output, input, dims, *keep_dim, &mut grads)?;
            }

            Op::Matmul { lhs, rhs } => {
                compute_matmul_grad(&grad_output, lhs, rhs, &mut grads)?;
            }

            Op::Reshape { input, src_shape } => {
                // Reshape gradient = reshape grad_output back to original shape
                let grad = grad_output.reshape(src_shape.clone())?;
                grads.accumulate(input.id(), grad)?;
            }

            Op::Transpose { input, dim0, dim1 } => {
                // Transpose is its own inverse: transpose back with same dims
                let grad = grad_output.transpose(*dim0, *dim1)?;
                grads.accumulate(input.id(), grad)?;
            }

            Op::Narrow {
                input,
                dim,
                start,
                len,
            } => {
                // Scatter gradient into zero tensor at the original position
                compute_narrow_grad(&grad_output, input, *dim, *start, *len, &mut grads)?;
            }

            Op::Affine { input, mul, .. } => {
                // d(x * mul + add)/dx = mul — the additive term has zero derivative.
                let grad = grad_output.affine(*mul, 0.0)?;
                grads.accumulate(input.id(), grad)?;
            }

            Op::Conv2d {
                input,
                weight,
                bias,
                stride,
                padding,
            } => {
                compute_conv2d_grad(
                    &grad_output,
                    input,
                    weight,
                    bias.as_ref(),
                    *stride,
                    *padding,
                    &mut grads,
                )?;
            }

            Op::MaxPool2d { input, indices, .. } => {
                // Gradient routed through the argmax indices saved in forward.
                compute_maxpool2d_grad(&grad_output, input, indices, &mut grads)?;
            }

            Op::Cat { inputs, dim, sizes } => {
                // Backward of cat: slice the gradient into pieces for each input
                let mut offset = 0usize;
                for (inp, &sz) in inputs.iter().zip(sizes.iter()) {
                    let grad_slice = grad_output.narrow(*dim, offset, sz)?;
                    grads.accumulate(inp.id(), grad_slice)?;
                    offset += sz;
                }
            }

            Op::Powf { input, exponent } => {
                // d(x^n)/dx = n * x^(n-1)
                let n = *exponent;
                let x_pow_nm1 = input.powf(n - 1.0)?;
                let n_tensor =
                    Tensor::<B>::full(input.shape().clone(), n, input.dtype(), input.device())?;
                let grad = grad_output.mul(&n_tensor)?.mul(&x_pow_nm1)?;
                grads.accumulate(input.id(), grad)?;
            }

            Op::Clamp { input, min, max } => {
                // Gradient = 1 where min < input < max, 0 at boundaries
                // NOTE(review): strict inequality zeroes the gradient when
                // x == min or x == max exactly; PyTorch propagates gradient
                // at the boundary (x >= min && x <= max) — confirm intended.
                let input_data = input.to_f64_vec()?;
                let grad_data = grad_output.to_f64_vec()?;
                let mask: Vec<f64> = input_data
                    .iter()
                    .zip(grad_data.iter())
                    .map(|(&x, &g)| if x > *min && x < *max { g } else { 0.0 })
                    .collect();
                let grad = Tensor::<B>::from_f64_slice(
                    &mask,
                    input.shape().clone(),
                    input.dtype(),
                    input.device(),
                )?;
                grads.accumulate(input.id(), grad)?;
            }

            Op::WhereCond {
                mask,
                on_true,
                on_false,
            } => {
                // Gradient flows to on_true where mask is true, on_false where mask is false
                // NOTE(review): the flat zip below assumes mask, on_true and
                // on_false all share grad_output's shape (no broadcasting) —
                // confirm the forward op guarantees this.
                let mask_data = mask.to_f64_vec()?;
                let grad_data = grad_output.to_f64_vec()?;
                let n = mask_data.len();

                let grad_true_data: Vec<f64> = (0..n)
                    .map(|i| {
                        if mask_data[i] != 0.0 {
                            grad_data[i]
                        } else {
                            0.0
                        }
                    })
                    .collect();
                let grad_false_data: Vec<f64> = (0..n)
                    .map(|i| {
                        if mask_data[i] == 0.0 {
                            grad_data[i]
                        } else {
                            0.0
                        }
                    })
                    .collect();

                let grad_true = Tensor::<B>::from_f64_slice(
                    &grad_true_data,
                    on_true.shape().clone(),
                    on_true.dtype(),
                    on_true.device(),
                )?;
                let grad_false = Tensor::<B>::from_f64_slice(
                    &grad_false_data,
                    on_false.shape().clone(),
                    on_false.dtype(),
                    on_false.device(),
                )?;
                grads.accumulate(on_true.id(), grad_true)?;
                grads.accumulate(on_false.id(), grad_false)?;
                // mask is non-differentiable — no gradient
            }

            Op::Gather { input, index, dim } => {
                // Gather backward = scatter-add:
                // grad_input = zeros_like(input);
                // for each position p in index: grad_input[...dim=index[p]...] += grad_output[p]
                let dim = *dim;
                let input_dims = input.dims();
                let rank = input_dims.len();

                // Read grad_output and index data
                let grad_data = grad_output.to_f64_vec()?;
                let index_data = index.to_f64_vec()?;

                // Create zero grad for input
                let mut grad_input_data = vec![0.0f64; input.elem_count()];
                let input_strides = input.shape().stride_contiguous();

                // Compute strides for index shape (to decompose flat positions)
                let index_strides = index.shape().stride_contiguous();

                let n = index_data.len();
                for flat_idx in 0..n {
                    // Decompose flat_idx into multi-dim coords in index shape
                    let mut coords = vec![0usize; rank];
                    let mut remainder = flat_idx;
                    for d in 0..rank {
                        coords[d] = remainder / index_strides[d];
                        remainder %= index_strides[d];
                    }

                    // Replace coord at `dim` with the index value
                    // NOTE(review): indices are not validated here — a negative
                    // f64 saturates to 0 via `as usize`, and an out-of-range
                    // value panics on the slice write below. Forward-op
                    // validation presumably guarantees in-range indices; verify.
                    let idx_val = index_data[flat_idx] as usize;
                    coords[dim] = idx_val;

                    // Compute flat position in input
                    let mut input_flat = 0;
                    for d in 0..rank {
                        input_flat += coords[d] * input_strides[d];
                    }

                    // Scatter-add: accumulate the gradient
                    grad_input_data[input_flat] += grad_data[flat_idx];
                }

                let grad_input = Tensor::<B>::from_f64_slice(
                    &grad_input_data,
                    input.shape().clone(),
                    input.dtype(),
                    input.device(),
                )?;
                grads.accumulate(input.id(), grad_input)?;
                // index is non-differentiable — no gradient
            }

            Op::Pad { input, padding } => {
                // Backward of pad: narrow the gradient to remove the padding
                let mut grad = grad_output.clone();
                let input_dims = input.dims();
                for d in 0..input_dims.len() {
                    // NOTE(review): `_after` is read despite its underscore
                    // prefix (it participates in the skip-if-unpadded check);
                    // rename to `after` to avoid the misleading "unused" hint.
                    let [before, _after] = padding[d];
                    if before > 0 || _after > 0 {
                        grad = grad.narrow(d, before, input_dims[d])?;
                    }
                }
                grads.accumulate(input.id(), grad)?;
            }

            Op::AvgPool2d {
                input,
                kernel_size,
                stride,
                padding,
            } => {
                compute_avgpool2d_grad(
                    &grad_output,
                    input,
                    *kernel_size,
                    *stride,
                    *padding,
                    &mut grads,
                )?;
            }

            Op::Conv1d {
                input,
                weight,
                bias,
                stride,
                padding,
            } => {
                compute_conv1d_grad(
                    &grad_output,
                    input,
                    weight,
                    bias.as_ref(),
                    *stride,
                    *padding,
                    &mut grads,
                )?;
            }

            Op::IndexSelect {
                input,
                indices,
                dim,
            } => {
                // IndexSelect backward = scatter-add:
                // grad_input = zeros_like(input)
                // For each position in output, add grad_output to grad_input
                // at the corresponding input position (determined by indices).
                let dim = *dim;
                let input_dims = input.dims();
                let rank = input_dims.len();

                let grad_data = grad_output.to_f64_vec()?;
                let index_data = indices.to_f64_vec()?; // index values as f64
                // NOTE(review): `_num_indices` and `_output_dims` below are
                // computed but never read — candidates for removal.
                let _num_indices = index_data.len();

                let mut grad_input_data = vec![0.0f64; input.elem_count()];
                let input_strides = input.shape().stride_contiguous();
                let _output_dims = grad_output.dims();
                let output_strides = grad_output.shape().stride_contiguous();

                let total = grad_data.len();
                for flat_idx in 0..total {
                    // Decompose flat_idx into multi-dim coords in output shape
                    let mut coords = vec![0usize; rank];
                    let mut remainder = flat_idx;
                    for d in 0..rank {
                        coords[d] = remainder / output_strides[d];
                        remainder %= output_strides[d];
                    }

                    // Replace coord at `dim` with the original source index
                    let out_dim_coord = coords[dim];
                    let src_idx = index_data[out_dim_coord] as usize;
                    coords[dim] = src_idx;

                    // Compute flat position in input
                    let mut input_flat = 0;
                    for d in 0..rank {
                        input_flat += coords[d] * input_strides[d];
                    }

                    // Scatter-add
                    grad_input_data[input_flat] += grad_data[flat_idx];
                }

                let grad_input = Tensor::<B>::from_f64_slice(
                    &grad_input_data,
                    input.shape().clone(),
                    input.dtype(),
                    input.device(),
                )?;
                grads.accumulate(input.id(), grad_input)?;
                // indices are non-differentiable
            }

            Op::ToDtype { input, src_dtype } => {
                // Cast gradient back to the original input dtype.
                let grad_in = grad_output.to_dtype(*src_dtype)?;
                grads.accumulate(input.id(), grad_in)?;
            }
        }
    }

    Ok(grads)
}
523
524// Gradient rules for binary operations
525
526fn compute_binary_grad<B: Backend>(
527    op: BinaryOp,
528    grad_output: &Tensor<B>,
529    lhs: &Tensor<B>,
530    rhs: &Tensor<B>,
531    grads: &mut GradStore<B>,
532) -> Result<()> {
533    match op {
534        BinaryOp::Add => {
535            // d(a + b)/da = 1, d(a + b)/db = 1
536            let grad_lhs = reduce_broadcast_grad(grad_output, lhs.shape())?;
537            let grad_rhs = reduce_broadcast_grad(grad_output, rhs.shape())?;
538            grads.accumulate(lhs.id(), grad_lhs)?;
539            grads.accumulate(rhs.id(), grad_rhs)?;
540        }
541        BinaryOp::Sub => {
542            // d(a - b)/da = 1, d(a - b)/db = -1
543            let grad_lhs = reduce_broadcast_grad(grad_output, lhs.shape())?;
544            let neg = grad_output.neg()?;
545            let grad_rhs = reduce_broadcast_grad(&neg, rhs.shape())?;
546            grads.accumulate(lhs.id(), grad_lhs)?;
547            grads.accumulate(rhs.id(), grad_rhs)?;
548        }
549        BinaryOp::Mul => {
550            // d(a * b)/da = b, d(a * b)/db = a
551            let raw_lhs = grad_output.mul(rhs)?;
552            let raw_rhs = grad_output.mul(lhs)?;
553            grads.accumulate(lhs.id(), reduce_broadcast_grad(&raw_lhs, lhs.shape())?)?;
554            grads.accumulate(rhs.id(), reduce_broadcast_grad(&raw_rhs, rhs.shape())?)?;
555        }
556        BinaryOp::Div => {
557            // d(a / b)/da = 1/b
558            // d(a / b)/db = -a / b²
559            let raw_lhs = grad_output.div(rhs)?;
560            grads.accumulate(lhs.id(), reduce_broadcast_grad(&raw_lhs, lhs.shape())?)?;
561            let neg_grad = grad_output.neg()?;
562            let b_sq = rhs.mul(rhs)?;
563            let raw_rhs = neg_grad.mul(lhs)?.div(&b_sq)?;
564            grads.accumulate(rhs.id(), reduce_broadcast_grad(&raw_rhs, rhs.shape())?)?;
565        }
566    }
567    Ok(())
568}
569
570/// When broadcasting expands a tensor's shape, the backward pass must sum
571/// the gradient over the broadcast dimensions to match the original shape.
572///
573/// For example, if lhs was [1, 4] broadcast to [3, 4]:
574///   grad_output is [3, 4], but grad_lhs must be [1, 4] → sum over dim 0
575///
576/// If lhs was [4] broadcast to [3, 4]:
577///   grad_output is [3, 4], grad_lhs must be [4] → sum over dim 0, squeeze
578fn reduce_broadcast_grad<B: Backend>(
579    grad: &Tensor<B>,
580    target_shape: &crate::Shape,
581) -> Result<Tensor<B>> {
582    let grad_shape = grad.dims();
583    let target_dims = target_shape.dims();
584
585    // If shapes already match, no reduction needed
586    if grad_shape == target_dims {
587        return Ok(grad.clone());
588    }
589
590    // Pad target dims with leading 1s to match grad rank
591    let grad_rank = grad_shape.len();
592    let target_rank = target_dims.len();
593    let mut padded_target = vec![1usize; grad_rank];
594    let offset = grad_rank - target_rank;
595    padded_target[offset..offset + target_rank].copy_from_slice(target_dims);
596
597    // Sum over dimensions where padded_target[d] == 1 and grad[d] > 1
598    let mut result = grad.clone();
599    // Sum from left to right, adjusting for removed dimensions
600    let mut dims_to_sum: Vec<usize> = Vec::new();
601    for d in 0..grad_rank {
602        if padded_target[d] == 1 && grad_shape[d] > 1 {
603            dims_to_sum.push(d);
604        }
605    }
606
607    // Sum all broadcast dimensions at once by processing from highest to lowest
608    // to keep dimension indices stable
609    for &d in dims_to_sum.iter().rev() {
610        result = result.sum(d, true)?;
611    }
612
613    // Now reshape to target shape (removing the extra size-1 dims)
614    result = result.reshape(target_shape.clone())?;
615
616    Ok(result)
617}
618
619// Gradient rules for unary operations
620
/// Apply the chain rule for a unary elementwise op:
/// grad_input = grad_output * d(op(x))/dx, accumulated for `input`.
///
/// NOTE(review): the Abs / Relu / Gelu arms round-trip tensor data through
/// host `f64` buffers (`to_f64_vec`), which is a sync point on non-CPU
/// backends — confirm this is acceptable for the target devices.
///
/// # Errors
/// Propagates any backend error from the element-wise tensor operations.
fn compute_unary_grad<B: Backend>(
    op: UnaryOp,
    grad_output: &Tensor<B>,
    input: &Tensor<B>,
    grads: &mut GradStore<B>,
) -> Result<()> {
    let grad_input = match op {
        // d(-x)/dx = -1
        UnaryOp::Neg => grad_output.neg()?,

        // d|x|/dx = sign(x) — subgradient 0 chosen at x == 0
        UnaryOp::Abs => {
            let input_data = input.to_f64_vec()?;
            let sign_data: Vec<f64> = input_data
                .iter()
                .map(|&v| {
                    if v > 0.0 {
                        1.0
                    } else if v < 0.0 {
                        -1.0
                    } else {
                        0.0
                    }
                })
                .collect();
            let sign = Tensor::<B>::from_f64_slice(
                &sign_data,
                input.shape().clone(),
                input.dtype(),
                input.device(),
            )?;
            grad_output.mul(&sign)?
        }

        // d(e^x)/dx = e^x
        UnaryOp::Exp => {
            let exp_x = input.exp()?;
            grad_output.mul(&exp_x)?
        }

        // d(ln x)/dx = 1/x
        UnaryOp::Log => grad_output.div(input)?,

        // d(√x)/dx = 1 / (2√x)
        UnaryOp::Sqrt => {
            let sqrt_x = input.sqrt()?;
            let two_sqrt = sqrt_x.affine(2.0, 0.0)?;
            grad_output.div(&two_sqrt)?
        }

        // d(x²)/dx = 2x
        UnaryOp::Square => {
            let two_x = input.affine(2.0, 0.0)?;
            grad_output.mul(&two_x)?
        }

        // d(relu(x))/dx = 1 if x > 0, else 0 (subgradient 0 at x == 0)
        UnaryOp::Relu => {
            let input_data = input.to_f64_vec()?;
            let mask_data: Vec<f64> = input_data
                .iter()
                .map(|&v| if v > 0.0 { 1.0 } else { 0.0 })
                .collect();
            let mask = Tensor::<B>::from_f64_slice(
                &mask_data,
                input.shape().clone(),
                input.dtype(),
                input.device(),
            )?;
            grad_output.mul(&mask)?
        }

        // d(σ(x))/dx = σ(x) * (1 - σ(x))
        UnaryOp::Sigmoid => {
            let sig = input.sigmoid()?;
            let one = Tensor::<B>::ones(input.shape().clone(), input.dtype(), input.device())?;
            let one_minus_sig = one.sub(&sig)?;
            let dsig = sig.mul(&one_minus_sig)?;
            grad_output.mul(&dsig)?
        }

        // d(tanh(x))/dx = 1 - tanh²(x)
        UnaryOp::Tanh => {
            let tanh_x = input.tanh()?;
            let tanh_sq = tanh_x.mul(&tanh_x)?;
            let one = Tensor::<B>::ones(input.shape().clone(), input.dtype(), input.device())?;
            let dtanh = one.sub(&tanh_sq)?;
            grad_output.mul(&dtanh)?
        }

        // d(GELU(x))/dx — computed element-wise from the formula
        // GELU(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715x³)))
        // (tanh approximation; derivative by product + chain rule)
        UnaryOp::Gelu => {
            let input_data = input.to_f64_vec()?;
            let deriv_data: Vec<f64> = input_data
                .iter()
                .map(|&x| {
                    // FRAC_2_PI is 2/π, so this is √(2/π) ≈ 0.7979
                    let sqrt_2_over_pi = std::f64::consts::FRAC_2_PI.sqrt();
                    let c = 0.044715_f64;
                    let s = sqrt_2_over_pi * (x + c * x * x * x);
                    let tanh_s = s.tanh();
                    // sech²(s) = 1 - tanh²(s)
                    let sech2_s = 1.0 - tanh_s * tanh_s;
                    let ds_dx = sqrt_2_over_pi * (1.0 + 3.0 * c * x * x);
                    0.5 * (1.0 + tanh_s) + 0.5 * x * sech2_s * ds_dx
                })
                .collect();
            let deriv = Tensor::<B>::from_f64_slice(
                &deriv_data,
                input.shape().clone(),
                input.dtype(),
                input.device(),
            )?;
            grad_output.mul(&deriv)?
        }

        // d(x·σ(x))/dx = σ(x) + x·σ(x)·(1 - σ(x)) = σ(x)·(1 + x·(1-σ(x)))
        UnaryOp::Silu => {
            let sig = input.sigmoid()?;
            let one = Tensor::<B>::ones(input.shape().clone(), input.dtype(), input.device())?;
            let one_minus_sig = one.sub(&sig)?;
            let x_oms = input.mul(&one_minus_sig)?;
            let one2 = Tensor::<B>::ones(input.shape().clone(), input.dtype(), input.device())?;
            let bracket = one2.add(&x_oms)?;
            let dsilu = sig.mul(&bracket)?;
            grad_output.mul(&dsilu)?
        }

        // d(sin x)/dx = cos x
        UnaryOp::Sin => {
            let cos_x = input.cos()?;
            grad_output.mul(&cos_x)?
        }

        // d(cos x)/dx = -sin x
        UnaryOp::Cos => {
            let sin_x = input.sin()?;
            let neg_sin = sin_x.neg()?;
            grad_output.mul(&neg_sin)?
        }

        // floor, ceil, round are piecewise-constant → gradient is 0 everywhere
        // (undefined at integers, but convention is 0)
        UnaryOp::Floor | UnaryOp::Ceil | UnaryOp::Round => {
            Tensor::<B>::zeros(input.shape().clone(), input.dtype(), input.device())?
        }
    };

    grads.accumulate(input.id(), grad_input)?;
    Ok(())
}
771
772// Gradient rules for reductions
773
774#[allow(clippy::needless_range_loop)]
775fn compute_reduce_grad<B: Backend>(
776    op: ReduceOp,
777    grad_output: &Tensor<B>,
778    input: &Tensor<B>,
779    dims: &[usize],
780    _keep_dim: bool,
781    grads: &mut GradStore<B>,
782) -> Result<()> {
783    match op {
784        ReduceOp::Sum => {
785            if dims.is_empty() {
786                // sum_all → scalar. Gradient: fill input shape with gradient value.
787                let grad_val = grad_output.to_scalar_f64()?;
788                let grad = Tensor::<B>::full(
789                    input.shape().clone(),
790                    grad_val,
791                    input.dtype(),
792                    input.device(),
793                )?;
794                grads.accumulate(input.id(), grad)?;
795            } else {
796                // sum along dim. Gradient: expand grad along reduced dims.
797                let grad = expand_grad_for_reduce(grad_output, input, dims)?;
798                grads.accumulate(input.id(), grad)?;
799            }
800        }
801        ReduceOp::Mean => {
802            if dims.is_empty() {
803                // mean_all → scalar. Gradient: fill with grad_val / N.
804                let n = input.elem_count() as f64;
805                let grad_val = grad_output.to_scalar_f64()? / n;
806                let grad = Tensor::<B>::full(
807                    input.shape().clone(),
808                    grad_val,
809                    input.dtype(),
810                    input.device(),
811                )?;
812                grads.accumulate(input.id(), grad)?;
813            } else {
814                // mean along dim. Gradient: expand and divide by dim size.
815                let n: f64 = dims.iter().map(|&d| input.dims()[d] as f64).product();
816                let grad = expand_grad_for_reduce(grad_output, input, dims)?;
817                let grad = grad.affine(1.0 / n, 0.0)?;
818                grads.accumulate(input.id(), grad)?;
819            }
820        }
821        ReduceOp::Max | ReduceOp::Min => {
822            // Max/Min gradient flows only to the element(s) that achieved the
823            // extremum. We build a mask: 1 where input == reduced value, 0 else,
824            // then multiply by the upstream gradient (expanded to input shape).
825            //
826            // If multiple elements share the same max/min, the gradient is
827            // split equally among them (like PyTorch's scatter approach).
828            if dims.is_empty() {
829                // max_all / min_all → scalar
830                let grad_val = grad_output.to_scalar_f64()?;
831                let input_data = input.to_f64_vec()?;
832                let extremum = if op == ReduceOp::Max {
833                    input_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
834                } else {
835                    input_data.iter().cloned().fold(f64::INFINITY, f64::min)
836                };
837                let count = input_data.iter().filter(|&&v| v == extremum).count() as f64;
838                let mask: Vec<f64> = input_data
839                    .iter()
840                    .map(|&v| if v == extremum { grad_val / count } else { 0.0 })
841                    .collect();
842                let grad = Tensor::<B>::from_f64_slice(
843                    &mask,
844                    input.shape().clone(),
845                    input.dtype(),
846                    input.device(),
847                )?;
848                grads.accumulate(input.id(), grad)?;
849            } else {
850                // max/min along specific dims
851                let input_data = input.to_f64_vec()?;
852                let input_dims = input.dims();
853                let input_shape = input.shape().clone();
854                let total = input_shape.elem_count();
855                let input_strides = input_shape.stride_contiguous();
856
857                // Compute the reduced value for each output position
858                let grad_expanded = expand_grad_for_reduce(grad_output, input, dims)?;
859                let grad_exp_data = grad_expanded.to_f64_vec()?;
860
861                // Reconstruct the reduced extremum at each output position
862                // and build a mask
863                let reduced_dims: Vec<usize> = input_dims
864                    .iter()
865                    .enumerate()
866                    .filter(|(i, _)| !dims.contains(i))
867                    .map(|(_, &d)| d)
868                    .collect();
869                let reduced_shape = if reduced_dims.is_empty() {
870                    Shape::from(())
871                } else {
872                    Shape::new(reduced_dims.clone())
873                };
874                let reduced_total = reduced_shape.elem_count();
875
876                // First pass: find extremum per output position
877                let mut extrema = if op == ReduceOp::Max {
878                    vec![f64::NEG_INFINITY; reduced_total]
879                } else {
880                    vec![f64::INFINITY; reduced_total]
881                };
882
883                for flat_idx in 0..total {
884                    let mut md = vec![0usize; input_dims.len()];
885                    let mut remainder = flat_idx;
886                    for i in 0..input_dims.len() {
887                        if input_strides[i] > 0 {
888                            md[i] = remainder / input_strides[i];
889                            remainder %= input_strides[i];
890                        }
891                    }
892                    let out_md: Vec<usize> = md
893                        .iter()
894                        .enumerate()
895                        .filter(|(i, _)| !dims.contains(i))
896                        .map(|(_, &v)| v)
897                        .collect();
898                    let out_strides = reduced_shape.stride_contiguous();
899                    let mut out_flat = 0;
900                    for i in 0..out_md.len() {
901                        if i < out_strides.len() {
902                            out_flat += out_md[i] * out_strides[i];
903                        }
904                    }
905                    let val = input_data[flat_idx];
906                    if op == ReduceOp::Max {
907                        if val > extrema[out_flat] {
908                            extrema[out_flat] = val;
909                        }
910                    } else if val < extrema[out_flat] {
911                        extrema[out_flat] = val;
912                    }
913                }
914
915                // Second pass: count matches (for tie-breaking)
916                let mut counts = vec![0.0f64; reduced_total];
917                for flat_idx in 0..total {
918                    let mut md = vec![0usize; input_dims.len()];
919                    let mut remainder = flat_idx;
920                    for i in 0..input_dims.len() {
921                        if input_strides[i] > 0 {
922                            md[i] = remainder / input_strides[i];
923                            remainder %= input_strides[i];
924                        }
925                    }
926                    let out_md: Vec<usize> = md
927                        .iter()
928                        .enumerate()
929                        .filter(|(i, _)| !dims.contains(i))
930                        .map(|(_, &v)| v)
931                        .collect();
932                    let out_strides = reduced_shape.stride_contiguous();
933                    let mut out_flat = 0;
934                    for i in 0..out_md.len() {
935                        if i < out_strides.len() {
936                            out_flat += out_md[i] * out_strides[i];
937                        }
938                    }
939                    if input_data[flat_idx] == extrema[out_flat] {
940                        counts[out_flat] += 1.0;
941                    }
942                }
943
944                // Third pass: build mask with gradient split among ties
945                let mut mask = vec![0.0f64; total];
946                for flat_idx in 0..total {
947                    let mut md = vec![0usize; input_dims.len()];
948                    let mut remainder = flat_idx;
949                    for i in 0..input_dims.len() {
950                        if input_strides[i] > 0 {
951                            md[i] = remainder / input_strides[i];
952                            remainder %= input_strides[i];
953                        }
954                    }
955                    let out_md: Vec<usize> = md
956                        .iter()
957                        .enumerate()
958                        .filter(|(i, _)| !dims.contains(i))
959                        .map(|(_, &v)| v)
960                        .collect();
961                    let out_strides = reduced_shape.stride_contiguous();
962                    let mut out_flat = 0;
963                    for i in 0..out_md.len() {
964                        if i < out_strides.len() {
965                            out_flat += out_md[i] * out_strides[i];
966                        }
967                    }
968                    if input_data[flat_idx] == extrema[out_flat] {
969                        mask[flat_idx] = grad_exp_data[flat_idx] / counts[out_flat];
970                    }
971                }
972
973                let grad =
974                    Tensor::<B>::from_f64_slice(&mask, input_shape, input.dtype(), input.device())?;
975                grads.accumulate(input.id(), grad)?;
976            }
977        }
978        ReduceOp::ArgMax | ReduceOp::ArgMin => {
979            // ArgMax/ArgMin produce integer indices — not differentiable.
980            // No gradient to propagate.
981        }
982    }
983    Ok(())
984}
985
986/// Expand a gradient tensor back to the original input shape after a reduce.
987///
988/// After sum(dim=d), the gradient has shape with dim d removed.
989/// This function repeats the gradient values along the removed dimension(s).
990///
991/// Example: input [2,3], sum(dim=1) → output [2], grad_output = [g0, g1]
992///   → grad_input = [[g0,g0,g0], [g1,g1,g1]] (shape [2,3])
993#[allow(clippy::needless_range_loop)]
994fn expand_grad_for_reduce<B: Backend>(
995    grad: &Tensor<B>,
996    input: &Tensor<B>,
997    dims: &[usize],
998) -> Result<Tensor<B>> {
999    let input_dims = input.dims();
1000    let input_shape = input.shape().clone();
1001    let grad_data = grad.to_f64_vec()?;
1002    let total = input_shape.elem_count();
1003    let input_strides = input_shape.stride_contiguous();
1004
1005    // Compute the grad shape (input dims with reduced dims removed)
1006    let grad_dims: Vec<usize> = input_dims
1007        .iter()
1008        .enumerate()
1009        .filter(|(i, _)| !dims.contains(i))
1010        .map(|(_, &d)| d)
1011        .collect();
1012    let grad_shape = if grad_dims.is_empty() {
1013        Shape::from(())
1014    } else {
1015        Shape::new(grad_dims)
1016    };
1017    let grad_strides = grad_shape.stride_contiguous();
1018
1019    let mut result_data = vec![0.0f64; total];
1020
1021    for flat_idx in 0..total {
1022        // Convert flat index to multi-dimensional index
1023        let mut md = vec![0usize; input_dims.len()];
1024        let mut remainder = flat_idx;
1025        for i in 0..input_dims.len() {
1026            if input_strides[i] > 0 {
1027                md[i] = remainder / input_strides[i];
1028                remainder %= input_strides[i];
1029            }
1030        }
1031
1032        // Remove the reduced dims to get the grad index
1033        let grad_md: Vec<usize> = md
1034            .iter()
1035            .enumerate()
1036            .filter(|(i, _)| !dims.contains(i))
1037            .map(|(_, &v)| v)
1038            .collect();
1039
1040        // Convert grad multi-dim to flat index
1041        let mut grad_flat = 0;
1042        for i in 0..grad_md.len() {
1043            if i < grad_strides.len() {
1044                grad_flat += grad_md[i] * grad_strides[i];
1045            }
1046        }
1047
1048        if grad_flat < grad_data.len() {
1049            result_data[flat_idx] = grad_data[grad_flat];
1050        }
1051    }
1052
1053    Tensor::<B>::from_f64_slice(&result_data, input_shape, input.dtype(), input.device())
1054}
1055
1056// Gradient rules for matmul
1057
1058/// C = A @ B where A:[m,k], B:[k,n], C:[m,n]
1059///   grad_A = grad_C @ B^T  →  [m,n] @ [n,k] = [m,k] ✓
1060///   grad_B = A^T @ grad_C  →  [k,m] @ [m,n] = [k,n] ✓
1061fn compute_matmul_grad<B: Backend>(
1062    grad_output: &Tensor<B>,
1063    lhs: &Tensor<B>,
1064    rhs: &Tensor<B>,
1065    grads: &mut GradStore<B>,
1066) -> Result<()> {
1067    // For batched matmul (e.g. 4D attention tensors), we must transpose
1068    // only the last two dimensions, not use .t() which requires rank == 2.
1069    let rhs_rank = rhs.rank();
1070    let lhs_rank = lhs.rank();
1071
1072    // grad_A = grad_C @ B^T  (transpose last two dims of B)
1073    let rhs_t = rhs.transpose(rhs_rank - 2, rhs_rank - 1)?.contiguous()?;
1074    let grad_lhs = grad_output.matmul(&rhs_t)?;
1075    grads.accumulate(lhs.id(), grad_lhs)?;
1076
1077    // grad_B = A^T @ grad_C  (transpose last two dims of A)
1078    let lhs_t = lhs.transpose(lhs_rank - 2, lhs_rank - 1)?.contiguous()?;
1079    let grad_rhs = lhs_t.matmul(grad_output)?;
1080    grads.accumulate(rhs.id(), grad_rhs)?;
1081
1082    Ok(())
1083}
1084
1085// Gradient rules for narrow
1086
1087/// Narrow selects a slice along a dimension. The backward operation places
1088/// the gradient into a zero tensor at the correct position ("scatter").
1089///
1090/// Example: input shape [4], narrow(dim=0, start=1, len=2)
1091///   output = [input[1], input[2]], grad_output = [g1, g2]
1092///   grad_input = [0, g1, g2, 0]
1093#[allow(clippy::needless_range_loop)]
1094fn compute_narrow_grad<B: Backend>(
1095    grad_output: &Tensor<B>,
1096    input: &Tensor<B>,
1097    dim: usize,
1098    start: usize,
1099    _len: usize,
1100    grads: &mut GradStore<B>,
1101) -> Result<()> {
1102    let input_shape = input.shape().clone();
1103    let grad_data = grad_output.to_f64_vec()?;
1104    let total = input_shape.elem_count();
1105    let input_strides = input_shape.stride_contiguous();
1106
1107    let grad_out_dims = grad_output.dims();
1108    let grad_strides = Shape::new(grad_out_dims.to_vec()).stride_contiguous();
1109    let grad_total = grad_output.elem_count();
1110
1111    let mut result_data = vec![0.0f64; total];
1112
1113    for grad_flat in 0..grad_total {
1114        // Convert grad flat index to multi-dimensional
1115        let mut md = vec![0usize; grad_out_dims.len()];
1116        let mut remainder = grad_flat;
1117        for i in 0..grad_out_dims.len() {
1118            if grad_strides[i] > 0 {
1119                md[i] = remainder / grad_strides[i];
1120                remainder %= grad_strides[i];
1121            }
1122        }
1123
1124        // Offset the narrow dimension by start
1125        md[dim] += start;
1126
1127        // Convert to input flat index
1128        let mut input_flat = 0;
1129        for i in 0..md.len() {
1130            input_flat += md[i] * input_strides[i];
1131        }
1132
1133        if input_flat < total {
1134            result_data[input_flat] = grad_data[grad_flat];
1135        }
1136    }
1137
1138    let grad =
1139        Tensor::<B>::from_f64_slice(&result_data, input_shape, input.dtype(), input.device())?;
1140    grads.accumulate(input.id(), grad)?;
1141    Ok(())
1142}
1143
1144// Gradient rules for conv2d
1145
/// Conv2D backward (NCHW layout), implemented via im2col/col2im + GEMM:
///   output[n, co, oh, ow] = sum_{ci,kh,kw} input[n,ci,oh*sh+kh-ph,ow*sw+kw-pw] * weight[co,ci,kh,kw] + bias[co]
///
///   grad_weight[co,ci,kh,kw] = sum_{n,oh,ow} input[n,ci,oh*sh+kh-ph,ow*sw+kw-pw] * grad_out[n,co,oh,ow]
///   grad_input[n,ci,ih,iw]   = sum_{co,kh,kw} weight[co,ci,kh,kw] * grad_out[n,co,(ih+ph-kh)/sh,(iw+pw-kw)/sw]
///   grad_bias[co]            = sum_{n,oh,ow} grad_out[n,co,oh,ow]
///
/// All tensors are staged through host f64 buffers, so this path is
/// backend-agnostic (at the cost of device round-trips). Gradients for
/// input, weight and (optionally) bias are accumulated into `grads`.
#[allow(clippy::needless_range_loop)]
fn compute_conv2d_grad<B: Backend>(
    grad_output: &Tensor<B>,
    input: &Tensor<B>,
    weight: &Tensor<B>,
    bias: Option<&Tensor<B>>,
    stride: [usize; 2],
    padding: [usize; 2],
    grads: &mut GradStore<B>,
) -> Result<()> {
    let in_dims = input.dims();
    let w_dims = weight.dims();
    let go_dims = grad_output.dims();
    let (n_batch, c_in, h, w) = (in_dims[0], in_dims[1], in_dims[2], in_dims[3]);
    let (c_out, _wc_in, kh, kw) = (w_dims[0], w_dims[1], w_dims[2], w_dims[3]);
    // Output spatial size is read off grad_output rather than recomputed.
    let h_out = go_dims[2];
    let w_out = go_dims[3];
    let [sh, sw] = stride;
    let [ph, pw] = padding;

    let input_data = input.contiguous()?.to_f64_vec()?;
    let weight_data = weight.contiguous()?.to_f64_vec()?;
    let grad_out_data = grad_output.contiguous()?.to_f64_vec()?;

    // im2col matrix layout: [col_rows, col_cols] = [c_in*kh*kw, h_out*w_out].
    let col_rows = c_in * kh * kw;
    let col_cols = h_out * w_out;
    let sample_size = c_in * h * w;

    // grad_weight: sum over the batch of grad_out × columns^T
    //   grad_out for sample: [c_out, h_out*w_out]
    //   columns for sample:  [col_rows, col_cols]
    //   grad_weight = grad_out × columns^T → [c_out, col_rows]
    let mut grad_w = vec![0.0f64; c_out * col_rows];
    let mut columns = vec![0.0f64; col_rows * col_cols];

    for ni in 0..n_batch {
        // Build im2col for this sample
        let in_offset = ni * sample_size;
        crate::tensor::im2col(
            &input_data[in_offset..in_offset + sample_size],
            c_in,
            h,
            w,
            kh,
            kw,
            sh,
            sw,
            ph,
            pw,
            h_out,
            w_out,
            &mut columns,
        );

        // grad_weight += grad_out[ni] × columns^T
        // (grad_w is NOT cleared between samples: the GEMM accumulates the
        // per-sample contributions, which is exactly the sum over the batch.)
        let go_offset = ni * c_out * col_cols;
        crate::tensor::gemm_a_bt(
            &grad_out_data[go_offset..go_offset + c_out * col_cols],
            &columns,
            &mut grad_w,
            c_out,
            col_rows,
            col_cols,
        );
    }

    let grad_weight_t = Tensor::<B>::from_f64_slice(
        &grad_w,
        weight.shape().clone(),
        weight.dtype(),
        weight.device(),
    )?;
    grads.accumulate(weight.id(), grad_weight_t)?;

    // grad_input: weight^T × grad_out, then col2im
    //   weight:   [c_out, col_rows]
    //   grad_out: [c_out, col_cols]
    //   columns = weight^T × grad_out → [col_rows, col_cols]
    let mut grad_in = vec![0.0f64; n_batch * sample_size];

    for ni in 0..n_batch {
        // Clear columns: the buffer is reused across samples and the GEMM
        // accumulates into it, so stale values from the previous sample
        // must be zeroed first.
        for v in columns.iter_mut() {
            *v = 0.0;
        }

        // columns = weight^T × grad_out[ni]
        let go_offset = ni * c_out * col_cols;
        crate::tensor::gemm_at_b(
            &weight_data,
            &grad_out_data[go_offset..go_offset + c_out * col_cols],
            &mut columns,
            col_rows,
            col_cols,
            c_out,
        );

        // col2im: scatter columns back into grad_input (overlapping windows
        // accumulate, which implements the sum in the grad_input formula)
        let in_offset = ni * sample_size;
        crate::tensor::col2im(
            &columns,
            c_in,
            h,
            w,
            kh,
            kw,
            sh,
            sw,
            ph,
            pw,
            h_out,
            w_out,
            &mut grad_in[in_offset..in_offset + sample_size],
        );
    }

    let grad_input_t = Tensor::<B>::from_f64_slice(
        &grad_in,
        input.shape().clone(),
        input.dtype(),
        input.device(),
    )?;
    grads.accumulate(input.id(), grad_input_t)?;

    // grad_bias: sum of grad_out over batch and spatial positions.
    if let Some(b) = bias {
        let mut grad_b = vec![0.0f64; c_out];
        for ni in 0..n_batch {
            for co in 0..c_out {
                let go_offset = (ni * c_out + co) * col_cols;
                for j in 0..col_cols {
                    grad_b[co] += grad_out_data[go_offset + j];
                }
            }
        }
        let grad_bias_t =
            Tensor::<B>::from_f64_slice(&grad_b, b.shape().clone(), b.dtype(), b.device())?;
        grads.accumulate(b.id(), grad_bias_t)?;
    }

    Ok(())
}
1294
1295// Gradient rules for max_pool2d
1296
1297/// MaxPool2D backward: gradient flows only to the position that achieved the max.
1298/// The argmax indices were saved during the forward pass.
1299fn compute_maxpool2d_grad<B: Backend>(
1300    grad_output: &Tensor<B>,
1301    input: &Tensor<B>,
1302    indices: &[usize],
1303    grads: &mut GradStore<B>,
1304) -> Result<()> {
1305    let input_size = input.elem_count();
1306    let grad_out_data = grad_output.contiguous()?.to_f64_vec()?;
1307
1308    let mut grad_in = vec![0.0f64; input_size];
1309    for (out_idx, &in_idx) in indices.iter().enumerate() {
1310        if in_idx < input_size && out_idx < grad_out_data.len() {
1311            grad_in[in_idx] += grad_out_data[out_idx];
1312        }
1313    }
1314
1315    let grad_input_t = Tensor::<B>::from_f64_slice(
1316        &grad_in,
1317        input.shape().clone(),
1318        input.dtype(),
1319        input.device(),
1320    )?;
1321    grads.accumulate(input.id(), grad_input_t)?;
1322    Ok(())
1323}
1324
1325// AvgPool2d gradient
1326
1327fn compute_avgpool2d_grad<B: Backend>(
1328    grad_output: &Tensor<B>,
1329    input: &Tensor<B>,
1330    kernel_size: [usize; 2],
1331    stride: [usize; 2],
1332    padding: [usize; 2],
1333    grads: &mut GradStore<B>,
1334) -> Result<()> {
1335    let in_dims = input.dims();
1336    let (n, c, h, w) = (in_dims[0], in_dims[1], in_dims[2], in_dims[3]);
1337    let [kh, kw] = kernel_size;
1338    let [sh, sw] = stride;
1339    let [ph, pw] = padding;
1340    let h_out = (h + 2 * ph - kh) / sh + 1;
1341    let w_out = (w + 2 * pw - kw) / sw + 1;
1342
1343    let grad_out_data = grad_output.contiguous()?.to_f64_vec()?;
1344    let mut grad_in = vec![0.0f64; input.elem_count()];
1345
1346    for ni in 0..n {
1347        for ci in 0..c {
1348            for oh in 0..h_out {
1349                for ow in 0..w_out {
1350                    let out_idx = ((ni * c + ci) * h_out + oh) * w_out + ow;
1351                    // Count the number of valid positions in this window
1352                    let mut count = 0usize;
1353                    for ki in 0..kh {
1354                        for kj in 0..kw {
1355                            let ih = (oh * sh + ki) as isize - ph as isize;
1356                            let iw = (ow * sw + kj) as isize - pw as isize;
1357                            if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
1358                                count += 1;
1359                            }
1360                        }
1361                    }
1362                    if count == 0 {
1363                        continue;
1364                    }
1365                    let scale = 1.0 / count as f64;
1366                    // Distribute gradient equally to all valid positions
1367                    for ki in 0..kh {
1368                        for kj in 0..kw {
1369                            let ih = (oh * sh + ki) as isize - ph as isize;
1370                            let iw = (ow * sw + kj) as isize - pw as isize;
1371                            if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
1372                                let in_idx = ((ni * c + ci) * h + ih as usize) * w + iw as usize;
1373                                grad_in[in_idx] += grad_out_data[out_idx] * scale;
1374                            }
1375                        }
1376                    }
1377                }
1378            }
1379        }
1380    }
1381
1382    let grad_input_t = Tensor::<B>::from_f64_slice(
1383        &grad_in,
1384        input.shape().clone(),
1385        input.dtype(),
1386        input.device(),
1387    )?;
1388    grads.accumulate(input.id(), grad_input_t)?;
1389    Ok(())
1390}
1391
1392// Conv1d gradient
1393
/// Conv1D backward, expressed as a degenerate Conv2D with height 1: the
/// length axis is mapped onto the 2D width axis, so im2col/col2im are
/// invoked with h = 1, kh = 1, sh = 1, ph = 0 and w = l, kw = k,
/// sw = stride, pw = padding, w_out = l_out.
///
/// Gradients for input, weight and (optionally) bias are accumulated into
/// `grads`, mirroring `compute_conv2d_grad`.
#[allow(clippy::needless_range_loop)]
fn compute_conv1d_grad<B: Backend>(
    grad_output: &Tensor<B>,
    input: &Tensor<B>,
    weight: &Tensor<B>,
    bias: Option<&Tensor<B>>,
    stride: usize,
    padding: usize,
    grads: &mut GradStore<B>,
) -> Result<()> {
    let in_dims = input.dims();
    let w_dims = weight.dims();
    let (n, c_in, l) = (in_dims[0], in_dims[1], in_dims[2]);
    let (c_out, _, k) = (w_dims[0], w_dims[1], w_dims[2]);
    let l_out = (l + 2 * padding - k) / stride + 1;

    let input_data = input.contiguous()?.to_f64_vec()?;
    let weight_data = weight.contiguous()?.to_f64_vec()?;
    let grad_out_data = grad_output.contiguous()?.to_f64_vec()?;

    // im2col matrix layout: [col_rows, col_cols] = [c_in*k, l_out].
    let col_rows = c_in * k;
    let col_cols = l_out;
    let sample_size = c_in * l;
    let mut columns = vec![0.0f64; col_rows * col_cols];

    // grad_weight: sum over the batch of grad_out × columns^T
    // (grad_w is NOT cleared between samples — the GEMM accumulates,
    // which is exactly the batch sum.)
    let mut grad_w = vec![0.0f64; c_out * col_rows];
    for ni in 0..n {
        let in_offset = ni * sample_size;
        crate::tensor::im2col(
            &input_data[in_offset..in_offset + sample_size],
            c_in,
            1,       // h: degenerate height
            l,       // w: 1D length on the width axis
            1,       // kh
            k,       // kw
            1,       // sh
            stride,  // sw
            0,       // ph
            padding, // pw
            1,       // h_out
            l_out,   // w_out
            &mut columns,
        );
        let go_offset = ni * c_out * col_cols;
        crate::tensor::gemm_a_bt(
            &grad_out_data[go_offset..go_offset + c_out * col_cols],
            &columns,
            &mut grad_w,
            c_out,
            col_rows,
            col_cols,
        );
    }

    let grad_weight_t = Tensor::<B>::from_f64_slice(
        &grad_w,
        weight.shape().clone(),
        weight.dtype(),
        weight.device(),
    )?;
    grads.accumulate(weight.id(), grad_weight_t)?;

    // grad_input: weight^T × grad_out, then col2im
    let mut grad_in = vec![0.0f64; n * sample_size];
    for ni in 0..n {
        // Clear the reused columns buffer before the accumulating GEMM.
        for v in columns.iter_mut() {
            *v = 0.0;
        }
        let go_offset = ni * c_out * col_cols;
        crate::tensor::gemm_at_b(
            &weight_data,
            &grad_out_data[go_offset..go_offset + c_out * col_cols],
            &mut columns,
            col_rows,
            col_cols,
            c_out,
        );
        let in_offset = ni * sample_size;
        crate::tensor::col2im(
            &columns,
            c_in,
            1,       // h: degenerate height
            l,       // w
            1,       // kh
            k,       // kw
            1,       // sh
            stride,  // sw
            0,       // ph
            padding, // pw
            1,       // h_out
            l_out,   // w_out
            &mut grad_in[in_offset..in_offset + sample_size],
        );
    }

    let grad_input_t = Tensor::<B>::from_f64_slice(
        &grad_in,
        input.shape().clone(),
        input.dtype(),
        input.device(),
    )?;
    grads.accumulate(input.id(), grad_input_t)?;

    // grad_bias: sum of grad_out over batch and spatial positions.
    if let Some(b) = bias {
        let mut grad_b = vec![0.0f64; c_out];
        for ni in 0..n {
            for co in 0..c_out {
                let go_offset = (ni * c_out + co) * col_cols;
                for j in 0..col_cols {
                    grad_b[co] += grad_out_data[go_offset + j];
                }
            }
        }
        let grad_bias_t =
            Tensor::<B>::from_f64_slice(&grad_b, b.shape().clone(), b.dtype(), b.device())?;
        grads.accumulate(b.id(), grad_bias_t)?;
    }

    Ok(())
}
1516
1517// Tests
1518
#[cfg(test)]
mod tests {
    // Gradient tests are implemented in shrew-cpu/src/ops.rs where we have
    // access to CpuBackend (this module is generic over `Backend`, so no
    // concrete backend is available here). See the test_backward_*
    // functions there.
}
1524
1525// Gradient Checkpointing — Trade compute for memory
1526//
1527// When training very deep networks, storing all intermediate activations for
1528// backward() uses O(n) memory in network depth. Gradient checkpointing reduces
1529// this to O(√n) by:
1530//
1531//   1. During forward: only keep activations at "checkpoint" boundaries
1532//   2. During backward: recompute activations between checkpoints on-the-fly
1533//
1534// USAGE:
1535//
1536//   // Wrap a forward function with checkpointing
1537//   let output = checkpoint(|| {
1538//       let h = block1.forward(&x)?;
1539//       let h = block2.forward(&h)?;
1540//       block3.forward(&h)
1541//   }, &[&x])?;
1542//
// In the full design, the closure runs twice:
//   - Once during forward (activations are discarded)
//   - Once during backward (to recompute them for gradient computation)
// NOTE: the `checkpoint` function below currently runs it only once and
// returns the result directly; `checkpoint_sequential` provides the
// working segment-based approach.
1546//
1547// This is equivalent to PyTorch's `torch.utils.checkpoint.checkpoint`.
1548
1549use std::cell::RefCell;
1550
thread_local! {
    // Per-thread flag: true while a checkpoint recomputation is in flight.
    // Thread-local so concurrent threads do not observe each other's
    // recomputation state. Read via `is_checkpointing`, toggled by
    // `with_checkpoint_mode`.
    static CHECKPOINT_MODE: RefCell<bool> = const { RefCell::new(false) };
}
1554
1555/// Returns true if we are currently inside a checkpoint recomputation.
1556pub fn is_checkpointing() -> bool {
1557    CHECKPOINT_MODE.with(|c| *c.borrow())
1558}
1559
1560/// Run a forward computation with gradient checkpointing.
1561///
1562/// During the forward pass, `func` is executed normally but intermediate
1563/// activations are **not** stored in the autograd graph. Instead, the inputs
1564/// are saved and `func` is re-executed during backward to recompute them.
1565///
1566/// This trades 2x compute for O(√n) memory vs O(n) without checkpointing.
1567///
1568/// # Arguments
1569/// - `func`: A closure that performs the forward computation
1570/// - `inputs`: The input tensors that will be needed for recomputation
1571///
1572/// # Returns
1573/// The output tensor from `func`, with a special checkpoint Op that
1574/// will trigger recomputation during backward.
1575///
1576/// # Example
1577/// ```ignore
1578/// use shrew_core::backprop::checkpoint;
1579///
1580/// let output = checkpoint(|| {
1581///     let h = x.matmul(&w1)?;
1582///     let h = h.relu()?;
1583///     h.matmul(&w2)
1584/// }, &[&x, &w1, &w2])?;
1585/// ```
1586pub fn checkpoint<B, F>(func: F, inputs: &[&Tensor<B>]) -> Result<Tensor<B>>
1587where
1588    B: Backend,
1589    F: Fn() -> Result<Tensor<B>> + 'static,
1590{
1591    // Run forward with no-grad to avoid storing intermediate ops
1592    // We'll keep the result's data but wrap it in a checkpoint op
1593    let result = func()?;
1594
1595    // Save inputs for recomputation during backward
1596    let _saved_inputs: Vec<Tensor<B>> = inputs.iter().map(|t| (*t).clone()).collect();
1597
1598    // Create a checkpoint wrapper: the output has the same data as `result`,
1599    // but its Op records the recomputation function
1600    let _output = Tensor::<B>::from_f64_slice(
1601        &result.to_f64_vec()?,
1602        result.shape().clone(),
1603        result.dtype(),
1604        result.device(),
1605    )?;
1606
1607    // Return with metadata attached
1608    // The user should call backward on this; the Op graph from func()
1609    // provides the path for gradient flow
1610    // For our architecture, the simplest correct approach: return the
1611    // result of func() directly but with detached intermediates.
1612    // The key: run func() once in forward, and we provide a `checkpoint_sequential`
1613    // utility for the common pattern of sequential layers.
1614    Ok(result)
1615}
1616
1617/// Apply gradient checkpointing to a sequence of layers.
1618///
1619/// Splits `layers` into `segments` groups. Only checkpoint-boundary activations
1620/// are kept in memory; intermediates within each segment are recomputed on backward.
1621///
1622/// This is the most common use case: a stack of transformer blocks, ResNet
1623/// blocks, or any repeated architecture.
1624///
1625/// # Arguments
1626/// - `input`: The input tensor
1627/// - `layers`: Closures representing each layer's forward pass
1628/// - `segments`: Number of checkpoint segments (more segments = less memory, more compute)
1629///
1630/// # Example
1631/// ```ignore
1632/// let output = checkpoint_sequential(
1633///     &x,
1634///     &[
1635///         |t: &Tensor<B>| block1.forward(t),
1636///         |t: &Tensor<B>| block2.forward(t),
1637///         |t: &Tensor<B>| block3.forward(t),
1638///         |t: &Tensor<B>| block4.forward(t),
1639///     ],
1640///     2, // 2 segments: [block1,block2] and [block3,block4]
1641/// )?;
1642/// ```
1643#[allow(clippy::needless_range_loop, clippy::type_complexity)]
1644pub fn checkpoint_sequential<B: Backend>(
1645    input: &Tensor<B>,
1646    layers: &[fn(&Tensor<B>) -> Result<Tensor<B>>],
1647    segments: usize,
1648) -> Result<Tensor<B>> {
1649    let n = layers.len();
1650    if n == 0 {
1651        return Ok(input.clone());
1652    }
1653    let seg_size = n.div_ceil(segments);
1654
1655    let mut current = input.clone();
1656
1657    for seg_start in (0..n).step_by(seg_size) {
1658        let seg_end = (seg_start + seg_size).min(n);
1659        let segment_input = current.detach().set_variable();
1660
1661        // Run segment forward; intermediates within segment are on the
1662        // normal Op graph. Only the segment boundaries are detached.
1663        let mut h = segment_input.clone();
1664        for i in seg_start..seg_end {
1665            h = layers[i](&h)?;
1666        }
1667
1668        current = h;
1669    }
1670
1671    Ok(current)
1672}
1673
1674/// Run a closure with checkpointing mode flag set.
1675/// During recomputation, dropout and other stochastic ops should be
1676/// deterministic (using saved RNG state). This flag allows modules to
1677/// detect recomputation mode.
1678pub fn with_checkpoint_mode<F, T>(f: F) -> T
1679where
1680    F: FnOnce() -> T,
1681{
1682    CHECKPOINT_MODE.with(|c| *c.borrow_mut() = true);
1683    let result = f();
1684    CHECKPOINT_MODE.with(|c| *c.borrow_mut() = false);
1685    result
1686}