Skip to main content

rustorch_core/
broadcast.rs

1use crate::autograd::BackwardOp;
2// use crate::storage::Storage;
3use crate::Tensor;
4use std::sync::Arc;
5
6#[derive(Debug)]
7pub struct ExpandBackward {
8    pub input: Tensor,
9    pub input_shape: Vec<usize>,
10}
11
12impl BackwardOp for ExpandBackward {
13    fn backward(&self, grad: &Tensor) {
14        if self.input.requires_grad() {
15            // Gradient reduction: sum over broadcasted dimensions
16            let grad_reduced = crate::ops::sum_to(grad, &self.input_shape);
17            self.input.accumulate_grad(&grad_reduced);
18            self.input.backward_step();
19        }
20    }
21}
22
/// Computes the shape that results from broadcasting `shape1` against `shape2`.
///
/// The two shapes are aligned at their trailing dimension. The shorter one behaves as if
/// it were padded on the left with 1s. For every aligned pair of sizes:
///
/// * equal sizes are kept as-is;
/// * a size of 1 yields to the other size (so `1` paired with `0` gives `0`).
///
/// Returns `None` when some aligned pair differs and neither size is 1.
pub fn broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let rank = shape1.len().max(shape2.len());

    // How many implicit leading 1s each shape gets once padded to `rank` dims.
    let pad1 = rank - shape1.len();
    let pad2 = rank - shape2.len();

    (0..rank)
        .map(|axis| {
            let a = if axis < pad1 { 1 } else { shape1[axis - pad1] };
            let b = if axis < pad2 { 1 } else { shape2[axis - pad2] };
            match (a, b) {
                (x, y) if x == y => Some(x),
                (1, y) => Some(y),
                (x, 1) => Some(x),
                _ => None,
            }
        })
        // Collecting into `Option<Vec<_>>` yields `None` if any axis is incompatible.
        .collect()
}
47
48impl Tensor {
49    // Lazy expansion: returns a view with modified strides (0 stride for broadcasted dims)
50    pub fn expand(&self, new_shape: &[usize]) -> Tensor {
51        if self.shape() == new_shape {
52            return self.clone();
53        }
54
55        let current_shape = self.shape();
56        let current_strides = self.strides();
57
58        let ndim_new = new_shape.len();
59        let ndim_old = current_shape.len();
60
61        if ndim_new < ndim_old {
62            panic!("expand: new shape must have >= dims than current shape");
63        }
64
65        let mut new_strides = vec![0; ndim_new];
66        let offset = ndim_new - ndim_old;
67
68        for i in 0..ndim_new {
69            if i < offset {
70                // New dimension added at front (broadcasting)
71                // If new_shape[i] > 1, stride is 0.
72                // If new_shape[i] == 1, stride is arbitrary (say 0).
73                new_strides[i] = 0;
74            } else {
75                let old_idx = i - offset;
76                let old_dim = current_shape[old_idx];
77                let new_dim = new_shape[i];
78
79                if old_dim == 1 && new_dim > 1 {
80                    // Broadcast existing dim: stride 0
81                    new_strides[i] = 0;
82                } else if old_dim == new_dim {
83                    // Inherit stride
84                    new_strides[i] = current_strides[old_idx];
85                } else {
86                    panic!(
87                        "expand: invalid shape {:?} -> {:?}",
88                        current_shape, new_shape
89                    );
90                }
91            }
92        }
93
94        // Return view
95        // We construct a new Tensor sharing the same storage
96        // This requires accessing private fields or using a constructor that takes strides.
97        // Tensor::new_with_storage usually assumes contiguous.
98        // We need `Tensor::new_with_storage_and_strides` or similar.
99        // Or modify `new_with_storage`?
100        // Let's check `tensor.rs`.
101
102        // For now, I will assume I can create it.
103        // I need to add `new_view` or similar to Tensor.
104        // But wait, `Tensor::new_with_storage` computes strides from shape assuming contiguous.
105        // I need to add a method to Tensor to create from storage + strides.
106
107        // HACK: I cannot easily modify `Tensor` private fields from here if they are private to crate.
108        // `broadcast.rs` is in `src/`, same crate. So I can access `pub(crate)` fields.
109        // But `Tensor` struct definition is in `tensor.rs`.
110        // `Tensor` wraps `Arc<TensorImpl>`. `TensorImpl` fields are `pub(crate)`.
111
112        use crate::tensor::TensorImpl;
113        use std::sync::Mutex;
114
115        let inner = TensorImpl {
116            storage: self.storage().clone(),
117            shape: new_shape.to_vec(),
118            strides: new_strides,
119            grad: Mutex::new(None),
120            requires_grad: self.requires_grad(),
121            op: if self.requires_grad() {
122                Some(Arc::new(ExpandBackward {
123                    input: self.clone(),
124                    input_shape: self.shape().to_vec(),
125                }))
126            } else {
127                None
128            },
129            is_leaf: false, // Views are not leaf usually? Or if they share storage...
130                            // If it has history, it's not leaf.
131        };
132
133        Tensor {
134            inner: Arc::new(inner),
135        }
136    }
137}