// tenflowers_core/quantization.rs
1//! Quantization operations for TenfloweRS
2//!
3//! This module provides quantization and dequantization operations for INT8/INT4
4//! to enable efficient inference and reduced memory usage.
5
6use crate::{DType, Result, Tensor, TensorError};
7use scirs2_core::num_traits::cast::cast;
8use scirs2_core::numeric::Float;
9
10#[cfg(feature = "gpu")]
11use crate::device::context::get_gpu_context;
12#[cfg(feature = "gpu")]
13use crate::gpu::buffer::GpuBuffer;
14#[cfg(feature = "gpu")]
15use bytemuck;
16#[cfg(feature = "gpu")]
17use scirs2_core::ndarray::Array1;
18#[cfg(feature = "gpu")]
19use wgpu::util::DeviceExt;
20
/// Quantization parameters for symmetric and asymmetric quantization
///
/// The mapping implemented by `quantize`/`dequantize` in this module is:
///   quantized = clamp(round(real / scale + zero_point), qmin, qmax)
///   real      = scale * (quantized - zero_point)
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Scale factor for quantization (real_value = scale * (quantized_value - zero_point))
    pub scale: f32,
    /// Zero point for asymmetric quantization (usually 0 for symmetric)
    pub zero_point: i32,
    /// Quantized data type (Int8 or Int4)
    pub dtype: DType,
    /// Minimum quantized value (e.g., -128 for INT8)
    pub qmin: i32,
    /// Maximum quantized value (e.g., 127 for INT8)
    pub qmax: i32,
}
35
36impl QuantizationParams {
37    /// Create symmetric INT8 quantization parameters
38    pub fn symmetric_int8(scale: f32) -> Self {
39        Self {
40            scale,
41            zero_point: 0,
42            dtype: DType::Int8,
43            qmin: -128,
44            qmax: 127,
45        }
46    }
47
48    /// Create asymmetric INT8 quantization parameters  
49    pub fn asymmetric_int8(scale: f32, zero_point: i32) -> Self {
50        Self {
51            scale,
52            zero_point,
53            dtype: DType::Int8,
54            qmin: -128,
55            qmax: 127,
56        }
57    }
58
59    /// Create symmetric INT4 quantization parameters
60    pub fn symmetric_int4(scale: f32) -> Self {
61        Self {
62            scale,
63            zero_point: 0,
64            dtype: DType::Int4,
65            qmin: -8,
66            qmax: 7,
67        }
68    }
69
70    /// Create asymmetric INT4 quantization parameters
71    pub fn asymmetric_int4(scale: f32, zero_point: i32) -> Self {
72        Self {
73            scale,
74            zero_point,
75            dtype: DType::Int4,
76            qmin: -8,
77            qmax: 7,
78        }
79    }
80
81    /// Calculate quantization parameters from tensor statistics
82    pub fn from_tensor_stats(min_val: f32, max_val: f32, dtype: DType) -> Result<Self> {
83        let (qmin, qmax) = match dtype {
84            DType::Int8 => (-128, 127),
85            DType::Int4 => (-8, 7),
86            _ => {
87                return Err(TensorError::invalid_argument(format!(
88                    "Unsupported quantization dtype: {dtype:?}"
89                )))
90            }
91        };
92
93        // For symmetric quantization
94        let abs_max = min_val.abs().max(max_val.abs());
95        let scale = abs_max / qmax as f32;
96
97        Ok(Self {
98            scale,
99            zero_point: 0,
100            dtype,
101            qmin,
102            qmax,
103        })
104    }
105
106    /// Calculate asymmetric quantization parameters from tensor statistics
107    pub fn asymmetric_from_tensor_stats(min_val: f32, max_val: f32, dtype: DType) -> Result<Self> {
108        let (qmin, qmax) = match dtype {
109            DType::Int8 => (-128, 127),
110            DType::Int4 => (-8, 7),
111            _ => {
112                return Err(TensorError::invalid_argument(format!(
113                    "Unsupported quantization dtype: {dtype:?}"
114                )))
115            }
116        };
117
118        let scale = (max_val - min_val) / (qmax - qmin) as f32;
119        let zero_point = qmin - (min_val / scale).round() as i32;
120
121        Ok(Self {
122            scale,
123            zero_point: zero_point.clamp(qmin, qmax),
124            dtype,
125            qmin,
126            qmax,
127        })
128    }
129}
130
131/// Quantize a float32 tensor to INT8 or INT4
132pub fn quantize<T>(tensor: &Tensor<T>, params: &QuantizationParams) -> Result<Tensor<i8>>
133where
134    T: Float + Send + Sync + Clone + Default + 'static + bytemuck::Pod + bytemuck::Zeroable,
135{
136    match &tensor.storage {
137        crate::tensor::TensorStorage::Cpu(data) => {
138            let quantized_data = data.mapv(|val| {
139                let val_f32 = cast::<T, f32>(val).unwrap_or(0.0);
140                let quantized =
141                    ((val_f32 / params.scale) + params.zero_point as f32).round() as i32;
142                quantized.clamp(params.qmin, params.qmax) as i8
143            });
144
145            Ok(Tensor::from_array(quantized_data))
146        }
147        #[cfg(feature = "gpu")]
148        crate::tensor::TensorStorage::Gpu(gpu_buffer) => gpu_quantize(gpu_buffer, params),
149    }
150}
151
152/// Dequantize an INT8 or INT4 tensor back to float32
153pub fn dequantize(tensor: &Tensor<i8>, params: &QuantizationParams) -> Result<Tensor<f32>> {
154    match &tensor.storage {
155        crate::tensor::TensorStorage::Cpu(data) => {
156            let dequantized_data =
157                data.mapv(|val| (val as i32 - params.zero_point) as f32 * params.scale);
158
159            Ok(Tensor::from_array(dequantized_data))
160        }
161        #[cfg(feature = "gpu")]
162        crate::tensor::TensorStorage::Gpu(gpu_buffer) => gpu_dequantize(gpu_buffer, params),
163    }
164}
165
166/// Dynamic quantization - calculates quantization parameters on the fly
167pub fn dynamic_quantize<T>(
168    tensor: &Tensor<T>,
169    dtype: DType,
170) -> Result<(Tensor<i8>, QuantizationParams)>
171where
172    T: Float
173        + PartialOrd
174        + Send
175        + Sync
176        + Clone
177        + Default
178        + 'static
179        + bytemuck::Pod
180        + bytemuck::Zeroable,
181{
182    match &tensor.storage {
183        crate::tensor::TensorStorage::Cpu(data) => {
184            // Calculate min and max values
185            let mut min_val = T::infinity();
186            let mut max_val = T::neg_infinity();
187
188            for &val in data.iter() {
189                if val < min_val {
190                    min_val = val;
191                }
192                if val > max_val {
193                    max_val = val;
194                }
195            }
196
197            let min_f32 = cast::<T, f32>(min_val).unwrap_or(0.0);
198            let max_f32 = cast::<T, f32>(max_val).unwrap_or(0.0);
199
200            let params = QuantizationParams::from_tensor_stats(min_f32, max_f32, dtype)?;
201            let quantized = quantize(tensor, &params)?;
202
203            Ok((quantized, params))
204        }
205        #[cfg(feature = "gpu")]
206        crate::tensor::TensorStorage::Gpu(gpu_buffer) => gpu_dynamic_quantize(gpu_buffer, dtype),
207    }
208}
209
210/// Quantize-aware training support - applies fake quantization during training
211pub fn fake_quantize<T>(tensor: &Tensor<T>, params: &QuantizationParams) -> Result<Tensor<T>>
212where
213    T: Float + Send + Sync + Clone + Default + 'static + bytemuck::Pod + bytemuck::Zeroable,
214{
215    match &tensor.storage {
216        crate::tensor::TensorStorage::Cpu(data) => {
217            let fake_quantized_data = data.mapv(|val| {
218                let val_f32 = cast::<T, f32>(val).unwrap_or(0.0);
219                let quantized =
220                    ((val_f32 / params.scale) + params.zero_point as f32).round() as i32;
221                let clamped = quantized.clamp(params.qmin, params.qmax);
222                let dequantized = (clamped - params.zero_point) as f32 * params.scale;
223                cast::<f32, T>(dequantized).unwrap_or_default()
224            });
225
226            Ok(Tensor::from_array(fake_quantized_data))
227        }
228        #[cfg(feature = "gpu")]
229        crate::tensor::TensorStorage::Gpu(gpu_buffer) => {
230            gpu_fake_quantize(gpu_buffer, params, tensor.shape())
231        }
232    }
233}
234
235/// Per-channel quantization for weights (common in neural networks)
236pub fn per_channel_quantize<T>(
237    tensor: &Tensor<T>,
238    channel_axis: usize,
239) -> Result<(Tensor<i8>, Vec<QuantizationParams>)>
240where
241    T: Float
242        + PartialOrd
243        + Send
244        + Sync
245        + Clone
246        + Default
247        + 'static
248        + bytemuck::Pod
249        + bytemuck::Zeroable,
250{
251    match &tensor.storage {
252        crate::tensor::TensorStorage::Cpu(data) => {
253            let shape = data.shape();
254            if channel_axis >= shape.len() {
255                return Err(TensorError::invalid_argument(format!(
256                    "Channel axis {} out of bounds for tensor with {} dimensions",
257                    channel_axis,
258                    shape.len()
259                )));
260            }
261
262            let num_channels = shape[channel_axis];
263            let mut params_vec = Vec::with_capacity(num_channels);
264
265            // For simplicity, we'll implement a basic version
266            // A full implementation would properly handle per-channel statistics
267            let overall_params = QuantizationParams::from_tensor_stats(0.0, 1.0, DType::Int8)?;
268            for _ in 0..num_channels {
269                params_vec.push(overall_params.clone());
270            }
271
272            let quantized = quantize(tensor, &overall_params)?;
273            Ok((quantized, params_vec))
274        }
275        #[cfg(feature = "gpu")]
276        crate::tensor::TensorStorage::Gpu(gpu_buffer) => {
277            gpu_per_channel_quantize(gpu_buffer, channel_axis)
278        }
279    }
280}
281
/// GPU implementation of [`quantize`]: dispatches the `quantize_int8` /
/// `quantize_int4` WGSL entry point, then reads the result back to the host.
///
/// The shader writes one `i32` per element (WGSL has no 8-bit storage type);
/// readback narrows each value to `i8`.
///
/// NOTE(review): the returned tensor is built 1-D (`Array1 ... into_dyn()`),
/// so the source tensor's shape is not preserved here, unlike
/// `gpu_fake_quantize` which receives the shape explicitly — TODO confirm
/// callers reshape, or thread the shape through.
#[cfg(feature = "gpu")]
fn gpu_quantize<T>(gpu_buffer: &GpuBuffer<T>, params: &QuantizationParams) -> Result<Tensor<i8>>
where
    T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    let device_id = match gpu_buffer.device_enum() {
        crate::Device::Gpu(id) => id,
        _ => {
            return Err(TensorError::invalid_argument(
                "Expected GPU device".to_string(),
            ))
        }
    };

    let gpu_ctx = get_gpu_context(device_id)?;

    // Create output buffer for quantized values (one i32 slot per element).
    let output_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("quantize_output"),
        size: (gpu_buffer.len() * std::mem::size_of::<i32>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    // Create params buffer: four f32s in the order the shader expects.
    let params_data = [
        params.scale,
        params.zero_point as f32,
        params.qmin as f32,
        params.qmax as f32,
    ];
    let params_buffer = gpu_ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("quantize_params"),
            contents: bytemuck::cast_slice(&params_data),
            usage: wgpu::BufferUsages::STORAGE,
        });

    // Create compute shader
    let shader = gpu_ctx
        .device
        .create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("quantization_ops"),
            source: wgpu::ShaderSource::Wgsl(
                include_str!("gpu/shaders/quantization_ops.wgsl").into(),
            ),
        });

    // Pick the entry point by target dtype.
    let compute_pipeline =
        gpu_ctx
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("quantize_pipeline"),
                layout: None,
                module: &shader,
                entry_point: Some(match params.dtype {
                    DType::Int8 => "quantize_int8",
                    DType::Int4 => "quantize_int4",
                    _ => "quantize",
                }),
                cache: None,
                compilation_options: Default::default(),
            });

    // Create bind group: input (0), output (1), params (2).
    let bind_group = gpu_ctx
        .device
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("quantize_bind_group"),
            layout: &compute_pipeline.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: gpu_buffer.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: output_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

    // Execute compute shader
    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("quantize_encoder"),
        });

    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("quantize_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&compute_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);

        let workgroup_size = 64;
        let num_workgroups = gpu_buffer.len().div_ceil(workgroup_size);
        compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
    }

    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Read back results through a mappable staging buffer.
    let staging_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("quantize_staging"),
        size: (gpu_buffer.len() * std::mem::size_of::<i32>()) as u64,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("readback_encoder"),
        });
    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, staging_buffer.size());
    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Map and read data; block until the map callback fires.
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = std::sync::mpsc::channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        sender.send(result).expect("channel send should succeed");
    });

    gpu_ctx
        .device
        .poll(wgpu::PollType::wait_indefinitely())
        .ok();
    receiver
        .recv()
        .map_err(|e| TensorError::ComputeError {
            operation: "gpu_buffer_read".to_string(),
            details: format!("Channel receive failed: {}", e),
            retry_possible: true,
            context: None,
        })?
        .map_err(|e| TensorError::invalid_argument(format!("Buffer mapping failed: {:?}", e)))?;

    let data = buffer_slice.get_mapped_range();
    let i32_data: &[i32] = bytemuck::cast_slice(&data);
    // Narrow the shader's i32 output to the i8 the API promises.
    let i8_data: Vec<i8> = i32_data.iter().map(|&x| x as i8).collect();

    drop(data);
    staging_buffer.unmap();

    let array = Array1::from_vec(i8_data).into_dyn();
    Ok(Tensor::from_array(array))
}
439
/// GPU implementation of [`dequantize`]: dispatches the `dequantize` WGSL
/// entry point over the i8 buffer and reads the f32 result back to the host.
///
/// NOTE(review): like `gpu_quantize`, the result is rebuilt as a 1-D array,
/// so the original tensor shape is not preserved — TODO confirm callers
/// reshape.
#[cfg(feature = "gpu")]
fn gpu_dequantize(gpu_buffer: &GpuBuffer<i8>, params: &QuantizationParams) -> Result<Tensor<f32>> {
    let device_id = match gpu_buffer.device_enum() {
        crate::Device::Gpu(id) => id,
        _ => {
            return Err(TensorError::invalid_argument(
                "Expected GPU device".to_string(),
            ))
        }
    };

    let gpu_ctx = get_gpu_context(device_id)?;

    // Create output buffer for dequantized values (one f32 per element)
    let output_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("dequantize_output"),
        size: (gpu_buffer.len() * std::mem::size_of::<f32>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    // Create params buffer: dequantize only needs scale and zero_point
    let params_data = [params.scale, params.zero_point as f32];
    let params_buffer = gpu_ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("dequantize_params"),
            contents: bytemuck::cast_slice(&params_data),
            usage: wgpu::BufferUsages::STORAGE,
        });

    // Create compute shader (shared quantization_ops module)
    let shader = gpu_ctx
        .device
        .create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("quantization_ops"),
            source: wgpu::ShaderSource::Wgsl(
                include_str!("gpu/shaders/quantization_ops.wgsl").into(),
            ),
        });

    let compute_pipeline =
        gpu_ctx
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("dequantize_pipeline"),
                layout: None,
                module: &shader,
                entry_point: Some("dequantize"),
                cache: None,
                compilation_options: Default::default(),
            });

    // Create bind group: input (0), output (1), params (2)
    let bind_group = gpu_ctx
        .device
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("dequantize_bind_group"),
            layout: &compute_pipeline.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: gpu_buffer.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: output_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

    // Execute compute shader
    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("dequantize_encoder"),
        });

    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("dequantize_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&compute_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);

        // One thread per element, 64 threads per workgroup (ceiling division).
        let workgroup_size = 64;
        let num_workgroups = (gpu_buffer.len() + workgroup_size - 1) / workgroup_size;
        compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
    }

    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Read back results through a mappable staging buffer
    let staging_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("dequantize_staging"),
        size: (gpu_buffer.len() * std::mem::size_of::<f32>()) as u64,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("readback_encoder"),
        });
    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, staging_buffer.size());
    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Map and read data; block until the async map callback fires
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = std::sync::mpsc::channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        sender.send(result).expect("channel send should succeed");
    });

    gpu_ctx
        .device
        .poll(wgpu::PollType::wait_indefinitely())
        .ok();
    receiver
        .recv()
        .map_err(|e| TensorError::ComputeError {
            operation: "gpu_buffer_read".to_string(),
            details: format!("Channel receive failed: {}", e),
            retry_possible: true,
            context: None,
        })?
        .map_err(|e| TensorError::invalid_argument(format!("Buffer mapping failed: {:?}", e)))?;

    let data = buffer_slice.get_mapped_range();
    let f32_data: &[f32] = bytemuck::cast_slice(&data);
    let result_vec: Vec<f32> = f32_data.to_vec();

    // Release the mapped range before unmapping the buffer.
    drop(data);
    staging_buffer.unmap();

    let array = Array1::from_vec(result_vec).into_dyn();
    Ok(Tensor::from_array(array))
}
584
/// GPU stub for [`dynamic_quantize`].
///
/// The GPU path is not implemented yet: the data is copied back to the host
/// and the CPU dynamic quantization is reused. A future version could run the
/// dynamic_quantize shader instead.
#[cfg(feature = "gpu")]
fn gpu_dynamic_quantize<T>(
    gpu_buffer: &GpuBuffer<T>,
    dtype: DType,
) -> Result<(Tensor<i8>, QuantizationParams)>
where
    T: Default
        + bytemuck::Pod
        + bytemuck::Zeroable
        + Clone
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::Float,
{
    let host_tensor = Tensor::from_array(gpu_buffer.to_cpu_array()?);
    dynamic_quantize(&host_tensor, dtype)
}
606
/// GPU implementation of [`fake_quantize`]: runs the `fake_quantize` WGSL
/// entry point, reads the result back to the host, then re-uploads it so the
/// returned tensor stays on the original GPU device with its shape intact.
///
/// Unlike `gpu_quantize`/`gpu_dequantize`, this function receives the tensor
/// shape explicitly and preserves it in the result.
#[cfg(feature = "gpu")]
fn gpu_fake_quantize<T>(
    gpu_buffer: &GpuBuffer<T>,
    params: &QuantizationParams,
    shape: &crate::Shape,
) -> Result<Tensor<T>>
where
    T: Default + bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    let device_id = match gpu_buffer.device_enum() {
        crate::Device::Gpu(id) => id,
        _ => {
            return Err(TensorError::invalid_argument(
                "Expected GPU device".to_string(),
            ))
        }
    };

    let gpu_ctx = get_gpu_context(device_id)?;

    // Create output buffer (same element type and count as the input)
    let output_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("fake_quantize_output"),
        size: (gpu_buffer.len() * std::mem::size_of::<T>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    // Create params buffer: four f32s in the order the shader expects
    let params_data = [
        params.scale,
        params.zero_point as f32,
        params.qmin as f32,
        params.qmax as f32,
    ];
    let params_buffer = gpu_ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("fake_quantize_params"),
            contents: bytemuck::cast_slice(&params_data),
            usage: wgpu::BufferUsages::STORAGE,
        });

    // Create compute shader (shared quantization_ops module)
    let shader = gpu_ctx
        .device
        .create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("quantization_ops"),
            source: wgpu::ShaderSource::Wgsl(
                include_str!("gpu/shaders/quantization_ops.wgsl").into(),
            ),
        });

    let compute_pipeline =
        gpu_ctx
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("fake_quantize_pipeline"),
                layout: None,
                module: &shader,
                entry_point: Some("fake_quantize"),
                cache: None,
                compilation_options: Default::default(),
            });

    // Create bind group: input (0), output (1), params (2)
    let bind_group = gpu_ctx
        .device
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("fake_quantize_bind_group"),
            layout: &compute_pipeline.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: gpu_buffer.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: output_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

    // Execute compute shader
    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("fake_quantize_encoder"),
        });

    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("fake_quantize_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&compute_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);

        // One thread per element, 64 threads per workgroup (ceiling division).
        let workgroup_size = 64;
        let num_workgroups = (gpu_buffer.len() + workgroup_size - 1) / workgroup_size;
        compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
    }

    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Read back results through a mappable staging buffer
    let staging_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("fake_quantize_staging"),
        size: (gpu_buffer.len() * std::mem::size_of::<T>()) as u64,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("readback_encoder"),
        });
    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, staging_buffer.size());
    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Map and read data; block until the async map callback fires
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = std::sync::mpsc::channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        sender.send(result).expect("channel send should succeed");
    });

    gpu_ctx
        .device
        .poll(wgpu::PollType::wait_indefinitely())
        .ok();
    receiver
        .recv()
        .map_err(|e| TensorError::ComputeError {
            operation: "gpu_buffer_read".to_string(),
            details: format!("Channel receive failed: {}", e),
            retry_possible: true,
            context: None,
        })?
        .map_err(|e| TensorError::invalid_argument(format!("Buffer mapping failed: {:?}", e)))?;

    let data = buffer_slice.get_mapped_range();
    let result_data: &[T] = bytemuck::cast_slice(&data);
    let result_vec: Vec<T> = result_data.to_vec();

    // Release the mapped range before unmapping the buffer.
    drop(data);
    staging_buffer.unmap();

    // Restore the original shape, then re-upload so the result lives on the
    // same GPU device as the input.
    let result_array = scirs2_core::ndarray::Array::from_shape_vec(shape.dims(), result_vec)
        .map_err(|e| TensorError::invalid_argument(format!("Shape mismatch: {:?}", e)))?
        .into_dyn();
    let result_buffer = GpuBuffer::from_cpu_array(&result_array, device_id)?;
    Ok(Tensor::from_gpu_buffer(result_buffer, shape.clone()))
}
766
/// GPU stub for [`per_channel_quantize`].
///
/// The GPU path is not implemented yet: the data is copied back to the host
/// and the CPU per-channel quantization is reused. A future version could run
/// the per_channel_quantize shader instead.
#[cfg(feature = "gpu")]
fn gpu_per_channel_quantize<T>(
    gpu_buffer: &GpuBuffer<T>,
    channel_axis: usize,
) -> Result<(Tensor<i8>, Vec<QuantizationParams>)>
where
    T: Default
        + bytemuck::Pod
        + bytemuck::Zeroable
        + Clone
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::Float,
{
    let host_tensor = Tensor::from_array(gpu_buffer.to_cpu_array()?);
    per_channel_quantize(&host_tensor, channel_axis)
}
788
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Small fixture shared by every test in this module.
    fn sample_tensor() -> Tensor<f32> {
        let values = Array::from_vec(vec![1.0f32, 2.0, 3.0, -1.0, -2.0]).into_dyn();
        Tensor::from_array(values)
    }

    #[test]
    fn test_symmetric_quantization() {
        let input = sample_tensor();
        let params = QuantizationParams::symmetric_int8(0.1);

        let quantized = quantize(&input, &params).expect("test: quantize should succeed");
        let restored = dequantize(&quantized, &params).expect("test: dequantize should succeed");

        // Quantize/dequantize should be approximately inverse operations
        // (within quantization error); at minimum the shape round-trips.
        assert!(restored.shape() == input.shape());
    }

    #[test]
    fn test_dynamic_quantization() {
        let input = sample_tensor();

        let (quantized, params) =
            dynamic_quantize(&input, DType::Int8).expect("test: dynamic_quantize should succeed");

        assert_eq!(quantized.dtype(), DType::Int8);
        assert!(params.scale > 0.0);
    }

    #[test]
    fn test_fake_quantization() {
        let input = sample_tensor();
        let params = QuantizationParams::symmetric_int8(0.1);

        let simulated =
            fake_quantize(&input, &params).expect("test: fake_quantize should succeed");

        // Fake quantization keeps both the data type and the shape.
        assert_eq!(simulated.dtype(), input.dtype());
        assert_eq!(simulated.shape(), input.shape());
    }

    #[test]
    #[cfg(feature = "gpu")]
    #[ignore = "GPU buffer usage conflicts - needs WGPU buffer management fixes"]
    fn test_gpu_quantization() {
        // Move the fixture onto the first GPU device.
        let device_tensor = sample_tensor()
            .to_device(crate::Device::Gpu(0))
            .expect("test: operation should succeed");

        let params = QuantizationParams::symmetric_int8(0.1);
        let quantized = quantize(&device_tensor, &params).expect("test: quantize should succeed");
        let restored = dequantize(&quantized, &params).expect("test: dequantize should succeed");

        // GPU quantization round-trip: dtypes and shape are preserved.
        assert_eq!(quantized.dtype(), DType::Int8);
        assert_eq!(restored.dtype(), DType::Float32);
        assert_eq!(restored.shape(), device_tensor.shape());
    }

    #[test]
    #[cfg(feature = "gpu")]
    #[ignore = "GPU buffer usage conflicts - needs WGPU buffer management fixes"]
    fn test_gpu_fake_quantization() {
        // Move the fixture onto the first GPU device.
        let device_tensor = sample_tensor()
            .to_device(crate::Device::Gpu(0))
            .expect("test: operation should succeed");

        let params = QuantizationParams::symmetric_int8(0.1);
        let simulated =
            fake_quantize(&device_tensor, &params).expect("test: fake_quantize should succeed");

        // Fake quantization keeps both the data type and the shape.
        assert_eq!(simulated.dtype(), device_tensor.dtype());
        assert_eq!(simulated.shape(), device_tensor.shape());
    }
}