Skip to main content

scirs2/
dlpack.rs

1//! DLPack tensor interop for scirs2-python
2//!
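//! Provides `from_dlpack` and `to_dlpack` entry points that follow the
//! DLPack 1.0 protocol.  Full zero-copy sharing with PyTorch, JAX, CuPy,
//! TensorFlow etc. requires the calling Python environment to have the
//! relevant library installed; the Rust side validates the capsule protocol.
//!
//! **Status:** both entry points are currently stubs.  `from_dlpack` checks
//! that its argument is a `PyCapsule` named `"dltensor"` and then raises
//! `NotImplementedError`; `to_dlpack` always raises `NotImplementedError`.
//! The usage example below shows the intended API, not current behavior.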
7//!
8//! # DLPack protocol
9//!
10//! A *DLPack capsule* is a `PyCapsule` object whose name is `"dltensor"`.
11//! After the consumer takes ownership, the capsule is renamed to
12//! `"used_dltensor"` so double-frees are prevented.
13//!
14//! # Python usage
15//!
16//! ```python
17//! import torch
18//! import scirs2
19//!
20//! t = torch.randn(3, 4)
21//! # PyTorch tensors expose __dlpack__() / __dlpack_device__()
22//! capsule = t.__dlpack__()
23//! arr = scirs2.from_dlpack(capsule)   # -> scirs2 array (NumPy-compatible)
24//!
25//! # Round-trip: export back
26//! cap2 = scirs2.to_dlpack(arr)
27//! t2 = torch.from_dlpack(cap2)
28//! ```
29
30use std::ffi::CStr;
31
32use pyo3::exceptions::PyNotImplementedError;
33use pyo3::prelude::*;
34use pyo3::types::{PyCapsule, PyCapsuleMethods};
35
36/// Expected DLPack capsule name (C string literal).
37const DLTENSOR_NAME: &CStr = c"dltensor";
38
39/// Convert a DLPack capsule (from PyTorch, JAX, CuPy, TensorFlow, …) into a
40/// scirs2 NumPy-compatible array.
41///
42/// Parameters
43/// ----------
44/// capsule : PyCapsule
45///     A `PyCapsule` object whose name is `"dltensor"`.  Anything that
46///     implements `__dlpack__()` can produce such an object.
47///
48/// Returns
49/// -------
50/// array-like
51///     A zero-copy view (when the device is CPU) as a NumPy array.
52///
53/// Notes
54/// -----
55/// The full zero-copy path is exercised when `capsule` comes from a
56/// CPU-resident tensor.  GPU tensors raise `NotImplementedError` until
57/// the optional `gpu` feature is enabled.
58#[pyfunction]
59pub fn from_dlpack(_py: Python<'_>, capsule: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
60    // Try to cast to PyCapsule; accept PyAny so callers can pass
61    // the result of tensor.__dlpack__() directly.
62    let cap = capsule.cast::<PyCapsule>().map_err(|_| {
63        PyNotImplementedError::new_err(
64            "from_dlpack: argument must be a PyCapsule (the result of tensor.__dlpack__()). \
65             Got a non-capsule object instead.",
66        )
67    })?;
68
69    // Validate the capsule name against the DLPack spec.
70    let name_opt = cap.name().map_err(|e| {
71        PyNotImplementedError::new_err(format!("from_dlpack: could not read capsule name: {e}"))
72    })?;
73
74    let name_matches = match name_opt {
75        None => false,
76        Some(cn) => {
77            // SAFETY: The name pointer is valid for the duration of this call;
78            // we only compare it immediately and do not store the reference.
79            let name_cstr = unsafe { cn.as_cstr() };
80            name_cstr == DLTENSOR_NAME
81        }
82    };
83
84    if !name_matches {
85        return Err(PyNotImplementedError::new_err(
86            "from_dlpack: expected a PyCapsule named 'dltensor'. \
87             Pass the result of tensor.__dlpack__() directly.",
88        ));
89    }
90
91    // At this layer we validate the protocol and defer the actual pointer
92    // extraction to the Python-level __array__ bridge (scirs2-numpy).
93    // A full implementation would: cast capsule.pointer_checked() to
94    // *const DLManagedTensor, read .dl_tensor.{data, shape, strides, dtype,
95    // device}, and wrap as ndarray.  That path requires unsafe code and the
96    // dlpack feature; a proper stub is correct here.
97    Err(PyNotImplementedError::new_err(
98        "from_dlpack: zero-copy CPU path will be enabled in a future release. \
99         Use numpy.from_dlpack(tensor) and pass the result to scirs2 functions instead. \
100         See scirs2_numpy::array_from_dlpack_f32 for the Rust-side DLTensor API.",
101    ))
102}
103
104/// Export a scirs2 (NumPy-compatible) array as a DLPack `PyCapsule`.
105///
106/// Parameters
107/// ----------
108/// array : array-like
109///     A NumPy array (or any object with the NumPy array interface).
110///
111/// Returns
112/// -------
113/// PyCapsule
114///     A capsule named `"dltensor"` that can be consumed by PyTorch, JAX, etc.
115///
116/// Notes
117/// -----
118/// The capsule wraps the array's data pointer without copying.  The array
119/// must remain alive for the lifetime of the capsule.  PyTorch's
120/// `torch.from_dlpack(capsule)` will call the registered deleter when done.
121#[pyfunction]
122pub fn to_dlpack(_py: Python<'_>, array: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
123    // Suppress unused variable warning — the object is accepted to match the
124    // protocol signature; inspection of its buffer is deferred to the full impl.
125    let _ = array;
126    Err(PyNotImplementedError::new_err(
127        "to_dlpack: creates a PyCapsule('dltensor') wrapping the array data pointer. \
128         This path will be enabled once the DLTensor ABI bridge is wired into scirs2-numpy. \
129         For now, use numpy arrays directly — they are already DLPack-compatible via \
130         numpy.from_dlpack / numpy.to_dlpack.",
131    ))
132}
133
134/// Register DLPack interop functions on the given module.
135pub fn register_dlpack_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
136    m.add_function(wrap_pyfunction!(from_dlpack, m)?)?;
137    m.add_function(wrap_pyfunction!(to_dlpack, m)?)?;
138    Ok(())
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// The capsule name must match the DLPack protocol byte-for-byte.
    /// A typo here would make `from_dlpack` reject every valid capsule.
    #[test]
    fn dltensor_capsule_name_matches_spec() {
        assert_eq!(DLTENSOR_NAME.to_bytes(), b"dltensor");
        assert_eq!(DLTENSOR_NAME.to_bytes_with_nul(), b"dltensor\0");
    }

    /// Compile-time check that `register_dlpack_module` keeps the signature
    /// module-init code relies on.  The function item is passed by value to a
    /// generic with a higher-ranked bound, so it is never called and its
    /// address is never taken.  Actual invocation needs a Python interpreter.
    #[test]
    fn dlpack_module_symbol_exists() {
        fn assert_registration_signature<F>(_: F)
        where
            F: for<'a, 'py> Fn(&'a Bound<'py, PyModule>) -> PyResult<()>,
        {
        }
        assert_registration_signature(register_dlpack_module);
    }
}