Skip to main content

svod_tensor/
broadcast.rs

//! Broadcasting operations for tensors.
//!
//! Implements NumPy-style broadcasting rules:
//! - Shapes are aligned from the right (trailing dimensions)
//! - Missing leading dimensions are treated as 1
//! - For each dimension, sizes must either match or one must be 1
//!
//! This module provides the infrastructure for automatic broadcasting
//! in binary operations, following Tinygrad's architecture.
10
11use snafu::ResultExt;
12
13use super::*;
14use svod_ir::shape::{align_shapes_left, broadcast_shape};
15
16impl Tensor {
17    /// Broadcast two tensors to a common shape for binary operations.
18    ///
19    /// This method implements automatic broadcasting similar to NumPy/PyTorch.
20    /// It aligns the shapes, computes the broadcast result shape, and broadcasts
21    /// each tensor to that shape.
22    ///
23    /// # Broadcasting Rules
24    ///
25    /// - Shapes are aligned from the right (trailing dimensions)
26    /// - Missing dimensions are padded with 1 on the left
27    /// - For each dimension, sizes must either match or one must be 1
28    /// - The result dimension is the maximum of the two
29    ///
30    /// # Examples
31    ///
32    /// ```ignore
33    /// // Scalar + Vector: [] + [3] -> [3]
34    /// let scalar = Tensor::from_slice([5.0f32]);
35    /// let vector = Tensor::from_slice([1.0f32, 2.0, 3.0]);
36    /// let (a, b) = scalar.broadcast_for_binop(&vector)?;
37    ///
38    /// // Matrix + Row: [2, 3] + [1, 3] -> [2, 3]
39    /// let matrix = Tensor::from_slice([1.0f32; 6]).try_reshape(&[2, 3])?;
40    /// let row = Tensor::from_slice([1.0f32; 3]).try_reshape(&[1, 3])?;
41    /// let (a, b) = matrix.broadcast_for_binop(&row)?;
42    /// ```
43    ///
44    /// # Errors
45    ///
46    /// Returns error if shapes are incompatible for broadcasting.
47    pub(crate) fn broadcast_for_binop(&self, other: &Tensor) -> Result<(Tensor, Tensor)> {
48        let self_shape = self.shape()?;
49        let other_shape = other.shape()?;
50
51        // Early return if shapes already match
52        if self_shape == other_shape {
53            return Ok((self.clone(), other.clone()));
54        }
55
56        // Handle scalar cases (empty shape means scalar in svod)
57        // Actually, in svod scalars have shape [1], but let's handle both
58        if self_shape.is_empty() && other_shape.is_empty() {
59            return Ok((self.clone(), other.clone()));
60        }
61
62        // Align shapes (pad with 1s on left)
63        let aligned = align_shapes_left(&[self_shape.clone(), other_shape.clone()]);
64
65        // Compute broadcast result shape
66        let result_shape = broadcast_shape(&aligned[0], &aligned[1]).context(UOpSnafu)?;
67
68        // Broadcast each tensor to result shape
69        let self_broadcast = self.broadcast_to(&result_shape)?;
70        let other_broadcast = other.broadcast_to(&result_shape)?;
71
72        Ok((self_broadcast, other_broadcast))
73    }
74
75    /// Broadcast tensor to a target shape.
76    ///
77    /// This is the low-level broadcast operation that reshapes (adds explicit 1 dimensions)
78    /// and then expands (replicates data along size-1 dimensions).
79    ///
80    /// # Algorithm
81    ///
82    /// 1. If shape already matches, return self
83    /// 2. Pad shape with 1s on the left to match rank
84    /// 3. Reshape to add explicit 1 dimensions
85    /// 4. Expand size-1 dimensions to target size
86    ///
87    /// # Examples
88    ///
89    /// ```ignore
90    /// // [3] -> [2, 3]
91    /// let t = Tensor::from_slice([1.0f32, 2.0, 3.0]);
92    /// let target = vec![SInt::from(2), SInt::from(3)];
93    /// let broadcasted = t.broadcast_to(&target)?;
94    /// ```
95    ///
96    /// # Errors
97    ///
98    /// Returns error if:
99    /// - Shape has more dimensions than target
100    /// - Dimension sizes are incompatible (not 1 and not equal to target)
101    pub fn broadcast_to(&self, target_shape: &svod_ir::shape::Shape) -> Result<Tensor> {
102        let self_shape = self.shape()?;
103
104        // Early return if already correct shape
105        if &self_shape == target_shape {
106            return Ok(self.clone());
107        }
108
109        // Cannot broadcast to fewer dimensions
110        if self_shape.len() > target_shape.len() {
111            return Err(Error::BroadcastFewerDimensions { from_dims: self_shape.len(), to_dims: target_shape.len() });
112        }
113
114        // Pad shape with 1s on left if needed
115        let aligned_shape = if self_shape.len() < target_shape.len() {
116            let padding = target_shape.len() - self_shape.len();
117            let mut new_shape = svod_ir::shape::Shape::new();
118            new_shape.extend(std::iter::repeat_n(svod_ir::SInt::from(1), padding));
119            new_shape.extend(self_shape.iter().cloned());
120            new_shape
121        } else {
122            self_shape.clone()
123        };
124
125        // Validate broadcast compatibility
126        for (i, (aligned_dim, target_dim)) in aligned_shape.iter().zip(target_shape.iter()).enumerate() {
127            if let (Some(aligned_size), Some(target_size)) = (aligned_dim.as_const(), target_dim.as_const())
128                && aligned_size != 1
129                && aligned_size != target_size
130            {
131                return Err(Error::BroadcastIncompatible { dim: i, from_size: aligned_size, to_size: target_size });
132            }
133            // For symbolic dimensions, conservatively assume they're compatible
134        }
135
136        // Reshape to add explicit 1 dimensions (if needed)
137        let reshaped = if aligned_shape != self_shape {
138            // Call IR layer directly to support symbolic dimensions
139            self.uop().try_reshape(&aligned_shape).map(Self::new).context(UOpSnafu)?
140        } else {
141            self.clone()
142        };
143
144        // Check if expansion is actually needed
145        if &aligned_shape == target_shape {
146            return Ok(reshaped);
147        }
148
149        // Expand to target shape - call IR layer directly to support symbolic dimensions
150        reshaped.uop().try_expand(target_shape).map(Self::new).context(UOpSnafu)
151    }
152}