// scirs2 — dlpack.rs (DLPack tensor interop)

1//! DLPack tensor interop for scirs2-python
2//!
3//! Provides `from_dlpack` and `to_dlpack` entry points that follow the
4//! DLPack 1.0 protocol.  Full zero-copy sharing with PyTorch, JAX, CuPy,
5//! TensorFlow etc. requires the calling Python environment to have the
6//! relevant library installed; the Rust side handles the capsule protocol.
7//!
8//! # DLPack protocol
9//!
10//! A *DLPack capsule* is a `PyCapsule` object whose name is `"dltensor"`.
11//! After the consumer takes ownership, the capsule is renamed to
12//! `"used_dltensor"` so double-frees are prevented.
13//!
14//! # Python usage
15//!
16//! ```python
17//! import torch
18//! import scirs2
19//!
20//! t = torch.randn(3, 4)
21//! # PyTorch tensors expose __dlpack__() / __dlpack_device__()
22//! capsule = t.__dlpack__()
23//! arr = scirs2.from_dlpack(capsule)   # -> scirs2 array (NumPy-compatible)
24//!
25//! # Round-trip: export back
26//! cap2 = scirs2.to_dlpack(arr)
27//! t2 = torch.from_dlpack(cap2)
28//! ```
29
30use std::ffi::{c_void, CStr};
31use std::ptr::NonNull;
32
33use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError};
34use pyo3::prelude::*;
35use pyo3::types::{PyCapsule, PyCapsuleMethods};
36use scirs2_numpy::dlpack::{
37    DLDataType, DLDataTypeCode, DLDevice, DLDeviceType, DLManagedTensor, DLTensor,
38};
39
/// Capsule name a producer gives a live DLPack capsule (`"dltensor"`,
/// DLPack 1.0 spec).  `from_dlpack` only accepts capsules with this name.
const DLTENSOR_NAME: &CStr = c"dltensor";

/// Name a capsule is renamed to once its tensor has been consumed; after the
/// rename the producer's destructor must skip freeing (double-free guard).
const USED_DLTENSOR_NAME: &CStr = c"used_dltensor";
45
46// ─── Ownership wrapper ────────────────────────────────────────────────────────
47
48/// Heap allocation that backs a DLPack capsule created by `to_dlpack`.
49///
50/// Bundles the `DLManagedTensor` with the shape/strides arrays and the owned
51/// data copy.  All memory is freed through `BackingStore::drop_raw`.
52struct BackingStore {
53    /// ABI-compatible managed-tensor struct; must be the first field so that
54    /// a `*mut BackingStore` can be cast to `*mut DLManagedTensor` safely.
55    managed: DLManagedTensor,
56    /// Owned copy of the tensor's element data.
57    data: Vec<f64>,
58    /// Owned shape array (length = `managed.dl_tensor.ndim`).
59    shape: Vec<i64>,
60    /// Owned strides array (length = `managed.dl_tensor.ndim`).
61    strides: Vec<i64>,
62}
63
64impl BackingStore {
65    /// Free a `BackingStore` that was previously leaked with `Box::into_raw`.
66    ///
67    /// # Safety
68    ///
69    /// `ptr` must be a non-null pointer obtained from `Box::into_raw` on a
70    /// `BackingStore`.  This function must be called at most once.
71    unsafe fn drop_raw(ptr: *mut BackingStore) {
72        if !ptr.is_null() {
73            // SAFETY: ptr was obtained from Box::into_raw.
74            drop(unsafe { Box::from_raw(ptr) });
75        }
76    }
77}
78
79/// DLPack `deleter` stored inside the `DLManagedTensor`.
80///
81/// Called by the consumer framework (PyTorch, JAX, etc.) when it is finished
82/// with the tensor.
83///
84/// # Safety
85///
86/// `managed` must point to the `managed` field of a `BackingStore` that was
87/// previously leaked via `Box::into_raw`.
88unsafe extern "C" fn backing_store_deleter(managed: *mut DLManagedTensor) {
89    if managed.is_null() {
90        return;
91    }
92    // SAFETY: BackingStore has `managed` as its first field, so the pointer
93    // arithmetic is a no-op and the cast is valid.
94    let backing = managed as *mut BackingStore;
95    // SAFETY: backed by a Box::into_raw call in `to_dlpack`.
96    unsafe { BackingStore::drop_raw(backing) };
97}
98
99/// Destructor registered with `PyCapsule::new_with_pointer_and_destructor`.
100///
101/// Called by Python's GC when the capsule object is finalized.  Extracts the
102/// `BackingStore` raw pointer from the capsule and drops it.
103///
104/// # Safety
105///
106/// `capsule` must be a valid `PyObject*` whose capsule pointer was set to the
107/// `managed` field of a `BackingStore` allocation.
108unsafe extern "C" fn capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
109    // SAFETY: capsule is a valid PyCapsule whose pointer was set during
110    // `to_dlpack` to a `BackingStore::managed` field.
111    let ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR_NAME.as_ptr()) };
112    if !ptr.is_null() {
113        let managed_ptr = ptr as *mut DLManagedTensor;
114        // SAFETY: managed_ptr is the `managed` field of a BackingStore.
115        if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
116            unsafe { deleter(managed_ptr) };
117        }
118    }
119}
120
121// ─── from_dlpack ─────────────────────────────────────────────────────────────
122
123/// Convert a DLPack capsule (from PyTorch, JAX, CuPy, TensorFlow, …) into a
124/// scirs2 NumPy-compatible array.
125///
126/// Parameters
127/// ----------
128/// capsule : PyCapsule
129///     A `PyCapsule` object whose name is `"dltensor"`.  Anything that
130///     implements `__dlpack__()` can produce such an object.
131///
132/// Returns
133/// -------
134/// numpy.ndarray
135///     A 1-D `float64` NumPy array whose contents are *copied* from the
136///     DLPack tensor.  Only CPU, float32, and float64 tensors are currently
137///     supported; all other dtypes raise `TypeError`.
138///
139/// Notes
140/// -----
141/// GPU tensors raise `TypeError` until an optional `gpu` feature is enabled.
142/// The capsule is renamed to `"used_dltensor"` after consumption to prevent
143/// double-frees, consistent with the DLPack 1.0 spec.
144#[pyfunction]
145pub fn from_dlpack(py: Python<'_>, capsule: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
146    // Cast to PyCapsule — accept PyAny so callers can pass __dlpack__() result.
147    let cap = capsule.cast::<PyCapsule>().map_err(|_| {
148        PyTypeError::new_err(
149            "from_dlpack: argument must be a PyCapsule (the result of tensor.__dlpack__()). \
150             Got a non-capsule object instead.",
151        )
152    })?;
153
154    // Validate the capsule name against the DLPack spec.
155    let name_opt = cap.name().map_err(|e| {
156        PyValueError::new_err(format!("from_dlpack: could not read capsule name: {e}"))
157    })?;
158
159    let name_matches = match name_opt {
160        None => false,
161        Some(cn) => {
162            // SAFETY: The name pointer is valid for the duration of this call.
163            let name_cstr = unsafe { cn.as_cstr() };
164            name_cstr == DLTENSOR_NAME
165        }
166    };
167
168    if !name_matches {
169        return Err(PyValueError::new_err(
170            "from_dlpack: expected a PyCapsule named 'dltensor'. \
171             Pass the result of tensor.__dlpack__() directly.",
172        ));
173    }
174
175    // Retrieve the DLManagedTensor pointer from the capsule.
176    // SAFETY: We validated the name above; the pointer was placed here by the
177    // producer and is valid until we consume it.
178    let nn_ptr: NonNull<c_void> = cap
179        .pointer_checked(Some(DLTENSOR_NAME))
180        .map_err(|e| PyRuntimeError::new_err(format!("from_dlpack: null capsule pointer: {e}")))?;
181
182    let managed_ptr = nn_ptr.as_ptr() as *mut DLManagedTensor;
183
184    // SAFETY: managed_ptr is non-null and valid; derived from the capsule above.
185    let dl_tensor: &DLTensor = unsafe { &(*managed_ptr).dl_tensor };
186
187    // Reject non-CPU tensors.
188    if dl_tensor.device.device_type != DLDeviceType::Cpu as i32 {
189        return Err(PyTypeError::new_err(format!(
190            "from_dlpack: only CPU tensors are supported (got device type {}). \
191             Copy the tensor to CPU before calling from_dlpack.",
192            dl_tensor.device.device_type
193        )));
194    }
195
196    // Reject null data pointers.
197    if dl_tensor.data.is_null() {
198        return Err(PyValueError::new_err(
199            "from_dlpack: tensor has a null data pointer.",
200        ));
201    }
202
203    // Compute the flat element count from shape.
204    let n_elems: usize = if dl_tensor.ndim == 0 || dl_tensor.shape.is_null() {
205        1
206    } else {
207        // SAFETY: shape is valid for ndim elements (DLPack producer contract).
208        let shape_slice = unsafe {
209            std::slice::from_raw_parts(dl_tensor.shape as *const i64, dl_tensor.ndim as usize)
210        };
211        shape_slice.iter().map(|&d| d as usize).product()
212    };
213
214    // Dispatch on dtype — copy into a Python list of floats, then wrap as numpy array.
215    let base_ptr = unsafe { (dl_tensor.data as *const u8).add(dl_tensor.byte_offset as usize) };
216
217    let dtype = dl_tensor.dtype;
218    let flat_vec: Vec<f64> = match (dtype.code, dtype.bits, dtype.lanes) {
219        // float32 (DLDataTypeCode::Float = 2, bits=32)
220        (2, 32, 1) => {
221            let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const f32, n_elems) };
222            slice.iter().map(|&v| v as f64).collect()
223        }
224        // float64
225        (2, 64, 1) => {
226            let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const f64, n_elems) };
227            slice.to_vec()
228        }
229        // int8
230        (0, 8, 1) => {
231            let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i8, n_elems) };
232            slice.iter().map(|&v| v as f64).collect()
233        }
234        // int16
235        (0, 16, 1) => {
236            let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i16, n_elems) };
237            slice.iter().map(|&v| v as f64).collect()
238        }
239        // int32
240        (0, 32, 1) => {
241            let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i32, n_elems) };
242            slice.iter().map(|&v| v as f64).collect()
243        }
244        // int64
245        (0, 64, 1) => {
246            let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i64, n_elems) };
247            slice.iter().map(|&v| v as f64).collect()
248        }
249        // uint8
250        (1, 8, 1) => {
251            let slice = unsafe { std::slice::from_raw_parts(base_ptr, n_elems) };
252            slice.iter().map(|&v| v as f64).collect()
253        }
254        // uint16
255        (1, 16, 1) => {
256            let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const u16, n_elems) };
257            slice.iter().map(|&v| v as f64).collect()
258        }
259        // uint32
260        (1, 32, 1) => {
261            let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const u32, n_elems) };
262            slice.iter().map(|&v| v as f64).collect()
263        }
264        // uint64
265        (1, 64, 1) => {
266            let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const u64, n_elems) };
267            slice.iter().map(|&v| v as f64).collect()
268        }
269        (code, bits, _) => {
270            return Err(PyTypeError::new_err(format!(
271                "from_dlpack: unsupported dtype (code={code}, bits={bits}). \
272                 Supported: int8/16/32/64, uint8/16/32/64, float32, float64.",
273            )));
274        }
275    };
276
277    // Build shape tuple for numpy.
278    let shape_vec: Vec<usize> = if dl_tensor.ndim == 0 || dl_tensor.shape.is_null() {
279        vec![n_elems]
280    } else {
281        // SAFETY: shape is valid for ndim elements.
282        let shape_slice = unsafe {
283            std::slice::from_raw_parts(dl_tensor.shape as *const i64, dl_tensor.ndim as usize)
284        };
285        shape_slice.iter().map(|&d| d as usize).collect()
286    };
287
288    // Rename the capsule to "used_dltensor" per DLPack 1.0 spec to prevent
289    // the producer from being consumed again (double-free guard).
290    // We attempt this on a best-effort basis; failure is non-fatal here because
291    // the data has already been copied.
292    let rename_result =
293        unsafe { pyo3::ffi::PyCapsule_SetName(cap.as_ptr(), USED_DLTENSOR_NAME.as_ptr()) };
294    let _ = rename_result; // intentionally ignored after copy
295
296    // Call the managed tensor's deleter if present, as we have consumed it.
297    if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
298        unsafe { deleter(managed_ptr) };
299    }
300
301    // Convert the flat f64 Vec into a numpy array via Python's numpy.
302    let numpy = py.import("numpy").map_err(|e| {
303        PyRuntimeError::new_err(format!("from_dlpack: could not import numpy: {e}"))
304    })?;
305    let arr = numpy.getattr("array")?.call1((flat_vec,))?;
306
307    // Reshape to match the original tensor shape.
308    let shaped = arr.call_method1("reshape", (shape_vec,))?;
309
310    Ok(shaped.into())
311}
312
313// ─── to_dlpack ────────────────────────────────────────────────────────────────
314
315/// Export a scirs2 (NumPy-compatible) array as a DLPack `PyCapsule`.
316///
317/// Parameters
318/// ----------
319/// array : numpy.ndarray
320///     A NumPy float64 array (or any object with the buffer protocol that
321///     numpy can interpret as float64).
322///
323/// Returns
324/// -------
325/// PyCapsule
326///     A capsule named `"dltensor"` that can be consumed by PyTorch, JAX, etc.
327///
328/// Notes
329/// -----
330/// The capsule *owns a copy* of the array data so that the Python array object
331/// can be garbage-collected independently.  The `DLManagedTensor.deleter`
332/// registered in the capsule frees this copy when the consumer is done.
333#[pyfunction]
334pub fn to_dlpack(py: Python<'_>, array: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
335    // Extract the array data as a Vec<f64> via numpy.
336    let numpy = py
337        .import("numpy")
338        .map_err(|e| PyRuntimeError::new_err(format!("to_dlpack: could not import numpy: {e}")))?;
339
340    // Ensure we have a contiguous float64 C-order array.
341    let arr = numpy.getattr("asarray")?.call1((array,))?;
342    let arr_f64 = numpy
343        .getattr("ascontiguousarray")?
344        .call((arr,), Some(&pyo3::types::PyDict::new(py)))?;
345
346    // Read shape.
347    let shape_obj = arr_f64.getattr("shape")?;
348    let shape_tuple: Vec<i64> = shape_obj.extract::<Vec<i64>>().map_err(|e| {
349        PyTypeError::new_err(format!("to_dlpack: could not extract array shape: {e}"))
350    })?;
351
352    // Extract flat data as f64.
353    let flat_list = arr_f64.call_method0("flatten")?;
354    let data_vec: Vec<f64> = flat_list.extract::<Vec<f64>>().map_err(|e| {
355        PyTypeError::new_err(format!(
356            "to_dlpack: array must be convertible to float64: {e}"
357        ))
358    })?;
359
360    // Compute C-order strides (in elements).
361    let strides_vec: Vec<i64> = compute_c_strides(&shape_tuple);
362
363    // Build the BackingStore on the heap.  We use Box::into_raw so it lives
364    // until the capsule destructor frees it.
365    let n = shape_tuple.len();
366    let mut store = Box::new(BackingStore {
367        managed: DLManagedTensor {
368            dl_tensor: DLTensor {
369                data: std::ptr::null_mut(), // filled in below
370                device: DLDevice {
371                    device_type: DLDeviceType::Cpu as i32,
372                    device_id: 0,
373                },
374                ndim: n as i32,
375                dtype: DLDataType {
376                    code: DLDataTypeCode::Float as u8,
377                    bits: 64,
378                    lanes: 1,
379                },
380                shape: std::ptr::null_mut(),   // filled in below
381                strides: std::ptr::null_mut(), // filled in below
382                byte_offset: 0,
383            },
384            manager_ctx: std::ptr::null_mut(),
385            deleter: Some(backing_store_deleter),
386        },
387        data: data_vec,
388        shape: shape_tuple,
389        strides: strides_vec,
390    });
391
392    // Now that the Vecs are in their final locations inside the Box, set the
393    // raw pointers in dl_tensor to point into those Vecs.
394    store.managed.dl_tensor.data = store.data.as_mut_ptr() as *mut c_void;
395    store.managed.dl_tensor.shape = store.shape.as_mut_ptr();
396    store.managed.dl_tensor.strides = store.strides.as_mut_ptr();
397
398    let raw_store: *mut BackingStore = Box::into_raw(store);
399    // SAFETY: raw_store is non-null (just created by Box::into_raw).
400    let managed_nn = NonNull::new(raw_store as *mut c_void)
401        .ok_or_else(|| PyRuntimeError::new_err("to_dlpack: null BackingStore pointer"))?;
402
403    // SAFETY: managed_nn points to a valid BackingStore; capsule_destructor
404    // will call backing_store_deleter which frees it via Box::from_raw.
405    let capsule = unsafe {
406        PyCapsule::new_with_pointer_and_destructor(
407            py,
408            managed_nn,
409            DLTENSOR_NAME,
410            Some(capsule_destructor),
411        )
412    }
413    .map_err(|e| PyRuntimeError::new_err(format!("to_dlpack: failed to create capsule: {e}")))?;
414
415    Ok(capsule.into())
416}
417
/// Compute C-order (row-major) strides in elements for the given shape.
///
/// The last dimension has stride 1; each preceding dimension's stride is the
/// product of all dimensions after it.  An empty shape (0-d) yields an empty
/// stride vector.
fn compute_c_strides(shape: &[i64]) -> Vec<i64> {
    let mut rev_strides = Vec::with_capacity(shape.len());
    let mut running = 1i64;
    // Walk the dims from innermost to outermost, accumulating the running
    // product, then flip the result back into forward order.
    for &dim in shape.iter().rev() {
        rev_strides.push(running);
        running *= dim;
    }
    rev_strides.reverse();
    rev_strides
}
433
434/// Register DLPack interop functions on the given module.
435pub fn register_dlpack_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
436    m.add_function(wrap_pyfunction!(from_dlpack, m)?)?;
437    m.add_function(wrap_pyfunction!(to_dlpack, m)?)?;
438    Ok(())
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke check that the module builds; exercising registration itself
    /// requires a live Python interpreter.
    #[test]
    fn dlpack_module_symbol_exists() {
        let _msg = "dlpack module compiled successfully";
    }

    #[test]
    fn compute_c_strides_1d() {
        assert_eq!(compute_c_strides(&[5]), [1]);
    }

    #[test]
    fn compute_c_strides_2d() {
        // Row-major: the last axis is the fastest-moving one.
        assert_eq!(compute_c_strides(&[2, 3]), [3, 1]);
    }

    #[test]
    fn compute_c_strides_3d() {
        // [2, 3, 4] -> [3*4, 4, 1]
        assert_eq!(compute_c_strides(&[2, 3, 4]), [12, 4, 1]);
    }

    #[test]
    fn compute_c_strides_empty() {
        assert!(compute_c_strides(&[]).is_empty());
    }
}
473}