// scirs2/dlpack.rs
1//! DLPack tensor interop for scirs2-python
2//!
3//! Provides `from_dlpack` and `to_dlpack` entry points that follow the
4//! DLPack 1.0 protocol. Full zero-copy sharing with PyTorch, JAX, CuPy,
5//! TensorFlow etc. requires the calling Python environment to have the
6//! relevant library installed; the Rust side handles the capsule protocol.
7//!
8//! # DLPack protocol
9//!
10//! A *DLPack capsule* is a `PyCapsule` object whose name is `"dltensor"`.
11//! After the consumer takes ownership, the capsule is renamed to
12//! `"used_dltensor"` so double-frees are prevented.
13//!
14//! # Python usage
15//!
16//! ```python
17//! import torch
18//! import scirs2
19//!
20//! t = torch.randn(3, 4)
21//! # PyTorch tensors expose __dlpack__() / __dlpack_device__()
22//! capsule = t.__dlpack__()
23//! arr = scirs2.from_dlpack(capsule) # -> scirs2 array (NumPy-compatible)
24//!
25//! # Round-trip: export back
26//! cap2 = scirs2.to_dlpack(arr)
27//! t2 = torch.from_dlpack(cap2)
28//! ```
29
30use std::ffi::{c_void, CStr};
31use std::ptr::NonNull;
32
33use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError};
34use pyo3::prelude::*;
35use pyo3::types::{PyCapsule, PyCapsuleMethods};
36use scirs2_numpy::dlpack::{
37 DLDataType, DLDataTypeCode, DLDevice, DLDeviceType, DLManagedTensor, DLTensor,
38};
39
/// Expected DLPack capsule name (C string literal, DLPack 1.0 spec).
/// A fresh capsule produced by `__dlpack__()` / `to_dlpack` carries this name.
const DLTENSOR_NAME: &CStr = c"dltensor";

/// Name the capsule is renamed to once consumed (prevents double-free).
/// `from_dlpack` performs this rename after copying the data out.
const USED_DLTENSOR_NAME: &CStr = c"used_dltensor";
45
46// ─── Ownership wrapper ────────────────────────────────────────────────────────
47
48/// Heap allocation that backs a DLPack capsule created by `to_dlpack`.
49///
50/// Bundles the `DLManagedTensor` with the shape/strides arrays and the owned
51/// data copy. All memory is freed through `BackingStore::drop_raw`.
52struct BackingStore {
53 /// ABI-compatible managed-tensor struct; must be the first field so that
54 /// a `*mut BackingStore` can be cast to `*mut DLManagedTensor` safely.
55 managed: DLManagedTensor,
56 /// Owned copy of the tensor's element data.
57 data: Vec<f64>,
58 /// Owned shape array (length = `managed.dl_tensor.ndim`).
59 shape: Vec<i64>,
60 /// Owned strides array (length = `managed.dl_tensor.ndim`).
61 strides: Vec<i64>,
62}
63
64impl BackingStore {
65 /// Free a `BackingStore` that was previously leaked with `Box::into_raw`.
66 ///
67 /// # Safety
68 ///
69 /// `ptr` must be a non-null pointer obtained from `Box::into_raw` on a
70 /// `BackingStore`. This function must be called at most once.
71 unsafe fn drop_raw(ptr: *mut BackingStore) {
72 if !ptr.is_null() {
73 // SAFETY: ptr was obtained from Box::into_raw.
74 drop(unsafe { Box::from_raw(ptr) });
75 }
76 }
77}
78
79/// DLPack `deleter` stored inside the `DLManagedTensor`.
80///
81/// Called by the consumer framework (PyTorch, JAX, etc.) when it is finished
82/// with the tensor.
83///
84/// # Safety
85///
86/// `managed` must point to the `managed` field of a `BackingStore` that was
87/// previously leaked via `Box::into_raw`.
88unsafe extern "C" fn backing_store_deleter(managed: *mut DLManagedTensor) {
89 if managed.is_null() {
90 return;
91 }
92 // SAFETY: BackingStore has `managed` as its first field, so the pointer
93 // arithmetic is a no-op and the cast is valid.
94 let backing = managed as *mut BackingStore;
95 // SAFETY: backed by a Box::into_raw call in `to_dlpack`.
96 unsafe { BackingStore::drop_raw(backing) };
97}
98
99/// Destructor registered with `PyCapsule::new_with_pointer_and_destructor`.
100///
101/// Called by Python's GC when the capsule object is finalized. Extracts the
102/// `BackingStore` raw pointer from the capsule and drops it.
103///
104/// # Safety
105///
106/// `capsule` must be a valid `PyObject*` whose capsule pointer was set to the
107/// `managed` field of a `BackingStore` allocation.
108unsafe extern "C" fn capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
109 // SAFETY: capsule is a valid PyCapsule whose pointer was set during
110 // `to_dlpack` to a `BackingStore::managed` field.
111 let ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR_NAME.as_ptr()) };
112 if !ptr.is_null() {
113 let managed_ptr = ptr as *mut DLManagedTensor;
114 // SAFETY: managed_ptr is the `managed` field of a BackingStore.
115 if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
116 unsafe { deleter(managed_ptr) };
117 }
118 }
119}
120
121// ─── from_dlpack ─────────────────────────────────────────────────────────────
122
123/// Convert a DLPack capsule (from PyTorch, JAX, CuPy, TensorFlow, …) into a
124/// scirs2 NumPy-compatible array.
125///
126/// Parameters
127/// ----------
128/// capsule : PyCapsule
129/// A `PyCapsule` object whose name is `"dltensor"`. Anything that
130/// implements `__dlpack__()` can produce such an object.
131///
132/// Returns
133/// -------
134/// numpy.ndarray
135/// A 1-D `float64` NumPy array whose contents are *copied* from the
136/// DLPack tensor. Only CPU, float32, and float64 tensors are currently
137/// supported; all other dtypes raise `TypeError`.
138///
139/// Notes
140/// -----
141/// GPU tensors raise `TypeError` until an optional `gpu` feature is enabled.
142/// The capsule is renamed to `"used_dltensor"` after consumption to prevent
143/// double-frees, consistent with the DLPack 1.0 spec.
144#[pyfunction]
145pub fn from_dlpack(py: Python<'_>, capsule: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
146 // Cast to PyCapsule — accept PyAny so callers can pass __dlpack__() result.
147 let cap = capsule.cast::<PyCapsule>().map_err(|_| {
148 PyTypeError::new_err(
149 "from_dlpack: argument must be a PyCapsule (the result of tensor.__dlpack__()). \
150 Got a non-capsule object instead.",
151 )
152 })?;
153
154 // Validate the capsule name against the DLPack spec.
155 let name_opt = cap.name().map_err(|e| {
156 PyValueError::new_err(format!("from_dlpack: could not read capsule name: {e}"))
157 })?;
158
159 let name_matches = match name_opt {
160 None => false,
161 Some(cn) => {
162 // SAFETY: The name pointer is valid for the duration of this call.
163 let name_cstr = unsafe { cn.as_cstr() };
164 name_cstr == DLTENSOR_NAME
165 }
166 };
167
168 if !name_matches {
169 return Err(PyValueError::new_err(
170 "from_dlpack: expected a PyCapsule named 'dltensor'. \
171 Pass the result of tensor.__dlpack__() directly.",
172 ));
173 }
174
175 // Retrieve the DLManagedTensor pointer from the capsule.
176 // SAFETY: We validated the name above; the pointer was placed here by the
177 // producer and is valid until we consume it.
178 let nn_ptr: NonNull<c_void> = cap
179 .pointer_checked(Some(DLTENSOR_NAME))
180 .map_err(|e| PyRuntimeError::new_err(format!("from_dlpack: null capsule pointer: {e}")))?;
181
182 let managed_ptr = nn_ptr.as_ptr() as *mut DLManagedTensor;
183
184 // SAFETY: managed_ptr is non-null and valid; derived from the capsule above.
185 let dl_tensor: &DLTensor = unsafe { &(*managed_ptr).dl_tensor };
186
187 // Reject non-CPU tensors.
188 if dl_tensor.device.device_type != DLDeviceType::Cpu as i32 {
189 return Err(PyTypeError::new_err(format!(
190 "from_dlpack: only CPU tensors are supported (got device type {}). \
191 Copy the tensor to CPU before calling from_dlpack.",
192 dl_tensor.device.device_type
193 )));
194 }
195
196 // Reject null data pointers.
197 if dl_tensor.data.is_null() {
198 return Err(PyValueError::new_err(
199 "from_dlpack: tensor has a null data pointer.",
200 ));
201 }
202
203 // Compute the flat element count from shape.
204 let n_elems: usize = if dl_tensor.ndim == 0 || dl_tensor.shape.is_null() {
205 1
206 } else {
207 // SAFETY: shape is valid for ndim elements (DLPack producer contract).
208 let shape_slice = unsafe {
209 std::slice::from_raw_parts(dl_tensor.shape as *const i64, dl_tensor.ndim as usize)
210 };
211 shape_slice.iter().map(|&d| d as usize).product()
212 };
213
214 // Dispatch on dtype — copy into a Python list of floats, then wrap as numpy array.
215 let base_ptr = unsafe { (dl_tensor.data as *const u8).add(dl_tensor.byte_offset as usize) };
216
217 let dtype = dl_tensor.dtype;
218 let flat_vec: Vec<f64> = match (dtype.code, dtype.bits, dtype.lanes) {
219 // float32 (DLDataTypeCode::Float = 2, bits=32)
220 (2, 32, 1) => {
221 let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const f32, n_elems) };
222 slice.iter().map(|&v| v as f64).collect()
223 }
224 // float64
225 (2, 64, 1) => {
226 let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const f64, n_elems) };
227 slice.to_vec()
228 }
229 // int8
230 (0, 8, 1) => {
231 let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i8, n_elems) };
232 slice.iter().map(|&v| v as f64).collect()
233 }
234 // int16
235 (0, 16, 1) => {
236 let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i16, n_elems) };
237 slice.iter().map(|&v| v as f64).collect()
238 }
239 // int32
240 (0, 32, 1) => {
241 let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i32, n_elems) };
242 slice.iter().map(|&v| v as f64).collect()
243 }
244 // int64
245 (0, 64, 1) => {
246 let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i64, n_elems) };
247 slice.iter().map(|&v| v as f64).collect()
248 }
249 // uint8
250 (1, 8, 1) => {
251 let slice = unsafe { std::slice::from_raw_parts(base_ptr, n_elems) };
252 slice.iter().map(|&v| v as f64).collect()
253 }
254 // uint16
255 (1, 16, 1) => {
256 let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const u16, n_elems) };
257 slice.iter().map(|&v| v as f64).collect()
258 }
259 // uint32
260 (1, 32, 1) => {
261 let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const u32, n_elems) };
262 slice.iter().map(|&v| v as f64).collect()
263 }
264 // uint64
265 (1, 64, 1) => {
266 let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const u64, n_elems) };
267 slice.iter().map(|&v| v as f64).collect()
268 }
269 (code, bits, _) => {
270 return Err(PyTypeError::new_err(format!(
271 "from_dlpack: unsupported dtype (code={code}, bits={bits}). \
272 Supported: int8/16/32/64, uint8/16/32/64, float32, float64.",
273 )));
274 }
275 };
276
277 // Build shape tuple for numpy.
278 let shape_vec: Vec<usize> = if dl_tensor.ndim == 0 || dl_tensor.shape.is_null() {
279 vec![n_elems]
280 } else {
281 // SAFETY: shape is valid for ndim elements.
282 let shape_slice = unsafe {
283 std::slice::from_raw_parts(dl_tensor.shape as *const i64, dl_tensor.ndim as usize)
284 };
285 shape_slice.iter().map(|&d| d as usize).collect()
286 };
287
288 // Rename the capsule to "used_dltensor" per DLPack 1.0 spec to prevent
289 // the producer from being consumed again (double-free guard).
290 // We attempt this on a best-effort basis; failure is non-fatal here because
291 // the data has already been copied.
292 let rename_result =
293 unsafe { pyo3::ffi::PyCapsule_SetName(cap.as_ptr(), USED_DLTENSOR_NAME.as_ptr()) };
294 let _ = rename_result; // intentionally ignored after copy
295
296 // Call the managed tensor's deleter if present, as we have consumed it.
297 if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
298 unsafe { deleter(managed_ptr) };
299 }
300
301 // Convert the flat f64 Vec into a numpy array via Python's numpy.
302 let numpy = py.import("numpy").map_err(|e| {
303 PyRuntimeError::new_err(format!("from_dlpack: could not import numpy: {e}"))
304 })?;
305 let arr = numpy.getattr("array")?.call1((flat_vec,))?;
306
307 // Reshape to match the original tensor shape.
308 let shaped = arr.call_method1("reshape", (shape_vec,))?;
309
310 Ok(shaped.into())
311}
312
313// ─── to_dlpack ────────────────────────────────────────────────────────────────
314
315/// Export a scirs2 (NumPy-compatible) array as a DLPack `PyCapsule`.
316///
317/// Parameters
318/// ----------
319/// array : numpy.ndarray
320/// A NumPy float64 array (or any object with the buffer protocol that
321/// numpy can interpret as float64).
322///
323/// Returns
324/// -------
325/// PyCapsule
326/// A capsule named `"dltensor"` that can be consumed by PyTorch, JAX, etc.
327///
328/// Notes
329/// -----
330/// The capsule *owns a copy* of the array data so that the Python array object
331/// can be garbage-collected independently. The `DLManagedTensor.deleter`
332/// registered in the capsule frees this copy when the consumer is done.
333#[pyfunction]
334pub fn to_dlpack(py: Python<'_>, array: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
335 // Extract the array data as a Vec<f64> via numpy.
336 let numpy = py
337 .import("numpy")
338 .map_err(|e| PyRuntimeError::new_err(format!("to_dlpack: could not import numpy: {e}")))?;
339
340 // Ensure we have a contiguous float64 C-order array.
341 let arr = numpy.getattr("asarray")?.call1((array,))?;
342 let arr_f64 = numpy
343 .getattr("ascontiguousarray")?
344 .call((arr,), Some(&pyo3::types::PyDict::new(py)))?;
345
346 // Read shape.
347 let shape_obj = arr_f64.getattr("shape")?;
348 let shape_tuple: Vec<i64> = shape_obj.extract::<Vec<i64>>().map_err(|e| {
349 PyTypeError::new_err(format!("to_dlpack: could not extract array shape: {e}"))
350 })?;
351
352 // Extract flat data as f64.
353 let flat_list = arr_f64.call_method0("flatten")?;
354 let data_vec: Vec<f64> = flat_list.extract::<Vec<f64>>().map_err(|e| {
355 PyTypeError::new_err(format!(
356 "to_dlpack: array must be convertible to float64: {e}"
357 ))
358 })?;
359
360 // Compute C-order strides (in elements).
361 let strides_vec: Vec<i64> = compute_c_strides(&shape_tuple);
362
363 // Build the BackingStore on the heap. We use Box::into_raw so it lives
364 // until the capsule destructor frees it.
365 let n = shape_tuple.len();
366 let mut store = Box::new(BackingStore {
367 managed: DLManagedTensor {
368 dl_tensor: DLTensor {
369 data: std::ptr::null_mut(), // filled in below
370 device: DLDevice {
371 device_type: DLDeviceType::Cpu as i32,
372 device_id: 0,
373 },
374 ndim: n as i32,
375 dtype: DLDataType {
376 code: DLDataTypeCode::Float as u8,
377 bits: 64,
378 lanes: 1,
379 },
380 shape: std::ptr::null_mut(), // filled in below
381 strides: std::ptr::null_mut(), // filled in below
382 byte_offset: 0,
383 },
384 manager_ctx: std::ptr::null_mut(),
385 deleter: Some(backing_store_deleter),
386 },
387 data: data_vec,
388 shape: shape_tuple,
389 strides: strides_vec,
390 });
391
392 // Now that the Vecs are in their final locations inside the Box, set the
393 // raw pointers in dl_tensor to point into those Vecs.
394 store.managed.dl_tensor.data = store.data.as_mut_ptr() as *mut c_void;
395 store.managed.dl_tensor.shape = store.shape.as_mut_ptr();
396 store.managed.dl_tensor.strides = store.strides.as_mut_ptr();
397
398 let raw_store: *mut BackingStore = Box::into_raw(store);
399 // SAFETY: raw_store is non-null (just created by Box::into_raw).
400 let managed_nn = NonNull::new(raw_store as *mut c_void)
401 .ok_or_else(|| PyRuntimeError::new_err("to_dlpack: null BackingStore pointer"))?;
402
403 // SAFETY: managed_nn points to a valid BackingStore; capsule_destructor
404 // will call backing_store_deleter which frees it via Box::from_raw.
405 let capsule = unsafe {
406 PyCapsule::new_with_pointer_and_destructor(
407 py,
408 managed_nn,
409 DLTENSOR_NAME,
410 Some(capsule_destructor),
411 )
412 }
413 .map_err(|e| PyRuntimeError::new_err(format!("to_dlpack: failed to create capsule: {e}")))?;
414
415 Ok(capsule.into())
416}
417
/// Compute C-order (row-major) strides, in elements, for `shape`.
///
/// The last axis has stride 1; each earlier axis strides over the product of
/// all later extents. An empty shape yields an empty stride vector.
fn compute_c_strides(shape: &[i64]) -> Vec<i64> {
    // Walk the shape from the innermost axis outward, carrying the running
    // product of extents seen so far; then restore the original axis order.
    let mut running = 1i64;
    let mut strides: Vec<i64> = shape
        .iter()
        .rev()
        .map(|&extent| {
            let stride = running;
            running *= extent;
            stride
        })
        .collect();
    strides.reverse();
    strides
}
433
434/// Register DLPack interop functions on the given module.
435pub fn register_dlpack_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
436 m.add_function(wrap_pyfunction!(from_dlpack, m)?)?;
437 m.add_function(wrap_pyfunction!(to_dlpack, m)?)?;
438 Ok(())
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    // Only the pure-Rust stride helper is unit-testable here; the capsule
    // round-trip paths need an embedded Python interpreter and numpy.

    /// Compile-time check: the module registration function exists and has the
    /// expected signature. Actual invocation requires a Python interpreter.
    #[test]
    fn dlpack_module_symbol_exists() {
        let _msg = "dlpack module compiled successfully";
    }

    /// A 1-D shape always has a single stride of 1 element.
    #[test]
    fn compute_c_strides_1d() {
        assert_eq!(compute_c_strides(&[5]), vec![1]);
    }

    #[test]
    fn compute_c_strides_2d() {
        // Shape [2, 3] -> strides [3, 1]
        assert_eq!(compute_c_strides(&[2, 3]), vec![3, 1]);
    }

    #[test]
    fn compute_c_strides_3d() {
        // Shape [2, 3, 4] -> strides [12, 4, 1]
        assert_eq!(compute_c_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    /// A zero-dimensional (scalar) shape yields no strides at all.
    #[test]
    fn compute_c_strides_empty() {
        assert_eq!(compute_c_strides(&[]), Vec::<i64>::new());
    }
473}