// torsh_tensor/bfloat16_ops.rs
1//! BFloat16 tensor operations and optimizations
2//!
3//! This module provides specialized operations for BFloat16 (bf16) tensors,
4//! including proper rounding modes and optimized implementations.
5
6use crate::{Tensor, TensorElement};
7use half::bf16;
8use torsh_core::{
9    dtype::{BF16RoundingMode, BFloat16Ops},
10    error::Result,
11};
12
13/// Extension trait for BFloat16 tensor operations
14pub trait BFloat16TensorOps<T: TensorElement> {
15    /// Convert tensor to bf16 with specified rounding mode
16    fn to_bf16_with_rounding(&self, mode: BF16RoundingMode) -> Result<Tensor<bf16>>;
17
18    /// Convert from bf16 tensor to higher precision
19    fn to_f32(&self) -> Result<Tensor<f32>>;
20
21    /// Perform operation in higher precision then round back to bf16
22    fn bf16_high_precision_op<F>(&self, op: F) -> Result<Tensor<bf16>>
23    where
24        F: Fn(&Tensor<f32>) -> Result<Tensor<f32>>;
25}
26
27impl BFloat16TensorOps<f32> for Tensor<f32> {
28    fn to_bf16_with_rounding(&self, mode: BF16RoundingMode) -> Result<Tensor<bf16>> {
29        let data = self.data()?;
30        let converted_data: Vec<bf16> = data
31            .iter()
32            .map(|&x| bf16::from_f32_with_rounding(x, mode))
33            .collect();
34
35        Tensor::from_data(converted_data, self.shape().dims().to_vec(), self.device())
36    }
37
38    fn to_f32(&self) -> Result<Tensor<f32>> {
39        // This doesn't make sense for f32 -> bf16, but included for completeness
40        self.to_bf16_with_rounding(BF16RoundingMode::NearestTiesToEven)?
41            .to_f32()
42    }
43
44    fn bf16_high_precision_op<F>(&self, op: F) -> Result<Tensor<bf16>>
45    where
46        F: Fn(&Tensor<f32>) -> Result<Tensor<f32>>,
47    {
48        // f32 is already high precision, apply op and convert to bf16
49        let result = op(self)?;
50        result.to_bf16_with_rounding(BF16RoundingMode::NearestTiesToEven)
51    }
52}
53
54impl BFloat16TensorOps<bf16> for Tensor<bf16> {
55    fn to_bf16_with_rounding(&self, _mode: BF16RoundingMode) -> Result<Tensor<bf16>> {
56        // Already bf16, return clone
57        Ok(self.clone())
58    }
59
60    fn to_f32(&self) -> Result<Tensor<f32>> {
61        let data = self.data()?;
62        let converted_data: Vec<f32> = data.iter().map(|&x| x.to_f32()).collect();
63
64        Tensor::from_data(converted_data, self.shape().dims().to_vec(), self.device())
65    }
66
67    fn bf16_high_precision_op<F>(&self, op: F) -> Result<Tensor<bf16>>
68    where
69        F: Fn(&Tensor<f32>) -> Result<Tensor<f32>>,
70    {
71        // Convert to f32, apply op, convert back to bf16
72        let f32_tensor = self.to_f32()?;
73        let result = op(&f32_tensor)?;
74        result.to_bf16_with_rounding(BF16RoundingMode::NearestTiesToEven)
75    }
76}
77
78/// Specialized bf16 arithmetic operations with proper rounding
79impl Tensor<bf16> {
80    /// Add two bf16 tensors with specified rounding mode
81    pub fn add_with_rounding(
82        &self,
83        other: &Tensor<bf16>,
84        mode: BF16RoundingMode,
85    ) -> Result<Tensor<bf16>> {
86        let self_data = self.data()?;
87        let other_data = other.data()?;
88
89        if self_data.len() != other_data.len() {
90            return Err(torsh_core::error::TorshError::InvalidArgument(
91                "Tensor shapes must match for addition".to_string(),
92            ));
93        }
94
95        let result_data: Vec<bf16> = self_data
96            .iter()
97            .zip(other_data.iter())
98            .map(|(&a, &b)| {
99                let sum_f32 = a.to_f32() + b.to_f32();
100                bf16::from_f32_with_rounding(sum_f32, mode)
101            })
102            .collect();
103
104        Tensor::from_data(result_data, self.shape().dims().to_vec(), self.device())
105    }
106
107    /// Multiply two bf16 tensors with specified rounding mode
108    pub fn mul_with_rounding(
109        &self,
110        other: &Tensor<bf16>,
111        mode: BF16RoundingMode,
112    ) -> Result<Tensor<bf16>> {
113        let self_data = self.data()?;
114        let other_data = other.data()?;
115
116        if self_data.len() != other_data.len() {
117            return Err(torsh_core::error::TorshError::InvalidArgument(
118                "Tensor shapes must match for multiplication".to_string(),
119            ));
120        }
121
122        let result_data: Vec<bf16> = self_data
123            .iter()
124            .zip(other_data.iter())
125            .map(|(&a, &b)| a.mul_with_rounding(b, mode))
126            .collect();
127
128        Tensor::from_data(result_data, self.shape().dims().to_vec(), self.device())
129    }
130
131    /// Fused multiply-add with proper bf16 rounding
132    pub fn fma_with_rounding(
133        &self,
134        other: &Tensor<bf16>,
135        addend: &Tensor<bf16>,
136        mode: BF16RoundingMode,
137    ) -> Result<Tensor<bf16>> {
138        let self_data = self.data()?;
139        let other_data = other.data()?;
140        let addend_data = addend.data()?;
141
142        if self_data.len() != other_data.len() || self_data.len() != addend_data.len() {
143            return Err(torsh_core::error::TorshError::InvalidArgument(
144                "All tensor shapes must match for FMA".to_string(),
145            ));
146        }
147
148        let result_data: Vec<bf16> = self_data
149            .iter()
150            .zip(other_data.iter())
151            .zip(addend_data.iter())
152            .map(|((&a, &b), &c)| a.fma_with_rounding(b, c, mode))
153            .collect();
154
155        Tensor::from_data(result_data, self.shape().dims().to_vec(), self.device())
156    }
157}
158
159/// Optimized bf16 creation functions
160pub mod creation {
161    use super::*;
162    use crate::creation;
163
164    /// Create bf16 tensor from f32 data with specified rounding
165    pub fn tensor_1d_bf16_from_f32(data: &[f32], mode: BF16RoundingMode) -> Result<Tensor<bf16>> {
166        let bf16_data: Vec<bf16> = data
167            .iter()
168            .map(|&x| bf16::from_f32_with_rounding(x, mode))
169            .collect();
170        creation::tensor_1d(&bf16_data)
171    }
172
173    /// Create 2D bf16 tensor from f32 data with specified rounding
174    pub fn tensor_2d_bf16_from_f32(
175        data: &[&[f32]],
176        mode: BF16RoundingMode,
177    ) -> Result<Tensor<bf16>> {
178        let rows = data.len();
179        let cols = if rows > 0 { data[0].len() } else { 0 };
180
181        let mut bf16_data = Vec::with_capacity(rows * cols);
182        for row in data {
183            for &val in row.iter() {
184                bf16_data.push(bf16::from_f32_with_rounding(val, mode));
185            }
186        }
187
188        Tensor::from_data(
189            bf16_data,
190            vec![rows, cols],
191            torsh_core::device::DeviceType::Cpu,
192        )
193    }
194
195    /// Create bf16 zeros tensor
196    pub fn zeros_bf16(shape: &[usize]) -> Result<Tensor<bf16>> {
197        creation::zeros::<bf16>(shape)
198    }
199
200    /// Create bf16 ones tensor
201    pub fn ones_bf16(shape: &[usize]) -> Result<Tensor<bf16>> {
202        creation::ones::<bf16>(shape)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation;
    use approx::assert_relative_eq;

    // bf16 tensors built from explicit bf16 values round-trip their data and
    // shape unchanged.
    #[test]
    fn test_bf16_tensor_creation() {
        let data = vec![
            bf16::from_f32(1.0),
            bf16::from_f32(2.0),
            bf16::from_f32(3.0),
        ];
        let tensor = creation::tensor_1d(&data).expect("bf16 tensor creation failed");

        assert_eq!(tensor.shape().dims(), &[3]);
        assert_eq!(tensor.data().expect("data retrieval failed"), data);
    }

    // zeros/ones constructors produce bf16 tensors filled with exact 0.0 / 1.0
    // (both exactly representable in bf16).
    #[test]
    fn test_bf16_zeros_ones() {
        let zeros = creation::zeros::<bf16>(&[2, 3]).expect("zeros creation failed");
        assert_eq!(zeros.shape().dims(), &[2, 3]);

        let zeros_data = zeros.data().expect("data retrieval failed");
        assert!(zeros_data.iter().all(|&x| x == bf16::from_f32(0.0)));

        let ones = creation::ones::<bf16>(&[2, 3]).expect("ones creation failed");
        let ones_data = ones.data().expect("data retrieval failed");
        assert!(ones_data.iter().all(|&x| x == bf16::from_f32(1.0)));
    }

    // Each tensor-level rounding mode must agree element-wise with the scalar
    // bf16::from_f32_with_rounding reference for the same mode.
    #[test]
    fn test_bf16_rounding_modes() {
        // 1.5 and 2.5 are tie cases where the modes can legitimately differ.
        let f32_data = vec![1.5f32, 2.5f32, 3.7f32];

        // Test different rounding modes
        let nearest_even = super::creation::tensor_1d_bf16_from_f32(
            &f32_data,
            BF16RoundingMode::NearestTiesToEven,
        )
        .expect("nearest_even creation failed");
        let nearest_away =
            super::creation::tensor_1d_bf16_from_f32(&f32_data, BF16RoundingMode::NearestTiesAway)
                .expect("nearest_away creation failed");
        let toward_zero =
            super::creation::tensor_1d_bf16_from_f32(&f32_data, BF16RoundingMode::TowardZero)
                .expect("toward_zero creation failed");

        let nearest_even_data = nearest_even.data().expect("data retrieval failed");
        let nearest_away_data = nearest_away.data().expect("data retrieval failed");
        let toward_zero_data = toward_zero.data().expect("data retrieval failed");

        // Verify different rounding behaviors for tie cases
        assert_eq!(
            nearest_even_data[0],
            bf16::from_f32_with_rounding(1.5, BF16RoundingMode::NearestTiesToEven)
        );
        assert_eq!(
            nearest_away_data[0],
            bf16::from_f32_with_rounding(1.5, BF16RoundingMode::NearestTiesAway)
        );
        assert_eq!(
            toward_zero_data[0],
            bf16::from_f32_with_rounding(1.5, BF16RoundingMode::TowardZero)
        );
    }

    // Element-wise add with rounding: all inputs and exact sums (2.0, 4.0) are
    // representable in bf16, so the results should be exact.
    #[test]
    fn test_bf16_arithmetic_with_rounding() {
        let a = creation::tensor_1d(&[bf16::from_f32(1.5), bf16::from_f32(2.5)])
            .expect("tensor creation failed");
        let b = creation::tensor_1d(&[bf16::from_f32(0.5), bf16::from_f32(1.5)])
            .expect("tensor creation failed");

        let result = a
            .add_with_rounding(&b, BF16RoundingMode::NearestTiesToEven)
            .expect("add_with_rounding failed");
        let result_data = result.data().expect("data retrieval failed");

        assert_relative_eq!(result_data[0].to_f32(), 2.0, epsilon = 1e-6);
        assert_relative_eq!(result_data[1].to_f32(), 4.0, epsilon = 1e-6);
    }

    // f32 -> bf16 -> f32 round trip: small integers survive, so only a loose
    // epsilon is needed to absorb the bf16 narrowing.
    #[test]
    fn test_bf16_conversion() {
        let f32_tensor =
            creation::tensor_1d(&[1.0f32, 2.0f32, 3.0f32]).expect("tensor creation failed");

        // Convert to bf16
        let bf16_tensor = f32_tensor
            .to_bf16_with_rounding(BF16RoundingMode::NearestTiesToEven)
            .expect("to_bf16 conversion failed");

        // Convert back to f32
        let f32_converted = bf16_tensor.to_f32().expect("to_f32 conversion failed");
        let f32_converted_data = f32_converted.data().expect("data retrieval failed");

        // Should be approximately equal (some precision loss expected)
        assert_relative_eq!(f32_converted_data[0], 1.0, epsilon = 1e-2);
        assert_relative_eq!(f32_converted_data[1], 2.0, epsilon = 1e-2);
        assert_relative_eq!(f32_converted_data[2], 3.0, epsilon = 1e-2);
    }

    // bf16_high_precision_op runs the closure in f32 and rounds the result
    // back: x^2 + 1 for x in {1, 2} gives {2, 5}, both bf16-exact.
    #[test]
    fn test_bf16_high_precision_op() {
        let bf16_tensor = creation::tensor_1d(&[bf16::from_f32(1.0), bf16::from_f32(2.0)])
            .expect("tensor creation failed");

        // Apply a complex operation in high precision
        let result = bf16_tensor
            .bf16_high_precision_op(|t| {
                let doubled = t.mul_op(t)?; // Square in f32 precision
                doubled.add_scalar(1.0) // Add 1 in f32 precision
            })
            .expect("bf16_high_precision_op failed");

        let result_data = result.data().expect("data retrieval failed");
        assert_relative_eq!(result_data[0].to_f32(), 2.0, epsilon = 1e-2); // 1^2 + 1 = 2
        assert_relative_eq!(result_data[1].to_f32(), 5.0, epsilon = 1e-2); // 2^2 + 1 = 5
    }

    // Fused multiply-add a * b + c element-wise; all operands and results are
    // small integers, exactly representable in bf16.
    #[test]
    fn test_bf16_fma() {
        let a = creation::tensor_1d(&[bf16::from_f32(2.0), bf16::from_f32(3.0)])
            .expect("tensor creation failed");
        let b = creation::tensor_1d(&[bf16::from_f32(4.0), bf16::from_f32(5.0)])
            .expect("tensor creation failed");
        let c = creation::tensor_1d(&[bf16::from_f32(1.0), bf16::from_f32(2.0)])
            .expect("tensor creation failed");

        let result = a
            .fma_with_rounding(&b, &c, BF16RoundingMode::NearestTiesToEven)
            .expect("fma_with_rounding failed");
        let result_data = result.data().expect("data retrieval failed");

        // FMA: a * b + c
        assert_relative_eq!(result_data[0].to_f32(), 9.0, epsilon = 1e-2); // 2 * 4 + 1 = 9
        assert_relative_eq!(result_data[1].to_f32(), 17.0, epsilon = 1e-2); // 3 * 5 + 2 = 17
    }

    // bf16 keeps f32's 8-bit exponent but only an 8-bit significand, so large
    // magnitudes fit with coarse precision and tiny values may lose digits.
    #[test]
    fn test_bf16_precision_limits() {
        // Test bf16 precision limits
        let large_value = 65504.0f32; // f16::MAX — well within bf16 range (bf16 max ~= 3.39e38)
        let small_value = 1e-6f32; // Very small value
        let large_tensor = super::creation::tensor_1d_bf16_from_f32(
            &[large_value],
            BF16RoundingMode::NearestTiesToEven,
        )
        .expect("large tensor creation failed");
        let small_tensor = super::creation::tensor_1d_bf16_from_f32(
            &[small_value],
            BF16RoundingMode::NearestTiesToEven,
        )
        .expect("small tensor creation failed");

        let large_data = large_tensor.data().expect("data retrieval failed");
        let small_data = small_tensor.data().expect("data retrieval failed");

        // Large values should be preserved with some precision loss
        assert!((large_data[0].to_f32() - large_value).abs() < 1000.0);

        // Very small values might be rounded to zero or have significant precision loss
        assert!(small_data[0].to_f32() >= 0.0);
    }
}
373}