1use crate::{DType, Result, Tensor, TensorError};
7use scirs2_core::num_traits::cast::cast;
8use scirs2_core::numeric::Float;
9
10#[cfg(feature = "gpu")]
11use crate::device::context::get_gpu_context;
12#[cfg(feature = "gpu")]
13use crate::gpu::buffer::GpuBuffer;
14#[cfg(feature = "gpu")]
15use bytemuck;
16#[cfg(feature = "gpu")]
17use scirs2_core::ndarray::Array1;
18#[cfg(feature = "gpu")]
19use wgpu::util::DeviceExt;
20
/// Parameters describing an affine (or symmetric) quantization mapping
/// between floating-point values and low-bit integer codes:
/// `q = clamp(round(x / scale) + zero_point, qmin, qmax)`.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Size of one quantized step in floating-point units.
    pub scale: f32,
    /// Integer offset added after scaling; 0 for symmetric schemes.
    pub zero_point: i32,
    /// Target integer dtype (this module supports `Int8` and `Int4`).
    pub dtype: DType,
    /// Smallest representable quantized code (e.g. -128 for int8).
    pub qmin: i32,
    /// Largest representable quantized code (e.g. 127 for int8).
    pub qmax: i32,
}
35
36impl QuantizationParams {
37 pub fn symmetric_int8(scale: f32) -> Self {
39 Self {
40 scale,
41 zero_point: 0,
42 dtype: DType::Int8,
43 qmin: -128,
44 qmax: 127,
45 }
46 }
47
48 pub fn asymmetric_int8(scale: f32, zero_point: i32) -> Self {
50 Self {
51 scale,
52 zero_point,
53 dtype: DType::Int8,
54 qmin: -128,
55 qmax: 127,
56 }
57 }
58
59 pub fn symmetric_int4(scale: f32) -> Self {
61 Self {
62 scale,
63 zero_point: 0,
64 dtype: DType::Int4,
65 qmin: -8,
66 qmax: 7,
67 }
68 }
69
70 pub fn asymmetric_int4(scale: f32, zero_point: i32) -> Self {
72 Self {
73 scale,
74 zero_point,
75 dtype: DType::Int4,
76 qmin: -8,
77 qmax: 7,
78 }
79 }
80
81 pub fn from_tensor_stats(min_val: f32, max_val: f32, dtype: DType) -> Result<Self> {
83 let (qmin, qmax) = match dtype {
84 DType::Int8 => (-128, 127),
85 DType::Int4 => (-8, 7),
86 _ => {
87 return Err(TensorError::invalid_argument(format!(
88 "Unsupported quantization dtype: {dtype:?}"
89 )))
90 }
91 };
92
93 let abs_max = min_val.abs().max(max_val.abs());
95 let scale = abs_max / qmax as f32;
96
97 Ok(Self {
98 scale,
99 zero_point: 0,
100 dtype,
101 qmin,
102 qmax,
103 })
104 }
105
106 pub fn asymmetric_from_tensor_stats(min_val: f32, max_val: f32, dtype: DType) -> Result<Self> {
108 let (qmin, qmax) = match dtype {
109 DType::Int8 => (-128, 127),
110 DType::Int4 => (-8, 7),
111 _ => {
112 return Err(TensorError::invalid_argument(format!(
113 "Unsupported quantization dtype: {dtype:?}"
114 )))
115 }
116 };
117
118 let scale = (max_val - min_val) / (qmax - qmin) as f32;
119 let zero_point = qmin - (min_val / scale).round() as i32;
120
121 Ok(Self {
122 scale,
123 zero_point: zero_point.clamp(qmin, qmax),
124 dtype,
125 qmin,
126 qmax,
127 })
128 }
129}
130
131pub fn quantize<T>(tensor: &Tensor<T>, params: &QuantizationParams) -> Result<Tensor<i8>>
133where
134 T: Float + Send + Sync + Clone + Default + 'static + bytemuck::Pod + bytemuck::Zeroable,
135{
136 match &tensor.storage {
137 crate::tensor::TensorStorage::Cpu(data) => {
138 let quantized_data = data.mapv(|val| {
139 let val_f32 = cast::<T, f32>(val).unwrap_or(0.0);
140 let quantized =
141 ((val_f32 / params.scale) + params.zero_point as f32).round() as i32;
142 quantized.clamp(params.qmin, params.qmax) as i8
143 });
144
145 Ok(Tensor::from_array(quantized_data))
146 }
147 #[cfg(feature = "gpu")]
148 crate::tensor::TensorStorage::Gpu(gpu_buffer) => gpu_quantize(gpu_buffer, params),
149 }
150}
151
152pub fn dequantize(tensor: &Tensor<i8>, params: &QuantizationParams) -> Result<Tensor<f32>> {
154 match &tensor.storage {
155 crate::tensor::TensorStorage::Cpu(data) => {
156 let dequantized_data =
157 data.mapv(|val| (val as i32 - params.zero_point) as f32 * params.scale);
158
159 Ok(Tensor::from_array(dequantized_data))
160 }
161 #[cfg(feature = "gpu")]
162 crate::tensor::TensorStorage::Gpu(gpu_buffer) => gpu_dequantize(gpu_buffer, params),
163 }
164}
165
166pub fn dynamic_quantize<T>(
168 tensor: &Tensor<T>,
169 dtype: DType,
170) -> Result<(Tensor<i8>, QuantizationParams)>
171where
172 T: Float
173 + PartialOrd
174 + Send
175 + Sync
176 + Clone
177 + Default
178 + 'static
179 + bytemuck::Pod
180 + bytemuck::Zeroable,
181{
182 match &tensor.storage {
183 crate::tensor::TensorStorage::Cpu(data) => {
184 let mut min_val = T::infinity();
186 let mut max_val = T::neg_infinity();
187
188 for &val in data.iter() {
189 if val < min_val {
190 min_val = val;
191 }
192 if val > max_val {
193 max_val = val;
194 }
195 }
196
197 let min_f32 = cast::<T, f32>(min_val).unwrap_or(0.0);
198 let max_f32 = cast::<T, f32>(max_val).unwrap_or(0.0);
199
200 let params = QuantizationParams::from_tensor_stats(min_f32, max_f32, dtype)?;
201 let quantized = quantize(tensor, ¶ms)?;
202
203 Ok((quantized, params))
204 }
205 #[cfg(feature = "gpu")]
206 crate::tensor::TensorStorage::Gpu(gpu_buffer) => gpu_dynamic_quantize(gpu_buffer, dtype),
207 }
208}
209
210pub fn fake_quantize<T>(tensor: &Tensor<T>, params: &QuantizationParams) -> Result<Tensor<T>>
212where
213 T: Float + Send + Sync + Clone + Default + 'static + bytemuck::Pod + bytemuck::Zeroable,
214{
215 match &tensor.storage {
216 crate::tensor::TensorStorage::Cpu(data) => {
217 let fake_quantized_data = data.mapv(|val| {
218 let val_f32 = cast::<T, f32>(val).unwrap_or(0.0);
219 let quantized =
220 ((val_f32 / params.scale) + params.zero_point as f32).round() as i32;
221 let clamped = quantized.clamp(params.qmin, params.qmax);
222 let dequantized = (clamped - params.zero_point) as f32 * params.scale;
223 cast::<f32, T>(dequantized).unwrap_or_default()
224 });
225
226 Ok(Tensor::from_array(fake_quantized_data))
227 }
228 #[cfg(feature = "gpu")]
229 crate::tensor::TensorStorage::Gpu(gpu_buffer) => {
230 gpu_fake_quantize(gpu_buffer, params, tensor.shape())
231 }
232 }
233}
234
235pub fn per_channel_quantize<T>(
237 tensor: &Tensor<T>,
238 channel_axis: usize,
239) -> Result<(Tensor<i8>, Vec<QuantizationParams>)>
240where
241 T: Float
242 + PartialOrd
243 + Send
244 + Sync
245 + Clone
246 + Default
247 + 'static
248 + bytemuck::Pod
249 + bytemuck::Zeroable,
250{
251 match &tensor.storage {
252 crate::tensor::TensorStorage::Cpu(data) => {
253 let shape = data.shape();
254 if channel_axis >= shape.len() {
255 return Err(TensorError::invalid_argument(format!(
256 "Channel axis {} out of bounds for tensor with {} dimensions",
257 channel_axis,
258 shape.len()
259 )));
260 }
261
262 let num_channels = shape[channel_axis];
263 let mut params_vec = Vec::with_capacity(num_channels);
264
265 let overall_params = QuantizationParams::from_tensor_stats(0.0, 1.0, DType::Int8)?;
268 for _ in 0..num_channels {
269 params_vec.push(overall_params.clone());
270 }
271
272 let quantized = quantize(tensor, &overall_params)?;
273 Ok((quantized, params_vec))
274 }
275 #[cfg(feature = "gpu")]
276 crate::tensor::TensorStorage::Gpu(gpu_buffer) => {
277 gpu_per_channel_quantize(gpu_buffer, channel_axis)
278 }
279 }
280}
281
#[cfg(feature = "gpu")]
/// Quantize a GPU-resident buffer via the `quantization_ops.wgsl` compute
/// shader, then read the codes back to the host as an int8 tensor.
///
/// The shader writes one `i32` per input element; the readback narrows each
/// to `i8` on the host. NOTE(review): the result is a flattened 1-D dynamic
/// array — the caller's original shape is not restored here; confirm callers
/// reshape if needed.
fn gpu_quantize<T>(gpu_buffer: &GpuBuffer<T>, params: &QuantizationParams) -> Result<Tensor<i8>>
where
    T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    // The buffer must live on a GPU device so we can resolve its context.
    let device_id = match gpu_buffer.device_enum() {
        crate::Device::Gpu(id) => id,
        _ => {
            return Err(TensorError::invalid_argument(
                "Expected GPU device".to_string(),
            ))
        }
    };

    let gpu_ctx = get_gpu_context(device_id)?;
    let buffer_size = gpu_buffer.len() * std::mem::size_of::<T>(); // NOTE(review): unused — candidate for removal

    // Output storage: the shader emits one i32 code per input element.
    let output_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("quantize_output"),
        size: (gpu_buffer.len() * std::mem::size_of::<i32>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    // Scale / zero point / clamp range packed as four f32s for binding 2.
    let params_data = [
        params.scale,
        params.zero_point as f32,
        params.qmin as f32,
        params.qmax as f32,
    ];
    let params_buffer = gpu_ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("quantize_params"),
            contents: bytemuck::cast_slice(&params_data),
            usage: wgpu::BufferUsages::STORAGE,
        });

    let shader = gpu_ctx
        .device
        .create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("quantization_ops"),
            source: wgpu::ShaderSource::Wgsl(
                include_str!("gpu/shaders/quantization_ops.wgsl").into(),
            ),
        });

    // The entry point is selected by target dtype; "quantize" is the fallback.
    let compute_pipeline =
        gpu_ctx
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("quantize_pipeline"),
                layout: None,
                module: &shader,
                entry_point: Some(match params.dtype {
                    DType::Int8 => "quantize_int8",
                    DType::Int4 => "quantize_int4",
                    _ => "quantize",
                }),
                cache: None,
                compilation_options: Default::default(),
            });

    // Bindings: 0 = input values, 1 = output codes, 2 = quantization params.
    let bind_group = gpu_ctx
        .device
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("quantize_bind_group"),
            layout: &compute_pipeline.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: gpu_buffer.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: output_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("quantize_encoder"),
        });

    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("quantize_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&compute_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);

        // One invocation per element, rounded up to whole 64-wide workgroups.
        let workgroup_size = 64;
        let num_workgroups = (gpu_buffer.len() + workgroup_size - 1) / workgroup_size;
        compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
    }

    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Mappable staging buffer used as the copy target for host readback.
    let staging_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("quantize_staging"),
        size: (gpu_buffer.len() * std::mem::size_of::<i32>()) as u64,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("readback_encoder"),
        });
    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, staging_buffer.size());
    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Block until mapping completes: the map_async callback sends through the
    // channel, and poll(wait_indefinitely) drives the device until it fires.
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = std::sync::mpsc::channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        sender.send(result).expect("channel send should succeed");
    });

    gpu_ctx
        .device
        .poll(wgpu::PollType::wait_indefinitely())
        .ok();
    receiver
        .recv()
        .map_err(|e| TensorError::ComputeError {
            operation: "gpu_buffer_read".to_string(),
            details: format!("Channel receive failed: {}", e),
            retry_possible: true,
            context: None,
        })?
        .map_err(|e| TensorError::invalid_argument(format!("Buffer mapping failed: {:?}", e)))?;

    // Narrow the shader's i32 codes to i8 on the host.
    let data = buffer_slice.get_mapped_range();
    let i32_data: &[i32] = bytemuck::cast_slice(&data);
    let i8_data: Vec<i8> = i32_data.iter().map(|&x| x as i8).collect();

    // The mapped view must be dropped before the buffer is unmapped.
    drop(data);
    staging_buffer.unmap();

    let array = Array1::from_vec(i8_data).into_dyn();
    Ok(Tensor::from_array(array))
}
439
#[cfg(feature = "gpu")]
/// Dequantize a GPU-resident int8 buffer via the `dequantize` entry point of
/// `quantization_ops.wgsl` and read the f32 results back to the host.
///
/// NOTE(review): like `gpu_quantize`, the result is a flattened 1-D dynamic
/// array — the original shape is not restored here.
fn gpu_dequantize(gpu_buffer: &GpuBuffer<i8>, params: &QuantizationParams) -> Result<Tensor<f32>> {
    // The buffer must live on a GPU device so we can resolve its context.
    let device_id = match gpu_buffer.device_enum() {
        crate::Device::Gpu(id) => id,
        _ => {
            return Err(TensorError::invalid_argument(
                "Expected GPU device".to_string(),
            ))
        }
    };

    let gpu_ctx = get_gpu_context(device_id)?;

    // Output storage: one f32 per input code.
    let output_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("dequantize_output"),
        size: (gpu_buffer.len() * std::mem::size_of::<f32>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    // Dequantization only needs scale and zero point (no clamp range).
    let params_data = [params.scale, params.zero_point as f32];
    let params_buffer = gpu_ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("dequantize_params"),
            contents: bytemuck::cast_slice(&params_data),
            usage: wgpu::BufferUsages::STORAGE,
        });

    let shader = gpu_ctx
        .device
        .create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("quantization_ops"),
            source: wgpu::ShaderSource::Wgsl(
                include_str!("gpu/shaders/quantization_ops.wgsl").into(),
            ),
        });

    let compute_pipeline =
        gpu_ctx
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("dequantize_pipeline"),
                layout: None,
                module: &shader,
                entry_point: Some("dequantize"),
                cache: None,
                compilation_options: Default::default(),
            });

    // Bindings: 0 = input codes, 1 = output values, 2 = params.
    let bind_group = gpu_ctx
        .device
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("dequantize_bind_group"),
            layout: &compute_pipeline.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: gpu_buffer.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: output_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("dequantize_encoder"),
        });

    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("dequantize_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&compute_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);

        // One invocation per element, rounded up to whole 64-wide workgroups.
        let workgroup_size = 64;
        let num_workgroups = (gpu_buffer.len() + workgroup_size - 1) / workgroup_size;
        compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
    }

    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Mappable staging buffer used as the copy target for host readback.
    let staging_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("dequantize_staging"),
        size: (gpu_buffer.len() * std::mem::size_of::<f32>()) as u64,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("readback_encoder"),
        });
    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, staging_buffer.size());
    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Block until mapping completes (see gpu_quantize for the same pattern).
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = std::sync::mpsc::channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        sender.send(result).expect("channel send should succeed");
    });

    gpu_ctx
        .device
        .poll(wgpu::PollType::wait_indefinitely())
        .ok();
    receiver
        .recv()
        .map_err(|e| TensorError::ComputeError {
            operation: "gpu_buffer_read".to_string(),
            details: format!("Channel receive failed: {}", e),
            retry_possible: true,
            context: None,
        })?
        .map_err(|e| TensorError::invalid_argument(format!("Buffer mapping failed: {:?}", e)))?;

    let data = buffer_slice.get_mapped_range();
    let f32_data: &[f32] = bytemuck::cast_slice(&data);
    let result_vec: Vec<f32> = f32_data.to_vec();

    // The mapped view must be dropped before the buffer is unmapped.
    drop(data);
    staging_buffer.unmap();

    let array = Array1::from_vec(result_vec).into_dyn();
    Ok(Tensor::from_array(array))
}
584
#[cfg(feature = "gpu")]
/// Dynamic quantization for GPU tensors.
///
/// Stages the buffer back to host memory and reuses the CPU
/// `dynamic_quantize` path (min/max reduction happens on the CPU).
fn gpu_dynamic_quantize<T>(
    gpu_buffer: &GpuBuffer<T>,
    dtype: DType,
) -> Result<(Tensor<i8>, QuantizationParams)>
where
    T: Default
        + bytemuck::Pod
        + bytemuck::Zeroable
        + Clone
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::Float,
{
    let host_tensor = Tensor::from_array(gpu_buffer.to_cpu_array()?);
    dynamic_quantize(&host_tensor, dtype)
}
606
#[cfg(feature = "gpu")]
/// Fake-quantize a GPU-resident buffer via the `fake_quantize` entry point of
/// `quantization_ops.wgsl`.
///
/// Unlike `gpu_quantize`/`gpu_dequantize`, the result is reshaped to `shape`
/// and re-uploaded, so the returned tensor lives back on the GPU.
fn gpu_fake_quantize<T>(
    gpu_buffer: &GpuBuffer<T>,
    params: &QuantizationParams,
    shape: &crate::Shape,
) -> Result<Tensor<T>>
where
    T: Default + bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    // The buffer must live on a GPU device so we can resolve its context.
    let device_id = match gpu_buffer.device_enum() {
        crate::Device::Gpu(id) => id,
        _ => {
            return Err(TensorError::invalid_argument(
                "Expected GPU device".to_string(),
            ))
        }
    };

    let gpu_ctx = get_gpu_context(device_id)?;

    // Output storage: same element type and count as the input.
    let output_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("fake_quantize_output"),
        size: (gpu_buffer.len() * std::mem::size_of::<T>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    // Scale / zero point / clamp range packed as four f32s for binding 2.
    let params_data = [
        params.scale,
        params.zero_point as f32,
        params.qmin as f32,
        params.qmax as f32,
    ];
    let params_buffer = gpu_ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("fake_quantize_params"),
            contents: bytemuck::cast_slice(&params_data),
            usage: wgpu::BufferUsages::STORAGE,
        });

    let shader = gpu_ctx
        .device
        .create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("quantization_ops"),
            source: wgpu::ShaderSource::Wgsl(
                include_str!("gpu/shaders/quantization_ops.wgsl").into(),
            ),
        });

    let compute_pipeline =
        gpu_ctx
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("fake_quantize_pipeline"),
                layout: None,
                module: &shader,
                entry_point: Some("fake_quantize"),
                cache: None,
                compilation_options: Default::default(),
            });

    // Bindings: 0 = input values, 1 = output values, 2 = params.
    let bind_group = gpu_ctx
        .device
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("fake_quantize_bind_group"),
            layout: &compute_pipeline.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: gpu_buffer.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: output_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("fake_quantize_encoder"),
        });

    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("fake_quantize_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&compute_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);

        // One invocation per element, rounded up to whole 64-wide workgroups.
        let workgroup_size = 64;
        let num_workgroups = (gpu_buffer.len() + workgroup_size - 1) / workgroup_size;
        compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
    }

    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Mappable staging buffer used as the copy target for host readback.
    let staging_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("fake_quantize_staging"),
        size: (gpu_buffer.len() * std::mem::size_of::<T>()) as u64,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let mut encoder = gpu_ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("readback_encoder"),
        });
    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, staging_buffer.size());
    gpu_ctx.queue.submit(std::iter::once(encoder.finish()));

    // Block until mapping completes (see gpu_quantize for the same pattern).
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = std::sync::mpsc::channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        sender.send(result).expect("channel send should succeed");
    });

    gpu_ctx
        .device
        .poll(wgpu::PollType::wait_indefinitely())
        .ok();
    receiver
        .recv()
        .map_err(|e| TensorError::ComputeError {
            operation: "gpu_buffer_read".to_string(),
            details: format!("Channel receive failed: {}", e),
            retry_possible: true,
            context: None,
        })?
        .map_err(|e| TensorError::invalid_argument(format!("Buffer mapping failed: {:?}", e)))?;

    let data = buffer_slice.get_mapped_range();
    let result_data: &[T] = bytemuck::cast_slice(&data);
    let result_vec: Vec<T> = result_data.to_vec();

    // The mapped view must be dropped before the buffer is unmapped.
    drop(data);
    staging_buffer.unmap();

    // Restore the caller's shape and move the result back onto the device.
    let result_array = scirs2_core::ndarray::Array::from_shape_vec(shape.dims(), result_vec)
        .map_err(|e| TensorError::invalid_argument(format!("Shape mismatch: {:?}", e)))?
        .into_dyn();
    let result_buffer = GpuBuffer::from_cpu_array(&result_array, device_id)?;
    Ok(Tensor::from_gpu_buffer(result_buffer, shape.clone()))
}
766
#[cfg(feature = "gpu")]
/// Per-channel quantization for GPU tensors.
///
/// Stages the buffer back to host memory and reuses the CPU
/// `per_channel_quantize` path.
fn gpu_per_channel_quantize<T>(
    gpu_buffer: &GpuBuffer<T>,
    channel_axis: usize,
) -> Result<(Tensor<i8>, Vec<QuantizationParams>)>
where
    T: Default
        + bytemuck::Pod
        + bytemuck::Zeroable
        + Clone
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::Float,
{
    let host_tensor = Tensor::from_array(gpu_buffer.to_cpu_array()?);
    per_channel_quantize(&host_tensor, channel_axis)
}
788
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Quantize → dequantize round-trip preserves the tensor's shape.
    #[test]
    fn test_symmetric_quantization() {
        let data = Array::from_vec(vec![1.0f32, 2.0, 3.0, -1.0, -2.0]).into_dyn();
        let tensor = Tensor::from_array(data);

        let params = QuantizationParams::symmetric_int8(0.1);
        let quantized = quantize(&tensor, &params).expect("test: quantize should succeed");
        let dequantized = dequantize(&quantized, &params).expect("test: dequantize should succeed");

        assert!(dequantized.shape() == tensor.shape());
    }

    /// dynamic_quantize derives a positive scale and produces int8 output.
    #[test]
    fn test_dynamic_quantization() {
        let data = Array::from_vec(vec![1.0f32, 2.0, 3.0, -1.0, -2.0]).into_dyn();
        let tensor = Tensor::from_array(data);

        let (quantized, params) =
            dynamic_quantize(&tensor, DType::Int8).expect("test: dynamic_quantize should succeed");

        assert_eq!(quantized.dtype(), DType::Int8);
        assert!(params.scale > 0.0);
    }

    /// fake_quantize keeps the input's dtype and shape (values are snapped).
    #[test]
    fn test_fake_quantization() {
        let data = Array::from_vec(vec![1.0f32, 2.0, 3.0, -1.0, -2.0]).into_dyn();
        let tensor = Tensor::from_array(data);

        let params = QuantizationParams::symmetric_int8(0.1);
        let fake_quantized =
            fake_quantize(&tensor, &params).expect("test: fake_quantize should succeed");

        assert_eq!(fake_quantized.dtype(), tensor.dtype());
        assert_eq!(fake_quantized.shape(), tensor.shape());
    }

    /// GPU round-trip through the quantize/dequantize compute shaders.
    #[test]
    #[cfg(feature = "gpu")]
    #[ignore = "GPU buffer usage conflicts - needs WGPU buffer management fixes"]
    fn test_gpu_quantization() {
        let data = Array::from_vec(vec![1.0f32, 2.0, 3.0, -1.0, -2.0]).into_dyn();
        let cpu_tensor = Tensor::from_array(data);

        let gpu_tensor = cpu_tensor
            .to_device(crate::Device::Gpu(0))
            .expect("test: operation should succeed");

        let params = QuantizationParams::symmetric_int8(0.1);
        let quantized = quantize(&gpu_tensor, &params).expect("test: quantize should succeed");
        let dequantized = dequantize(&quantized, &params).expect("test: dequantize should succeed");

        assert_eq!(quantized.dtype(), DType::Int8);
        assert_eq!(dequantized.dtype(), DType::Float32);
        assert_eq!(dequantized.shape(), gpu_tensor.shape());
    }

    /// GPU fake_quantize keeps dtype and shape of the device tensor.
    #[test]
    #[cfg(feature = "gpu")]
    #[ignore = "GPU buffer usage conflicts - needs WGPU buffer management fixes"]
    fn test_gpu_fake_quantization() {
        let data = Array::from_vec(vec![1.0f32, 2.0, 3.0, -1.0, -2.0]).into_dyn();
        let cpu_tensor = Tensor::from_array(data);

        let gpu_tensor = cpu_tensor
            .to_device(crate::Device::Gpu(0))
            .expect("test: operation should succeed");

        let params = QuantizationParams::symmetric_int8(0.1);
        let fake_quantized =
            fake_quantize(&gpu_tensor, &params).expect("test: fake_quantize should succeed");

        assert_eq!(fake_quantized.dtype(), gpu_tensor.dtype());
        assert_eq!(fake_quantized.shape(), gpu_tensor.shape());
    }
}
876}