svod_tensor/broadcast.rs
1//! Broadcasting operations for tensors.
2//!
3//! Implements NumPy-style broadcasting rules:
4//! - Shapes are aligned from the right (trailing dimensions)
5//! - Missing dimensions are treated as 1
6//! - For each dimension, sizes must either match or one must be 1
7//!
8//! This module provides the infrastructure for automatic broadcasting
9//! in binary operations, matching Tinygrad's architecture.
10
11use snafu::ResultExt;
12
13use super::*;
14use svod_ir::shape::{align_shapes_left, broadcast_shape};
15
16impl Tensor {
17 /// Broadcast two tensors to a common shape for binary operations.
18 ///
19 /// This method implements automatic broadcasting similar to NumPy/PyTorch.
20 /// It aligns the shapes, computes the broadcast result shape, and broadcasts
21 /// each tensor to that shape.
22 ///
23 /// # Broadcasting Rules
24 ///
25 /// - Shapes are aligned from the right (trailing dimensions)
26 /// - Missing dimensions are padded with 1 on the left
27 /// - For each dimension, sizes must either match or one must be 1
28 /// - The result dimension is the maximum of the two
29 ///
30 /// # Examples
31 ///
32 /// ```ignore
33 /// // Scalar + Vector: [] + [3] -> [3]
34 /// let scalar = Tensor::from_slice([5.0f32]);
35 /// let vector = Tensor::from_slice([1.0f32, 2.0, 3.0]);
36 /// let (a, b) = scalar.broadcast_for_binop(&vector)?;
37 ///
38 /// // Matrix + Row: [2, 3] + [1, 3] -> [2, 3]
39 /// let matrix = Tensor::from_slice([1.0f32; 6]).try_reshape(&[2, 3])?;
40 /// let row = Tensor::from_slice([1.0f32; 3]).try_reshape(&[1, 3])?;
41 /// let (a, b) = matrix.broadcast_for_binop(&row)?;
42 /// ```
43 ///
44 /// # Errors
45 ///
46 /// Returns error if shapes are incompatible for broadcasting.
47 pub(crate) fn broadcast_for_binop(&self, other: &Tensor) -> Result<(Tensor, Tensor)> {
48 let self_shape = self.shape()?;
49 let other_shape = other.shape()?;
50
51 // Early return if shapes already match
52 if self_shape == other_shape {
53 return Ok((self.clone(), other.clone()));
54 }
55
56 // Handle scalar cases (empty shape means scalar in svod)
57 // Actually, in svod scalars have shape [1], but let's handle both
58 if self_shape.is_empty() && other_shape.is_empty() {
59 return Ok((self.clone(), other.clone()));
60 }
61
62 // Align shapes (pad with 1s on left)
63 let aligned = align_shapes_left(&[self_shape.clone(), other_shape.clone()]);
64
65 // Compute broadcast result shape
66 let result_shape = broadcast_shape(&aligned[0], &aligned[1]).context(UOpSnafu)?;
67
68 // Broadcast each tensor to result shape
69 let self_broadcast = self.broadcast_to(&result_shape)?;
70 let other_broadcast = other.broadcast_to(&result_shape)?;
71
72 Ok((self_broadcast, other_broadcast))
73 }
74
75 /// Broadcast tensor to a target shape.
76 ///
77 /// This is the low-level broadcast operation that reshapes (adds explicit 1 dimensions)
78 /// and then expands (replicates data along size-1 dimensions).
79 ///
80 /// # Algorithm
81 ///
82 /// 1. If shape already matches, return self
83 /// 2. Pad shape with 1s on the left to match rank
84 /// 3. Reshape to add explicit 1 dimensions
85 /// 4. Expand size-1 dimensions to target size
86 ///
87 /// # Examples
88 ///
89 /// ```ignore
90 /// // [3] -> [2, 3]
91 /// let t = Tensor::from_slice([1.0f32, 2.0, 3.0]);
92 /// let target = vec![SInt::from(2), SInt::from(3)];
93 /// let broadcasted = t.broadcast_to(&target)?;
94 /// ```
95 ///
96 /// # Errors
97 ///
98 /// Returns error if:
99 /// - Shape has more dimensions than target
100 /// - Dimension sizes are incompatible (not 1 and not equal to target)
101 pub fn broadcast_to(&self, target_shape: &svod_ir::shape::Shape) -> Result<Tensor> {
102 let self_shape = self.shape()?;
103
104 // Early return if already correct shape
105 if &self_shape == target_shape {
106 return Ok(self.clone());
107 }
108
109 // Cannot broadcast to fewer dimensions
110 if self_shape.len() > target_shape.len() {
111 return Err(Error::BroadcastFewerDimensions { from_dims: self_shape.len(), to_dims: target_shape.len() });
112 }
113
114 // Pad shape with 1s on left if needed
115 let aligned_shape = if self_shape.len() < target_shape.len() {
116 let padding = target_shape.len() - self_shape.len();
117 let mut new_shape = svod_ir::shape::Shape::new();
118 new_shape.extend(std::iter::repeat_n(svod_ir::SInt::from(1), padding));
119 new_shape.extend(self_shape.iter().cloned());
120 new_shape
121 } else {
122 self_shape.clone()
123 };
124
125 // Validate broadcast compatibility
126 for (i, (aligned_dim, target_dim)) in aligned_shape.iter().zip(target_shape.iter()).enumerate() {
127 if let (Some(aligned_size), Some(target_size)) = (aligned_dim.as_const(), target_dim.as_const())
128 && aligned_size != 1
129 && aligned_size != target_size
130 {
131 return Err(Error::BroadcastIncompatible { dim: i, from_size: aligned_size, to_size: target_size });
132 }
133 // For symbolic dimensions, conservatively assume they're compatible
134 }
135
136 // Reshape to add explicit 1 dimensions (if needed)
137 let reshaped = if aligned_shape != self_shape {
138 // Call IR layer directly to support symbolic dimensions
139 self.uop().try_reshape(&aligned_shape).map(Self::new).context(UOpSnafu)?
140 } else {
141 self.clone()
142 };
143
144 // Check if expansion is actually needed
145 if &aligned_shape == target_shape {
146 return Ok(reshaped);
147 }
148
149 // Expand to target shape - call IR layer directly to support symbolic dimensions
150 reshaped.uop().try_expand(target_shape).map(Self::new).context(UOpSnafu)
151 }
152}