rustorch_core/
broadcast.rs1use crate::autograd::BackwardOp;
2use crate::Tensor;
4use std::sync::Arc;
5
/// Backward node recorded when a tensor is broadcast via `Tensor::expand`.
///
/// During the backward pass the incoming gradient has the *expanded* shape;
/// it must be summed back down to `input_shape` (via `crate::ops::sum_to`)
/// before being accumulated into `input`'s gradient.
#[derive(Debug)]
pub struct ExpandBackward {
    // The tensor that was expanded; gradients flow back into it.
    pub input: Tensor,
    // Shape of `input` before expansion, used to reduce the gradient.
    pub input_shape: Vec<usize>,
}
11
12impl BackwardOp for ExpandBackward {
13 fn backward(&self, grad: &Tensor) {
14 if self.input.requires_grad() {
15 let grad_reduced = crate::ops::sum_to(grad, &self.input_shape);
17 self.input.accumulate_grad(&grad_reduced);
18 self.input.backward_step();
19 }
20 }
21}
22
/// Computes the NumPy-style broadcast of two shapes.
///
/// Shapes are right-aligned; for each trailing dimension pair the dims must
/// either match or one of them must be `1` (which broadcasts to the other).
/// Missing leading dimensions are treated as `1`.
///
/// Returns `Some(result_shape)` on success, or `None` when the shapes are
/// incompatible.
pub fn broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let max_len = shape1.len().max(shape2.len());

    // Walk both shapes from their trailing dimension, padding the shorter
    // one with implicit 1s once its iterator is exhausted.
    let mut trailing1 = shape1.iter().rev();
    let mut trailing2 = shape2.iter().rev();

    let mut dims = Vec::with_capacity(max_len);
    for _ in 0..max_len {
        let d1 = trailing1.next().copied().unwrap_or(1);
        let d2 = trailing2.next().copied().unwrap_or(1);

        if d1 == d2 || d2 == 1 {
            dims.push(d1);
        } else if d1 == 1 {
            dims.push(d2);
        } else {
            // Neither equal nor broadcastable: incompatible shapes.
            return None;
        }
    }

    // Dimensions were collected trailing-first; restore leading-first order.
    dims.reverse();
    Some(dims)
}
47
48impl Tensor {
49 pub fn expand(&self, new_shape: &[usize]) -> Tensor {
51 if self.shape() == new_shape {
52 return self.clone();
53 }
54
55 let current_shape = self.shape();
56 let current_strides = self.strides();
57
58 let ndim_new = new_shape.len();
59 let ndim_old = current_shape.len();
60
61 if ndim_new < ndim_old {
62 panic!("expand: new shape must have >= dims than current shape");
63 }
64
65 let mut new_strides = vec![0; ndim_new];
66 let offset = ndim_new - ndim_old;
67
68 for i in 0..ndim_new {
69 if i < offset {
70 new_strides[i] = 0;
74 } else {
75 let old_idx = i - offset;
76 let old_dim = current_shape[old_idx];
77 let new_dim = new_shape[i];
78
79 if old_dim == 1 && new_dim > 1 {
80 new_strides[i] = 0;
82 } else if old_dim == new_dim {
83 new_strides[i] = current_strides[old_idx];
85 } else {
86 panic!(
87 "expand: invalid shape {:?} -> {:?}",
88 current_shape, new_shape
89 );
90 }
91 }
92 }
93
94 use crate::tensor::TensorImpl;
113 use std::sync::Mutex;
114
115 let inner = TensorImpl {
116 storage: self.storage().clone(),
117 shape: new_shape.to_vec(),
118 strides: new_strides,
119 grad: Mutex::new(None),
120 requires_grad: self.requires_grad(),
121 op: if self.requires_grad() {
122 Some(Arc::new(ExpandBackward {
123 input: self.clone(),
124 input_shape: self.shape().to_vec(),
125 }))
126 } else {
127 None
128 },
129 is_leaf: false, };
132
133 Tensor {
134 inner: Arc::new(inner),
135 }
136 }
137}