1use std::collections::{HashMap, HashSet};
72
73use crate::backend::{Backend, BinaryOp, ReduceOp, UnaryOp};
74use crate::error::Result;
75use crate::op::{Op, TensorId};
76use crate::shape::Shape;
77use crate::tensor::Tensor;
78
/// Accumulated gradients produced by the backward pass, keyed by tensor id.
pub struct GradStore<B: Backend> {
    // Running sum of all gradient contributions seen so far per tensor.
    grads: HashMap<TensorId, Tensor<B>>,
}
86
// Manual impl: `#[derive(Clone)]` would add a `B: Clone` bound, which the
// backend type parameter is not required to satisfy (tensors themselves are
// cloned throughout this file without bounding `B`).
impl<B: Backend> Clone for GradStore<B> {
    fn clone(&self) -> Self {
        GradStore {
            grads: self.grads.clone(),
        }
    }
}
94
impl<B: Backend> Default for GradStore<B> {
    /// Equivalent to [`GradStore::new`]: an empty store.
    fn default() -> Self {
        Self::new()
    }
}
100
101impl<B: Backend> GradStore<B> {
102 pub fn new() -> Self {
104 GradStore {
105 grads: HashMap::new(),
106 }
107 }
108
109 pub fn get(&self, tensor: &Tensor<B>) -> Option<&Tensor<B>> {
111 self.grads.get(&tensor.id())
112 }
113
114 fn get_by_id(&self, id: &TensorId) -> Option<&Tensor<B>> {
115 self.grads.get(id)
116 }
117
118 pub fn accumulate(&mut self, id: TensorId, grad: Tensor<B>) -> Result<()> {
122 if let Some(existing) = self.grads.get(&id) {
123 let new_grad = existing.add(&grad)?;
124 self.grads.insert(id, new_grad);
125 } else {
126 self.grads.insert(id, grad);
127 }
128 Ok(())
129 }
130}
131
132fn build_topo<B: Backend>(root: &Tensor<B>) -> Vec<Tensor<B>> {
138 let mut visited = HashSet::new();
139 let mut order = Vec::new();
140
141 fn visit<B: Backend>(
142 t: &Tensor<B>,
143 visited: &mut HashSet<TensorId>,
144 order: &mut Vec<Tensor<B>>,
145 ) {
146 if visited.contains(&t.id()) {
147 return;
148 }
149 visited.insert(t.id());
150 for input in t.op().inputs() {
152 visit(input, visited, order);
153 }
154 order.push(t.clone());
156 }
157
158 visit(root, &mut visited, &mut order);
159 order
160}
161
/// Reverse-mode automatic differentiation starting from the scalar `root`.
///
/// Seeds the root's gradient with ones, walks the graph in reverse
/// topological order, and accumulates a gradient for every tensor that
/// contributed to `root`. Tensors the gradient never reaches (or that sit
/// behind non-differentiable ops) simply do not appear in the returned
/// store.
///
/// # Errors
/// Returns an error if `root` is not a single-element tensor, or if any
/// underlying tensor operation fails.
#[allow(clippy::needless_range_loop)]
pub fn backward<B: Backend>(root: &Tensor<B>) -> Result<GradStore<B>> {
    if root.elem_count() != 1 {
        return Err(crate::Error::msg(
            "backward() requires a scalar tensor (single element). \
             Use .sum_all() or .mean_all() to reduce to a scalar first.",
        ));
    }

    let topo = build_topo(root);

    let mut grads = GradStore::new();
    // Seed gradient: d(root)/d(root) = 1.
    let ones = Tensor::<B>::ones(root.shape().clone(), root.dtype(), root.device())?;
    grads.grads.insert(root.id(), ones);

    // Reverse topological order guarantees a tensor's gradient is complete
    // before it is propagated to that tensor's inputs.
    for tensor in topo.iter().rev() {
        // No gradient reached this tensor -> nothing to propagate.
        let grad_output = match grads.get_by_id(&tensor.id()) {
            Some(g) => g.clone(),
            None => continue,
        };

        match tensor.op() {
            // Leaf tensor: no inputs to propagate to.
            Op::None => {}

            // Layout change only; gradient passes through unchanged.
            Op::Contiguous { input } => {
                grads.accumulate(input.id(), grad_output)?;
            }

            Op::Binary { lhs, rhs, op } => {
                compute_binary_grad(*op, &grad_output, lhs, rhs, &mut grads)?;
            }

            Op::Unary { input, op } => {
                compute_unary_grad(*op, &grad_output, input, &mut grads)?;
            }

            Op::Reduce {
                input,
                op,
                dims,
                keep_dim,
            } => {
                compute_reduce_grad(*op, &grad_output, input, dims, *keep_dim, &mut grads)?;
            }

            Op::Matmul { lhs, rhs } => {
                compute_matmul_grad(&grad_output, lhs, rhs, &mut grads)?;
            }

            // Reshape is a bijection on elements: reshape the gradient back
            // to the input's original shape.
            Op::Reshape { input, src_shape } => {
                let grad = grad_output.reshape(src_shape.clone())?;
                grads.accumulate(input.id(), grad)?;
            }

            // Transposing the same two dims is self-inverse: apply the same
            // swap to the gradient.
            Op::Transpose { input, dim0, dim1 } => {
                let grad = grad_output.transpose(*dim0, *dim1)?;
                grads.accumulate(input.id(), grad)?;
            }

            Op::Narrow {
                input,
                dim,
                start,
                len,
            } => {
                compute_narrow_grad(&grad_output, input, *dim, *start, *len, &mut grads)?;
            }

            // d(mul*x + add)/dx = mul; the additive constant drops out.
            Op::Affine { input, mul, .. } => {
                let grad = grad_output.affine(*mul, 0.0)?;
                grads.accumulate(input.id(), grad)?;
            }

            Op::Conv2d {
                input,
                weight,
                bias,
                stride,
                padding,
            } => {
                compute_conv2d_grad(
                    &grad_output,
                    input,
                    weight,
                    bias.as_ref(),
                    *stride,
                    *padding,
                    &mut grads,
                )?;
            }

            Op::MaxPool2d { input, indices, .. } => {
                compute_maxpool2d_grad(&grad_output, input, indices, &mut grads)?;
            }

            // Concatenation: slice the gradient back into per-input pieces
            // along `dim` using the recorded input sizes.
            Op::Cat { inputs, dim, sizes } => {
                let mut offset = 0usize;
                for (inp, &sz) in inputs.iter().zip(sizes.iter()) {
                    let grad_slice = grad_output.narrow(*dim, offset, sz)?;
                    grads.accumulate(inp.id(), grad_slice)?;
                    offset += sz;
                }
            }

            // d(x^n)/dx = n * x^(n-1).
            Op::Powf { input, exponent } => {
                let n = *exponent;
                let x_pow_nm1 = input.powf(n - 1.0)?;
                let n_tensor =
                    Tensor::<B>::full(input.shape().clone(), n, input.dtype(), input.device())?;
                let grad = grad_output.mul(&n_tensor)?.mul(&x_pow_nm1)?;
                grads.accumulate(input.id(), grad)?;
            }

            // Gradient passes only where the input was strictly inside
            // (min, max); clamped elements get zero.
            // NOTE(review): elements exactly at the boundary (x == min or
            // x == max) are zeroed here; some frameworks pass gradient at
            // the boundary — confirm which convention is intended.
            Op::Clamp { input, min, max } => {
                let input_data = input.to_f64_vec()?;
                let grad_data = grad_output.to_f64_vec()?;
                let mask: Vec<f64> = input_data
                    .iter()
                    .zip(grad_data.iter())
                    .map(|(&x, &g)| if x > *min && x < *max { g } else { 0.0 })
                    .collect();
                let grad = Tensor::<B>::from_f64_slice(
                    &mask,
                    input.shape().clone(),
                    input.dtype(),
                    input.device(),
                )?;
                grads.accumulate(input.id(), grad)?;
            }

            // Route each gradient element to whichever branch the mask
            // selected in the forward pass; the mask itself is treated as
            // non-differentiable and gets no gradient.
            Op::WhereCond {
                mask,
                on_true,
                on_false,
            } => {
                let mask_data = mask.to_f64_vec()?;
                let grad_data = grad_output.to_f64_vec()?;
                let n = mask_data.len();

                let grad_true_data: Vec<f64> = (0..n)
                    .map(|i| {
                        if mask_data[i] != 0.0 {
                            grad_data[i]
                        } else {
                            0.0
                        }
                    })
                    .collect();
                let grad_false_data: Vec<f64> = (0..n)
                    .map(|i| {
                        if mask_data[i] == 0.0 {
                            grad_data[i]
                        } else {
                            0.0
                        }
                    })
                    .collect();

                let grad_true = Tensor::<B>::from_f64_slice(
                    &grad_true_data,
                    on_true.shape().clone(),
                    on_true.dtype(),
                    on_true.device(),
                )?;
                let grad_false = Tensor::<B>::from_f64_slice(
                    &grad_false_data,
                    on_false.shape().clone(),
                    on_false.dtype(),
                    on_false.device(),
                )?;
                grads.accumulate(on_true.id(), grad_true)?;
                grads.accumulate(on_false.id(), grad_false)?;
            }

            // gather backward = scatter-add: each output-gradient element is
            // added back to the input position it was gathered from.
            // Assumes contiguous layouts (all strides via `stride_contiguous`).
            Op::Gather { input, index, dim } => {
                let dim = *dim;
                let input_dims = input.dims();
                let rank = input_dims.len();

                let grad_data = grad_output.to_f64_vec()?;
                let index_data = index.to_f64_vec()?;

                let mut grad_input_data = vec![0.0f64; input.elem_count()];
                let input_strides = input.shape().stride_contiguous();

                let index_strides = index.shape().stride_contiguous();

                let n = index_data.len();
                for flat_idx in 0..n {
                    // Decode the flat offset into multi-dim coordinates of
                    // the index/output tensor.
                    let mut coords = vec![0usize; rank];
                    let mut remainder = flat_idx;
                    for d in 0..rank {
                        coords[d] = remainder / index_strides[d];
                        remainder %= index_strides[d];
                    }

                    // Replace the gathered axis with the recorded source
                    // position.
                    let idx_val = index_data[flat_idx] as usize;
                    coords[dim] = idx_val;

                    let mut input_flat = 0;
                    for d in 0..rank {
                        input_flat += coords[d] * input_strides[d];
                    }

                    // `+=` handles repeated indices (fan-in).
                    grad_input_data[input_flat] += grad_data[flat_idx];
                }

                let grad_input = Tensor::<B>::from_f64_slice(
                    &grad_input_data,
                    input.shape().clone(),
                    input.dtype(),
                    input.device(),
                )?;
                grads.accumulate(input.id(), grad_input)?;
            }

            // Padding backward: crop the gradient back to the unpadded
            // region along every padded dimension.
            Op::Pad { input, padding } => {
                let mut grad = grad_output.clone();
                let input_dims = input.dims();
                for d in 0..input_dims.len() {
                    let [before, _after] = padding[d];
                    if before > 0 || _after > 0 {
                        grad = grad.narrow(d, before, input_dims[d])?;
                    }
                }
                grads.accumulate(input.id(), grad)?;
            }

            Op::AvgPool2d {
                input,
                kernel_size,
                stride,
                padding,
            } => {
                compute_avgpool2d_grad(
                    &grad_output,
                    input,
                    *kernel_size,
                    *stride,
                    *padding,
                    &mut grads,
                )?;
            }

            Op::Conv1d {
                input,
                weight,
                bias,
                stride,
                padding,
            } => {
                compute_conv1d_grad(
                    &grad_output,
                    input,
                    weight,
                    bias.as_ref(),
                    *stride,
                    *padding,
                    &mut grads,
                )?;
            }

            // index_select backward = scatter-add along `dim`, driven by the
            // positions recorded in `indices`.
            Op::IndexSelect {
                input,
                indices,
                dim,
            } => {
                let dim = *dim;
                let input_dims = input.dims();
                let rank = input_dims.len();

                let grad_data = grad_output.to_f64_vec()?;
                let index_data = indices.to_f64_vec()?;
                let _num_indices = index_data.len();

                let mut grad_input_data = vec![0.0f64; input.elem_count()];
                let input_strides = input.shape().stride_contiguous();
                let _output_dims = grad_output.dims();
                let output_strides = grad_output.shape().stride_contiguous();

                let total = grad_data.len();
                for flat_idx in 0..total {
                    // Decode output coordinates from the flat offset.
                    let mut coords = vec![0usize; rank];
                    let mut remainder = flat_idx;
                    for d in 0..rank {
                        coords[d] = remainder / output_strides[d];
                        remainder %= output_strides[d];
                    }

                    // Map the selected position back to its source row in
                    // the input.
                    let out_dim_coord = coords[dim];
                    let src_idx = index_data[out_dim_coord] as usize;
                    coords[dim] = src_idx;

                    let mut input_flat = 0;
                    for d in 0..rank {
                        input_flat += coords[d] * input_strides[d];
                    }

                    // `+=` accumulates rows selected more than once.
                    grad_input_data[input_flat] += grad_data[flat_idx];
                }

                let grad_input = Tensor::<B>::from_f64_slice(
                    &grad_input_data,
                    input.shape().clone(),
                    input.dtype(),
                    input.device(),
                )?;
                grads.accumulate(input.id(), grad_input)?;
            }

            // Cast the gradient back to the input's original dtype.
            Op::ToDtype { input, src_dtype } => {
                let grad_in = grad_output.to_dtype(*src_dtype)?;
                grads.accumulate(input.id(), grad_in)?;
            }
        }
    }

    Ok(grads)
}
523
524fn compute_binary_grad<B: Backend>(
527 op: BinaryOp,
528 grad_output: &Tensor<B>,
529 lhs: &Tensor<B>,
530 rhs: &Tensor<B>,
531 grads: &mut GradStore<B>,
532) -> Result<()> {
533 match op {
534 BinaryOp::Add => {
535 let grad_lhs = reduce_broadcast_grad(grad_output, lhs.shape())?;
537 let grad_rhs = reduce_broadcast_grad(grad_output, rhs.shape())?;
538 grads.accumulate(lhs.id(), grad_lhs)?;
539 grads.accumulate(rhs.id(), grad_rhs)?;
540 }
541 BinaryOp::Sub => {
542 let grad_lhs = reduce_broadcast_grad(grad_output, lhs.shape())?;
544 let neg = grad_output.neg()?;
545 let grad_rhs = reduce_broadcast_grad(&neg, rhs.shape())?;
546 grads.accumulate(lhs.id(), grad_lhs)?;
547 grads.accumulate(rhs.id(), grad_rhs)?;
548 }
549 BinaryOp::Mul => {
550 let raw_lhs = grad_output.mul(rhs)?;
552 let raw_rhs = grad_output.mul(lhs)?;
553 grads.accumulate(lhs.id(), reduce_broadcast_grad(&raw_lhs, lhs.shape())?)?;
554 grads.accumulate(rhs.id(), reduce_broadcast_grad(&raw_rhs, rhs.shape())?)?;
555 }
556 BinaryOp::Div => {
557 let raw_lhs = grad_output.div(rhs)?;
560 grads.accumulate(lhs.id(), reduce_broadcast_grad(&raw_lhs, lhs.shape())?)?;
561 let neg_grad = grad_output.neg()?;
562 let b_sq = rhs.mul(rhs)?;
563 let raw_rhs = neg_grad.mul(lhs)?.div(&b_sq)?;
564 grads.accumulate(rhs.id(), reduce_broadcast_grad(&raw_rhs, rhs.shape())?)?;
565 }
566 }
567 Ok(())
568}
569
/// Sums `grad` down to `target_shape`, undoing the implicit expansion that
/// broadcasting performed in the forward op.
///
/// Assumes `grad`'s rank is >= the target's rank (true for gradients of
/// broadcast outputs); otherwise `grad_rank - target_rank` below would
/// underflow — NOTE(review): confirm callers never pass a lower-rank grad.
fn reduce_broadcast_grad<B: Backend>(
    grad: &Tensor<B>,
    target_shape: &crate::Shape,
) -> Result<Tensor<B>> {
    let grad_shape = grad.dims();
    let target_dims = target_shape.dims();

    // Fast path: shapes already agree, no broadcasting happened.
    if grad_shape == target_dims {
        return Ok(grad.clone());
    }

    // Right-align the target dims against the gradient dims, padding the
    // leading positions with 1 (numpy-style broadcasting alignment).
    let grad_rank = grad_shape.len();
    let target_rank = target_dims.len();
    let mut padded_target = vec![1usize; grad_rank];
    let offset = grad_rank - target_rank;
    padded_target[offset..offset + target_rank].copy_from_slice(target_dims);

    let mut result = grad.clone();
    // Any axis where the target was size 1 but the gradient is larger was
    // broadcast, so its contributions must be summed together.
    let mut dims_to_sum: Vec<usize> = Vec::new();
    for d in 0..grad_rank {
        if padded_target[d] == 1 && grad_shape[d] > 1 {
            dims_to_sum.push(d);
        }
    }

    // Sum with keep_dim=true so remaining axis indices stay valid while
    // iterating (reverse order for the same reason).
    for &d in dims_to_sum.iter().rev() {
        result = result.sum(d, true)?;
    }

    // Drop the kept singleton axes (and any leading padding) in one reshape.
    result = result.reshape(target_shape.clone())?;

    Ok(result)
}
618
/// Propagates `grad_output` through an element-wise unary op by multiplying
/// with the op's pointwise derivative evaluated at `input`.
fn compute_unary_grad<B: Backend>(
    op: UnaryOp,
    grad_output: &Tensor<B>,
    input: &Tensor<B>,
    grads: &mut GradStore<B>,
) -> Result<()> {
    let grad_input = match op {
        // d(-x)/dx = -1
        UnaryOp::Neg => grad_output.neg()?,

        // d|x|/dx = sign(x); the subgradient at x == 0 is taken to be 0.
        UnaryOp::Abs => {
            let input_data = input.to_f64_vec()?;
            let sign_data: Vec<f64> = input_data
                .iter()
                .map(|&v| {
                    if v > 0.0 {
                        1.0
                    } else if v < 0.0 {
                        -1.0
                    } else {
                        0.0
                    }
                })
                .collect();
            let sign = Tensor::<B>::from_f64_slice(
                &sign_data,
                input.shape().clone(),
                input.dtype(),
                input.device(),
            )?;
            grad_output.mul(&sign)?
        }

        // d(e^x)/dx = e^x
        UnaryOp::Exp => {
            let exp_x = input.exp()?;
            grad_output.mul(&exp_x)?
        }

        // d(ln x)/dx = 1/x
        UnaryOp::Log => grad_output.div(input)?,

        // d(sqrt x)/dx = 1 / (2 sqrt x)
        UnaryOp::Sqrt => {
            let sqrt_x = input.sqrt()?;
            let two_sqrt = sqrt_x.affine(2.0, 0.0)?;
            grad_output.div(&two_sqrt)?
        }

        // d(x^2)/dx = 2x
        UnaryOp::Square => {
            let two_x = input.affine(2.0, 0.0)?;
            grad_output.mul(&two_x)?
        }

        // d(relu x)/dx = 1 for x > 0, else 0.
        UnaryOp::Relu => {
            let input_data = input.to_f64_vec()?;
            let mask_data: Vec<f64> = input_data
                .iter()
                .map(|&v| if v > 0.0 { 1.0 } else { 0.0 })
                .collect();
            let mask = Tensor::<B>::from_f64_slice(
                &mask_data,
                input.shape().clone(),
                input.dtype(),
                input.device(),
            )?;
            grad_output.mul(&mask)?
        }

        // d(sigmoid x)/dx = sigmoid(x) * (1 - sigmoid(x))
        UnaryOp::Sigmoid => {
            let sig = input.sigmoid()?;
            let one = Tensor::<B>::ones(input.shape().clone(), input.dtype(), input.device())?;
            let one_minus_sig = one.sub(&sig)?;
            let dsig = sig.mul(&one_minus_sig)?;
            grad_output.mul(&dsig)?
        }

        // d(tanh x)/dx = 1 - tanh^2(x)
        UnaryOp::Tanh => {
            let tanh_x = input.tanh()?;
            let tanh_sq = tanh_x.mul(&tanh_x)?;
            let one = Tensor::<B>::ones(input.shape().clone(), input.dtype(), input.device())?;
            let dtanh = one.sub(&tanh_sq)?;
            grad_output.mul(&dtanh)?
        }

        // Derivative of the tanh-approximation GELU:
        // gelu(x) ~= 0.5 x (1 + tanh(s)), s = sqrt(2/pi) * (x + 0.044715 x^3)
        UnaryOp::Gelu => {
            let input_data = input.to_f64_vec()?;
            let deriv_data: Vec<f64> = input_data
                .iter()
                .map(|&x| {
                    // FRAC_2_PI is 2/pi, so this is sqrt(2/pi) ~= 0.79788.
                    let sqrt_2_over_pi = std::f64::consts::FRAC_2_PI.sqrt();
                    let c = 0.044715_f64;
                    let s = sqrt_2_over_pi * (x + c * x * x * x);
                    let tanh_s = s.tanh();
                    // sech^2(s) = 1 - tanh^2(s)
                    let sech2_s = 1.0 - tanh_s * tanh_s;
                    let ds_dx = sqrt_2_over_pi * (1.0 + 3.0 * c * x * x);
                    // Product rule on 0.5 x (1 + tanh(s)).
                    0.5 * (1.0 + tanh_s) + 0.5 * x * sech2_s * ds_dx
                })
                .collect();
            let deriv = Tensor::<B>::from_f64_slice(
                &deriv_data,
                input.shape().clone(),
                input.dtype(),
                input.device(),
            )?;
            grad_output.mul(&deriv)?
        }

        // d(silu x)/dx = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        UnaryOp::Silu => {
            let sig = input.sigmoid()?;
            let one = Tensor::<B>::ones(input.shape().clone(), input.dtype(), input.device())?;
            let one_minus_sig = one.sub(&sig)?;
            let x_oms = input.mul(&one_minus_sig)?;
            let one2 = Tensor::<B>::ones(input.shape().clone(), input.dtype(), input.device())?;
            let bracket = one2.add(&x_oms)?;
            let dsilu = sig.mul(&bracket)?;
            grad_output.mul(&dsilu)?
        }

        // d(sin x)/dx = cos x
        UnaryOp::Sin => {
            let cos_x = input.cos()?;
            grad_output.mul(&cos_x)?
        }

        // d(cos x)/dx = -sin x
        UnaryOp::Cos => {
            let sin_x = input.sin()?;
            let neg_sin = sin_x.neg()?;
            grad_output.mul(&neg_sin)?
        }

        // Step functions: zero derivative almost everywhere.
        UnaryOp::Floor | UnaryOp::Ceil | UnaryOp::Round => {
            Tensor::<B>::zeros(input.shape().clone(), input.dtype(), input.device())?
        }
    };

    grads.accumulate(input.id(), grad_input)?;
    Ok(())
}
771
/// Routes the gradient of a reduction back to the reduced input.
///
/// * `Sum`: the gradient is broadcast unchanged over the reduced axes.
/// * `Mean`: broadcast, then divided by the number of reduced elements.
/// * `Max`/`Min`: gradient flows only to positions holding the extremum,
///   split evenly among ties.
/// * `ArgMax`/`ArgMin`: integer-valued and non-differentiable — no gradient.
///
/// `_keep_dim` is unused: for contiguous layouts the flat element order of
/// the reduced gradient is identical whether the reduced axes were kept as
/// size-1 or removed.
#[allow(clippy::needless_range_loop)]
fn compute_reduce_grad<B: Backend>(
    op: ReduceOp,
    grad_output: &Tensor<B>,
    input: &Tensor<B>,
    dims: &[usize],
    _keep_dim: bool,
    grads: &mut GradStore<B>,
) -> Result<()> {
    match op {
        ReduceOp::Sum => {
            if dims.is_empty() {
                // Full reduction: every input element receives the scalar
                // gradient.
                let grad_val = grad_output.to_scalar_f64()?;
                let grad = Tensor::<B>::full(
                    input.shape().clone(),
                    grad_val,
                    input.dtype(),
                    input.device(),
                )?;
                grads.accumulate(input.id(), grad)?;
            } else {
                let grad = expand_grad_for_reduce(grad_output, input, dims)?;
                grads.accumulate(input.id(), grad)?;
            }
        }
        ReduceOp::Mean => {
            if dims.is_empty() {
                // Full mean: scalar gradient divided by the element count.
                let n = input.elem_count() as f64;
                let grad_val = grad_output.to_scalar_f64()? / n;
                let grad = Tensor::<B>::full(
                    input.shape().clone(),
                    grad_val,
                    input.dtype(),
                    input.device(),
                )?;
                grads.accumulate(input.id(), grad)?;
            } else {
                // Scale by the product of the reduced dim sizes.
                let n: f64 = dims.iter().map(|&d| input.dims()[d] as f64).product();
                let grad = expand_grad_for_reduce(grad_output, input, dims)?;
                let grad = grad.affine(1.0 / n, 0.0)?;
                grads.accumulate(input.id(), grad)?;
            }
        }
        ReduceOp::Max | ReduceOp::Min => {
            if dims.is_empty() {
                // Global extremum: share the scalar gradient equally among
                // every element equal to it.
                let grad_val = grad_output.to_scalar_f64()?;
                let input_data = input.to_f64_vec()?;
                let extremum = if op == ReduceOp::Max {
                    input_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
                } else {
                    input_data.iter().cloned().fold(f64::INFINITY, f64::min)
                };
                let count = input_data.iter().filter(|&&v| v == extremum).count() as f64;
                let mask: Vec<f64> = input_data
                    .iter()
                    .map(|&v| if v == extremum { grad_val / count } else { 0.0 })
                    .collect();
                let grad = Tensor::<B>::from_f64_slice(
                    &mask,
                    input.shape().clone(),
                    input.dtype(),
                    input.device(),
                )?;
                grads.accumulate(input.id(), grad)?;
            } else {
                // Per-slice extremum over the reduced axes, in three passes
                // over the input: (1) find each output cell's extremum,
                // (2) count ties per cell, (3) scatter the tie-split
                // gradient to matching positions.
                let input_data = input.to_f64_vec()?;
                let input_dims = input.dims();
                let input_shape = input.shape().clone();
                let total = input_shape.elem_count();
                let input_strides = input_shape.stride_contiguous();

                let grad_expanded = expand_grad_for_reduce(grad_output, input, dims)?;
                let grad_exp_data = grad_expanded.to_f64_vec()?;

                // Shape with the reduced axes removed (one cell per output
                // slice).
                let reduced_dims: Vec<usize> = input_dims
                    .iter()
                    .enumerate()
                    .filter(|(i, _)| !dims.contains(i))
                    .map(|(_, &d)| d)
                    .collect();
                let reduced_shape = if reduced_dims.is_empty() {
                    Shape::from(())
                } else {
                    Shape::new(reduced_dims.clone())
                };
                let reduced_total = reduced_shape.elem_count();

                // Pass 1: extremum per output cell.
                let mut extrema = if op == ReduceOp::Max {
                    vec![f64::NEG_INFINITY; reduced_total]
                } else {
                    vec![f64::INFINITY; reduced_total]
                };

                for flat_idx in 0..total {
                    // Decode input coordinates from the flat offset.
                    let mut md = vec![0usize; input_dims.len()];
                    let mut remainder = flat_idx;
                    for i in 0..input_dims.len() {
                        if input_strides[i] > 0 {
                            md[i] = remainder / input_strides[i];
                            remainder %= input_strides[i];
                        }
                    }
                    // Drop the reduced axes to find the owning output cell.
                    let out_md: Vec<usize> = md
                        .iter()
                        .enumerate()
                        .filter(|(i, _)| !dims.contains(i))
                        .map(|(_, &v)| v)
                        .collect();
                    let out_strides = reduced_shape.stride_contiguous();
                    let mut out_flat = 0;
                    for i in 0..out_md.len() {
                        if i < out_strides.len() {
                            out_flat += out_md[i] * out_strides[i];
                        }
                    }
                    let val = input_data[flat_idx];
                    if op == ReduceOp::Max {
                        if val > extrema[out_flat] {
                            extrema[out_flat] = val;
                        }
                    } else if val < extrema[out_flat] {
                        extrema[out_flat] = val;
                    }
                }

                // Pass 2: count ties per output cell (exact f64 equality,
                // matching how the extremum was selected above).
                let mut counts = vec![0.0f64; reduced_total];
                for flat_idx in 0..total {
                    let mut md = vec![0usize; input_dims.len()];
                    let mut remainder = flat_idx;
                    for i in 0..input_dims.len() {
                        if input_strides[i] > 0 {
                            md[i] = remainder / input_strides[i];
                            remainder %= input_strides[i];
                        }
                    }
                    let out_md: Vec<usize> = md
                        .iter()
                        .enumerate()
                        .filter(|(i, _)| !dims.contains(i))
                        .map(|(_, &v)| v)
                        .collect();
                    let out_strides = reduced_shape.stride_contiguous();
                    let mut out_flat = 0;
                    for i in 0..out_md.len() {
                        if i < out_strides.len() {
                            out_flat += out_md[i] * out_strides[i];
                        }
                    }
                    if input_data[flat_idx] == extrema[out_flat] {
                        counts[out_flat] += 1.0;
                    }
                }

                // Pass 3: scatter the gradient, divided among ties.
                let mut mask = vec![0.0f64; total];
                for flat_idx in 0..total {
                    let mut md = vec![0usize; input_dims.len()];
                    let mut remainder = flat_idx;
                    for i in 0..input_dims.len() {
                        if input_strides[i] > 0 {
                            md[i] = remainder / input_strides[i];
                            remainder %= input_strides[i];
                        }
                    }
                    let out_md: Vec<usize> = md
                        .iter()
                        .enumerate()
                        .filter(|(i, _)| !dims.contains(i))
                        .map(|(_, &v)| v)
                        .collect();
                    let out_strides = reduced_shape.stride_contiguous();
                    let mut out_flat = 0;
                    for i in 0..out_md.len() {
                        if i < out_strides.len() {
                            out_flat += out_md[i] * out_strides[i];
                        }
                    }
                    if input_data[flat_idx] == extrema[out_flat] {
                        mask[flat_idx] = grad_exp_data[flat_idx] / counts[out_flat];
                    }
                }

                let grad =
                    Tensor::<B>::from_f64_slice(&mask, input_shape, input.dtype(), input.device())?;
                grads.accumulate(input.id(), grad)?;
            }
        }
        // Index-producing reductions carry no gradient.
        ReduceOp::ArgMax | ReduceOp::ArgMin => {}
    }
    Ok(())
}
985
/// Broadcasts a reduced gradient back over the reduced axes so it matches
/// `input`'s shape: each input position receives the gradient of the output
/// cell it was reduced into.
///
/// Works on flat host buffers and assumes contiguous layouts (all strides
/// via `stride_contiguous`). `grad` may have the reduced axes either
/// removed or kept as size-1 — size-1 axes do not change the contiguous
/// flat element order, so both layouts decode identically.
#[allow(clippy::needless_range_loop)]
fn expand_grad_for_reduce<B: Backend>(
    grad: &Tensor<B>,
    input: &Tensor<B>,
    dims: &[usize],
) -> Result<Tensor<B>> {
    let input_dims = input.dims();
    let input_shape = input.shape().clone();
    let grad_data = grad.to_f64_vec()?;
    let total = input_shape.elem_count();
    let input_strides = input_shape.stride_contiguous();

    // Shape of the reduced gradient: input dims with `dims` removed.
    let grad_dims: Vec<usize> = input_dims
        .iter()
        .enumerate()
        .filter(|(i, _)| !dims.contains(i))
        .map(|(_, &d)| d)
        .collect();
    let grad_shape = if grad_dims.is_empty() {
        Shape::from(())
    } else {
        Shape::new(grad_dims)
    };
    let grad_strides = grad_shape.stride_contiguous();

    let mut result_data = vec![0.0f64; total];

    for flat_idx in 0..total {
        // Decode the input coordinates from the flat offset.
        let mut md = vec![0usize; input_dims.len()];
        let mut remainder = flat_idx;
        for i in 0..input_dims.len() {
            if input_strides[i] > 0 {
                md[i] = remainder / input_strides[i];
                remainder %= input_strides[i];
            }
        }

        // Project away the reduced axes to get the source gradient cell.
        let grad_md: Vec<usize> = md
            .iter()
            .enumerate()
            .filter(|(i, _)| !dims.contains(i))
            .map(|(_, &v)| v)
            .collect();

        let mut grad_flat = 0;
        for i in 0..grad_md.len() {
            if i < grad_strides.len() {
                grad_flat += grad_md[i] * grad_strides[i];
            }
        }

        // Defensive bounds check; out-of-range cells keep a zero gradient.
        if grad_flat < grad_data.len() {
            result_data[flat_idx] = grad_data[grad_flat];
        }
    }

    Tensor::<B>::from_f64_slice(&result_data, input_shape, input.dtype(), input.device())
}
1055
/// Gradients of C = A @ B: dA = dC @ B^T and dB = A^T @ dC.
///
/// Transposes swap only the last two dims, so batched matmuls with matching
/// batch dims work as-is. NOTE(review): if the batch dims of `lhs`/`rhs`
/// were broadcast in the forward pass, the gradients are not summed back
/// down here — confirm whether broadcast batched matmul is supported.
fn compute_matmul_grad<B: Backend>(
    grad_output: &Tensor<B>,
    lhs: &Tensor<B>,
    rhs: &Tensor<B>,
    grads: &mut GradStore<B>,
) -> Result<()> {
    let rhs_rank = rhs.rank();
    let lhs_rank = lhs.rank();

    // dA = dC @ B^T
    let rhs_t = rhs.transpose(rhs_rank - 2, rhs_rank - 1)?.contiguous()?;
    let grad_lhs = grad_output.matmul(&rhs_t)?;
    grads.accumulate(lhs.id(), grad_lhs)?;

    // dB = A^T @ dC
    let lhs_t = lhs.transpose(lhs_rank - 2, lhs_rank - 1)?.contiguous()?;
    let grad_rhs = lhs_t.matmul(grad_output)?;
    grads.accumulate(rhs.id(), grad_rhs)?;

    Ok(())
}
1084
/// Backward pass for `narrow`: embeds the (smaller) output gradient into a
/// zero tensor of the input's shape, shifted by `start` along `dim`.
/// Positions outside the narrowed window keep a zero gradient.
///
/// Assumes contiguous layouts on both sides (`stride_contiguous`).
/// `_len` is implied by `grad_output`'s shape and therefore unused.
#[allow(clippy::needless_range_loop)]
fn compute_narrow_grad<B: Backend>(
    grad_output: &Tensor<B>,
    input: &Tensor<B>,
    dim: usize,
    start: usize,
    _len: usize,
    grads: &mut GradStore<B>,
) -> Result<()> {
    let input_shape = input.shape().clone();
    let grad_data = grad_output.to_f64_vec()?;
    let total = input_shape.elem_count();
    let input_strides = input_shape.stride_contiguous();

    let grad_out_dims = grad_output.dims();
    let grad_strides = Shape::new(grad_out_dims.to_vec()).stride_contiguous();
    let grad_total = grad_output.elem_count();

    let mut result_data = vec![0.0f64; total];

    for grad_flat in 0..grad_total {
        // Decode the output-gradient coordinates from the flat offset.
        let mut md = vec![0usize; grad_out_dims.len()];
        let mut remainder = grad_flat;
        for i in 0..grad_out_dims.len() {
            if grad_strides[i] > 0 {
                md[i] = remainder / grad_strides[i];
                remainder %= grad_strides[i];
            }
        }

        // Shift back by the narrow offset along the narrowed dim.
        md[dim] += start;

        let mut input_flat = 0;
        for i in 0..md.len() {
            input_flat += md[i] * input_strides[i];
        }

        // Defensive bounds check before writing.
        if input_flat < total {
            result_data[input_flat] = grad_data[grad_flat];
        }
    }

    let grad =
        Tensor::<B>::from_f64_slice(&result_data, input_shape, input.dtype(), input.device())?;
    grads.accumulate(input.id(), grad)?;
    Ok(())
}
1143
/// Backward pass for 2-D convolution using im2col/gemm on host f64 buffers.
///
/// Produces up to three gradients:
/// * weight: accumulated over the batch from dY_n and im2col(X_n),
/// * input: per-sample col2im of the weight/grad-output product,
/// * bias (if present): dY summed over batch and spatial positions.
///
/// Layout assumptions (from the indexing below): input is NCHW, weight is
/// (c_out, c_in, kh, kw), grad_output is (n, c_out, h_out, w_out).
#[allow(clippy::needless_range_loop)]
fn compute_conv2d_grad<B: Backend>(
    grad_output: &Tensor<B>,
    input: &Tensor<B>,
    weight: &Tensor<B>,
    bias: Option<&Tensor<B>>,
    stride: [usize; 2],
    padding: [usize; 2],
    grads: &mut GradStore<B>,
) -> Result<()> {
    let in_dims = input.dims();
    let w_dims = weight.dims();
    let go_dims = grad_output.dims();
    let (n_batch, c_in, h, w) = (in_dims[0], in_dims[1], in_dims[2], in_dims[3]);
    let (c_out, _wc_in, kh, kw) = (w_dims[0], w_dims[1], w_dims[2], w_dims[3]);
    // Output spatial size is taken from grad_output rather than recomputed.
    let h_out = go_dims[2];
    let w_out = go_dims[3];
    let [sh, sw] = stride;
    let [ph, pw] = padding;

    let input_data = input.contiguous()?.to_f64_vec()?;
    let weight_data = weight.contiguous()?.to_f64_vec()?;
    let grad_out_data = grad_output.contiguous()?.to_f64_vec()?;

    // im2col matrix per sample: (c_in*kh*kw) rows x (h_out*w_out) cols.
    let col_rows = c_in * kh * kw;
    let col_cols = h_out * w_out;
    let sample_size = c_in * h * w;

    // Weight gradient, accumulated across the whole batch by gemm_a_bt.
    let mut grad_w = vec![0.0f64; c_out * col_rows];
    // Scratch column buffer, reused for every sample.
    let mut columns = vec![0.0f64; col_rows * col_cols];

    for ni in 0..n_batch {
        let in_offset = ni * sample_size;
        crate::tensor::im2col(
            &input_data[in_offset..in_offset + sample_size],
            c_in,
            h,
            w,
            kh,
            kw,
            sh,
            sw,
            ph,
            pw,
            h_out,
            w_out,
            &mut columns,
        );

        // Per the helper's name, computes A @ B^T: dY_n @ columns^T,
        // accumulated into grad_w.
        let go_offset = ni * c_out * col_cols;
        crate::tensor::gemm_a_bt(
            &grad_out_data[go_offset..go_offset + c_out * col_cols],
            &columns,
            &mut grad_w,
            c_out,
            col_rows,
            col_cols,
        );
    }

    let grad_weight_t = Tensor::<B>::from_f64_slice(
        &grad_w,
        weight.shape().clone(),
        weight.dtype(),
        weight.device(),
    )?;
    grads.accumulate(weight.id(), grad_weight_t)?;

    // Input gradient, one sample at a time.
    let mut grad_in = vec![0.0f64; n_batch * sample_size];

    for ni in 0..n_batch {
        // Reset the scratch buffer; gemm_at_b presumably accumulates into
        // it (NOTE(review): confirm against the helper's contract).
        for v in columns.iter_mut() {
            *v = 0.0;
        }

        // Per the helper's name, computes A^T @ B: W^T @ dY_n.
        let go_offset = ni * c_out * col_cols;
        crate::tensor::gemm_at_b(
            &weight_data,
            &grad_out_data[go_offset..go_offset + c_out * col_cols],
            &mut columns,
            col_rows,
            col_cols,
            c_out,
        );

        // Fold the column matrix back into image layout for this sample.
        let in_offset = ni * sample_size;
        crate::tensor::col2im(
            &columns,
            c_in,
            h,
            w,
            kh,
            kw,
            sh,
            sw,
            ph,
            pw,
            h_out,
            w_out,
            &mut grad_in[in_offset..in_offset + sample_size],
        );
    }

    let grad_input_t = Tensor::<B>::from_f64_slice(
        &grad_in,
        input.shape().clone(),
        input.dtype(),
        input.device(),
    )?;
    grads.accumulate(input.id(), grad_input_t)?;

    // Bias gradient: sum dY over batch and spatial positions per channel.
    if let Some(b) = bias {
        let mut grad_b = vec![0.0f64; c_out];
        for ni in 0..n_batch {
            for co in 0..c_out {
                let go_offset = (ni * c_out + co) * col_cols;
                for j in 0..col_cols {
                    grad_b[co] += grad_out_data[go_offset + j];
                }
            }
        }
        let grad_bias_t =
            Tensor::<B>::from_f64_slice(&grad_b, b.shape().clone(), b.dtype(), b.device())?;
        grads.accumulate(b.id(), grad_bias_t)?;
    }

    Ok(())
}
1294
1295fn compute_maxpool2d_grad<B: Backend>(
1300 grad_output: &Tensor<B>,
1301 input: &Tensor<B>,
1302 indices: &[usize],
1303 grads: &mut GradStore<B>,
1304) -> Result<()> {
1305 let input_size = input.elem_count();
1306 let grad_out_data = grad_output.contiguous()?.to_f64_vec()?;
1307
1308 let mut grad_in = vec![0.0f64; input_size];
1309 for (out_idx, &in_idx) in indices.iter().enumerate() {
1310 if in_idx < input_size && out_idx < grad_out_data.len() {
1311 grad_in[in_idx] += grad_out_data[out_idx];
1312 }
1313 }
1314
1315 let grad_input_t = Tensor::<B>::from_f64_slice(
1316 &grad_in,
1317 input.shape().clone(),
1318 input.dtype(),
1319 input.device(),
1320 )?;
1321 grads.accumulate(input.id(), grad_input_t)?;
1322 Ok(())
1323}
1324
1325fn compute_avgpool2d_grad<B: Backend>(
1328 grad_output: &Tensor<B>,
1329 input: &Tensor<B>,
1330 kernel_size: [usize; 2],
1331 stride: [usize; 2],
1332 padding: [usize; 2],
1333 grads: &mut GradStore<B>,
1334) -> Result<()> {
1335 let in_dims = input.dims();
1336 let (n, c, h, w) = (in_dims[0], in_dims[1], in_dims[2], in_dims[3]);
1337 let [kh, kw] = kernel_size;
1338 let [sh, sw] = stride;
1339 let [ph, pw] = padding;
1340 let h_out = (h + 2 * ph - kh) / sh + 1;
1341 let w_out = (w + 2 * pw - kw) / sw + 1;
1342
1343 let grad_out_data = grad_output.contiguous()?.to_f64_vec()?;
1344 let mut grad_in = vec![0.0f64; input.elem_count()];
1345
1346 for ni in 0..n {
1347 for ci in 0..c {
1348 for oh in 0..h_out {
1349 for ow in 0..w_out {
1350 let out_idx = ((ni * c + ci) * h_out + oh) * w_out + ow;
1351 let mut count = 0usize;
1353 for ki in 0..kh {
1354 for kj in 0..kw {
1355 let ih = (oh * sh + ki) as isize - ph as isize;
1356 let iw = (ow * sw + kj) as isize - pw as isize;
1357 if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
1358 count += 1;
1359 }
1360 }
1361 }
1362 if count == 0 {
1363 continue;
1364 }
1365 let scale = 1.0 / count as f64;
1366 for ki in 0..kh {
1368 for kj in 0..kw {
1369 let ih = (oh * sh + ki) as isize - ph as isize;
1370 let iw = (ow * sw + kj) as isize - pw as isize;
1371 if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
1372 let in_idx = ((ni * c + ci) * h + ih as usize) * w + iw as usize;
1373 grad_in[in_idx] += grad_out_data[out_idx] * scale;
1374 }
1375 }
1376 }
1377 }
1378 }
1379 }
1380 }
1381
1382 let grad_input_t = Tensor::<B>::from_f64_slice(
1383 &grad_in,
1384 input.shape().clone(),
1385 input.dtype(),
1386 input.device(),
1387 )?;
1388 grads.accumulate(input.id(), grad_input_t)?;
1389 Ok(())
1390}
1391
/// Backward pass for 1-D convolution.
///
/// Reuses the 2-D im2col/col2im helpers by treating the signal as an image
/// of height 1 (kernel height 1, vertical stride 1, vertical padding 0).
/// Layouts from the indexing below: input (n, c_in, l), weight
/// (c_out, c_in, k), grad_output (n, c_out, l_out).
#[allow(clippy::needless_range_loop)]
fn compute_conv1d_grad<B: Backend>(
    grad_output: &Tensor<B>,
    input: &Tensor<B>,
    weight: &Tensor<B>,
    bias: Option<&Tensor<B>>,
    stride: usize,
    padding: usize,
    grads: &mut GradStore<B>,
) -> Result<()> {
    let in_dims = input.dims();
    let w_dims = weight.dims();
    let (n, c_in, l) = (in_dims[0], in_dims[1], in_dims[2]);
    let (c_out, _, k) = (w_dims[0], w_dims[1], w_dims[2]);
    let l_out = (l + 2 * padding - k) / stride + 1;

    let input_data = input.contiguous()?.to_f64_vec()?;
    let weight_data = weight.contiguous()?.to_f64_vec()?;
    let grad_out_data = grad_output.contiguous()?.to_f64_vec()?;

    // Column matrix per sample: (c_in*k) rows x l_out cols.
    let col_rows = c_in * k;
    let col_cols = l_out;
    let sample_size = c_in * l;
    let mut columns = vec![0.0f64; col_rows * col_cols];

    // Weight gradient, accumulated over the batch.
    let mut grad_w = vec![0.0f64; c_out * col_rows];
    for ni in 0..n {
        let in_offset = ni * sample_size;
        // Height-1 image: (h=1, kh=1, sh=1, ph=0).
        crate::tensor::im2col(
            &input_data[in_offset..in_offset + sample_size],
            c_in,
            1,
            l,
            1,
            k,
            1,
            stride,
            0,
            padding,
            1,
            l_out,
            &mut columns,
        );
        // Per the helper's name, A @ B^T: dY_n @ columns^T into grad_w.
        let go_offset = ni * c_out * col_cols;
        crate::tensor::gemm_a_bt(
            &grad_out_data[go_offset..go_offset + c_out * col_cols],
            &columns,
            &mut grad_w,
            c_out,
            col_rows,
            col_cols,
        );
    }

    let grad_weight_t = Tensor::<B>::from_f64_slice(
        &grad_w,
        weight.shape().clone(),
        weight.dtype(),
        weight.device(),
    )?;
    grads.accumulate(weight.id(), grad_weight_t)?;

    // Input gradient per sample: col2im of W^T @ dY_n.
    let mut grad_in = vec![0.0f64; n * sample_size];
    for ni in 0..n {
        // Reset the scratch buffer for this sample.
        for v in columns.iter_mut() {
            *v = 0.0;
        }
        let go_offset = ni * c_out * col_cols;
        crate::tensor::gemm_at_b(
            &weight_data,
            &grad_out_data[go_offset..go_offset + c_out * col_cols],
            &mut columns,
            col_rows,
            col_cols,
            c_out,
        );
        let in_offset = ni * sample_size;
        crate::tensor::col2im(
            &columns,
            c_in,
            1,
            l,
            1,
            k,
            1,
            stride,
            0,
            padding,
            1,
            l_out,
            &mut grad_in[in_offset..in_offset + sample_size],
        );
    }

    let grad_input_t = Tensor::<B>::from_f64_slice(
        &grad_in,
        input.shape().clone(),
        input.dtype(),
        input.device(),
    )?;
    grads.accumulate(input.id(), grad_input_t)?;

    // Bias gradient: sum dY over batch and positions per output channel.
    if let Some(b) = bias {
        let mut grad_b = vec![0.0f64; c_out];
        for ni in 0..n {
            for co in 0..c_out {
                let go_offset = (ni * c_out + co) * col_cols;
                for j in 0..col_cols {
                    grad_b[co] += grad_out_data[go_offset + j];
                }
            }
        }
        let grad_bias_t =
            Tensor::<B>::from_f64_slice(&grad_b, b.shape().clone(), b.dtype(), b.device())?;
        grads.accumulate(b.id(), grad_bias_t)?;
    }

    Ok(())
}
1516
#[cfg(test)]
mod tests {
    // No unit tests in this module yet.
}
1524
1525use std::cell::RefCell;
1550
thread_local! {
    // Per-thread flag set by `with_checkpoint_mode` and read by
    // `is_checkpointing`.
    static CHECKPOINT_MODE: RefCell<bool> = const { RefCell::new(false) };
}
1554
/// Returns `true` while the current thread is executing inside
/// `with_checkpoint_mode`.
pub fn is_checkpointing() -> bool {
    CHECKPOINT_MODE.with(|c| *c.borrow())
}
1559
/// Runs `func` and returns its result.
///
/// NOTE(review): despite the name, this does not implement activation
/// checkpointing yet — `_saved_inputs` and `_output` are built but never
/// used, and the graph produced by `func` is returned unmodified. The
/// `_output` round-trip through `to_f64_vec` costs a full host copy with no
/// observable effect; confirm whether this is a stub awaiting recompute
/// logic before relying on it to save memory. The `'static` bound on `F`
/// also looks unnecessary for the current body — verify before relaxing.
pub fn checkpoint<B, F>(func: F, inputs: &[&Tensor<B>]) -> Result<Tensor<B>>
where
    B: Backend,
    F: Fn() -> Result<Tensor<B>> + 'static,
{
    let result = func()?;

    // Cloned handles to the inputs (currently unused).
    let _saved_inputs: Vec<Tensor<B>> = inputs.iter().map(|t| (*t).clone()).collect();

    // Rebuilt copy of the output's data (currently unused).
    let _output = Tensor::<B>::from_f64_slice(
        &result.to_f64_vec()?,
        result.shape().clone(),
        result.dtype(),
        result.device(),
    )?;

    Ok(result)
}
1616
1617#[allow(clippy::needless_range_loop, clippy::type_complexity)]
1644pub fn checkpoint_sequential<B: Backend>(
1645 input: &Tensor<B>,
1646 layers: &[fn(&Tensor<B>) -> Result<Tensor<B>>],
1647 segments: usize,
1648) -> Result<Tensor<B>> {
1649 let n = layers.len();
1650 if n == 0 {
1651 return Ok(input.clone());
1652 }
1653 let seg_size = n.div_ceil(segments);
1654
1655 let mut current = input.clone();
1656
1657 for seg_start in (0..n).step_by(seg_size) {
1658 let seg_end = (seg_start + seg_size).min(n);
1659 let segment_input = current.detach().set_variable();
1660
1661 let mut h = segment_input.clone();
1664 for i in seg_start..seg_end {
1665 h = layers[i](&h)?;
1666 }
1667
1668 current = h;
1669 }
1670
1671 Ok(current)
1672}
1673
1674pub fn with_checkpoint_mode<F, T>(f: F) -> T
1679where
1680 F: FnOnce() -> T,
1681{
1682 CHECKPOINT_MODE.with(|c| *c.borrow_mut() = true);
1683 let result = f();
1684 CHECKPOINT_MODE.with(|c| *c.borrow_mut() = false);
1685 result
1686}