1use crate::{Tensor, TensorElement};
7use half::bf16;
8use torsh_core::{
9 dtype::{BF16RoundingMode, BFloat16Ops},
10 error::Result,
11};
12
13pub trait BFloat16TensorOps<T: TensorElement> {
15 fn to_bf16_with_rounding(&self, mode: BF16RoundingMode) -> Result<Tensor<bf16>>;
17
18 fn to_f32(&self) -> Result<Tensor<f32>>;
20
21 fn bf16_high_precision_op<F>(&self, op: F) -> Result<Tensor<bf16>>
23 where
24 F: Fn(&Tensor<f32>) -> Result<Tensor<f32>>;
25}
26
27impl BFloat16TensorOps<f32> for Tensor<f32> {
28 fn to_bf16_with_rounding(&self, mode: BF16RoundingMode) -> Result<Tensor<bf16>> {
29 let data = self.data()?;
30 let converted_data: Vec<bf16> = data
31 .iter()
32 .map(|&x| bf16::from_f32_with_rounding(x, mode))
33 .collect();
34
35 Tensor::from_data(converted_data, self.shape().dims().to_vec(), self.device())
36 }
37
38 fn to_f32(&self) -> Result<Tensor<f32>> {
39 self.to_bf16_with_rounding(BF16RoundingMode::NearestTiesToEven)?
41 .to_f32()
42 }
43
44 fn bf16_high_precision_op<F>(&self, op: F) -> Result<Tensor<bf16>>
45 where
46 F: Fn(&Tensor<f32>) -> Result<Tensor<f32>>,
47 {
48 let result = op(self)?;
50 result.to_bf16_with_rounding(BF16RoundingMode::NearestTiesToEven)
51 }
52}
53
54impl BFloat16TensorOps<bf16> for Tensor<bf16> {
55 fn to_bf16_with_rounding(&self, _mode: BF16RoundingMode) -> Result<Tensor<bf16>> {
56 Ok(self.clone())
58 }
59
60 fn to_f32(&self) -> Result<Tensor<f32>> {
61 let data = self.data()?;
62 let converted_data: Vec<f32> = data.iter().map(|&x| x.to_f32()).collect();
63
64 Tensor::from_data(converted_data, self.shape().dims().to_vec(), self.device())
65 }
66
67 fn bf16_high_precision_op<F>(&self, op: F) -> Result<Tensor<bf16>>
68 where
69 F: Fn(&Tensor<f32>) -> Result<Tensor<f32>>,
70 {
71 let f32_tensor = self.to_f32()?;
73 let result = op(&f32_tensor)?;
74 result.to_bf16_with_rounding(BF16RoundingMode::NearestTiesToEven)
75 }
76}
77
78impl Tensor<bf16> {
80 pub fn add_with_rounding(
82 &self,
83 other: &Tensor<bf16>,
84 mode: BF16RoundingMode,
85 ) -> Result<Tensor<bf16>> {
86 let self_data = self.data()?;
87 let other_data = other.data()?;
88
89 if self_data.len() != other_data.len() {
90 return Err(torsh_core::error::TorshError::InvalidArgument(
91 "Tensor shapes must match for addition".to_string(),
92 ));
93 }
94
95 let result_data: Vec<bf16> = self_data
96 .iter()
97 .zip(other_data.iter())
98 .map(|(&a, &b)| {
99 let sum_f32 = a.to_f32() + b.to_f32();
100 bf16::from_f32_with_rounding(sum_f32, mode)
101 })
102 .collect();
103
104 Tensor::from_data(result_data, self.shape().dims().to_vec(), self.device())
105 }
106
107 pub fn mul_with_rounding(
109 &self,
110 other: &Tensor<bf16>,
111 mode: BF16RoundingMode,
112 ) -> Result<Tensor<bf16>> {
113 let self_data = self.data()?;
114 let other_data = other.data()?;
115
116 if self_data.len() != other_data.len() {
117 return Err(torsh_core::error::TorshError::InvalidArgument(
118 "Tensor shapes must match for multiplication".to_string(),
119 ));
120 }
121
122 let result_data: Vec<bf16> = self_data
123 .iter()
124 .zip(other_data.iter())
125 .map(|(&a, &b)| a.mul_with_rounding(b, mode))
126 .collect();
127
128 Tensor::from_data(result_data, self.shape().dims().to_vec(), self.device())
129 }
130
131 pub fn fma_with_rounding(
133 &self,
134 other: &Tensor<bf16>,
135 addend: &Tensor<bf16>,
136 mode: BF16RoundingMode,
137 ) -> Result<Tensor<bf16>> {
138 let self_data = self.data()?;
139 let other_data = other.data()?;
140 let addend_data = addend.data()?;
141
142 if self_data.len() != other_data.len() || self_data.len() != addend_data.len() {
143 return Err(torsh_core::error::TorshError::InvalidArgument(
144 "All tensor shapes must match for FMA".to_string(),
145 ));
146 }
147
148 let result_data: Vec<bf16> = self_data
149 .iter()
150 .zip(other_data.iter())
151 .zip(addend_data.iter())
152 .map(|((&a, &b), &c)| a.fma_with_rounding(b, c, mode))
153 .collect();
154
155 Tensor::from_data(result_data, self.shape().dims().to_vec(), self.device())
156 }
157}
158
159pub mod creation {
161 use super::*;
162 use crate::creation;
163
164 pub fn tensor_1d_bf16_from_f32(data: &[f32], mode: BF16RoundingMode) -> Result<Tensor<bf16>> {
166 let bf16_data: Vec<bf16> = data
167 .iter()
168 .map(|&x| bf16::from_f32_with_rounding(x, mode))
169 .collect();
170 creation::tensor_1d(&bf16_data)
171 }
172
173 pub fn tensor_2d_bf16_from_f32(
175 data: &[&[f32]],
176 mode: BF16RoundingMode,
177 ) -> Result<Tensor<bf16>> {
178 let rows = data.len();
179 let cols = if rows > 0 { data[0].len() } else { 0 };
180
181 let mut bf16_data = Vec::with_capacity(rows * cols);
182 for row in data {
183 for &val in row.iter() {
184 bf16_data.push(bf16::from_f32_with_rounding(val, mode));
185 }
186 }
187
188 Tensor::from_data(
189 bf16_data,
190 vec![rows, cols],
191 torsh_core::device::DeviceType::Cpu,
192 )
193 }
194
195 pub fn zeros_bf16(shape: &[usize]) -> Result<Tensor<bf16>> {
197 creation::zeros::<bf16>(shape)
198 }
199
200 pub fn ones_bf16(shape: &[usize]) -> Result<Tensor<bf16>> {
202 creation::ones::<bf16>(shape)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation;
    use approx::assert_relative_eq;

    /// Builds a 1-D `bf16` tensor from `f32` literals using `bf16::from_f32`.
    fn bf16_vector(values: &[f32]) -> Tensor<bf16> {
        let converted: Vec<bf16> = values.iter().map(|&v| bf16::from_f32(v)).collect();
        creation::tensor_1d(&converted).expect("tensor creation failed")
    }

    /// A 1-D tensor built from bf16 values reports the right dims and returns the same data.
    #[test]
    fn test_bf16_tensor_creation() {
        let expected: Vec<bf16> = [1.0f32, 2.0, 3.0]
            .iter()
            .map(|&v| bf16::from_f32(v))
            .collect();
        let tensor = creation::tensor_1d(&expected).expect("bf16 tensor creation failed");

        assert_eq!(tensor.shape().dims(), &[3]);
        assert_eq!(tensor.data().expect("data retrieval failed"), expected);
    }

    /// `zeros` and `ones` fill a bf16 tensor with the corresponding constant.
    #[test]
    fn test_bf16_zeros_ones() {
        let zero = bf16::from_f32(0.0);
        let one = bf16::from_f32(1.0);

        let zeros = creation::zeros::<bf16>(&[2, 3]).expect("zeros creation failed");
        assert_eq!(zeros.shape().dims(), &[2, 3]);
        let zeros_data = zeros.data().expect("data retrieval failed");
        assert!(zeros_data.iter().all(|&x| x == zero));

        let ones = creation::ones::<bf16>(&[2, 3]).expect("ones creation failed");
        let ones_data = ones.data().expect("data retrieval failed");
        assert!(ones_data.iter().all(|&x| x == one));
    }

    /// For each rounding mode, the tensor constructor matches the scalar conversion on the
    /// first element.
    #[test]
    fn test_bf16_rounding_modes() {
        let f32_data = vec![1.5f32, 2.5f32, 3.7f32];
        let cases = [
            (
                BF16RoundingMode::NearestTiesToEven,
                "nearest_even creation failed",
            ),
            (
                BF16RoundingMode::NearestTiesAway,
                "nearest_away creation failed",
            ),
            (BF16RoundingMode::TowardZero, "toward_zero creation failed"),
        ];

        for &(mode, failure) in cases.iter() {
            let tensor =
                super::creation::tensor_1d_bf16_from_f32(&f32_data, mode).expect(failure);
            let values = tensor.data().expect("data retrieval failed");

            assert_eq!(values[0], bf16::from_f32_with_rounding(1.5, mode));
        }
    }

    /// Addition with explicit rounding yields the exact sums for representable inputs.
    #[test]
    fn test_bf16_arithmetic_with_rounding() {
        let a = bf16_vector(&[1.5, 2.5]);
        let b = bf16_vector(&[0.5, 1.5]);

        let sum = a
            .add_with_rounding(&b, BF16RoundingMode::NearestTiesToEven)
            .expect("add_with_rounding failed");
        let sum_data = sum.data().expect("data retrieval failed");

        for (index, &want) in [2.0f32, 4.0].iter().enumerate() {
            assert_relative_eq!(sum_data[index].to_f32(), want, epsilon = 1e-6);
        }
    }

    /// f32 -> bf16 -> f32 stays within tolerance of the original values.
    #[test]
    fn test_bf16_conversion() {
        let originals = [1.0f32, 2.0f32, 3.0f32];
        let f32_tensor = creation::tensor_1d(&originals).expect("tensor creation failed");

        let bf16_tensor = f32_tensor
            .to_bf16_with_rounding(BF16RoundingMode::NearestTiesToEven)
            .expect("to_bf16 conversion failed");
        let round_tripped = bf16_tensor
            .to_f32()
            .expect("to_f32 conversion failed")
            .data()
            .expect("data retrieval failed");

        for (index, &want) in originals.iter().enumerate() {
            assert_relative_eq!(round_tripped[index], want, epsilon = 1e-2);
        }
    }

    /// The high-precision helper evaluates `t * t + 1` in f32 and returns bf16 results.
    #[test]
    fn test_bf16_high_precision_op() {
        let bf16_tensor = bf16_vector(&[1.0, 2.0]);

        let result = bf16_tensor
            .bf16_high_precision_op(|t| t.mul_op(t)?.add_scalar(1.0))
            .expect("bf16_high_precision_op failed");
        let result_data = result.data().expect("data retrieval failed");

        // 1 * 1 + 1 = 2 and 2 * 2 + 1 = 5.
        for (index, &want) in [2.0f32, 5.0].iter().enumerate() {
            assert_relative_eq!(result_data[index].to_f32(), want, epsilon = 1e-2);
        }
    }

    /// Fused multiply-add computes `a * b + c` element-wise.
    #[test]
    fn test_bf16_fma() {
        let a = bf16_vector(&[2.0, 3.0]);
        let b = bf16_vector(&[4.0, 5.0]);
        let c = bf16_vector(&[1.0, 2.0]);

        let result = a
            .fma_with_rounding(&b, &c, BF16RoundingMode::NearestTiesToEven)
            .expect("fma_with_rounding failed");
        let result_data = result.data().expect("data retrieval failed");

        // 2 * 4 + 1 = 9 and 3 * 5 + 2 = 17.
        for (index, &want) in [9.0f32, 17.0].iter().enumerate() {
            assert_relative_eq!(result_data[index].to_f32(), want, epsilon = 1e-2);
        }
    }

    /// A large magnitude stays within a coarse tolerance and a tiny value stays non-negative.
    #[test]
    fn test_bf16_precision_limits() {
        let large_value = 65504.0f32;
        let small_value = 1e-6f32;
        let mode = BF16RoundingMode::NearestTiesToEven;

        let large_tensor = super::creation::tensor_1d_bf16_from_f32(&[large_value], mode)
            .expect("large tensor creation failed");
        let small_tensor = super::creation::tensor_1d_bf16_from_f32(&[small_value], mode)
            .expect("small tensor creation failed");

        let large_data = large_tensor.data().expect("data retrieval failed");
        let small_data = small_tensor.data().expect("data retrieval failed");

        let large_error = (large_data[0].to_f32() - large_value).abs();
        assert!(large_error < 1000.0);

        assert!(small_data[0].to_f32() >= 0.0);
    }
}