1use crate::error::{LinalgError, LinalgResult};
9use crate::matrixfree::{LinearOperator, MatrixFreeOp};
10use crate::quantization::calibration::determine_data_type;
11use crate::quantization::{QuantizationMethod, QuantizationParams};
12use scirs2_core::ndarray::ScalarOperand;
13use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
14use scirs2_core::numeric::{AsPrimitive, Float, FromPrimitive, NumAssign, One, Zero};
15use std::fmt::Debug;
16use std::iter::Sum;
17use std::sync::Arc;
18
19pub type MatVecFn<F> = Arc<dyn Fn(&ArrayView1<F>) -> LinalgResult<Array1<F>> + Send + Sync>;
21
22pub struct QuantizedMatrixFreeOp<F>
24where
25 F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug + 'static,
26{
27 shape: (usize, usize),
29
30 params: QuantizationParams,
32
33 op_fn: MatVecFn<F>,
35
36 symmetric: bool,
38
39 positive_definite: bool,
41}
42
43impl<F> QuantizedMatrixFreeOp<F>
44where
45 F: Float
46 + NumAssign
47 + Zero
48 + Sum
49 + One
50 + ScalarOperand
51 + Send
52 + Sync
53 + Debug
54 + FromPrimitive
55 + AsPrimitive<f32>
56 + 'static,
57 f32: AsPrimitive<F>,
58{
59 pub fn new<O>(
76 rows: usize,
77 cols: usize,
78 bits: u8,
79 method: QuantizationMethod,
80 op_fn: O,
81 ) -> LinalgResult<Self>
82 where
83 O: Fn(&ArrayView1<F>) -> LinalgResult<Array1<F>> + Send + Sync + 'static,
84 {
85 let min_val: f32 = 0.0;
87 let max_val: f32 = 1.0;
88 let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
89 let abs_max = max_val.abs().max(min_val.abs());
90 let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
91 (scale, 0)
92 } else {
93 let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
94 let zero_point = (-min_val / scale).round() as i32;
95 (scale, zero_point)
96 };
97
98 let params = QuantizationParams {
100 bits,
101 scale,
102 zero_point,
103 min_val,
104 max_val,
105 method,
106 data_type: determine_data_type(bits),
107 channel_scales: None,
108 channel_zero_points: None,
109 };
110
111 Ok(QuantizedMatrixFreeOp {
112 shape: (rows, cols),
113 params,
114 op_fn: Arc::new(op_fn),
115 symmetric: false,
116 positive_definite: false,
117 })
118 }
119
120 pub fn frommatrix(
135 matrix: &ArrayView2<F>,
136 bits: u8,
137 method: QuantizationMethod,
138 ) -> LinalgResult<Self> {
139 let matrix_f32: Array1<f32> = matrix.iter().map(|&x| x.as_()).collect();
141
142 let (min_val, max_val) = if method == QuantizationMethod::Symmetric {
144 let max_abs = matrix_f32.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
145 (-max_abs, max_abs)
146 } else {
147 let min_val = matrix_f32.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
148 let max_val = matrix_f32
149 .iter()
150 .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
151 (min_val, max_val)
152 };
153
154 let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
156 let abs_max = max_val.abs().max(min_val.abs());
157 let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
158 (scale, 0)
159 } else {
160 let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
161 let zero_point = (-min_val / scale).round() as i32;
162 (scale, zero_point)
163 };
164
165 let params = QuantizationParams {
167 bits,
168 scale,
169 zero_point,
170 min_val,
171 max_val,
172 method,
173 data_type: determine_data_type(bits),
174 channel_scales: None,
175 channel_zero_points: None,
176 };
177
178 let shape = matrix.dim();
180
181 let quantized_data: Vec<i8> = matrix_f32
183 .iter()
184 .map(|&val| {
185 if method == QuantizationMethod::Symmetric {
186 (val / scale)
187 .round()
188 .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32)
189 as i8
190 } else {
191 ((val / scale) + zero_point as f32)
192 .round()
193 .clamp(0.0, ((1 << bits) - 1) as f32) as i8
194 }
195 })
196 .collect();
197
198 let op_fn = move |x: &ArrayView1<F>| -> LinalgResult<Array1<F>> {
200 if x.len() != shape.1 {
201 return Err(LinalgError::ShapeError(format!(
202 "Input vector has wrong length: expected {}, got {}",
203 shape.1,
204 x.len()
205 )));
206 }
207
208 let x_f32: Vec<f32> = x.iter().map(|&val| val.as_()).collect();
210
211 let mut result = Array1::zeros(shape.0);
213
214 for i in 0..shape.0 {
216 let mut sum = 0.0f32;
217 for j in 0..shape.1 {
218 let q_val = quantized_data[i * shape.1 + j] as f32;
219 let dequantized = if method == QuantizationMethod::Symmetric {
220 q_val * scale
221 } else {
222 (q_val - zero_point as f32) * scale
223 };
224 sum += dequantized * x_f32[j];
225 }
226 result[i] = F::from_f32(sum).unwrap_or(F::zero());
227 }
228
229 Ok(result)
230 };
231
232 let symmetric = method == QuantizationMethod::Symmetric
234 && shape.0 == shape.1
235 && ismatrix_symmetric(matrix);
236
237 Ok(QuantizedMatrixFreeOp {
238 shape,
239 params,
240 op_fn: Arc::new(op_fn),
241 symmetric,
242 positive_definite: false, })
244 }
245
246 pub fn symmetric(mut self) -> Self {
252 if self.shape.0 != self.shape.1 {
253 panic!("Only square operators can be symmetric");
254 }
255 self.symmetric = true;
256 self
257 }
258
259 pub fn positive_definite(mut self) -> Self {
265 if !self.symmetric {
266 panic!("Only symmetric operators can be positive definite");
267 }
268 self.positive_definite = true;
269 self
270 }
271
272 pub fn params(&self) -> &QuantizationParams {
274 &self.params
275 }
276
277 pub fn block_diagonal(
292 blocks: Vec<ArrayView2<F>>,
293 bits: u8,
294 method: QuantizationMethod,
295 ) -> LinalgResult<Self> {
296 if blocks.is_empty() {
297 return Err(LinalgError::ValueError("Empty blocks vector".to_string()));
298 }
299
300 let total_rows = blocks.iter().map(|b| b.dim().0).sum();
302 let total_cols = blocks.iter().map(|b| b.dim().1).sum();
303
304 let mut block_data = Vec::new();
306
307 for block in &blocks {
308 let block_f32: Vec<f32> = block.iter().map(|&x| x.as_()).collect();
310
311 let (min_val, max_val) = if method == QuantizationMethod::Symmetric {
313 let max_abs = block_f32.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
314 (-max_abs, max_abs)
315 } else {
316 let min_val = block_f32.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
317 let max_val = block_f32
318 .iter()
319 .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
320 (min_val, max_val)
321 };
322
323 let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
325 let abs_max = max_val.abs().max(min_val.abs());
326 let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
327 (scale, 0)
328 } else {
329 let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
330 let zero_point = (-min_val / scale).round() as i32;
331 (scale, zero_point)
332 };
333
334 let quantized: Vec<i8> = block_f32
336 .iter()
337 .map(|&val| {
338 if method == QuantizationMethod::Symmetric {
339 (val / scale)
340 .round()
341 .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32)
342 as i8
343 } else {
344 ((val / scale) + zero_point as f32)
345 .round()
346 .clamp(0.0, ((1 << bits) - 1) as f32) as i8
347 }
348 })
349 .collect();
350
351 block_data.push((block.dim(), quantized, scale, zero_point));
353 }
354
355 let block_data_clone = block_data.clone();
357 let blocks_method = method;
358
359 let op_fn = move |x: &ArrayView1<F>| -> LinalgResult<Array1<F>> {
360 if x.len() != total_cols {
361 return Err(LinalgError::ShapeError(format!(
362 "Input vector has wrong length: expected {}, got {}",
363 total_cols,
364 x.len()
365 )));
366 }
367
368 let x_f32: Vec<f32> = x.iter().map(|&val| val.as_()).collect();
370
371 let mut result = Array1::zeros(total_rows);
373
374 let mut row_offset = 0;
376 let mut col_offset = 0;
377
378 for (shape, quantized, scale, zero_point) in block_data_clone.iter() {
379 let block_rows = shape.0;
380 let block_cols = shape.1;
381
382 for i in 0..block_rows {
384 let mut sum = 0.0f32;
385 for j in 0..block_cols {
386 let x_idx = col_offset + j;
387 if x_idx < x_f32.len() {
388 let q_val = quantized[i * block_cols + j] as f32;
389 let dequantized = if blocks_method == QuantizationMethod::Symmetric {
390 q_val * (*scale)
391 } else {
392 (q_val - (*zero_point) as f32) * (*scale)
393 };
394 sum += dequantized * x_f32[x_idx];
395 }
396 }
397
398 let result_idx = row_offset + i;
399 if result_idx < result.len() {
400 result[result_idx] = F::from_f32(sum).unwrap_or(F::zero());
401 }
402 }
403
404 row_offset += block_rows;
405 col_offset += block_cols;
406 }
407
408 Ok(result)
409 };
410
411 let global_min_val = block_data
413 .iter()
414 .map(|(_, _, scale, zero_point)| {
415 if method == QuantizationMethod::Symmetric {
416 -(*scale) * ((1 << (bits - 1)) - 1) as f32
417 } else {
418 -(*zero_point) as f32 * (*scale)
419 }
420 })
421 .fold(f32::INFINITY, |a, b| a.min(b));
422
423 let global_max_val = block_data
424 .iter()
425 .map(|(_, _, scale_, _)| {
426 if method == QuantizationMethod::Symmetric {
427 (*scale_) * ((1 << (bits - 1)) - 1) as f32
428 } else {
429 (*scale_) * ((1 << bits) - 1) as f32
430 }
431 })
432 .fold(f32::NEG_INFINITY, |a, b| a.max(b));
433
434 let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
436 let abs_max = global_max_val.abs().max(global_min_val.abs());
437 let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
438 (scale, 0)
439 } else {
440 let scale = (global_max_val - global_min_val) / ((1 << bits) - 1) as f32;
441 let zero_point = (-global_min_val / scale).round() as i32;
442 (scale, zero_point)
443 };
444
445 let params = QuantizationParams {
446 bits,
447 scale,
448 zero_point,
449 min_val: global_min_val,
450 max_val: global_max_val,
451 method,
452 data_type: determine_data_type(bits),
453 channel_scales: None,
454 channel_zero_points: None,
455 };
456
457 let all_square = blocks.iter().all(|b| b.dim().0 == b.dim().1);
459 let symmetric = method == QuantizationMethod::Symmetric && all_square;
460
461 Ok(QuantizedMatrixFreeOp {
462 shape: (total_rows, total_cols),
463 params,
464 op_fn: Arc::new(op_fn),
465 symmetric,
466 positive_definite: false,
467 })
468 }
469
470 pub fn sparse(
488 rows: usize,
489 cols: usize,
490 indices: Vec<(usize, usize)>,
491 values: &ArrayView1<F>,
492 bits: u8,
493 method: QuantizationMethod,
494 ) -> LinalgResult<Self> {
495 if indices.len() != values.len() {
496 return Err(LinalgError::ShapeError(
497 "Indices and values must have the same length".to_string(),
498 ));
499 }
500
501 for &(i, j) in &indices {
503 if i >= rows || j >= cols {
504 return Err(LinalgError::ShapeError(format!(
505 "Index ({i}, {j}) out of bounds for matrix of shape ({rows}, {cols})"
506 )));
507 }
508 }
509
510 let values_f32: Vec<f32> = values.iter().map(|&val| val.as_()).collect();
512
513 let (min_val, max_val) = if method == QuantizationMethod::Symmetric {
515 let max_abs = values_f32.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
516 (-max_abs, max_abs)
517 } else {
518 let min_val = values_f32.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
519 let max_val = values_f32
520 .iter()
521 .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
522 (min_val, max_val)
523 };
524
525 let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
527 let abs_max = max_val.abs().max(min_val.abs());
528 let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
529 (scale, 0)
530 } else {
531 let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
532 let zero_point = (-min_val / scale).round() as i32;
533 (scale, zero_point)
534 };
535
536 let quantized_data: Vec<i8> = values_f32
538 .iter()
539 .map(|&val| {
540 if method == QuantizationMethod::Symmetric {
541 (val / scale)
542 .round()
543 .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32)
544 as i8
545 } else {
546 ((val / scale) + zero_point as f32)
547 .round()
548 .clamp(0.0, ((1 << bits) - 1) as f32) as i8
549 }
550 })
551 .collect();
552
553 let indices_owned = indices.clone();
555 let sparse_method = method;
556
557 let op_fn = move |x: &ArrayView1<F>| -> LinalgResult<Array1<F>> {
559 if x.len() != cols {
560 return Err(LinalgError::ShapeError(format!(
561 "Input vector has wrong length: expected {}, got {}",
562 cols,
563 x.len()
564 )));
565 }
566
567 let x_f32: Vec<f32> = x.iter().map(|&val| val.as_()).collect();
569
570 let mut result = Array1::zeros(rows);
572
573 for (idx, &(i, j)) in indices_owned.iter().enumerate() {
575 if idx < quantized_data.len() {
576 let q_val = quantized_data[idx] as f32;
577 let dequantized = if sparse_method == QuantizationMethod::Symmetric {
578 q_val * scale
579 } else {
580 (q_val - zero_point as f32) * scale
581 };
582
583 result[i] += F::from_f32(dequantized * x_f32[j]).unwrap_or(F::zero());
584 }
585 }
586
587 Ok(result)
588 };
589
590 let params = QuantizationParams {
592 bits,
593 scale,
594 zero_point,
595 min_val,
596 max_val,
597 method,
598 data_type: determine_data_type(bits),
599 channel_scales: None,
600 channel_zero_points: None,
601 };
602
603 let symmetric = rows == cols
605 && method == QuantizationMethod::Symmetric
606 && indices
607 .iter()
608 .all(|&(i, j)| i == j || indices.contains(&(j, i)));
609
610 Ok(QuantizedMatrixFreeOp {
611 shape: (rows, cols),
612 params,
613 op_fn: Arc::new(op_fn),
614 symmetric,
615 positive_definite: false,
616 })
617 }
618
619 pub fn banded(
636 n: usize,
637 bands: Vec<(isize, ArrayView1<F>)>,
638 bits: u8,
639 method: QuantizationMethod,
640 ) -> LinalgResult<Self> {
641 for &(offset, ref band) in &bands {
643 let expected_len = n - offset.unsigned_abs();
644 if band.len() != expected_len {
645 return Err(LinalgError::ShapeError(format!(
646 "Band with offset {} should have length {}, got {}",
647 offset,
648 expected_len,
649 band.len()
650 )));
651 }
652 }
653
654 let mut band_data = Vec::new();
656
657 for (offset, band) in &bands {
658 let band_f32: Vec<f32> = band.iter().map(|&x| x.as_()).collect();
660
661 let (min_val, max_val) = if method == QuantizationMethod::Symmetric {
663 let max_abs = band_f32.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
664 (-max_abs, max_abs)
665 } else {
666 let min_val = band_f32.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
667 let max_val = band_f32
668 .iter()
669 .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
670 (min_val, max_val)
671 };
672
673 let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
675 let abs_max = max_val.abs().max(min_val.abs());
676 let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
677 (scale, 0)
678 } else {
679 let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
680 let zero_point = (-min_val / scale).round() as i32;
681 (scale, zero_point)
682 };
683
684 let quantized: Vec<i8> = band_f32
686 .iter()
687 .map(|&val| {
688 if method == QuantizationMethod::Symmetric {
689 (val / scale)
690 .round()
691 .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32)
692 as i8
693 } else {
694 ((val / scale) + zero_point as f32)
695 .round()
696 .clamp(0.0, ((1 << bits) - 1) as f32) as i8
697 }
698 })
699 .collect();
700
701 band_data.push((*offset, quantized, scale, zero_point));
703 }
704
705 let band_data_clone = band_data.clone();
707 let banded_method = method;
708
709 let op_fn = move |x: &ArrayView1<F>| -> LinalgResult<Array1<F>> {
710 if x.len() != n {
711 return Err(LinalgError::ShapeError(format!(
712 "Expected vector of length {}, got {}",
713 n,
714 x.len()
715 )));
716 }
717
718 let x_f32: Vec<f32> = x.iter().map(|&val| val.as_()).collect();
720
721 let mut result = Array1::zeros(n);
723
724 for (offset, quantized, scale, zero_point) in &band_data_clone {
726 let band_len = quantized.len();
727
728 if *offset >= 0 {
729 let offset_usize = *offset as usize;
731 for i in 0..band_len {
732 if i < n && (i + offset_usize) < n {
733 let q_val = quantized[i] as f32;
734 let dequantized = if banded_method == QuantizationMethod::Symmetric {
735 q_val * (*scale)
736 } else {
737 (q_val - (*zero_point) as f32) * (*scale)
738 };
739
740 result[i] += F::from_f32(dequantized * x_f32[i + offset_usize])
741 .unwrap_or(F::zero());
742 }
743 }
744 } else {
745 let offset_usize = (-*offset) as usize;
747 for i in 0..band_len {
748 if (i + offset_usize) < n && i < n {
749 let q_val = quantized[i] as f32;
750 let dequantized = if banded_method == QuantizationMethod::Symmetric {
751 q_val * (*scale)
752 } else {
753 (q_val - (*zero_point) as f32) * (*scale)
754 };
755
756 result[i + offset_usize] +=
757 F::from_f32(dequantized * x_f32[i]).unwrap_or(F::zero());
758 }
759 }
760 }
761 }
762
763 Ok(result)
764 };
765
766 let global_min_val = band_data
768 .iter()
769 .map(|(_, _, scale, zero_point)| {
770 if method == QuantizationMethod::Symmetric {
771 -(*scale) * ((1 << (bits - 1)) - 1) as f32
772 } else {
773 -(*zero_point) as f32 * (*scale)
774 }
775 })
776 .fold(f32::INFINITY, |a, b| a.min(b));
777
778 let global_max_val = band_data
779 .iter()
780 .map(|(_, _, scale_, _)| {
781 if method == QuantizationMethod::Symmetric {
782 (*scale_) * ((1 << (bits - 1)) - 1) as f32
783 } else {
784 (*scale_) * ((1 << bits) - 1) as f32
785 }
786 })
787 .fold(f32::NEG_INFINITY, |a, b| a.max(b));
788
789 let params = QuantizationParams {
791 bits,
792 scale: 1.0, zero_point: 0,
794 min_val: global_min_val,
795 max_val: global_max_val,
796 method,
797 data_type: determine_data_type(bits),
798 channel_scales: None,
799 channel_zero_points: None,
800 };
801
802 let symmetric = method == QuantizationMethod::Symmetric
804 && band_data.iter().all(|(offset, _, _, _)| {
805 *offset == 0 || band_data.iter().any(|(o, _, _, _)| *o == -*offset)
808 });
809
810 Ok(QuantizedMatrixFreeOp {
811 shape: (n, n),
812 params,
813 op_fn: Arc::new(op_fn),
814 symmetric,
815 positive_definite: false,
816 })
817 }
818}
819
820impl<F> MatrixFreeOp<F> for QuantizedMatrixFreeOp<F>
821where
822 F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug + 'static,
823{
824 fn apply(&self, x: &ArrayView1<F>) -> LinalgResult<Array1<F>> {
825 if x.len() != self.shape.1 {
826 return Err(LinalgError::ShapeError(format!(
827 "Input vector has wrong length: expected {}, got {}",
828 self.shape.1,
829 x.len()
830 )));
831 }
832 (self.op_fn)(x)
833 }
834
835 fn nrows(&self) -> usize {
836 self.shape.0
837 }
838
839 fn ncols(&self) -> usize {
840 self.shape.1
841 }
842
843 fn is_symmetric(&self) -> bool {
844 self.symmetric
845 }
846
847 fn is_positive_definite(&self) -> bool {
848 self.positive_definite
849 }
850}
851
852#[allow(dead_code)]
865pub fn quantized_to_linear_operator<F>(op: &QuantizedMatrixFreeOp<F>) -> LinearOperator<F>
866where
867 F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug + 'static,
868{
869 let rows = op.nrows();
870 let cols = op.ncols();
871 let is_symmetric = op.is_symmetric();
872 let is_positive_definite = op.is_positive_definite();
873
874 let op_clone = op.clone();
877
878 let linear_op = if rows == cols {
879 LinearOperator::new(rows, move |x: &ArrayView1<F>| match op_clone.apply(x) {
880 Ok(result) => result,
881 Err(_) => Array1::zeros(rows),
882 })
883 } else {
884 LinearOperator::new_rectangular(rows, cols, move |x: &ArrayView1<F>| {
885 match op_clone.apply(x) {
886 Ok(result) => result,
887 Err(_) => Array1::zeros(rows),
888 }
889 })
890 };
891
892 if is_symmetric {
894 let linear_op = linear_op.symmetric();
895 if is_positive_definite {
896 linear_op.positive_definite()
897 } else {
898 linear_op
899 }
900 } else {
901 linear_op
902 }
903}
904
905impl<F> Clone for QuantizedMatrixFreeOp<F>
907where
908 F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug + 'static,
909{
910 fn clone(&self) -> Self {
911 QuantizedMatrixFreeOp {
912 shape: self.shape,
913 params: self.params.clone(),
914 op_fn: Arc::clone(&self.op_fn),
915 symmetric: self.symmetric,
916 positive_definite: self.positive_definite,
917 }
918 }
919}
920
921#[allow(dead_code)]
923fn ismatrix_symmetric<F>(matrix: &ArrayView2<F>) -> bool
924where
925 F: Float + PartialEq,
926{
927 let (rows, cols) = matrix.dim();
928 if rows != cols {
929 return false;
930 }
931
932 for i in 0..rows {
933 for j in i + 1..cols {
934 if matrix[[i, j]] != matrix[[j, i]] {
935 return false;
936 }
937 }
938 }
939
940 true
941}
942
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Asserts equal lengths and element-wise agreement within `epsilon`.
    fn assert_vectors_close(actual: &Array1<f32>, expected: &Array1<f32>, epsilon: f32) {
        assert_eq!(actual.len(), expected.len());
        for (&got, &want) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = epsilon);
        }
    }

    #[test]
    fn test_quantizedmatrix_free_op_frommatrix() {
        let dense = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let input = array![1.0f32, 2.0, 3.0];

        let op =
            QuantizedMatrixFreeOp::frommatrix(&dense.view(), 8, QuantizationMethod::Symmetric)
                .unwrap();
        let output = op.apply(&input.view()).unwrap();

        // The loose tolerance absorbs 8-bit quantization error.
        assert_vectors_close(&output, &dense.dot(&input), 1.0);
    }

    #[test]
    fn test_quantizedmatrix_free_op_block_diagonal() {
        let upper_block = array![[1.0f32, 2.0], [3.0, 4.0]];
        let lower_block = array![[5.0f32]];
        let input = array![1.0f32, 2.0, 3.0];

        let op = QuantizedMatrixFreeOp::block_diagonal(
            vec![upper_block.view(), lower_block.view()],
            8,
            QuantizationMethod::Symmetric,
        )
        .unwrap();
        let output = op.apply(&input.view()).unwrap();

        // [1*1 + 2*2, 3*1 + 4*2, 5*3]
        assert_vectors_close(&output, &array![5.0f32, 11.0, 15.0], 1.0);
    }

    #[test]
    fn test_quantizedmatrix_free_op_sparse() {
        let coords = vec![(0, 0), (0, 2), (1, 1), (2, 0), (2, 2)];
        let entries = array![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let input = array![1.0f32, 2.0, 3.0];

        let op = QuantizedMatrixFreeOp::sparse(
            3,
            3,
            coords,
            &entries.view(),
            8,
            QuantizationMethod::Symmetric,
        )
        .unwrap();
        let output = op.apply(&input.view()).unwrap();

        // [1*1 + 2*3, 3*2, 4*1 + 5*3]
        assert_vectors_close(&output, &array![7.0f32, 6.0, 19.0], 1.0);
    }

    #[test]
    fn test_quantizedmatrix_free_op_banded() {
        let diagonal = array![2.0f32, 3.0, 4.0];
        let above = array![1.0f32, 1.0];
        let below = array![1.0f32, 1.0];
        let input = array![1.0f32, 2.0, 3.0];

        let op = QuantizedMatrixFreeOp::banded(
            3,
            vec![(0, diagonal.view()), (1, above.view()), (-1, below.view())],
            8,
            QuantizationMethod::Symmetric,
        )
        .unwrap();
        let output = op.apply(&input.view()).unwrap();

        // Tridiagonal [[2,1,0],[1,3,1],[0,1,4]] times [1,2,3].
        assert_vectors_close(&output, &array![4.0f32, 10.0, 14.0], 1.0);
    }

    #[test]
    fn test_quantized_to_linear_operator() {
        let dense = array![[1.0f32, 2.0], [2.0, 3.0]];

        let quantized_op =
            QuantizedMatrixFreeOp::frommatrix(&dense.view(), 8, QuantizationMethod::Symmetric)
                .unwrap()
                .symmetric()
                .positive_definite();
        let linear_op = quantized_to_linear_operator(&quantized_op);

        assert_eq!(linear_op.nrows(), quantized_op.nrows());
        assert_eq!(linear_op.ncols(), quantized_op.ncols());
        assert_eq!(linear_op.is_symmetric(), quantized_op.is_symmetric());
        assert_eq!(
            linear_op.is_positive_definite(),
            quantized_op.is_positive_definite()
        );

        let input = array![1.0f32, 2.0];
        let from_quantized = quantized_op.apply(&input.view()).unwrap();
        let from_linear = linear_op.apply(&input.view()).unwrap();

        // Both wrap the same closure, so the results must agree to float precision.
        assert_vectors_close(&from_quantized, &from_linear, 1e-6);
    }
}