// scirs2_numpy/dlpack_cuda.rs
1//! CUDA / GPU tensor passthrough via DLPack.
2//!
3//! When a DLPack capsule contains a tensor resident on CUDA (or another
4//! non-CPU device), naively trying to consume it as a CPU `ndarray` would
5//! either panic or silently trigger an unacceptable host-device copy.
6//!
7//! This module provides:
8//!
9//! - [`CudaTensorInfo`] — metadata extracted from a CUDA DLPack tensor
10//!   without triggering any data copy.
11//! - [`cuda_tensor_info_from_dltensor`] — pure-Rust function operating
12//!   directly on a [`DLTensor`]; no Python runtime needed.
13//! - [`dlpack_auto_dispatch_f32`] / [`dlpack_auto_dispatch_f64`] — device-
14//!   aware dispatch that returns an `ndarray` view for CPU tensors, or
15//!   `CudaTensorInfo` for GPU tensors, with **no host-device copy**.
16//!
17//! # Design
18//!
19//! The DLPack standard defines device type codes:
20//!
21//! | Code | Device |
22//! |------|--------|
23//! | 1    | CPU    |
24//! | 2    | CUDA   |
25//! | 3    | CUDA pinned host |
26//! | 4    | OpenCL |
27//! | 7    | Vulkan |
28//! | 8    | Metal  |
29//! | 10   | ROCm   |
30//!
31//! For CPU tensors (type 1) the existing zero-copy `array_from_dlpack_f32/f64`
32//! functions are used directly.  For CUDA tensors (type 2) we extract shape,
33//! dtype, device_id, and byte_offset without touching the data pointer.
34//!
35//! # CUDA runtime linkage
36//!
37//! Full GPU-to-GPU processing (e.g. copying the tensor buffer to a
38//! `cudarc`-managed allocation) requires the `cuda_special` cargo feature
39//! and CUDA runtime linkage, which is deliberately kept **out of default
40//! features** to preserve the Pure Rust build.  With default features only
41//! the metadata extraction path is available.
42
43use std::ffi::CStr;
44
45use pyo3::prelude::*;
46use pyo3::types::PyDict;
47
48use crate::dlpack::{
49    array_from_dlpack_f32, array_from_dlpack_f64, DLDataType, DLDeviceType, DLManagedTensor,
50    DLTensor, DlpackError,
51};
52
/// DLPack capsule name used when the capsule has not yet been consumed.
///
/// Per the DLPack protocol, producers create the capsule under this name and
/// consumers rename it after taking ownership of the managed tensor.
const DLTENSOR_NAME: &CStr = c"dltensor";

/// DLPack capsule name used after the capsule has been consumed by a framework.
///
/// Attempting to call [`cuda_tensor_info`] on an already-consumed capsule raises
/// [`pyo3::exceptions::PyValueError`].
const USED_DLTENSOR_NAME: &CStr = c"used_dltensor";
61
/// Metadata extracted from a CUDA-resident DLPack tensor.
///
/// Accessing this struct **does not** copy tensor data from GPU to CPU.
/// It records only the shape, dtype, device index, and byte offset that are
/// safe to read from the capsule header.
#[derive(Debug, Clone)]
pub struct CudaTensorInfo {
    /// Zero-based index of the CUDA device (e.g., 0 for the first GPU).
    pub device_id: i32,
    /// Tensor dimensions in row-major (C) order.
    ///
    /// Dimensions are copied from the DLPack `i64` shape array and cast to
    /// `usize` by [`cuda_tensor_info_from_dltensor`].
    pub shape: Vec<usize>,
    /// DLPack element data-type descriptor.
    pub dtype: DLDataType,
    /// Byte offset from the data pointer to the first element.
    pub byte_offset: u64,
    /// Raw device-type code (2 = CUDA, 10 = ROCm, etc.).
    pub device_type_code: i32,
}
80
81impl CudaTensorInfo {
82    /// Return the total number of elements (product of shape dimensions).
83    pub fn numel(&self) -> usize {
84        self.shape.iter().product()
85    }
86
87    /// Return the element bit-width.
88    pub fn dtype_bits(&self) -> u8 {
89        self.dtype.bits
90    }
91
92    /// Return a human-readable device string (e.g. `"cuda:0"`).
93    pub fn device_str(&self) -> String {
94        let name = match self.device_type_code {
95            2 => "cuda",
96            3 => "cuda_host",
97            4 => "opencl",
98            7 => "vulkan",
99            8 => "metal",
100            10 => "rocm",
101            _ => "unknown",
102        };
103        format!("{}:{}", name, self.device_id)
104    }
105}
106
/// The result of auto-dispatching a DLPack tensor based on its device.
///
/// No host-device copy is performed for [`DLPackDispatchResult::Gpu`].
///
/// The lifetime `'a` ties a [`DLPackDispatchResult::Cpu`] view to the
/// underlying DLPack data buffer, which the caller must keep alive.
pub enum DLPackDispatchResult<'a, T> {
    /// The tensor is on CPU and has been zero-copy viewed as an `ndarray`.
    Cpu(ndarray::ArrayViewD<'a, T>),
    /// The tensor is on a GPU (or other accelerator) — metadata is returned
    /// without touching the data buffer.
    Gpu(CudaTensorInfo),
    /// The tensor is on an unrecognised device.
    ///
    /// NOTE(review): the dispatch helpers in this file currently return
    /// `Gpu` for every non-CPU device and never construct this variant —
    /// confirm whether callers elsewhere produce it or it is reserved.
    OtherDevice {
        /// DLPack device-type code.
        device_type: i32,
        /// Device index.
        device_id: i32,
    },
}
124
125/// Extract [`CudaTensorInfo`] from a raw [`DLTensor`] pointer.
126///
127/// This is the **pure-Rust** implementation that does not require a Python
128/// runtime; the PyO3 wrappers delegate to this function after extracting the
129/// `DLTensor` from the capsule.
130///
131/// # Errors
132///
133/// Returns [`DlpackError::NonCpuDevice`] when `tensor.device.device_type == 1`
134/// (CPU), because `CudaTensorInfo` is only meaningful for non-CPU devices.
135///
136/// Returns [`DlpackError::NullPointer`] when `tensor.data` is null.
137///
138/// # Safety
139///
140/// `tensor.shape` must be valid for `tensor.ndim` elements.
141pub fn cuda_tensor_info_from_dltensor(tensor: &DLTensor) -> Result<CudaTensorInfo, DlpackError> {
142    if tensor.data.is_null() {
143        return Err(DlpackError::NullPointer);
144    }
145    // Reject CPU tensors — this function is for non-CPU devices.
146    if tensor.device.device_type == DLDeviceType::Cpu as i32 {
147        return Err(DlpackError::NonCpuDevice);
148    }
149    let ndim = tensor.ndim.max(0) as usize;
150    let shape = if ndim == 0 || tensor.shape.is_null() {
151        Vec::new()
152    } else {
153        // SAFETY: caller guarantees shape is valid for ndim elements.
154        unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, ndim) }
155            .iter()
156            .map(|&d| d as usize)
157            .collect()
158    };
159    Ok(CudaTensorInfo {
160        device_id: tensor.device.device_id,
161        shape,
162        dtype: tensor.dtype,
163        byte_offset: tensor.byte_offset,
164        device_type_code: tensor.device.device_type,
165    })
166}
167
168/// Dispatch an `f32` DLPack tensor to CPU or GPU path without a CPU roundtrip.
169///
170/// - **CPU (device_type=1)**: zero-copy `ndarray::ArrayViewD<f32>`.
171/// - **GPU/accelerator (device_type≠1)**: metadata only, no data copy.
172///
173/// # Safety
174///
175/// `tensor` must be a valid, aligned, non-null pointer to a [`DLTensor`]
176/// whose `shape` field is valid for `ndim` elements.  The tensor and its data
177/// must remain live and unmodified for the lifetime `'a` of the returned view.
178pub unsafe fn dlpack_auto_dispatch_f32<'a>(
179    tensor: *const DLTensor,
180) -> Result<DLPackDispatchResult<'a, f32>, DlpackError> {
181    // SAFETY: caller guarantees tensor is valid.
182    let t = unsafe { &*tensor };
183    match t.device.device_type {
184        dt if dt == DLDeviceType::Cpu as i32 => {
185            // SAFETY: forwarded from caller invariants.
186            let view = unsafe { array_from_dlpack_f32(tensor)? };
187            Ok(DLPackDispatchResult::Cpu(view))
188        }
189        _ => {
190            let info = cuda_tensor_info_from_dltensor(t)?;
191            Ok(DLPackDispatchResult::Gpu(info))
192        }
193    }
194}
195
196/// Dispatch an `f64` DLPack tensor to CPU or GPU path without a CPU roundtrip.
197///
198/// Same semantics as [`dlpack_auto_dispatch_f32`] but for 64-bit floats.
199///
200/// # Safety
201///
202/// Same invariants as [`dlpack_auto_dispatch_f32`].
203pub unsafe fn dlpack_auto_dispatch_f64<'a>(
204    tensor: *const DLTensor,
205) -> Result<DLPackDispatchResult<'a, f64>, DlpackError> {
206    // SAFETY: caller guarantees tensor is valid.
207    let t = unsafe { &*tensor };
208    match t.device.device_type {
209        dt if dt == DLDeviceType::Cpu as i32 => {
210            // SAFETY: forwarded from caller invariants.
211            let view = unsafe { array_from_dlpack_f64(tensor)? };
212            Ok(DLPackDispatchResult::Cpu(view))
213        }
214        _ => {
215            let info = cuda_tensor_info_from_dltensor(t)?;
216            Ok(DLPackDispatchResult::Gpu(info))
217        }
218    }
219}
220
221// ─── Python-facing capsule API ───────────────────────────────────────────────
222
/// Extract [`CudaTensorInfo`] from a Python DLPack capsule object.
///
/// Accepts any Python object that wraps a `"dltensor"` PyCapsule — typically
/// the return value of `tensor.__dlpack__()` on a PyTorch, JAX, or CuPy GPU
/// tensor.
///
/// # Errors
///
/// Returns [`PyValueError`] when:
/// - the object is not a PyCapsule named `"dltensor"`,
/// - the capsule has already been consumed (name `"used_dltensor"`),
/// - the underlying tensor is CPU-resident (use the regular DLPack CPU path),
/// - the data pointer inside the managed tensor is null.
///
/// # Example (Python side)
///
/// ```python
/// import torch
/// t = torch.zeros(4, 4, device="cuda")
/// capsule = t.__dlpack__()
/// info = scirs2_numpy.get_cuda_tensor_info(capsule)
/// # info == {"device_id": 0, "shape": [4, 4], "device_type": 2, "device_str": "cuda:0"}
/// ```
pub fn cuda_tensor_info(capsule: &Bound<'_, PyAny>) -> PyResult<CudaTensorInfo> {
    // ── Step 1: obtain the raw PyObject* ────────────────────────────────────
    // `as_ptr` borrows; `capsule` keeps the object (and capsule payload) alive
    // for the rest of this function.
    let raw_obj: *mut pyo3::ffi::PyObject = capsule.as_ptr();

    // ── Step 2: detect "used_dltensor" and give a clear error ───────────────
    // PyCapsule_IsValid returns 1 when the capsule exists and its name matches.
    // We check `used_dltensor` first so the error message is actionable.
    // Note: IsValid never sets a Python exception, so this probe is side-effect
    // free even when `capsule` is not a capsule at all.
    let is_used =
        unsafe { pyo3::ffi::PyCapsule_IsValid(raw_obj, USED_DLTENSOR_NAME.as_ptr()) == 1 };
    if is_used {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "DLPack capsule has already been consumed ('used_dltensor'). \
             Call __dlpack__() again on the original tensor.",
        ));
    }

    // ── Step 3: retrieve the dltensor pointer ───────────────────────────────
    // Fails (returns null + sets TypeError/ValueError) when the object is not
    // a capsule or its name is not exactly "dltensor".
    let raw_ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(raw_obj, DLTENSOR_NAME.as_ptr()) };

    if raw_ptr.is_null() {
        // PyCapsule_GetPointer sets a Python exception when it fails.
        // Propagate it as a PyErr.
        return Err(PyErr::fetch(capsule.py()));
    }

    // ── Step 4: dereference the managed tensor to get DLTensor ─────────────
    // SAFETY: `raw_ptr` was returned by PyCapsule_GetPointer with the correct
    // capsule name; it points to the DLManagedTensor stored when the capsule
    // was created.  The capsule (and therefore this allocation) remains live
    // as long as `capsule` is live.
    let managed = unsafe { &*(raw_ptr as *const DLManagedTensor) };
    let dl_tensor = &managed.dl_tensor;

    // ── Step 5: delegate to pure-Rust extractor ─────────────────────────────
    // Translate the crate-level DlpackError into Python-friendly ValueErrors.
    cuda_tensor_info_from_dltensor(dl_tensor).map_err(|e| match e {
        DlpackError::NonCpuDevice => pyo3::exceptions::PyValueError::new_err(
            "cuda_tensor_info requires a non-CPU DLPack tensor. \
             Use the standard DLPack CPU path for CPU tensors.",
        ),
        DlpackError::NullPointer => {
            pyo3::exceptions::PyValueError::new_err("DLPack tensor has a null data pointer.")
        }
        other => pyo3::exceptions::PyValueError::new_err(format!("DLPack error: {other}")),
    })
}
291
292/// Python-facing function: extract GPU tensor metadata from a DLPack capsule.
293///
294/// Accepts the capsule returned by `tensor.__dlpack__()` and returns a dict:
295///
296/// ```text
297/// {
298///   "device_id":   int,   # zero-based GPU index
299///   "shape":       list,  # tensor dimensions
300///   "device_type": int,   # raw DLPack device code (2=CUDA, 10=ROCm, …)
301///   "device_str":  str,   # human-readable, e.g. "cuda:0"
302/// }
303/// ```
304///
305/// Raises `ValueError` for CPU tensors, consumed capsules, or null data pointers.
306#[pyfunction]
307pub fn get_cuda_tensor_info(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<Py<PyDict>> {
308    let info = cuda_tensor_info(obj)?;
309
310    let dict = PyDict::new(py);
311    dict.set_item("device_id", info.device_id)?;
312    dict.set_item("shape", info.shape.clone())?;
313    dict.set_item("device_type", info.device_type_code)?;
314    dict.set_item("device_str", info.device_str())?;
315    Ok(dict.into())
316}
317
318/// Register the `get_cuda_tensor_info` function into a PyO3 module.
319///
320/// Call this from your `#[pymodule]` init function to expose the function to Python.
321pub fn register_dlpack_cuda_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
322    m.add_function(wrap_pyfunction!(get_cuda_tensor_info, m)?)?;
323    Ok(())
324}
325
326// ─── Testing helpers ─────────────────────────────────────────────────────────
327
/// Build a mock non-CPU [`DLTensor`] for testing (never accesses data pointer).
///
/// The `data` pointer is set to a sentinel non-null value; **do not** read from it.
#[cfg(test)]
fn make_non_cpu_dltensor(device_type: i32, device_id: i32, shape: &[i64]) -> DLTensor {
    use crate::dlpack::{DLDataTypeCode, DLDevice};
    use std::ffi::c_void;

    // Non-null placeholder; the metadata-only code paths never dereference it.
    static PLACEHOLDER: u8 = 0;

    let device = DLDevice {
        device_type,
        device_id,
    };
    // All mock tensors use f32 elements (code=Float, 32 bits, 1 lane).
    let dtype = DLDataType {
        code: DLDataTypeCode::Float as u8,
        bits: 32,
        lanes: 1,
    };

    DLTensor {
        data: std::ptr::addr_of!(PLACEHOLDER) as *mut c_void,
        device,
        ndim: shape.len() as i32,
        dtype,
        shape: shape.as_ptr() as *mut i64,
        strides: std::ptr::null_mut(),
        byte_offset: 0,
    }
}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dlpack::{dlpack_from_slice, DLDeviceType};
    use std::ffi::c_void;

    // ─── cuda_tensor_info_from_dltensor ──────────────────────────────────────

    #[test]
    fn test_cuda_tensor_info_rejects_cpu_device() {
        let data = [1.0_f64, 2.0, 3.0];
        let shape = [3_i64];
        let tensor = dlpack_from_slice(&data, &shape);
        // CPU tensor (device_type=1) must be rejected.
        let result = cuda_tensor_info_from_dltensor(&tensor);
        assert!(
            matches!(result, Err(DlpackError::NonCpuDevice)),
            "CPU tensor should be rejected by cuda_tensor_info_from_dltensor"
        );
    }

    #[test]
    fn test_cuda_tensor_info_rejects_null_data() {
        let shape = [4_i64, 4];
        let mut tensor = make_non_cpu_dltensor(2, 0, &shape);
        // Overwrite the sentinel with null to exercise the null-data guard.
        tensor.data = std::ptr::null_mut();
        let result = cuda_tensor_info_from_dltensor(&tensor);
        assert!(
            matches!(result, Err(DlpackError::NullPointer)),
            "null data pointer should be rejected"
        );
    }

    #[test]
    fn test_cuda_tensor_info_extracts_shape() {
        let shape = [3_i64, 4, 5];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor)
            .expect("CUDA tensor should produce CudaTensorInfo");
        assert_eq!(info.shape, vec![3, 4, 5], "shape mismatch");
        assert_eq!(info.numel(), 60, "numel mismatch");
    }

    #[test]
    fn test_cuda_tensor_info_extracts_device_id() {
        let shape = [8_i64];
        let tensor = make_non_cpu_dltensor(2, 3, &shape); // device_id = 3 (4th GPU)
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should produce CudaTensorInfo");
        assert_eq!(info.device_id, 3, "device_id mismatch");
        assert_eq!(
            info.device_type_code, 2,
            "device_type_code should be CUDA (2)"
        );
    }

    #[test]
    fn test_cuda_tensor_info_device_str() {
        let shape = [1_i64];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.device_str(), "cuda:0");
    }

    #[test]
    fn test_rocm_tensor_info_device_str() {
        let shape = [1_i64];
        let tensor = make_non_cpu_dltensor(10, 1, &shape); // ROCm device 1
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.device_str(), "rocm:1");
    }

    #[test]
    fn test_cuda_tensor_info_zero_dim_tensor() {
        // ndim=0, shape ptr null.
        // Built inline (not via make_non_cpu_dltensor) so the shape pointer
        // itself can be null, which the extractor must tolerate.
        use crate::dlpack::{DLDataType, DLDataTypeCode};
        static SENTINEL: u8 = 0;
        let tensor = DLTensor {
            data: &SENTINEL as *const u8 as *mut c_void,
            device: crate::dlpack::DLDevice {
                device_type: 2,
                device_id: 0,
            },
            ndim: 0,
            dtype: DLDataType {
                code: DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: std::ptr::null_mut(),
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("zero-dim should succeed");
        assert!(info.shape.is_empty(), "zero-dim shape should be empty");
        assert_eq!(info.numel(), 1, "empty product is 1");
    }

    // ─── dlpack_auto_dispatch_f32 ────────────────────────────────────────────

    #[test]
    fn test_dlpack_auto_dispatch_cpu_f32_returns_array() {
        let data = [1.0_f32, 2.0, 3.0, 4.0];
        let shape = [2_i64, 2];
        let tensor = crate::dlpack::DLTensor {
            data: data.as_ptr() as *mut c_void,
            device: crate::dlpack::DLDevice {
                device_type: DLDeviceType::Cpu as i32,
                device_id: 0,
            },
            ndim: 2,
            dtype: crate::dlpack::DLDataType {
                code: crate::dlpack::DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: shape.as_ptr() as *mut i64,
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        // SAFETY: tensor is valid; data and shape are alive.
        let result = unsafe { dlpack_auto_dispatch_f32(&tensor as *const _) }
            .expect("CPU dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Cpu(_)),
            "CPU tensor should return Cpu variant"
        );
        if let DLPackDispatchResult::Cpu(view) = result {
            assert_eq!(view.shape(), &[2, 2]);
            assert_eq!(view[[0, 0]], 1.0_f32);
        }
    }

    #[test]
    fn test_dlpack_auto_dispatch_cuda_f32_returns_gpu_info() {
        let shape = [8_i64];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Cuda as i32, 0, &shape);
        // SAFETY: tensor is valid; shape is alive.
        let result = unsafe { dlpack_auto_dispatch_f32(&tensor as *const _) }
            .expect("CUDA dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Gpu(_)),
            "CUDA tensor should return Gpu variant"
        );
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.shape, vec![8]);
            assert_eq!(info.device_type_code, 2);
        }
    }

    #[test]
    fn test_dlpack_auto_dispatch_cpu_f64_returns_array() {
        let data = [10.0_f64, 20.0, 30.0];
        let shape = [3_i64];
        let tensor = dlpack_from_slice(&data, &shape);
        // SAFETY: tensor is valid; data and shape are alive.
        let result = unsafe { dlpack_auto_dispatch_f64(&tensor as *const _) }
            .expect("CPU f64 dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Cpu(_)),
            "CPU f64 tensor should return Cpu variant"
        );
    }

    #[test]
    fn test_dlpack_auto_dispatch_cuda_f64_returns_gpu_info() {
        let shape = [4_i64, 4];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Cuda as i32, 1, &shape);
        // SAFETY: tensor is valid; shape is alive.
        let result = unsafe { dlpack_auto_dispatch_f64(&tensor as *const _) }
            .expect("CUDA f64 dispatch should succeed");
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.shape, vec![4, 4]);
            assert_eq!(info.device_id, 1);
        } else {
            panic!("expected Gpu variant");
        }
    }

    #[test]
    fn test_dlpack_other_device_passthrough() {
        // Metal (device_type=8) should also return Gpu variant.
        let shape = [16_i64];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Metal as i32, 0, &shape);
        // SAFETY: tensor is valid; shape is alive.
        let result = unsafe { dlpack_auto_dispatch_f32(&tensor as *const _) }
            .expect("Metal dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Gpu(_)),
            "Metal tensor should return Gpu variant"
        );
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.device_str(), "metal:0");
        }
    }

    #[test]
    fn test_dlpack_rocm_passthrough() {
        let shape = [32_i64];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Rocm as i32, 2, &shape);
        // SAFETY: tensor is valid; shape is alive.
        let result = unsafe { dlpack_auto_dispatch_f64(&tensor as *const _) }
            .expect("ROCm dispatch should succeed");
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.device_type_code, 10);
            assert_eq!(info.device_id, 2);
        } else {
            panic!("expected Gpu variant for ROCm device");
        }
    }

    #[test]
    fn test_cuda_tensor_numel_empty_shape() {
        let shape: [i64; 0] = [];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        // ndim=0, shape ptr points to empty slice.
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.numel(), 1, "empty shape product is 1");
    }

    #[test]
    fn test_cuda_tensor_dtype_bits() {
        let shape = [4_i64];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        // make_non_cpu_dltensor always builds a 32-bit float dtype.
        assert_eq!(info.dtype_bits(), 32, "dtype bits should be 32");
    }
}