scirs2_linalg/quantization/
quantized_matrixfree.rs

1//! Matrix-free operations for quantized tensors
2//!
3//! This module provides matrix-free operations for quantized tensors, enabling
4//! efficient memory usage and computation for large models. It combines the benefits
5//! of quantization (reduced memory footprint) with matrix-free operations (no need to
6//! materialize large matrices).
7
8use crate::error::{LinalgError, LinalgResult};
9use crate::matrixfree::{LinearOperator, MatrixFreeOp};
10use crate::quantization::calibration::determine_data_type;
11use crate::quantization::{QuantizationMethod, QuantizationParams};
12use scirs2_core::ndarray::ScalarOperand;
13use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
14use scirs2_core::numeric::{AsPrimitive, Float, FromPrimitive, NumAssign, One, Zero};
15use std::fmt::Debug;
16use std::iter::Sum;
17use std::sync::Arc;
18
19/// Type alias for the matrix-vector product function
20pub type MatVecFn<F> = Arc<dyn Fn(&ArrayView1<F>) -> LinalgResult<Array1<F>> + Send + Sync>;
21
22/// A matrix-free operator that represents a quantized matrix
23pub struct QuantizedMatrixFreeOp<F>
24where
25    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug + 'static,
26{
27    /// The shape of the matrix (rows, columns)
28    shape: (usize, usize),
29
30    /// The quantization parameters
31    params: QuantizationParams,
32
33    /// Matrix application function that doesn't require storing the full matrix
34    op_fn: MatVecFn<F>,
35
36    /// Flag indicating whether the operator is symmetric
37    symmetric: bool,
38
39    /// Flag indicating whether the operator is positive definite
40    positive_definite: bool,
41}
42
43impl<F> QuantizedMatrixFreeOp<F>
44where
45    F: Float
46        + NumAssign
47        + Zero
48        + Sum
49        + One
50        + ScalarOperand
51        + Send
52        + Sync
53        + Debug
54        + FromPrimitive
55        + AsPrimitive<f32>
56        + 'static,
57    f32: AsPrimitive<F>,
58{
59    /// Create a new quantized matrix-free operator from a function
60    ///
61    /// This allows direct specification of how the operator acts on vectors
62    /// without materializing the quantized matrix.
63    ///
64    /// # Arguments
65    ///
66    /// * `rows` - Number of rows in the matrix
67    /// * `cols` - Number of columns in the matrix
68    /// * `bits` - Bit width for quantization
69    /// * `method` - Quantization method
70    /// * `op_fn` - Function that implements the matrix-vector product in the quantized domain
71    ///
72    /// # Returns
73    ///
74    /// A new `QuantizedMatrixFreeOp` instance
75    pub fn new<O>(
76        rows: usize,
77        cols: usize,
78        bits: u8,
79        method: QuantizationMethod,
80        op_fn: O,
81    ) -> LinalgResult<Self>
82    where
83        O: Fn(&ArrayView1<F>) -> LinalgResult<Array1<F>> + Send + Sync + 'static,
84    {
85        // Create default quantization parameters - these will be refined when data is observed
86        let min_val: f32 = 0.0;
87        let max_val: f32 = 1.0;
88        let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
89            let abs_max = max_val.abs().max(min_val.abs());
90            let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
91            (scale, 0)
92        } else {
93            let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
94            let zero_point = (-min_val / scale).round() as i32;
95            (scale, zero_point)
96        };
97
98        // Create the quantization parameters
99        let params = QuantizationParams {
100            bits,
101            scale,
102            zero_point,
103            min_val,
104            max_val,
105            method,
106            data_type: determine_data_type(bits),
107            channel_scales: None,
108            channel_zero_points: None,
109        };
110
111        Ok(QuantizedMatrixFreeOp {
112            shape: (rows, cols),
113            params,
114            op_fn: Arc::new(op_fn),
115            symmetric: false,
116            positive_definite: false,
117        })
118    }
119
120    /// Create a quantized matrix-free operator from an explicit matrix
121    ///
122    /// This quantizes the matrix and creates a matrix-free operator that
123    /// applies the quantized matrix without materializing it for each operation.
124    ///
125    /// # Arguments
126    ///
127    /// * `matrix` - Matrix to quantize
128    /// * `bits` - Bit width for quantization
129    /// * `method` - Quantization method
130    ///
131    /// # Returns
132    ///
133    /// A new `QuantizedMatrixFreeOp` instance
134    pub fn frommatrix(
135        matrix: &ArrayView2<F>,
136        bits: u8,
137        method: QuantizationMethod,
138    ) -> LinalgResult<Self> {
139        // Convert matrix to f32 for quantization
140        let matrix_f32: Array1<f32> = matrix.iter().map(|&x| x.as_()).collect();
141
142        // Get min/max values for quantization parameters
143        let (min_val, max_val) = if method == QuantizationMethod::Symmetric {
144            let max_abs = matrix_f32.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
145            (-max_abs, max_abs)
146        } else {
147            let min_val = matrix_f32.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
148            let max_val = matrix_f32
149                .iter()
150                .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
151            (min_val, max_val)
152        };
153
154        // Calculate quantization parameters
155        let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
156            let abs_max = max_val.abs().max(min_val.abs());
157            let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
158            (scale, 0)
159        } else {
160            let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
161            let zero_point = (-min_val / scale).round() as i32;
162            (scale, zero_point)
163        };
164
165        // Create the quantization parameters
166        let params = QuantizationParams {
167            bits,
168            scale,
169            zero_point,
170            min_val,
171            max_val,
172            method,
173            data_type: determine_data_type(bits),
174            channel_scales: None,
175            channel_zero_points: None,
176        };
177
178        // Copy the dimensions before we move the matrix
179        let shape = matrix.dim();
180
181        // Quantize matrix ahead of time
182        let quantized_data: Vec<i8> = matrix_f32
183            .iter()
184            .map(|&val| {
185                if method == QuantizationMethod::Symmetric {
186                    (val / scale)
187                        .round()
188                        .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32)
189                        as i8
190                } else {
191                    ((val / scale) + zero_point as f32)
192                        .round()
193                        .clamp(0.0, ((1 << bits) - 1) as f32) as i8
194                }
195            })
196            .collect();
197
198        // Create the matrix-vector product function
199        let op_fn = move |x: &ArrayView1<F>| -> LinalgResult<Array1<F>> {
200            if x.len() != shape.1 {
201                return Err(LinalgError::ShapeError(format!(
202                    "Input vector has wrong length: expected {}, got {}",
203                    shape.1,
204                    x.len()
205                )));
206            }
207
208            // Convert input to f32
209            let x_f32: Vec<f32> = x.iter().map(|&val| val.as_()).collect();
210
211            // Apply the quantized matrix-vector product
212            let mut result = Array1::zeros(shape.0);
213
214            // Implement matrix-vector product manually
215            for i in 0..shape.0 {
216                let mut sum = 0.0f32;
217                for j in 0..shape.1 {
218                    let q_val = quantized_data[i * shape.1 + j] as f32;
219                    let dequantized = if method == QuantizationMethod::Symmetric {
220                        q_val * scale
221                    } else {
222                        (q_val - zero_point as f32) * scale
223                    };
224                    sum += dequantized * x_f32[j];
225                }
226                result[i] = F::from_f32(sum).unwrap_or(F::zero());
227            }
228
229            Ok(result)
230        };
231
232        // Determine if the operator is symmetric
233        let symmetric = method == QuantizationMethod::Symmetric
234            && shape.0 == shape.1
235            && ismatrix_symmetric(matrix);
236
237        Ok(QuantizedMatrixFreeOp {
238            shape,
239            params,
240            op_fn: Arc::new(op_fn),
241            symmetric,
242            positive_definite: false, // We can't reliably detect this from the matrix
243        })
244    }
245
246    /// Mark the operator as symmetric
247    ///
248    /// # Returns
249    ///
250    /// Self with the symmetric flag set to true
251    pub fn symmetric(mut self) -> Self {
252        if self.shape.0 != self.shape.1 {
253            panic!("Only square operators can be symmetric");
254        }
255        self.symmetric = true;
256        self
257    }
258
259    /// Mark the operator as positive definite
260    ///
261    /// # Returns
262    ///
263    /// Self with the positive_definite flag set to true
264    pub fn positive_definite(mut self) -> Self {
265        if !self.symmetric {
266            panic!("Only symmetric operators can be positive definite");
267        }
268        self.positive_definite = true;
269        self
270    }
271
272    /// Get the quantization parameters
273    pub fn params(&self) -> &QuantizationParams {
274        &self.params
275    }
276
277    /// Create a memory-efficient operator for block-diagonal matrices
278    ///
279    /// This is particularly useful for large models with block structure,
280    /// as it avoids materializing the full matrix.
281    ///
282    /// # Arguments
283    ///
284    /// * `blocks` - A vector of smaller matrices to place on the diagonal
285    /// * `bits` - Bit width for quantization
286    /// * `method` - Quantization method
287    ///
288    /// # Returns
289    ///
290    /// A new `QuantizedMatrixFreeOp` instance
291    pub fn block_diagonal(
292        blocks: Vec<ArrayView2<F>>,
293        bits: u8,
294        method: QuantizationMethod,
295    ) -> LinalgResult<Self> {
296        if blocks.is_empty() {
297            return Err(LinalgError::ValueError("Empty blocks vector".to_string()));
298        }
299
300        // Calculate total dimensions
301        let total_rows = blocks.iter().map(|b| b.dim().0).sum();
302        let total_cols = blocks.iter().map(|b| b.dim().1).sum();
303
304        // Quantize each block separately
305        let mut block_data = Vec::new();
306
307        for block in &blocks {
308            // Convert to f32
309            let block_f32: Vec<f32> = block.iter().map(|&x| x.as_()).collect();
310
311            // Get min/max for this block
312            let (min_val, max_val) = if method == QuantizationMethod::Symmetric {
313                let max_abs = block_f32.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
314                (-max_abs, max_abs)
315            } else {
316                let min_val = block_f32.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
317                let max_val = block_f32
318                    .iter()
319                    .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
320                (min_val, max_val)
321            };
322
323            // Calculate quantization parameters for this block
324            let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
325                let abs_max = max_val.abs().max(min_val.abs());
326                let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
327                (scale, 0)
328            } else {
329                let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
330                let zero_point = (-min_val / scale).round() as i32;
331                (scale, zero_point)
332            };
333
334            // Quantize the block
335            let quantized: Vec<i8> = block_f32
336                .iter()
337                .map(|&val| {
338                    if method == QuantizationMethod::Symmetric {
339                        (val / scale)
340                            .round()
341                            .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32)
342                            as i8
343                    } else {
344                        ((val / scale) + zero_point as f32)
345                            .round()
346                            .clamp(0.0, ((1 << bits) - 1) as f32) as i8
347                    }
348                })
349                .collect();
350
351            // Store the dimensions, quantized data, and quantization parameters
352            block_data.push((block.dim(), quantized, scale, zero_point));
353        }
354
355        // Create a function that applies the block-diagonal matrix
356        let block_data_clone = block_data.clone();
357        let blocks_method = method;
358
359        let op_fn = move |x: &ArrayView1<F>| -> LinalgResult<Array1<F>> {
360            if x.len() != total_cols {
361                return Err(LinalgError::ShapeError(format!(
362                    "Input vector has wrong length: expected {}, got {}",
363                    total_cols,
364                    x.len()
365                )));
366            }
367
368            // Convert input to f32
369            let x_f32: Vec<f32> = x.iter().map(|&val| val.as_()).collect();
370
371            // Prepare result vector
372            let mut result = Array1::zeros(total_rows);
373
374            // Process each block
375            let mut row_offset = 0;
376            let mut col_offset = 0;
377
378            for (shape, quantized, scale, zero_point) in block_data_clone.iter() {
379                let block_rows = shape.0;
380                let block_cols = shape.1;
381
382                // Apply this block to the corresponding section of the input vector
383                for i in 0..block_rows {
384                    let mut sum = 0.0f32;
385                    for j in 0..block_cols {
386                        let x_idx = col_offset + j;
387                        if x_idx < x_f32.len() {
388                            let q_val = quantized[i * block_cols + j] as f32;
389                            let dequantized = if blocks_method == QuantizationMethod::Symmetric {
390                                q_val * (*scale)
391                            } else {
392                                (q_val - (*zero_point) as f32) * (*scale)
393                            };
394                            sum += dequantized * x_f32[x_idx];
395                        }
396                    }
397
398                    let result_idx = row_offset + i;
399                    if result_idx < result.len() {
400                        result[result_idx] = F::from_f32(sum).unwrap_or(F::zero());
401                    }
402                }
403
404                row_offset += block_rows;
405                col_offset += block_cols;
406            }
407
408            Ok(result)
409        };
410
411        // Calculate global min/max values for our parameters
412        let global_min_val = block_data
413            .iter()
414            .map(|(_, _, scale, zero_point)| {
415                if method == QuantizationMethod::Symmetric {
416                    -(*scale) * ((1 << (bits - 1)) - 1) as f32
417                } else {
418                    -(*zero_point) as f32 * (*scale)
419                }
420            })
421            .fold(f32::INFINITY, |a, b| a.min(b));
422
423        let global_max_val = block_data
424            .iter()
425            .map(|(_, _, scale_, _)| {
426                if method == QuantizationMethod::Symmetric {
427                    (*scale_) * ((1 << (bits - 1)) - 1) as f32
428                } else {
429                    (*scale_) * ((1 << bits) - 1) as f32
430                }
431            })
432            .fold(f32::NEG_INFINITY, |a, b| a.max(b));
433
434        // Create the quantization parameters
435        let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
436            let abs_max = global_max_val.abs().max(global_min_val.abs());
437            let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
438            (scale, 0)
439        } else {
440            let scale = (global_max_val - global_min_val) / ((1 << bits) - 1) as f32;
441            let zero_point = (-global_min_val / scale).round() as i32;
442            (scale, zero_point)
443        };
444
445        let params = QuantizationParams {
446            bits,
447            scale,
448            zero_point,
449            min_val: global_min_val,
450            max_val: global_max_val,
451            method,
452            data_type: determine_data_type(bits),
453            channel_scales: None,
454            channel_zero_points: None,
455        };
456
457        // Check if all blocks are square and the operator could be symmetric
458        let all_square = blocks.iter().all(|b| b.dim().0 == b.dim().1);
459        let symmetric = method == QuantizationMethod::Symmetric && all_square;
460
461        Ok(QuantizedMatrixFreeOp {
462            shape: (total_rows, total_cols),
463            params,
464            op_fn: Arc::new(op_fn),
465            symmetric,
466            positive_definite: false,
467        })
468    }
469
470    /// Create a memory-efficient operator for structured sparse matrices
471    ///
472    /// This is particularly useful for large sparse models, as it stores
473    /// only the non-zero elements and their indices.
474    ///
475    /// # Arguments
476    ///
477    /// * `rows` - Number of rows in the matrix
478    /// * `cols` - Number of columns in the matrix
479    /// * `indices` - Pairs of (row, column) indices for non-zero elements
480    /// * `values` - Values at the corresponding indices
481    /// * `bits` - Bit width for quantization
482    /// * `method` - Quantization method
483    ///
484    /// # Returns
485    ///
486    /// A new `QuantizedMatrixFreeOp` instance
487    pub fn sparse(
488        rows: usize,
489        cols: usize,
490        indices: Vec<(usize, usize)>,
491        values: &ArrayView1<F>,
492        bits: u8,
493        method: QuantizationMethod,
494    ) -> LinalgResult<Self> {
495        if indices.len() != values.len() {
496            return Err(LinalgError::ShapeError(
497                "Indices and values must have the same length".to_string(),
498            ));
499        }
500
501        // Validate indices
502        for &(i, j) in &indices {
503            if i >= rows || j >= cols {
504                return Err(LinalgError::ShapeError(format!(
505                    "Index ({i}, {j}) out of bounds for matrix of shape ({rows}, {cols})"
506                )));
507            }
508        }
509
510        // Convert values to f32 for quantization
511        let values_f32: Vec<f32> = values.iter().map(|&val| val.as_()).collect();
512
513        // Get min/max values for quantization parameters
514        let (min_val, max_val) = if method == QuantizationMethod::Symmetric {
515            let max_abs = values_f32.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
516            (-max_abs, max_abs)
517        } else {
518            let min_val = values_f32.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
519            let max_val = values_f32
520                .iter()
521                .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
522            (min_val, max_val)
523        };
524
525        // Calculate quantization parameters
526        let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
527            let abs_max = max_val.abs().max(min_val.abs());
528            let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
529            (scale, 0)
530        } else {
531            let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
532            let zero_point = (-min_val / scale).round() as i32;
533            (scale, zero_point)
534        };
535
536        // Quantize the values
537        let quantized_data: Vec<i8> = values_f32
538            .iter()
539            .map(|&val| {
540                if method == QuantizationMethod::Symmetric {
541                    (val / scale)
542                        .round()
543                        .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32)
544                        as i8
545                } else {
546                    ((val / scale) + zero_point as f32)
547                        .round()
548                        .clamp(0.0, ((1 << bits) - 1) as f32) as i8
549                }
550            })
551            .collect();
552
553        // Create a copy of indices for the closure
554        let indices_owned = indices.clone();
555        let sparse_method = method;
556
557        // Create a function that applies the sparse matrix
558        let op_fn = move |x: &ArrayView1<F>| -> LinalgResult<Array1<F>> {
559            if x.len() != cols {
560                return Err(LinalgError::ShapeError(format!(
561                    "Input vector has wrong length: expected {}, got {}",
562                    cols,
563                    x.len()
564                )));
565            }
566
567            // Convert input to f32
568            let x_f32: Vec<f32> = x.iter().map(|&val| val.as_()).collect();
569
570            // Prepare result vector
571            let mut result = Array1::zeros(rows);
572
573            // Apply the sparse matrix
574            for (idx, &(i, j)) in indices_owned.iter().enumerate() {
575                if idx < quantized_data.len() {
576                    let q_val = quantized_data[idx] as f32;
577                    let dequantized = if sparse_method == QuantizationMethod::Symmetric {
578                        q_val * scale
579                    } else {
580                        (q_val - zero_point as f32) * scale
581                    };
582
583                    result[i] += F::from_f32(dequantized * x_f32[j]).unwrap_or(F::zero());
584                }
585            }
586
587            Ok(result)
588        };
589
590        // Create quantization parameters
591        let params = QuantizationParams {
592            bits,
593            scale,
594            zero_point,
595            min_val,
596            max_val,
597            method,
598            data_type: determine_data_type(bits),
599            channel_scales: None,
600            channel_zero_points: None,
601        };
602
603        // Check if the matrix could be symmetric
604        let symmetric = rows == cols
605            && method == QuantizationMethod::Symmetric
606            && indices
607                .iter()
608                .all(|&(i, j)| i == j || indices.contains(&(j, i)));
609
610        Ok(QuantizedMatrixFreeOp {
611            shape: (rows, cols),
612            params,
613            op_fn: Arc::new(op_fn),
614            symmetric,
615            positive_definite: false,
616        })
617    }
618
619    /// Create a memory-efficient operator for a banded matrix
620    ///
621    /// This is particularly useful for banded matrices like tridiagonal
622    /// or pentadiagonal matrices, as it stores only the bands.
623    ///
624    /// # Arguments
625    ///
626    /// * `n` - Size of the square matrix
627    /// * `bands` - Vector of (offset, band_values) pairs, where offset is the diagonal offset
628    ///   (0 for main diagonal, 1 for first super-diagonal, -1 for first sub-diagonal)
629    /// * `bits` - Bit width for quantization
630    /// * `method` - Quantization method
631    ///
632    /// # Returns
633    ///
634    /// A new `QuantizedMatrixFreeOp` instance
635    pub fn banded(
636        n: usize,
637        bands: Vec<(isize, ArrayView1<F>)>,
638        bits: u8,
639        method: QuantizationMethod,
640    ) -> LinalgResult<Self> {
641        // Validate bands
642        for &(offset, ref band) in &bands {
643            let expected_len = n - offset.unsigned_abs();
644            if band.len() != expected_len {
645                return Err(LinalgError::ShapeError(format!(
646                    "Band with offset {} should have length {}, got {}",
647                    offset,
648                    expected_len,
649                    band.len()
650                )));
651            }
652        }
653
654        // Quantize each band
655        let mut band_data = Vec::new();
656
657        for (offset, band) in &bands {
658            // Convert to f32
659            let band_f32: Vec<f32> = band.iter().map(|&x| x.as_()).collect();
660
661            // Get min/max for this band
662            let (min_val, max_val) = if method == QuantizationMethod::Symmetric {
663                let max_abs = band_f32.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
664                (-max_abs, max_abs)
665            } else {
666                let min_val = band_f32.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
667                let max_val = band_f32
668                    .iter()
669                    .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
670                (min_val, max_val)
671            };
672
673            // Calculate quantization parameters for this band
674            let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
675                let abs_max = max_val.abs().max(min_val.abs());
676                let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
677                (scale, 0)
678            } else {
679                let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
680                let zero_point = (-min_val / scale).round() as i32;
681                (scale, zero_point)
682            };
683
684            // Quantize the band
685            let quantized: Vec<i8> = band_f32
686                .iter()
687                .map(|&val| {
688                    if method == QuantizationMethod::Symmetric {
689                        (val / scale)
690                            .round()
691                            .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32)
692                            as i8
693                    } else {
694                        ((val / scale) + zero_point as f32)
695                            .round()
696                            .clamp(0.0, ((1 << bits) - 1) as f32) as i8
697                    }
698                })
699                .collect();
700
701            // Store the offset, quantized data, and quantization parameters
702            band_data.push((*offset, quantized, scale, zero_point));
703        }
704
705        // Create a function that applies the banded matrix
706        let band_data_clone = band_data.clone();
707        let banded_method = method;
708
709        let op_fn = move |x: &ArrayView1<F>| -> LinalgResult<Array1<F>> {
710            if x.len() != n {
711                return Err(LinalgError::ShapeError(format!(
712                    "Expected vector of length {}, got {}",
713                    n,
714                    x.len()
715                )));
716            }
717
718            // Convert input to f32
719            let x_f32: Vec<f32> = x.iter().map(|&val| val.as_()).collect();
720
721            // Prepare result vector
722            let mut result = Array1::zeros(n);
723
724            // Apply each band
725            for (offset, quantized, scale, zero_point) in &band_data_clone {
726                let band_len = quantized.len();
727
728                if *offset >= 0 {
729                    // Super-diagonal or main diagonal
730                    let offset_usize = *offset as usize;
731                    for i in 0..band_len {
732                        if i < n && (i + offset_usize) < n {
733                            let q_val = quantized[i] as f32;
734                            let dequantized = if banded_method == QuantizationMethod::Symmetric {
735                                q_val * (*scale)
736                            } else {
737                                (q_val - (*zero_point) as f32) * (*scale)
738                            };
739
740                            result[i] += F::from_f32(dequantized * x_f32[i + offset_usize])
741                                .unwrap_or(F::zero());
742                        }
743                    }
744                } else {
745                    // Sub-diagonal
746                    let offset_usize = (-*offset) as usize;
747                    for i in 0..band_len {
748                        if (i + offset_usize) < n && i < n {
749                            let q_val = quantized[i] as f32;
750                            let dequantized = if banded_method == QuantizationMethod::Symmetric {
751                                q_val * (*scale)
752                            } else {
753                                (q_val - (*zero_point) as f32) * (*scale)
754                            };
755
756                            result[i + offset_usize] +=
757                                F::from_f32(dequantized * x_f32[i]).unwrap_or(F::zero());
758                        }
759                    }
760                }
761            }
762
763            Ok(result)
764        };
765
766        // Calculate global min/max values for our parameters
767        let global_min_val = band_data
768            .iter()
769            .map(|(_, _, scale, zero_point)| {
770                if method == QuantizationMethod::Symmetric {
771                    -(*scale) * ((1 << (bits - 1)) - 1) as f32
772                } else {
773                    -(*zero_point) as f32 * (*scale)
774                }
775            })
776            .fold(f32::INFINITY, |a, b| a.min(b));
777
778        let global_max_val = band_data
779            .iter()
780            .map(|(_, _, scale_, _)| {
781                if method == QuantizationMethod::Symmetric {
782                    (*scale_) * ((1 << (bits - 1)) - 1) as f32
783                } else {
784                    (*scale_) * ((1 << bits) - 1) as f32
785                }
786            })
787            .fold(f32::NEG_INFINITY, |a, b| a.max(b));
788
789        // Create the quantization parameters
790        let params = QuantizationParams {
791            bits,
792            scale: 1.0, // These are placeholder values since we store per-band parameters
793            zero_point: 0,
794            min_val: global_min_val,
795            max_val: global_max_val,
796            method,
797            data_type: determine_data_type(bits),
798            channel_scales: None,
799            channel_zero_points: None,
800        };
801
802        // Check if the matrix could be symmetric
803        let symmetric = method == QuantizationMethod::Symmetric
804            && band_data.iter().all(|(offset, _, _, _)| {
805                // For a symmetric banded matrix, if there's a band at offset k,
806                // there must also be a band at offset -k
807                *offset == 0 || band_data.iter().any(|(o, _, _, _)| *o == -*offset)
808            });
809
810        Ok(QuantizedMatrixFreeOp {
811            shape: (n, n),
812            params,
813            op_fn: Arc::new(op_fn),
814            symmetric,
815            positive_definite: false,
816        })
817    }
818}
819
820impl<F> MatrixFreeOp<F> for QuantizedMatrixFreeOp<F>
821where
822    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug + 'static,
823{
824    fn apply(&self, x: &ArrayView1<F>) -> LinalgResult<Array1<F>> {
825        if x.len() != self.shape.1 {
826            return Err(LinalgError::ShapeError(format!(
827                "Input vector has wrong length: expected {}, got {}",
828                self.shape.1,
829                x.len()
830            )));
831        }
832        (self.op_fn)(x)
833    }
834
835    fn nrows(&self) -> usize {
836        self.shape.0
837    }
838
839    fn ncols(&self) -> usize {
840        self.shape.1
841    }
842
843    fn is_symmetric(&self) -> bool {
844        self.symmetric
845    }
846
847    fn is_positive_definite(&self) -> bool {
848        self.positive_definite
849    }
850}
851
852/// Convert a QuantizedMatrixFreeOp to a generic LinearOperator
853///
854/// This is useful when you want to use the quantized operator with
855/// algorithms that expect a LinearOperator.
856///
857/// # Arguments
858///
859/// * `op` - The quantized matrix-free operator
860///
861/// # Returns
862///
863/// A LinearOperator that wraps the quantized operator
864#[allow(dead_code)]
865pub fn quantized_to_linear_operator<F>(op: &QuantizedMatrixFreeOp<F>) -> LinearOperator<F>
866where
867    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug + 'static,
868{
869    let rows = op.nrows();
870    let cols = op.ncols();
871    let is_symmetric = op.is_symmetric();
872    let is_positive_definite = op.is_positive_definite();
873
874    // We need to clone op.op_fn, but can't directly due to the trait bound
875    // So we create a new closure that delegates to the original
876    let op_clone = op.clone();
877
878    let linear_op = if rows == cols {
879        LinearOperator::new(rows, move |x: &ArrayView1<F>| match op_clone.apply(x) {
880            Ok(result) => result,
881            Err(_) => Array1::zeros(rows),
882        })
883    } else {
884        LinearOperator::new_rectangular(rows, cols, move |x: &ArrayView1<F>| {
885            match op_clone.apply(x) {
886                Ok(result) => result,
887                Err(_) => Array1::zeros(rows),
888            }
889        })
890    };
891
892    // Add flags if applicable
893    if is_symmetric {
894        let linear_op = linear_op.symmetric();
895        if is_positive_definite {
896            linear_op.positive_definite()
897        } else {
898            linear_op
899        }
900    } else {
901        linear_op
902    }
903}
904
905// Add clone support to QuantizedMatrixFreeOp
906impl<F> Clone for QuantizedMatrixFreeOp<F>
907where
908    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug + 'static,
909{
910    fn clone(&self) -> Self {
911        QuantizedMatrixFreeOp {
912            shape: self.shape,
913            params: self.params.clone(),
914            op_fn: Arc::clone(&self.op_fn),
915            symmetric: self.symmetric,
916            positive_definite: self.positive_definite,
917        }
918    }
919}
920
921/// Check if a matrix is symmetric
922#[allow(dead_code)]
923fn ismatrix_symmetric<F>(matrix: &ArrayView2<F>) -> bool
924where
925    F: Float + PartialEq,
926{
927    let (rows, cols) = matrix.dim();
928    if rows != cols {
929        return false;
930    }
931
932    for i in 0..rows {
933        for j in i + 1..cols {
934            if matrix[[i, j]] != matrix[[j, i]] {
935                return false;
936            }
937        }
938    }
939
940    true
941}
942
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Asserts equal lengths, then element-wise agreement within tolerance `eps`.
    fn assert_vec_close(actual: &Array1<f32>, expected: &Array1<f32>, eps: f32) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = eps);
        }
    }

    #[test]
    fn test_quantizedmatrix_free_op_frommatrix() {
        let dense = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let op =
            QuantizedMatrixFreeOp::frommatrix(&dense.view(), 8, QuantizationMethod::Symmetric)
                .unwrap();

        let x = array![1.0f32, 2.0, 3.0];
        let y = op.apply(&x.view()).unwrap();

        // Reference: exact (unquantized) dense product.
        let reference = dense.dot(&x);

        assert_vec_close(&y, &reference, 1.0);
    }

    #[test]
    fn test_quantizedmatrix_free_op_block_diagonal() {
        let upper_block = array![[1.0f32, 2.0], [3.0, 4.0]];
        let lower_block = array![[5.0f32]];

        let op = QuantizedMatrixFreeOp::block_diagonal(
            vec![upper_block.view(), lower_block.view()],
            8,
            QuantizationMethod::Symmetric,
        )
        .unwrap();

        let x = array![1.0f32, 2.0, 3.0];
        let y = op.apply(&x.view()).unwrap();

        // [1*1 + 2*2, 3*1 + 4*2, 5*3] = [5, 11, 15]
        assert_vec_close(&y, &array![5.0f32, 11.0, 15.0], 1.0);
    }

    #[test]
    fn test_quantizedmatrix_free_op_sparse() {
        // Dense equivalent:
        // [ 1.0 0.0 2.0 ]
        // [ 0.0 3.0 0.0 ]
        // [ 4.0 0.0 5.0 ]
        let coords = vec![(0, 0), (0, 2), (1, 1), (2, 0), (2, 2)];
        let nonzeros = array![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let op = QuantizedMatrixFreeOp::sparse(
            3,
            3,
            coords,
            &nonzeros.view(),
            8,
            QuantizationMethod::Symmetric,
        )
        .unwrap();

        let x = array![1.0f32, 2.0, 3.0];
        let y = op.apply(&x.view()).unwrap();

        // [1*1 + 2*3, 3*2, 4*1 + 5*3] = [7, 6, 19]
        assert_vec_close(&y, &array![7.0f32, 6.0, 19.0], 1.0);
    }

    #[test]
    fn test_quantizedmatrix_free_op_banded() {
        // Tridiagonal matrix:
        // [ 2.0 1.0 0.0 ]
        // [ 1.0 3.0 1.0 ]
        // [ 0.0 1.0 4.0 ]
        let diag = array![2.0f32, 3.0, 4.0];
        let upper = array![1.0f32, 1.0];
        let lower = array![1.0f32, 1.0];

        let bands = vec![(0, diag.view()), (1, upper.view()), (-1, lower.view())];

        let op = QuantizedMatrixFreeOp::banded(3, bands, 8, QuantizationMethod::Symmetric).unwrap();

        let x = array![1.0f32, 2.0, 3.0];
        let y = op.apply(&x.view()).unwrap();

        // [2*1 + 1*2, 1*1 + 3*2 + 1*3, 1*2 + 4*3] = [4, 10, 14]
        assert_vec_close(&y, &array![4.0f32, 10.0, 14.0], 1.0);
    }

    #[test]
    fn test_quantized_to_linear_operator() {
        let dense = array![[1.0f32, 2.0], [2.0, 3.0]];

        let quantized_op =
            QuantizedMatrixFreeOp::frommatrix(&dense.view(), 8, QuantizationMethod::Symmetric)
                .unwrap()
                .symmetric()
                .positive_definite();

        let linear_op = quantized_to_linear_operator(&quantized_op);

        // Shape and structural flags must survive the conversion.
        assert_eq!(linear_op.nrows(), quantized_op.nrows());
        assert_eq!(linear_op.ncols(), quantized_op.ncols());
        assert_eq!(linear_op.is_symmetric(), quantized_op.is_symmetric());
        assert_eq!(
            linear_op.is_positive_definite(),
            quantized_op.is_positive_definite()
        );

        // Both operators must produce the same product.
        let x = array![1.0f32, 2.0];
        let from_quantized = quantized_op.apply(&x.view()).unwrap();
        let from_linear = linear_op.apply(&x.view()).unwrap();

        assert_vec_close(&from_quantized, &from_linear, 1e-6);
    }
}