1use scirs2_core::parallel_ops::*;
15use torsh_core::error::{Result as TorshResult, TorshError};
16
17pub fn quantize_per_tensor_affine_simd(
19 input: &[f32],
20 scale: f32,
21 zero_point: i32,
22 output: &mut [f32],
23) -> TorshResult<()> {
24 if input.len() != output.len() {
25 return Err(TorshError::InvalidArgument(
26 "Input and output length mismatch".to_string(),
27 ));
28 }
29
30 if scale <= 0.0 {
31 return Err(TorshError::InvalidArgument(
32 "Scale must be positive".to_string(),
33 ));
34 }
35
36 let inv_scale = 1.0 / scale;
37 let zero_point_f32 = zero_point as f32;
38
39 input
41 .par_iter()
42 .zip(output.par_iter_mut())
43 .for_each(|(&x, out)| {
44 let quantized = (x * inv_scale).round() + zero_point_f32;
45 *out = quantized.clamp(-128.0, 127.0);
46 });
47
48 Ok(())
49}
50
51pub fn dequantize_per_tensor_affine_simd(
53 input: &[f32],
54 scale: f32,
55 zero_point: i32,
56 output: &mut [f32],
57) -> TorshResult<()> {
58 if input.len() != output.len() {
59 return Err(TorshError::InvalidArgument(
60 "Input and output length mismatch".to_string(),
61 ));
62 }
63
64 if scale <= 0.0 {
65 return Err(TorshError::InvalidArgument(
66 "Scale must be positive".to_string(),
67 ));
68 }
69
70 let zero_point_f32 = zero_point as f32;
71
72 input
74 .par_iter()
75 .zip(output.par_iter_mut())
76 .for_each(|(&x, out)| {
77 *out = (x - zero_point_f32) * scale;
78 });
79
80 Ok(())
81}
82
83pub fn find_min_max_simd(data: &[f32]) -> TorshResult<(f32, f32)> {
85 if data.is_empty() {
86 return Err(TorshError::InvalidArgument(
87 "Cannot find min/max of empty array".to_string(),
88 ));
89 }
90
91 const CHUNK_SIZE: usize = 1024; let (min_val, max_val) = if data.len() > CHUNK_SIZE {
94 data.par_chunks(CHUNK_SIZE)
96 .map(|chunk| {
97 let mut local_min = f32::INFINITY;
98 let mut local_max = f32::NEG_INFINITY;
99 for &val in chunk {
100 local_min = local_min.min(val);
101 local_max = local_max.max(val);
102 }
103 (local_min, local_max)
104 })
105 .reduce(
106 || (f32::INFINITY, f32::NEG_INFINITY),
107 |(min1, max1), (min2, max2)| (min1.min(min2), max1.max(max2)),
108 )
109 } else {
110 let mut min_val = f32::INFINITY;
112 let mut max_val = f32::NEG_INFINITY;
113 for &val in data {
114 min_val = min_val.min(val);
115 max_val = max_val.max(val);
116 }
117 (min_val, max_val)
118 };
119
120 Ok((min_val, max_val))
121}
122
123pub fn quantize_per_channel_simd(
125 input: &[f32],
126 scales: &[f32],
127 zero_points: &[i32],
128 channel_size: usize,
129 output: &mut [f32],
130) -> TorshResult<()> {
131 if input.len() != output.len() {
132 return Err(TorshError::InvalidArgument(
133 "Input and output length mismatch".to_string(),
134 ));
135 }
136
137 let num_channels = scales.len();
138 if num_channels != zero_points.len() {
139 return Err(TorshError::InvalidArgument(
140 "Scales and zero_points length mismatch".to_string(),
141 ));
142 }
143
144 if input.len() != num_channels * channel_size {
145 return Err(TorshError::InvalidArgument(
146 "Input size does not match channel configuration".to_string(),
147 ));
148 }
149
150 for (ch, (&scale, &zero_point)) in scales.iter().zip(zero_points.iter()).enumerate() {
152 if scale <= 0.0 {
153 return Err(TorshError::InvalidArgument(format!(
154 "Scale for channel {} must be positive",
155 ch
156 )));
157 }
158
159 let channel_start = ch * channel_size;
160 let channel_end = channel_start + channel_size;
161
162 let input_slice = &input[channel_start..channel_end];
163 let output_slice = &mut output[channel_start..channel_end];
164
165 quantize_per_tensor_affine_simd(input_slice, scale, zero_point, output_slice)?;
166 }
167
168 Ok(())
169}
170
171pub fn quantize_batch_consistent_simd(
173 tensors: &[&[f32]],
174 scale: f32,
175 zero_point: i32,
176 outputs: &mut [&mut [f32]],
177) -> TorshResult<()> {
178 if tensors.len() != outputs.len() {
179 return Err(TorshError::InvalidArgument(
180 "Number of input tensors must match output tensors".to_string(),
181 ));
182 }
183
184 tensors
186 .par_iter()
187 .zip(outputs.par_iter_mut())
188 .try_for_each(|(input, output)| {
189 quantize_per_tensor_affine_simd(input, scale, zero_point, output)
190 })?;
191
192 Ok(())
193}
194
195pub fn quantize_to_int8_simd(
197 input: &[f32],
198 scale: f32,
199 zero_point: i32,
200 output: &mut [i8],
201) -> TorshResult<()> {
202 if input.len() != output.len() {
203 return Err(TorshError::InvalidArgument(
204 "Input and output length mismatch".to_string(),
205 ));
206 }
207
208 if scale <= 0.0 {
209 return Err(TorshError::InvalidArgument(
210 "Scale must be positive".to_string(),
211 ));
212 }
213
214 let inv_scale = 1.0 / scale;
215 let zero_point_f32 = zero_point as f32;
216
217 input
219 .par_iter()
220 .zip(output.par_iter_mut())
221 .for_each(|(&x, out)| {
222 let quantized = (x * inv_scale).round() + zero_point_f32;
223 *out = quantized.clamp(-128.0, 127.0) as i8;
224 });
225
226 Ok(())
227}
228
229pub fn calculate_tensor_stats_simd(data: &[f32]) -> TorshResult<TensorStats> {
231 if data.is_empty() {
232 return Err(TorshError::InvalidArgument(
233 "Cannot calculate stats of empty tensor".to_string(),
234 ));
235 }
236
237 let (min_val, max_val) = find_min_max_simd(data)?;
238
239 let sum: f64 = data.par_iter().map(|&x| x as f64).sum();
241 let mean = sum / data.len() as f64;
242
243 let variance_sum: f64 = data
245 .par_iter()
246 .map(|&x| {
247 let diff = x as f64 - mean;
248 diff * diff
249 })
250 .sum();
251 let variance = variance_sum / data.len() as f64;
252 let std_dev = variance.sqrt();
253
254 Ok(TensorStats {
255 min: min_val,
256 max: max_val,
257 mean: mean as f32,
258 std_dev: std_dev as f32,
259 variance: variance as f32,
260 })
261}
262
/// Summary statistics for a tensor, as produced by
/// `calculate_tensor_stats_simd`.
#[derive(Debug, Clone)]
pub struct TensorStats {
    // Smallest element value.
    pub min: f32,
    // Largest element value.
    pub max: f32,
    // Arithmetic mean (computed in f64, narrowed to f32).
    pub mean: f32,
    // Population standard deviation (sqrt of `variance`).
    pub std_dev: f32,
    // Population variance (divides by n, not n - 1).
    pub variance: f32,
}
272
/// Reports whether this build targets any SIMD instruction set the
/// quantization kernels can benefit from. Evaluated at compile time
/// from the enabled `target_feature`s.
pub fn is_simd_available() -> bool {
    cfg!(target_feature = "avx512f")
        || cfg!(target_feature = "avx2")
        || cfg!(target_feature = "avx")
        || cfg!(target_feature = "sse2")
        || cfg!(target_feature = "neon")
}
284
/// Returns the number of f32 lanes in the widest vector unit enabled
/// at compile time (16 for AVX-512, 8 for AVX2, 4 for SSE2/NEON,
/// 1 for scalar-only builds).
pub fn get_simd_width() -> usize {
    match () {
        _ if cfg!(target_feature = "avx512f") => 16,
        _ if cfg!(target_feature = "avx2") => 8,
        _ if cfg!(any(target_feature = "sse2", target_feature = "neon")) => 4,
        _ => 1,
    }
}
299
/// Affine-quantizes `input` into `output` on aarch64 builds:
/// `q = round(x / scale) + zero_point`, clamped to `[-128.0, 127.0]`
/// and stored as `f32`.
///
/// NOTE(review): despite the name, this routine contains no NEON
/// intrinsics — it relies on the compiler auto-vectorizing the
/// element-wise loop (NEON is baseline on aarch64). The previous
/// implementation ran the identical scalar body twice (a hand-rolled
/// 4-wide chunk loop plus a remainder loop); that duplication is
/// collapsed into a single iterator loop with unchanged per-element
/// math.
///
/// # Errors
/// Returns `InvalidArgument` when the slices differ in length or when
/// `scale` is not strictly positive.
#[cfg(target_arch = "aarch64")]
pub fn quantize_neon_optimized(
    input: &[f32],
    scale: f32,
    zero_point: i32,
    output: &mut [f32],
) -> TorshResult<()> {
    if input.len() != output.len() {
        return Err(TorshError::InvalidArgument(
            "Input and output length mismatch".to_string(),
        ));
    }
    if scale <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Scale must be positive".to_string(),
        ));
    }

    let inv_scale = 1.0 / scale;
    let zero_point_f32 = zero_point as f32;

    // Iterator form avoids per-element bounds checks and lets LLVM
    // vectorize the whole pass.
    for (&x, out) in input.iter().zip(output.iter_mut()) {
        *out = ((x * inv_scale).round() + zero_point_f32).clamp(-128.0, 127.0);
    }

    Ok(())
}
351
/// Returns `(min, max)` over `data` on aarch64 builds.
///
/// NOTE(review): no NEON intrinsics are used — the previous
/// implementation's 4-wide chunk loop performed byte-identical scalar
/// work to its remainder loop, so both are replaced by one fold; the
/// compiler is free to vectorize it.
///
/// # Errors
/// Returns `InvalidArgument` when `data` is empty.
#[cfg(target_arch = "aarch64")]
pub fn find_min_max_neon(data: &[f32]) -> TorshResult<(f32, f32)> {
    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot find min/max of empty array".to_string(),
        ));
    }

    // Same per-element math as before: f32::min / f32::max folded over
    // the whole slice.
    let extremes = data.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi), &v| (lo.min(v), hi.max(v)),
    );

    Ok(extremes)
}
388
389pub fn quantize_mobile_optimized(
391 input: &[f32],
392 scale: f32,
393 zero_point: i32,
394 output: &mut [i8],
395 use_reduced_precision: bool,
396) -> TorshResult<()> {
397 if input.len() != output.len() {
398 return Err(TorshError::InvalidArgument(
399 "Input and output length mismatch".to_string(),
400 ));
401 }
402
403 if scale <= 0.0 {
404 return Err(TorshError::InvalidArgument(
405 "Scale must be positive".to_string(),
406 ));
407 }
408
409 let inv_scale = if use_reduced_precision {
410 1.0 / scale
412 } else {
413 1.0 / scale
414 };
415
416 let zero_point_f32 = zero_point as f32;
417
418 const MOBILE_CHUNK_SIZE: usize = 256;
420
421 if input.len() > MOBILE_CHUNK_SIZE {
422 input
424 .chunks(MOBILE_CHUNK_SIZE)
425 .zip(output.chunks_mut(MOBILE_CHUNK_SIZE))
426 .for_each(|(input_chunk, output_chunk)| {
427 for (&x, out) in input_chunk.iter().zip(output_chunk.iter_mut()) {
428 let quantized = if use_reduced_precision {
429 (x * inv_scale + 0.5).floor() + zero_point_f32
431 } else {
432 (x * inv_scale).round() + zero_point_f32
433 };
434 *out = quantized.clamp(-128.0, 127.0) as i8;
435 }
436 });
437 } else {
438 for (&x, out) in input.iter().zip(output.iter_mut()) {
440 let quantized = (x * inv_scale).round() + zero_point_f32;
441 *out = quantized.clamp(-128.0, 127.0) as i8;
442 }
443 }
444
445 Ok(())
446}
447
448pub fn get_mobile_optimization_hints() -> MobileOptimizationHints {
450 MobileOptimizationHints {
451 prefer_int8: true,
452 use_reduced_precision: cfg!(target_os = "android") || cfg!(target_os = "ios"),
453 optimal_chunk_size: if cfg!(target_arch = "aarch64") {
454 256
455 } else {
456 512
457 },
458 enable_fast_math: true,
459 prefer_sequential: false, }
461}
462
/// Tuning hints for mobile targets, as produced by
/// `get_mobile_optimization_hints`.
#[derive(Debug, Clone)]
pub struct MobileOptimizationHints {
    // Prefer int8 quantization over wider formats.
    pub prefer_int8: bool,
    // Allow faster, slightly less precise rounding.
    pub use_reduced_precision: bool,
    // Suggested element count per processing chunk.
    pub optimal_chunk_size: usize,
    // Permit fast-math style optimizations.
    pub enable_fast_math: bool,
    // Prefer sequential over parallel execution.
    pub prefer_sequential: bool,
}
472
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Asserts two f32 slices are element-wise equal within 1e-6.
    fn assert_slices_close(got: &[f32], want: &[f32]) {
        assert_eq!(got.len(), want.len());
        for (g, w) in got.iter().zip(want.iter()) {
            assert_relative_eq!(*g, *w, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_quantize_per_tensor_affine_simd() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = [0.0f32; 4];

        quantize_per_tensor_affine_simd(&input, 0.1, 0, &mut output).unwrap();

        assert_slices_close(&output, &[10.0, 20.0, 30.0, 40.0]);
    }

    #[test]
    fn test_dequantize_per_tensor_affine_simd() {
        let input = [10.0f32, 20.0, 30.0, 40.0];
        let mut output = [0.0f32; 4];

        dequantize_per_tensor_affine_simd(&input, 0.1, 0, &mut output).unwrap();

        assert_slices_close(&output, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_find_min_max_simd() {
        let data = [-1.5f32, 0.0, 2.3, -0.8, 4.7, 1.2];

        let (lo, hi) = find_min_max_simd(&data).unwrap();

        assert_relative_eq!(lo, -1.5, epsilon = 1e-6);
        assert_relative_eq!(hi, 4.7, epsilon = 1e-6);
    }

    #[test]
    fn test_calculate_tensor_stats_simd() {
        let stats = calculate_tensor_stats_simd(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();

        assert_relative_eq!(stats.min, 1.0, epsilon = 1e-6);
        assert_relative_eq!(stats.max, 5.0, epsilon = 1e-6);
        assert_relative_eq!(stats.mean, 3.0, epsilon = 1e-6);
        // Population variance of 1..=5 is 2.0, so std_dev is sqrt(2).
        assert_relative_eq!(stats.std_dev, (2.0f64).sqrt() as f32, epsilon = 1e-4);
    }

    #[test]
    fn test_quantize_to_int8_simd() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = [0i8; 4];

        quantize_to_int8_simd(&input, 0.1, 0, &mut output).unwrap();

        assert_eq!(output, [10, 20, 30, 40]);
    }

    #[test]
    fn test_error_cases() {
        let input = [1.0f32, 2.0];

        // Length mismatch is rejected.
        let mut too_long = [0.0f32; 3];
        assert!(quantize_per_tensor_affine_simd(&input, 0.1, 0, &mut too_long).is_err());

        // Non-positive scale is rejected.
        let mut matching = [0.0f32; 2];
        assert!(quantize_per_tensor_affine_simd(&input, -0.1, 0, &mut matching).is_err());

        // Empty data has no min/max.
        assert!(find_min_max_simd(&[]).is_err());
    }

    #[test]
    fn test_simd_availability() {
        let width = get_simd_width();

        assert!(width >= 1);
        if is_simd_available() {
            assert!(width > 1);
        }
    }

    #[test]
    fn test_mobile_optimized_quantization() {
        let input = [1.0f32, 2.0, 3.0, 4.0, -1.0, -2.0];
        let mut output = [0i8; 6];

        quantize_mobile_optimized(&input, 0.1, 0, &mut output, false).unwrap();

        assert_eq!(output, [10, 20, 30, 40, -10, -20]);
    }

    #[test]
    fn test_mobile_optimized_quantization_reduced_precision() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = [0i8; 4];

        quantize_mobile_optimized(&input, 0.1, 0, &mut output, true).unwrap();

        // Reduced precision may differ by at most one step from exact rounding.
        assert!((output[0] as f32 - 10.0).abs() <= 1.0);
        assert!((output[1] as f32 - 20.0).abs() <= 1.0);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_neon_quantization() {
        let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = [0.0f32; 8];

        quantize_neon_optimized(&input, 0.1, 0, &mut output).unwrap();

        assert_relative_eq!(output[0], 10.0, epsilon = 1e-6);
        assert_relative_eq!(output[1], 20.0, epsilon = 1e-6);
        assert_relative_eq!(output[7], 80.0, epsilon = 1e-6);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_neon_min_max() {
        let data = [-1.5f32, 0.0, 2.3, -0.8, 4.7, 1.2, 9.5, -2.1];

        let (lo, hi) = find_min_max_neon(&data).unwrap();

        assert_relative_eq!(lo, -2.1, epsilon = 1e-6);
        assert_relative_eq!(hi, 9.5, epsilon = 1e-6);
    }

    #[test]
    fn test_mobile_optimization_hints() {
        let hints = get_mobile_optimization_hints();

        assert!(hints.prefer_int8);
        assert!(hints.optimal_chunk_size > 0);
        assert!(!hints.prefer_sequential);
    }
}