
torsh_quantization/specialized.rs

//! Specialized quantization algorithms for advanced use cases
//!
//! This module provides advanced quantization techniques beyond standard INT8 quantization,
//! including low-bit quantization, extreme compression methods, and adaptive precision.
//!
//! # Features
//!
//! - **INT4 Quantization**: 4-bit quantization for extreme compression
//! - **Binary Quantization**: 1-bit quantization using {-1, +1} values
//! - **Ternary Quantization**: 2-bit quantization using {-1, 0, +1} values
//! - **Group-wise Quantization**: Channel grouping for improved accuracy
//! - **Mixed Precision**: Layer-specific precision assignment
//! - **Adaptive Thresholding**: Smart threshold selection for extreme quantization
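//!
//! # Example
//!
//! A minimal sketch of the per-tensor INT4 entry point, mirroring the tests at the bottom
//! of this file. It assumes the module is exposed as `torsh_quantization::specialized` and
//! that the `tensor_1d` helper from `torsh_tensor::creation` is available; the block is
//! marked `ignore` so it is not compiled as a doctest.
//!
//! ```ignore
//! use torsh_quantization::config::QuantConfig;
//! use torsh_quantization::specialized::quantize_int4_per_tensor;
//! use torsh_tensor::creation::tensor_1d;
//!
//! let tensor = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0])?;
//! let (quantized, scale, zero_point) = quantize_int4_per_tensor(&tensor, &QuantConfig::int4())?;
//! // Quantized values are stored as f32 but clamped to the INT4 range [-8, 7].
//! ```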

use crate::config::{MixedPrecisionConfig, QuantConfig};

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
extern crate alloc;

// `ToString` and the `vec!` macro are not in the core prelude, so they are imported from `alloc`.
#[cfg(not(feature = "std"))]
use alloc::{collections::BTreeMap as HashMap, string::{String, ToString}, vec, vec::Vec};

use torsh_core::{
    dtype::DType,
    error::{Result as TorshResult, TorshError},
};
use torsh_tensor::Tensor;

/// INT4 quantization (4-bit per tensor)
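///
/// Maps values to the signed 4-bit range `[-8, 7]` using an affine scheme:
/// `q = clamp(round(x / scale) + zero_point, -8, 7)`, so the original value can be
/// recovered approximately as `x ≈ (q - zero_point) * scale`.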
pub fn quantize_int4_per_tensor(
    tensor: &Tensor,
    _config: &QuantConfig,
) -> TorshResult<(Tensor, f32, i32)> {
    let data = tensor.data()?;
    let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b)).min(0.0);
    let max_val = data
        .iter()
        .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
        .max(0.0);

    // INT4 range: -8 to 7
    let scale = (max_val - min_val) / 15.0; // 15 = 7 - (-8)
    let scale = if scale == 0.0 { 1.0 } else { scale };

    let zero_point = (-8.0 - min_val / scale).round().clamp(-8.0, 7.0) as i32;

    let quantized_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            let quantized = (x / scale).round() + zero_point as f32;
            quantized.clamp(-8.0, 7.0) // Store as f32 for compatibility
        })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok((quantized_tensor, scale, zero_point))
}

/// INT4 per-channel quantization
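///
/// Each slice along `axis` gets its own scale and zero point; the returned scalar
/// parameters are the per-channel averages (see the note at the end of the function).
///
/// A minimal sketch, marked `ignore` and assuming the same API surface as the tests below:
///
/// ```ignore
/// // Quantize a 2x3 weight matrix per output channel (axis 0).
/// let weights = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3], DeviceType::Cpu)?;
/// let (quantized, avg_scale, avg_zero_point) =
///     quantize_int4_per_channel(&weights, 0, &QuantConfig::int4())?;
/// ```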
pub fn quantize_int4_per_channel(
    tensor: &Tensor,
    axis: usize,
    _config: &QuantConfig,
) -> TorshResult<(Tensor, f32, i32)> {
    let binding = tensor.shape();
    let shape = binding.dims();

    if axis >= shape.len() {
        return Err(TorshError::InvalidArgument(
            "Axis out of bounds".to_string(),
        ));
    }

    let num_channels = shape[axis];
    let data = tensor.data()?;

    // Calculate strides for efficient channel access
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut scales = Vec::with_capacity(num_channels);
    let mut zero_points = Vec::with_capacity(num_channels);
    let mut quantized_data = vec![0.0f32; data.len()];

    // Process each channel
    for ch in 0..num_channels {
        let mut channel_min = f32::INFINITY;
        let mut channel_max = f32::NEG_INFINITY;

        // Calculate channel statistics
        for (i, &val) in data.iter().enumerate() {
            let mut ch_idx = 0;
            let mut remaining = i;

            // Calculate channel index for this element
            for (dim, &stride) in strides.iter().enumerate() {
                let coord = remaining / stride;
                remaining %= stride;
                if dim == axis {
                    ch_idx = coord;
                }
            }

            if ch_idx == ch {
                channel_min = channel_min.min(val);
                channel_max = channel_max.max(val);
            }
        }

        // Ensure the range spans zero so that zero is exactly representable
        channel_min = channel_min.min(0.0);
        channel_max = channel_max.max(0.0);

        // Calculate INT4 quantization parameters for this channel
        let scale = (channel_max - channel_min) / 15.0; // INT4 range: -8 to 7
        let scale = if scale == 0.0 { 1.0 } else { scale };
        let zero_point = (-8.0 - channel_min / scale).round().clamp(-8.0, 7.0) as i32;

        scales.push(scale);
        zero_points.push(zero_point);

        // Quantize channel data
        for (i, &val) in data.iter().enumerate() {
            let mut ch_idx = 0;
            let mut remaining = i;

            for (dim, &stride) in strides.iter().enumerate() {
                let coord = remaining / stride;
                remaining %= stride;
                if dim == axis {
                    ch_idx = coord;
                }
            }

            if ch_idx == ch {
                let quantized = (val / scale).round() + zero_point as f32;
                quantized_data[i] = quantized.clamp(-8.0, 7.0);
            }
        }
    }

    let quantized_tensor = Tensor::from_data(quantized_data, shape.to_vec(), tensor.device())?;

    // Return average parameters for compatibility
    let avg_scale = scales.iter().sum::<f32>() / scales.len() as f32;
    let avg_zero_point =
        (zero_points.iter().sum::<i32>() as f32 / zero_points.len() as f32).round() as i32;

    Ok((quantized_tensor, avg_scale, avg_zero_point))
}

/// Binary quantization (-1, +1)
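///
/// Values are mapped to `sign(x)` and the scale is the mean absolute value of the
/// tensor, so dequantization is simply `x ≈ q * scale` (the zero point is always 0).
///
/// A minimal sketch, marked `ignore`, mirroring `test_quantize_binary` below:
///
/// ```ignore
/// let tensor = tensor_1d(&[-2.0, -1.0, 0.0, 1.0, 2.0])?;
/// let (quantized, scale, zero_point) = quantize_binary(&tensor)?;
/// assert_eq!(zero_point, 0);
/// // Every element of `quantized` is either -1.0 or 1.0.
/// ```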
pub fn quantize_binary(tensor: &Tensor) -> TorshResult<(Tensor, f32, i32)> {
    let data = tensor.data()?;

    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot quantize empty tensor".to_string(),
        ));
    }

    // Calculate scale as the mean of absolute values
    let scale = data.iter().map(|&x| x.abs()).sum::<f32>() / data.len() as f32;
    let scale = if scale == 0.0 { 1.0 } else { scale };

    let quantized_data: Vec<f32> = data
        .iter()
        .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok((quantized_tensor, scale, 0)) // Binary is symmetric, so zero_point = 0
}

/// Ternary quantization (-1, 0, +1)
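///
/// Values whose magnitude falls below a threshold (fixed at 70% of the maximum absolute
/// value) are zeroed; the rest map to their sign. The scale is the mean absolute value
/// of the surviving elements.
///
/// A minimal sketch, marked `ignore`, mirroring `test_quantize_ternary` below:
///
/// ```ignore
/// let tensor = tensor_1d(&[-3.0, -1.0, 0.1, 1.0, 3.0])?;
/// let (quantized, scale, _zero_point) = quantize_ternary(&tensor)?;
/// // Elements are -1.0, 0.0, or 1.0; here only +/-3.0 exceed the 0.7 * max_abs threshold.
/// ```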
pub fn quantize_ternary(tensor: &Tensor) -> TorshResult<(Tensor, f32, i32)> {
    let data = tensor.data()?;

    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot quantize empty tensor".to_string(),
        ));
    }

    // Calculate threshold as fraction of max absolute value
    let max_abs = data.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
    let threshold = max_abs * 0.7; // Threshold parameter

    // Calculate scale from non-zero values
    let non_zero_sum: f32 = data
        .iter()
        .filter(|&&x| x.abs() > threshold)
        .map(|&x| x.abs())
        .sum();
    let non_zero_count = data.iter().filter(|&&x| x.abs() > threshold).count();

    let scale = if non_zero_count > 0 {
        non_zero_sum / non_zero_count as f32
    } else {
        1.0
    };

    let quantized_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            if x.abs() <= threshold {
                0.0
            } else if x > 0.0 {
                1.0
            } else {
                -1.0
            }
        })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok((quantized_tensor, scale, 0)) // Ternary is symmetric, so zero_point = 0
}

/// Group-wise quantization (divide channels into groups and quantize per-group)
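///
/// Channels along `axis` are split into groups of `group_size` consecutive channels;
/// each group is quantized with its own scale and zero point derived from
/// `config.get_qint_range()`. The returned scalars are group averages.
///
/// A minimal sketch, marked `ignore`, mirroring `test_quantize_group_wise` below:
///
/// ```ignore
/// let tensor = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3], DeviceType::Cpu)?;
/// let config = QuantConfig::group_wise(1, 2);
/// let (quantized, avg_scale, avg_zero_point) = quantize_group_wise(&tensor, 1, 2, &config)?;
/// ```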
pub fn quantize_group_wise(
    tensor: &Tensor,
    axis: usize,
    group_size: usize,
    config: &QuantConfig,
) -> TorshResult<(Tensor, f32, i32)> {
    let binding = tensor.shape();
    let shape = binding.dims();

    if axis >= shape.len() {
        return Err(TorshError::InvalidArgument(
            "Axis out of bounds".to_string(),
        ));
    }

    if group_size == 0 {
        return Err(TorshError::InvalidArgument(
            "Group size must be greater than 0".to_string(),
        ));
    }

    let num_channels = shape[axis];
    let num_groups = num_channels.div_ceil(group_size); // Ceiling division

    let data = tensor.data()?;
    let mut quantized_data = vec![0.0f32; data.len()];

    // Calculate strides for indexing
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut group_scales = Vec::new();
    let mut group_zero_points = Vec::new();

    // Process each group
    for group_idx in 0..num_groups {
        let start_ch = group_idx * group_size;
        let end_ch = (start_ch + group_size).min(num_channels);

        // Collect data for this group
        let mut group_data = Vec::new();
        for ch in start_ch..end_ch {
            // Extract data for this channel
            for (i, _) in data.iter().enumerate() {
                let idx = i;
                let mut ch_idx = 0;
                let mut remaining = idx;

                // Calculate channel index for this element
                for (dim, &stride) in strides.iter().enumerate() {
                    let coord = remaining / stride;
                    remaining %= stride;
                    if dim == axis {
                        ch_idx = coord;
                    }
                }

                if ch_idx == ch {
                    group_data.push(data[i]);
                }
            }
        }

        if group_data.is_empty() {
            continue;
        }

        // Calculate quantization parameters for this group
        let min_val = group_data
            .iter()
            .fold(f32::INFINITY, |a, &b| a.min(b))
            .min(0.0);
        let max_val = group_data
            .iter()
            .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
            .max(0.0);

        let (qmin, qmax) = config.get_qint_range();
        let scale = (max_val - min_val) / (qmax - qmin) as f32;
        let scale = if scale == 0.0 { 1.0 } else { scale };

        let zero_point = (qmin as f32 - min_val / scale)
            .round()
            .clamp(qmin as f32, qmax as f32) as i32;

        group_scales.push(scale);
        group_zero_points.push(zero_point);

        // Quantize this group's data
        for ch in start_ch..end_ch {
            for i in 0..data.len() {
                let idx = i;
                let mut ch_idx = 0;
                let mut remaining = idx;

                // Calculate channel index for this element
                for (dim, &stride) in strides.iter().enumerate() {
                    let coord = remaining / stride;
                    remaining %= stride;
                    if dim == axis {
                        ch_idx = coord;
                    }
                }

                if ch_idx == ch {
                    let quantized = (data[i] / scale).round() + zero_point as f32;
                    quantized_data[i] = quantized.clamp(qmin as f32, qmax as f32);
                }
            }
        }
    }

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    // Return average scale and zero_point for compatibility
    let avg_scale = if group_scales.is_empty() {
        1.0
    } else {
        group_scales.iter().sum::<f32>() / group_scales.len() as f32
    };
    let avg_zero_point = if group_zero_points.is_empty() {
        0
    } else {
        (group_zero_points.iter().sum::<i32>() as f32 / group_zero_points.len() as f32).round()
            as i32
    };

    Ok((quantized_tensor, avg_scale, avg_zero_point))
}

/// Mixed precision quantization for different layers
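///
/// Each layer is quantized with the precision selected by [`determine_layer_precision`],
/// so different layers in the same map can end up with different configs.
///
/// A minimal sketch, marked `ignore`, mirroring `test_mixed_precision` below:
///
/// ```ignore
/// let mut tensors = HashMap::new();
/// tensors.insert("embedding".to_string(), tensor_1d(&[1.0, 2.0, 3.0])?);
/// tensors.insert("attention".to_string(), tensor_1d(&[4.0, 5.0, 6.0])?);
///
/// let results = quantize_mixed_precision(&tensors, &MixedPrecisionConfig::default())?;
/// assert_eq!(results.len(), 2);
/// ```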
pub fn quantize_mixed_precision(
    tensors: &HashMap<String, Tensor>,
    config: &MixedPrecisionConfig,
) -> TorshResult<HashMap<String, (Tensor, f32, i32)>> {
    let mut results = HashMap::new();

    for (layer_name, tensor) in tensors {
        // Determine precision for this layer
        let precision = determine_layer_precision(layer_name, config);

        // Create quantization config for this precision
        let layer_config = create_precision_config(precision);

        // Quantize using appropriate scheme
        let result = crate::algorithms::quantize_with_config(tensor, &layer_config)?;
        results.insert(layer_name.clone(), result);
    }

    Ok(results)
}

/// Determine precision for a layer based on mixed precision config
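///
/// Patterns from `config.layer_precision` are matched as substrings of the layer name;
/// the first match wins, otherwise `config.default_precision` is returned.
///
/// A minimal sketch, marked `ignore`, using the default config as exercised by
/// `test_determine_layer_precision` below:
///
/// ```ignore
/// let config = MixedPrecisionConfig::default();
/// assert_eq!(determine_layer_precision("layer.embedding.weight", &config), DType::I8);
/// assert_eq!(determine_layer_precision("layer.attention.query", &config), DType::F16);
/// ```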
pub fn determine_layer_precision(layer_name: &str, config: &MixedPrecisionConfig) -> DType {
    // Check configured patterns first (substring match on the layer name)
    for (pattern, precision) in &config.layer_precision {
        if layer_name.contains(pattern) {
            return *precision;
        }
    }

    // Return default precision
    config.default_precision
}

/// Create quantization config for specific precision
pub fn create_precision_config(precision: DType) -> QuantConfig {
    match precision {
        DType::I8 => QuantConfig::int8(),
        DType::U8 => QuantConfig::uint8(),
        DType::F16 => {
            // For FP16, we don't quantize but use reduced precision
            QuantConfig {
                dtype: DType::F16,
                enable_fake_quant: false,
                ..Default::default()
            }
        }
        DType::F32 => {
            // Keep full precision
            QuantConfig {
                dtype: DType::F32,
                enable_fake_quant: false,
                ..Default::default()
            }
        }
        _ => QuantConfig::int8(), // Default fallback
    }
}

/// Advanced binary quantization with learned threshold
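///
/// Unlike plain [`quantize_binary`], values whose magnitude is at or below the threshold
/// are zeroed, so the output can contain {-1, 0, +1}. If no threshold is given, the mean
/// absolute value of the tensor is used. The threshold that was used is returned as the
/// fourth element of the tuple.
///
/// A minimal sketch, marked `ignore`, mirroring `test_binary_learned_threshold` below:
///
/// ```ignore
/// let tensor = tensor_1d(&[-2.0, -0.1, 0.1, 0.5, 2.0])?;
/// let (quantized, scale, _zp, threshold) = quantize_binary_learned_threshold(&tensor, Some(0.3))?;
/// assert_eq!(threshold, 0.3);
/// ```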
pub fn quantize_binary_learned_threshold(
    tensor: &Tensor,
    threshold: Option<f32>,
) -> TorshResult<(Tensor, f32, i32, f32)> {
    let data = tensor.data()?;

    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot quantize empty tensor".to_string(),
        ));
    }

    // Use provided threshold or learn it
    let threshold = threshold.unwrap_or_else(|| {
        // Simple threshold learning: mean of absolute values
        let abs_sum: f32 = data.iter().map(|&x| x.abs()).sum();
        abs_sum / data.len() as f32
    });

    // Calculate scale from values above threshold
    let above_threshold: Vec<f32> = data
        .iter()
        .filter(|&&x| x.abs() > threshold)
        .cloned()
        .collect();

    let scale = if above_threshold.is_empty() {
        1.0
    } else {
        above_threshold.iter().map(|&x| x.abs()).sum::<f32>() / above_threshold.len() as f32
    };

    let quantized_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            if x.abs() <= threshold {
                0.0
            } else if x >= 0.0 {
                1.0
            } else {
                -1.0
            }
        })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok((quantized_tensor, scale, 0, threshold))
}

/// Adaptive ternary quantization with optimal threshold selection
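///
/// The threshold is chosen by a coarse grid search over 10% to 100% of the maximum
/// absolute value, minimizing the mean squared error computed by
/// [`calculate_ternary_error`]; the selected threshold is returned as the fourth
/// element of the tuple.
///
/// A minimal sketch, marked `ignore`, mirroring `test_ternary_adaptive` below:
///
/// ```ignore
/// let tensor = tensor_1d(&[-3.0, -0.5, 0.0, 0.5, 3.0])?;
/// let (quantized, scale, _zp, threshold) = quantize_ternary_adaptive(&tensor)?;
/// assert!(threshold > 0.0);
/// ```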
pub fn quantize_ternary_adaptive(tensor: &Tensor) -> TorshResult<(Tensor, f32, i32, f32)> {
    let data = tensor.data()?;

    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot quantize empty tensor".to_string(),
        ));
    }

    // Find optimal threshold using search
    let max_abs = data.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
    let mut best_threshold = 0.0;
    let mut best_error = f32::INFINITY;

    // Search for optimal threshold
    for i in 1..=10 {
        let threshold = max_abs * (i as f32 * 0.1);
        let error = calculate_ternary_error(&data, threshold);
        if error < best_error {
            best_error = error;
            best_threshold = threshold;
        }
    }

    // Apply quantization with best threshold
    let non_zero_sum: f32 = data
        .iter()
        .filter(|&&x| x.abs() > best_threshold)
        .map(|&x| x.abs())
        .sum();
    let non_zero_count = data.iter().filter(|&&x| x.abs() > best_threshold).count();

    let scale = if non_zero_count > 0 {
        non_zero_sum / non_zero_count as f32
    } else {
        1.0
    };

    let quantized_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            if x.abs() <= best_threshold {
                0.0
            } else if x > 0.0 {
                1.0
            } else {
                -1.0
            }
        })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok((quantized_tensor, scale, 0, best_threshold))
}

/// Calculate quantization error for ternary quantization with given threshold
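///
/// Returns the mean squared error `mean((x - q(x))^2)`, where `q(x)` is 0 when
/// `|x| <= threshold` and `sign(x)` otherwise. The error is computed against the
/// unscaled ternary values, so it is only meaningful for ranking candidate thresholds.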
fn calculate_ternary_error(data: &[f32], threshold: f32) -> f32 {
    data.iter()
        .map(|&x| {
            let quantized = if x.abs() <= threshold {
                0.0
            } else if x > 0.0 {
                1.0
            } else {
                -1.0
            };
            (x - quantized).powi(2)
        })
        .sum::<f32>()
        / data.len() as f32
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::tensor_1d;

    #[test]
    fn test_quantize_int4_per_tensor() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int4();

        let result = quantize_int4_per_tensor(&tensor, &config);
        assert!(result.is_ok());

        let (quantized, scale, zero_point) = result.unwrap();
        assert!(scale > 0.0);
        assert!(zero_point >= -8 && zero_point <= 7);

        let quantized_data = quantized.data().unwrap();
        assert_eq!(quantized_data.len(), data.len());

        // Check that values are in INT4 range
        for &val in &quantized_data {
            assert!(val >= -8.0 && val <= 7.0);
        }
    }

    #[test]
    fn test_quantize_binary() {
        let data = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        let result = quantize_binary(&tensor);
        assert!(result.is_ok());

        let (quantized, scale, zero_point) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0); // Binary is symmetric

        let quantized_data = quantized.data().unwrap();
        assert_eq!(quantized_data.len(), data.len());

        // Check that all values are either -1.0 or 1.0
        for &val in &quantized_data {
            assert!(val == -1.0 || val == 1.0);
        }
    }

    #[test]
    fn test_quantize_ternary() {
        let data = vec![-3.0, -1.0, 0.1, 1.0, 3.0];
        let tensor = tensor_1d(&data).unwrap();

        let result = quantize_ternary(&tensor);
        assert!(result.is_ok());

        let (quantized, scale, zero_point) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0); // Ternary is symmetric

        let quantized_data = quantized.data().unwrap();
        assert_eq!(quantized_data.len(), data.len());

        // Check that all values are -1.0, 0.0, or 1.0
        for &val in &quantized_data {
            assert!(val == -1.0 || val == 0.0 || val == 1.0);
        }
    }

    #[test]
    fn test_quantize_group_wise() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_data(data, vec![2, 3], DeviceType::Cpu).unwrap();
        let config = QuantConfig::group_wise(1, 2);

        let result = quantize_group_wise(&tensor, 1, 2, &config);
        assert!(result.is_ok());

        let (quantized, scale, _zero_point) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(quantized.shape().dims(), tensor.shape().dims());
    }

    #[test]
    fn test_mixed_precision() {
        let mut tensors = HashMap::new();
        tensors.insert(
            "embedding".to_string(),
            tensor_1d(&[1.0, 2.0, 3.0]).unwrap(),
        );
        tensors.insert(
            "attention".to_string(),
            tensor_1d(&[4.0, 5.0, 6.0]).unwrap(),
        );

        let config = MixedPrecisionConfig::default();

        let result = quantize_mixed_precision(&tensors, &config);
        assert!(result.is_ok());

        let results = result.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.contains_key("embedding"));
        assert!(results.contains_key("attention"));
    }

    #[test]
    fn test_determine_layer_precision() {
        let config = MixedPrecisionConfig::default();

        let embedding_precision = determine_layer_precision("layer.embedding.weight", &config);
        assert_eq!(embedding_precision, DType::I8);

        let attention_precision = determine_layer_precision("layer.attention.query", &config);
        assert_eq!(attention_precision, DType::F16);

        let unknown_precision = determine_layer_precision("layer.unknown.weight", &config);
        assert_eq!(unknown_precision, DType::I8); // Default
    }

    #[test]
    fn test_binary_learned_threshold() {
        let data = vec![-2.0, -0.1, 0.1, 0.5, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        let result = quantize_binary_learned_threshold(&tensor, Some(0.3));
        assert!(result.is_ok());

        let (quantized, scale, zero_point, threshold) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);
        assert_eq!(threshold, 0.3);

        let quantized_data = quantized.data().unwrap();

        // Values below threshold should become 0, others ±1
        for (i, &original) in data.iter().enumerate() {
            let expected = if original.abs() <= 0.3 {
                0.0
            } else if original >= 0.0 {
                1.0
            } else {
                -1.0
            };
            assert_eq!(quantized_data[i], expected);
        }
    }

    #[test]
    fn test_ternary_adaptive() {
        let data = vec![-3.0, -0.5, 0.0, 0.5, 3.0];
        let tensor = tensor_1d(&data).unwrap();

        let result = quantize_ternary_adaptive(&tensor);
        assert!(result.is_ok());

        let (quantized, scale, zero_point, threshold) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);
        assert!(threshold > 0.0);

        let quantized_data = quantized.data().unwrap();
        assert_eq!(quantized_data.len(), data.len());

        // All values should be -1, 0, or 1
        for &val in &quantized_data {
            assert!(val == -1.0 || val == 0.0 || val == 1.0);
        }
    }

    #[test]
    fn test_error_cases() {
        // Test empty tensor
        let empty_data: Vec<f32> = vec![];
        let empty_tensor = tensor_1d(&empty_data).unwrap();

        assert!(quantize_binary(&empty_tensor).is_err());
        assert!(quantize_ternary(&empty_tensor).is_err());

        // Test invalid axis
        let data = vec![1.0, 2.0, 3.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::group_wise(0, 2);

        let result = quantize_group_wise(&tensor, 5, 2, &config);
        assert!(result.is_err());

        // Test zero group size
        let result = quantize_group_wise(&tensor, 0, 0, &config);
        assert!(result.is_err());
    }
}