rustorch_core/ops/
view.rs1use crate::autograd::BackwardOp;
2use crate::storage::Storage;
3use crate::tensor::TensorImpl;
4use crate::Tensor;
5use std::sync::{Arc, Mutex};
6
/// Backward op for `reshape`: restores the gradient to the input's
/// pre-reshape shape and propagates it.
#[derive(Debug)]
pub struct ReshapeBackward {
    /// The tensor that was reshaped; gradients flow back into it.
    pub input: Tensor,
    /// Shape of `input` before the reshape, used to restore the gradient.
    pub input_shape: Vec<usize>,
}
12
13impl BackwardOp for ReshapeBackward {
14 fn backward(&self, grad: &Tensor) {
15 if self.input.requires_grad() {
16 let grad_reshaped = grad.reshape(&self.input_shape);
18 self.input.accumulate_grad(&grad_reshaped);
19 self.input.backward_step();
20 }
21 }
22}
23
/// Backward op for `permute`: applies the inverse axis permutation to the
/// incoming gradient and propagates it to the input.
#[derive(Debug)]
pub struct PermuteBackward {
    /// The tensor that was permuted; gradients flow back into it.
    pub input: Tensor,
    /// Forward axis permutation; inverted in `backward`.
    pub dims: Vec<usize>,
}
31
32impl BackwardOp for PermuteBackward {
33 fn backward(&self, grad: &Tensor) {
34 if self.input.requires_grad() {
35 let ndim = self.dims.len();
36 let mut inverse_dims = vec![0; ndim];
37 for (i, &d) in self.dims.iter().enumerate() {
38 inverse_dims[d] = i;
39 }
40
41 let grad_permuted = grad.permute(&inverse_dims);
42 self.input.accumulate_grad(&grad_permuted);
43 self.input.backward_step();
44 }
45 }
46}
47
48pub fn permute(input: &Tensor, dims: &[usize]) -> Tensor {
49 let ndim = input.shape().len();
50 if dims.len() != ndim {
51 panic!(
52 "Permute dims length {} does not match tensor ndim {}",
53 dims.len(),
54 ndim
55 );
56 }
57
58 let mut seen = vec![false; ndim];
60 for &d in dims {
61 if d >= ndim || seen[d] {
62 panic!("Invalid permutation {:?}", dims);
63 }
64 seen[d] = true;
65 }
66
67 let old_shape = input.shape();
68 let old_strides = input.strides();
69
70 let mut new_shape = vec![0; ndim];
71 let mut new_strides = vec![0; ndim];
72
73 for (i, &d) in dims.iter().enumerate() {
74 new_shape[i] = old_shape[d];
75 new_strides[i] = old_strides[d];
76 }
77
78 let inner = &input.inner;
83
84 let mut tensor = Tensor {
85 inner: Arc::new(TensorImpl {
86 storage: inner.storage.clone(),
87 shape: new_shape,
88 strides: new_strides,
89 grad: Mutex::new(None),
90 requires_grad: inner.requires_grad,
91 op: None,
92 is_leaf: false,
93 }),
94 };
95
96 if input.requires_grad() {
97 tensor.set_op(Arc::new(PermuteBackward {
98 input: input.clone(),
99 dims: dims.to_vec(),
100 }));
101 }
102
103 tensor
104}
105
106pub fn transpose(input: &Tensor, dim0: usize, dim1: usize) -> Tensor {
107 let ndim = input.shape().len();
108 let mut dims: Vec<usize> = (0..ndim).collect();
109 dims.swap(dim0, dim1);
110 permute(input, &dims)
111}
112
113pub fn contiguous(input: &Tensor) -> Tensor {
114 if input.is_contiguous() {
115 return input.clone();
116 }
117
118 #[cfg(feature = "wgpu_backend")]
119 {
120 if input.storage().device().is_wgpu() {
121 if let Some(input_buf) = input.storage().wgpu_buffer() {
122 let output_buf = crate::backend::wgpu::contiguous_wgpu(
123 input_buf,
124 input.shape(),
125 input.strides(),
126 );
127
128 let size: usize = input.shape().iter().product();
129 let storage = Storage::new_wgpu(output_buf, size, 0);
130 let mut tensor = Tensor::new_with_storage(storage, input.shape());
131 if input.requires_grad() {
132 tensor.set_requires_grad_mut(true);
133 tensor.set_op(Arc::new(ContiguousBackward {
134 input: input.clone(),
135 }));
136 }
137 return tensor;
138 }
139 }
140 }
141
142 let shape = input.shape();
143 let size: usize = shape.iter().product();
144 let mut data = vec![0.0; size];
145
146 let input_guard = input.data();
147 let input_storage = &*input_guard;
148 let strides = input.strides();
149 let storage_len = input_storage.len();
150
151 for (i, val) in data.iter_mut().enumerate().take(size) {
152 let mut physical_offset = 0;
153 let mut temp_i = i;
154 for dim_idx in (0..shape.len()).rev() {
155 let dim_size = shape[dim_idx];
156 let coord = temp_i % dim_size;
157 temp_i /= dim_size;
158 physical_offset += coord * strides[dim_idx];
159 }
160 if storage_len == 1 {
161 *val = input_storage[0];
162 } else if physical_offset < storage_len {
163 *val = input_storage[physical_offset];
164 } else {
165 *val = 0.0;
166 }
167 }
168
169 let storage = Storage::new(data);
170 let mut tensor = Tensor::new_with_storage(storage, shape);
171 if input.requires_grad() {
172 tensor.set_requires_grad_mut(true);
173 tensor.set_op(Arc::new(ContiguousBackward {
174 input: input.clone(),
175 }));
176 }
177 tensor
178}
179
/// Backward op for `contiguous`: scatters the gradient back into the
/// (possibly strided) layout of the original input.
#[derive(Debug)]
pub struct ContiguousBackward {
    /// The possibly non-contiguous tensor the forward pass copied from.
    pub input: Tensor,
}
184
185impl BackwardOp for ContiguousBackward {
186 fn backward(&self, grad: &Tensor) {
187 if self.input.requires_grad() {
188 let grad_contig = if grad.is_contiguous() {
189 grad.clone()
190 } else {
191 grad.contiguous()
192 };
193 let grad_view = grad_contig.reshape(self.input.shape());
194
195 let mut grad_input = if self.input.is_contiguous() {
196 grad_view
197 } else {
198 let mut data = vec![0.0; self.input.shape().iter().product()];
199 let strides = self.input.strides();
200 let shape = self.input.shape();
201
202 let grad_guard = grad_view.data();
203 let grad_data = &*grad_guard;
204
205 for (i, &g) in grad_data.iter().enumerate() {
206 let mut physical_offset = 0;
207 let mut temp_i = i;
208 for dim_idx in (0..shape.len()).rev() {
209 let dim_size = shape[dim_idx];
210 let coord = temp_i % dim_size;
211 temp_i /= dim_size;
212 physical_offset += coord * strides[dim_idx];
213 }
214 data[physical_offset] = g;
215 }
216
217 Tensor::new_with_storage(Storage::new(data), shape)
218 };
219
220 grad_input.set_requires_grad_mut(true);
221 self.input.accumulate_grad(&grad_input);
222 self.input.backward_step();
223 }
224 }
225}
226
227pub fn sum_to(input: &Tensor, shape: &[usize]) -> Tensor {
228 if input.shape() == shape {
229 return input.clone();
230 }
231
232 #[cfg(feature = "wgpu_backend")]
233 {
234 if input.storage().device().is_wgpu() {
235 let input_shape = input.shape();
236 let input_ndim = input_shape.len();
237 let output_ndim = shape.len();
238
239 if output_ndim == 0 || (output_ndim == 1 && shape[0] == 1) {
240 let input_contig = if input.is_contiguous() {
241 input.clone()
242 } else {
243 input.contiguous()
244 };
245 if let Some(input_buf) = input_contig.storage().wgpu_buffer() {
246 let total_size: usize = input_contig.shape().iter().product();
247 let output_buf =
248 crate::backend::wgpu::reduce_sum_all_wgpu(input_buf, total_size);
249 let storage = Storage::new_wgpu(output_buf, 1, 0);
250 return Tensor::new_with_storage(storage, shape);
251 }
252 }
253
254 if input_ndim == 2 && output_ndim == 1 && input_shape[1] == shape[0] {
255 let input_contig = if input.is_contiguous() {
256 input.clone()
257 } else {
258 input.contiguous()
259 };
260 if let Some(input_buf) = input_contig.storage().wgpu_buffer() {
261 let output_buf =
262 crate::backend::wgpu::reduce_sum_dim0_wgpu(input_buf, input_contig.shape());
263 let size: usize = shape.iter().product();
264 let storage = Storage::new_wgpu(output_buf, size, 0);
265 return Tensor::new_with_storage(storage, shape);
266 }
267 }
268
269 if input_ndim == 2 && output_ndim == 1 && input_shape[0] == shape[0] {
270 let input_contig = if input.is_contiguous() {
271 input.clone()
272 } else {
273 input.contiguous()
274 };
275 if let Some(input_buf) = input_contig.storage().wgpu_buffer() {
276 let output_buf = crate::backend::wgpu::reduce_sum_dim_wgpu(
277 input_buf,
278 input_contig.shape(),
279 1,
280 );
281 let size: usize = shape.iter().product();
282 let storage = Storage::new_wgpu(output_buf, size, 0);
283 return Tensor::new_with_storage(storage, shape);
284 }
285 }
286 }
287 }
288
289 let input_contig = if input.is_contiguous() {
290 input.clone()
291 } else {
292 input.contiguous()
293 };
294 let input_shape = input_contig.shape();
295 let output_shape = shape;
296
297 if input_shape.len() == 2 && output_shape.len() == 1 {
298 if input_shape[1] == output_shape[0] {
299 let m = input_shape[0];
300 let n = input_shape[1];
301 let data = input_contig.data();
302 let mut result = vec![0.0; n];
303
304 for (j, out) in result.iter_mut().enumerate().take(n) {
305 let mut col = vec![0.0f32; m];
306 for i in 0..m {
307 col[i] = data[i * n + j];
308 }
309 *out = crate::ops::sum_auto(&col);
310 }
311
312 return Tensor::new_with_storage(Storage::new(result), output_shape);
313 }
314 }
315
316 if input_shape.len() == 1 && output_shape.len() == 1 {
317 if input_shape[0] == output_shape[0] {
318 return input_contig.clone();
319 }
320 }
321
322 if output_shape.iter().product::<usize>() == 1 {
323 let data = input_contig.data();
324 let sum = crate::ops::sum_auto(&data);
325 return Tensor::new_with_storage(Storage::new(vec![sum]), output_shape);
326 }
327
328 panic!(
329 "General sum_to not implemented for shape {:?} -> {:?}",
330 input_shape, output_shape
331 );
332}