Skip to main content

rustorch_core/ops/
view.rs

1use crate::autograd::BackwardOp;
2use crate::storage::Storage;
3use crate::tensor::TensorImpl;
4use crate::Tensor;
5use std::sync::{Arc, Mutex};
6
7#[derive(Debug)]
8pub struct ReshapeBackward {
9    pub input: Tensor,
10    pub input_shape: Vec<usize>,
11}
12
13impl BackwardOp for ReshapeBackward {
14    fn backward(&self, grad: &Tensor) {
15        if self.input.requires_grad() {
16            // Gradient should be reshaped back to input shape
17            let grad_reshaped = grad.reshape(&self.input_shape);
18            self.input.accumulate_grad(&grad_reshaped);
19            self.input.backward_step();
20        }
21    }
22}
23
24// --- Permute ---
25
26#[derive(Debug)]
27pub struct PermuteBackward {
28    pub input: Tensor,
29    pub dims: Vec<usize>, // Original permutation
30}
31
32impl BackwardOp for PermuteBackward {
33    fn backward(&self, grad: &Tensor) {
34        if self.input.requires_grad() {
35            let ndim = self.dims.len();
36            let mut inverse_dims = vec![0; ndim];
37            for (i, &d) in self.dims.iter().enumerate() {
38                inverse_dims[d] = i;
39            }
40
41            let grad_permuted = grad.permute(&inverse_dims);
42            self.input.accumulate_grad(&grad_permuted);
43            self.input.backward_step();
44        }
45    }
46}
47
48pub fn permute(input: &Tensor, dims: &[usize]) -> Tensor {
49    let ndim = input.shape().len();
50    if dims.len() != ndim {
51        panic!(
52            "Permute dims length {} does not match tensor ndim {}",
53            dims.len(),
54            ndim
55        );
56    }
57
58    // Check if dims are valid permutation
59    let mut seen = vec![false; ndim];
60    for &d in dims {
61        if d >= ndim || seen[d] {
62            panic!("Invalid permutation {:?}", dims);
63        }
64        seen[d] = true;
65    }
66
67    let old_shape = input.shape();
68    let old_strides = input.strides();
69
70    let mut new_shape = vec![0; ndim];
71    let mut new_strides = vec![0; ndim];
72
73    for (i, &d) in dims.iter().enumerate() {
74        new_shape[i] = old_shape[d];
75        new_strides[i] = old_strides[d];
76    }
77
78    // Create new tensor sharing storage
79    // Need access to internal fields. TensorImpl fields are pub(crate).
80    // View operations share storage.
81
82    let inner = &input.inner;
83
84    let mut tensor = Tensor {
85        inner: Arc::new(TensorImpl {
86            storage: inner.storage.clone(),
87            shape: new_shape,
88            strides: new_strides,
89            grad: Mutex::new(None),
90            requires_grad: inner.requires_grad,
91            op: None,
92            is_leaf: false,
93        }),
94    };
95
96    if input.requires_grad() {
97        tensor.set_op(Arc::new(PermuteBackward {
98            input: input.clone(),
99            dims: dims.to_vec(),
100        }));
101    }
102
103    tensor
104}
105
106pub fn transpose(input: &Tensor, dim0: usize, dim1: usize) -> Tensor {
107    let ndim = input.shape().len();
108    let mut dims: Vec<usize> = (0..ndim).collect();
109    dims.swap(dim0, dim1);
110    permute(input, &dims)
111}
112
113pub fn contiguous(input: &Tensor) -> Tensor {
114    if input.is_contiguous() {
115        return input.clone();
116    }
117
118    #[cfg(feature = "wgpu_backend")]
119    {
120        if input.storage().device().is_wgpu() {
121            if let Some(input_buf) = input.storage().wgpu_buffer() {
122                let output_buf = crate::backend::wgpu::contiguous_wgpu(
123                    input_buf,
124                    input.shape(),
125                    input.strides(),
126                );
127
128                let size: usize = input.shape().iter().product();
129                let storage = Storage::new_wgpu(output_buf, size, 0);
130                let mut tensor = Tensor::new_with_storage(storage, input.shape());
131                if input.requires_grad() {
132                    tensor.set_requires_grad_mut(true);
133                    tensor.set_op(Arc::new(ContiguousBackward {
134                        input: input.clone(),
135                    }));
136                }
137                return tensor;
138            }
139        }
140    }
141
142    let shape = input.shape();
143    let size: usize = shape.iter().product();
144    let mut data = vec![0.0; size];
145
146    let input_guard = input.data();
147    let input_storage = &*input_guard;
148    let strides = input.strides();
149    let storage_len = input_storage.len();
150
151    for (i, val) in data.iter_mut().enumerate().take(size) {
152        let mut physical_offset = 0;
153        let mut temp_i = i;
154        for dim_idx in (0..shape.len()).rev() {
155            let dim_size = shape[dim_idx];
156            let coord = temp_i % dim_size;
157            temp_i /= dim_size;
158            physical_offset += coord * strides[dim_idx];
159        }
160        if storage_len == 1 {
161            *val = input_storage[0];
162        } else if physical_offset < storage_len {
163            *val = input_storage[physical_offset];
164        } else {
165            *val = 0.0;
166        }
167    }
168
169    let storage = Storage::new(data);
170    let mut tensor = Tensor::new_with_storage(storage, shape);
171    if input.requires_grad() {
172        tensor.set_requires_grad_mut(true);
173        tensor.set_op(Arc::new(ContiguousBackward {
174            input: input.clone(),
175        }));
176    }
177    tensor
178}
179
180#[derive(Debug)]
181pub struct ContiguousBackward {
182    pub input: Tensor,
183}
184
185impl BackwardOp for ContiguousBackward {
186    fn backward(&self, grad: &Tensor) {
187        if self.input.requires_grad() {
188            let grad_contig = if grad.is_contiguous() {
189                grad.clone()
190            } else {
191                grad.contiguous()
192            };
193            let grad_view = grad_contig.reshape(self.input.shape());
194
195            let mut grad_input = if self.input.is_contiguous() {
196                grad_view
197            } else {
198                let mut data = vec![0.0; self.input.shape().iter().product()];
199                let strides = self.input.strides();
200                let shape = self.input.shape();
201
202                let grad_guard = grad_view.data();
203                let grad_data = &*grad_guard;
204
205                for (i, &g) in grad_data.iter().enumerate() {
206                    let mut physical_offset = 0;
207                    let mut temp_i = i;
208                    for dim_idx in (0..shape.len()).rev() {
209                        let dim_size = shape[dim_idx];
210                        let coord = temp_i % dim_size;
211                        temp_i /= dim_size;
212                        physical_offset += coord * strides[dim_idx];
213                    }
214                    data[physical_offset] = g;
215                }
216
217                Tensor::new_with_storage(Storage::new(data), shape)
218            };
219
220            grad_input.set_requires_grad_mut(true);
221            self.input.accumulate_grad(&grad_input);
222            self.input.backward_step();
223        }
224    }
225}
226
227pub fn sum_to(input: &Tensor, shape: &[usize]) -> Tensor {
228    if input.shape() == shape {
229        return input.clone();
230    }
231
232    #[cfg(feature = "wgpu_backend")]
233    {
234        if input.storage().device().is_wgpu() {
235            let input_shape = input.shape();
236            let input_ndim = input_shape.len();
237            let output_ndim = shape.len();
238
239            if output_ndim == 0 || (output_ndim == 1 && shape[0] == 1) {
240                let input_contig = if input.is_contiguous() {
241                    input.clone()
242                } else {
243                    input.contiguous()
244                };
245                if let Some(input_buf) = input_contig.storage().wgpu_buffer() {
246                    let total_size: usize = input_contig.shape().iter().product();
247                    let output_buf =
248                        crate::backend::wgpu::reduce_sum_all_wgpu(input_buf, total_size);
249                    let storage = Storage::new_wgpu(output_buf, 1, 0);
250                    return Tensor::new_with_storage(storage, shape);
251                }
252            }
253
254            if input_ndim == 2 && output_ndim == 1 && input_shape[1] == shape[0] {
255                let input_contig = if input.is_contiguous() {
256                    input.clone()
257                } else {
258                    input.contiguous()
259                };
260                if let Some(input_buf) = input_contig.storage().wgpu_buffer() {
261                    let output_buf =
262                        crate::backend::wgpu::reduce_sum_dim0_wgpu(input_buf, input_contig.shape());
263                    let size: usize = shape.iter().product();
264                    let storage = Storage::new_wgpu(output_buf, size, 0);
265                    return Tensor::new_with_storage(storage, shape);
266                }
267            }
268
269            if input_ndim == 2 && output_ndim == 1 && input_shape[0] == shape[0] {
270                let input_contig = if input.is_contiguous() {
271                    input.clone()
272                } else {
273                    input.contiguous()
274                };
275                if let Some(input_buf) = input_contig.storage().wgpu_buffer() {
276                    let output_buf = crate::backend::wgpu::reduce_sum_dim_wgpu(
277                        input_buf,
278                        input_contig.shape(),
279                        1,
280                    );
281                    let size: usize = shape.iter().product();
282                    let storage = Storage::new_wgpu(output_buf, size, 0);
283                    return Tensor::new_with_storage(storage, shape);
284                }
285            }
286        }
287    }
288
289    let input_contig = if input.is_contiguous() {
290        input.clone()
291    } else {
292        input.contiguous()
293    };
294    let input_shape = input_contig.shape();
295    let output_shape = shape;
296
297    if input_shape.len() == 2 && output_shape.len() == 1 {
298        if input_shape[1] == output_shape[0] {
299            let m = input_shape[0];
300            let n = input_shape[1];
301            let data = input_contig.data();
302            let mut result = vec![0.0; n];
303
304            for (j, out) in result.iter_mut().enumerate().take(n) {
305                let mut col = vec![0.0f32; m];
306                for i in 0..m {
307                    col[i] = data[i * n + j];
308                }
309                *out = crate::ops::sum_auto(&col);
310            }
311
312            return Tensor::new_with_storage(Storage::new(result), output_shape);
313        }
314    }
315
316    if input_shape.len() == 1 && output_shape.len() == 1 {
317        if input_shape[0] == output_shape[0] {
318            return input_contig.clone();
319        }
320    }
321
322    if output_shape.iter().product::<usize>() == 1 {
323        let data = input_contig.data();
324        let sum = crate::ops::sum_auto(&data);
325        return Tensor::new_with_storage(Storage::new(vec![sum]), output_shape);
326    }
327
328    panic!(
329        "General sum_to not implemented for shape {:?} -> {:?}",
330        input_shape, output_shape
331    );
332}