scirs2/dlpack.rs
1//! DLPack tensor interop for scirs2-python
2//!
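//! Provides `from_dlpack` and `to_dlpack` entry points that follow the
//! DLPack 1.0 protocol. Full zero-copy sharing with PyTorch, JAX, CuPy,
//! TensorFlow etc. requires the calling Python environment to have the
//! relevant library installed; the Rust side handles the capsule protocol.
//!
//! **Status:** both entry points currently raise `NotImplementedError`.
//! `from_dlpack` validates its argument before raising but does not yet
//! convert it; the usage example below shows the intended API.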
7//!
8//! # DLPack protocol
9//!
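//! A *DLPack capsule* is a `PyCapsule` object whose name is `"dltensor"`
//! (or `"dltensor_versioned"` for DLPack >= 1.0 versioned tensors).
//! After the consumer takes ownership, the capsule is renamed to
//! `"used_dltensor"` (respectively `"used_dltensor_versioned"`) so the
//! tensor cannot be consumed twice and double-frees are prevented.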
13//!
14//! # Python usage
15//!
16//! ```python
17//! import torch
18//! import scirs2
19//!
20//! t = torch.randn(3, 4)
21//! # PyTorch tensors expose __dlpack__() / __dlpack_device__()
22//! capsule = t.__dlpack__()
23//! arr = scirs2.from_dlpack(capsule) # -> scirs2 array (NumPy-compatible)
24//!
25//! # Round-trip: export back
26//! cap2 = scirs2.to_dlpack(arr)
27//! t2 = torch.from_dlpack(cap2)
28//! ```
29
30use std::ffi::CStr;
31
32use pyo3::exceptions::PyNotImplementedError;
33use pyo3::prelude::*;
34use pyo3::types::{PyCapsule, PyCapsuleMethods};
35
36/// Expected DLPack capsule name (C string literal).
37const DLTENSOR_NAME: &CStr = c"dltensor";
38
39/// Convert a DLPack capsule (from PyTorch, JAX, CuPy, TensorFlow, …) into a
40/// scirs2 NumPy-compatible array.
41///
42/// Parameters
43/// ----------
44/// capsule : PyCapsule
45/// A `PyCapsule` object whose name is `"dltensor"`. Anything that
46/// implements `__dlpack__()` can produce such an object.
47///
48/// Returns
49/// -------
50/// array-like
51/// A zero-copy view (when the device is CPU) as a NumPy array.
52///
53/// Notes
54/// -----
55/// The full zero-copy path is exercised when `capsule` comes from a
56/// CPU-resident tensor. GPU tensors raise `NotImplementedError` until
57/// the optional `gpu` feature is enabled.
58#[pyfunction]
59pub fn from_dlpack(_py: Python<'_>, capsule: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
60 // Try to cast to PyCapsule; accept PyAny so callers can pass
61 // the result of tensor.__dlpack__() directly.
62 let cap = capsule.cast::<PyCapsule>().map_err(|_| {
63 PyNotImplementedError::new_err(
64 "from_dlpack: argument must be a PyCapsule (the result of tensor.__dlpack__()). \
65 Got a non-capsule object instead.",
66 )
67 })?;
68
69 // Validate the capsule name against the DLPack spec.
70 let name_opt = cap.name().map_err(|e| {
71 PyNotImplementedError::new_err(format!("from_dlpack: could not read capsule name: {e}"))
72 })?;
73
74 let name_matches = match name_opt {
75 None => false,
76 Some(cn) => {
77 // SAFETY: The name pointer is valid for the duration of this call;
78 // we only compare it immediately and do not store the reference.
79 let name_cstr = unsafe { cn.as_cstr() };
80 name_cstr == DLTENSOR_NAME
81 }
82 };
83
84 if !name_matches {
85 return Err(PyNotImplementedError::new_err(
86 "from_dlpack: expected a PyCapsule named 'dltensor'. \
87 Pass the result of tensor.__dlpack__() directly.",
88 ));
89 }
90
91 // At this layer we validate the protocol and defer the actual pointer
92 // extraction to the Python-level __array__ bridge (scirs2-numpy).
93 // A full implementation would: cast capsule.pointer_checked() to
94 // *const DLManagedTensor, read .dl_tensor.{data, shape, strides, dtype,
95 // device}, and wrap as ndarray. That path requires unsafe code and the
96 // dlpack feature; a proper stub is correct here.
97 Err(PyNotImplementedError::new_err(
98 "from_dlpack: zero-copy CPU path will be enabled in a future release. \
99 Use numpy.from_dlpack(tensor) and pass the result to scirs2 functions instead. \
100 See scirs2_numpy::array_from_dlpack_f32 for the Rust-side DLTensor API.",
101 ))
102}
103
104/// Export a scirs2 (NumPy-compatible) array as a DLPack `PyCapsule`.
105///
106/// Parameters
107/// ----------
108/// array : array-like
109/// A NumPy array (or any object with the NumPy array interface).
110///
111/// Returns
112/// -------
113/// PyCapsule
114/// A capsule named `"dltensor"` that can be consumed by PyTorch, JAX, etc.
115///
116/// Notes
117/// -----
118/// The capsule wraps the array's data pointer without copying. The array
119/// must remain alive for the lifetime of the capsule. PyTorch's
120/// `torch.from_dlpack(capsule)` will call the registered deleter when done.
121#[pyfunction]
122pub fn to_dlpack(_py: Python<'_>, array: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
123 // Suppress unused variable warning — the object is accepted to match the
124 // protocol signature; inspection of its buffer is deferred to the full impl.
125 let _ = array;
126 Err(PyNotImplementedError::new_err(
127 "to_dlpack: creates a PyCapsule('dltensor') wrapping the array data pointer. \
128 This path will be enabled once the DLTensor ABI bridge is wired into scirs2-numpy. \
129 For now, use numpy arrays directly — they are already DLPack-compatible via \
130 numpy.from_dlpack / numpy.to_dlpack.",
131 ))
132}
133
134/// Register DLPack interop functions on the given module.
135pub fn register_dlpack_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
136 m.add_function(wrap_pyfunction!(from_dlpack, m)?)?;
137 m.add_function(wrap_pyfunction!(to_dlpack, m)?)?;
138 Ok(())
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// The expected capsule name must match the DLPack spec byte-for-byte,
    /// because it is compared exactly against the name producers set.
    #[test]
    fn dltensor_name_matches_dlpack_spec() {
        assert_eq!(DLTENSOR_NAME.to_bytes(), b"dltensor");
    }

    /// Compile-time check that `register_dlpack_module` keeps the signature
    /// callers rely on. Actually invoking it requires a Python interpreter,
    /// so only the coercion to a function pointer is verified here.
    #[test]
    fn dlpack_module_symbol_exists() {
        let _register: fn(&Bound<'_, PyModule>) -> PyResult<()> = register_dlpack_module;
    }
}