use std::ffi::CStr;

use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::dlpack::{
    array_from_dlpack_f32, array_from_dlpack_f64, DLDataType, DLDeviceType, DLManagedTensor,
    DLTensor, DlpackError,
};

/// Capsule name of a live (unconsumed) DLPack tensor.
const DLTENSOR_NAME: &CStr = c"dltensor";

/// Capsule name after a consumer has taken ownership of the tensor.
const USED_DLTENSOR_NAME: &CStr = c"used_dltensor";
/// Metadata describing a DLPack tensor that lives on a non-CPU device.
#[derive(Debug, Clone)]
pub struct CudaTensorInfo {
    /// Device ordinal (e.g. the CUDA device index).
    pub device_id: i32,
    /// Extent of each dimension, in elements.
    pub shape: Vec<usize>,
    /// DLPack element type descriptor.
    pub dtype: DLDataType,
    /// Offset in bytes from the start of the data allocation.
    pub byte_offset: u64,
    /// Raw DLPack device type code (2 = CUDA, 10 = ROCm, ...).
    pub device_type_code: i32,
}

impl CudaTensorInfo {
    /// Total number of elements; the empty product, 1, for zero-dim tensors.
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Bit width of a single element (e.g. 32 for `f32`).
    pub fn dtype_bits(&self) -> u8 {
        self.dtype.bits
    }

    /// Human-readable device string such as `"cuda:0"` or `"rocm:1"`.
    pub fn device_str(&self) -> String {
        let name = match self.device_type_code {
            2 => "cuda",
            3 => "cuda_host",
            4 => "opencl",
            7 => "vulkan",
            8 => "metal",
            10 => "rocm",
            _ => "unknown",
        };
        format!("{}:{}", name, self.device_id)
    }
}

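/// Outcome of dispatching a DLPack tensor by device type.
///
/// A minimal sketch of how a caller might branch on the result; `ptr` stands
/// in for a valid `*const DLTensor` obtained from a DLPack producer:
///
/// ```ignore
/// let result = unsafe { dlpack_auto_dispatch_f32(ptr) }?;
/// match result {
///     DLPackDispatchResult::Cpu(view) => println!("sum = {}", view.sum()),
///     DLPackDispatchResult::Gpu(info) => println!("resident on {}", info.device_str()),
///     DLPackDispatchResult::OtherDevice { device_type, device_id } => {
///         println!("unhandled device {device_type}:{device_id}")
///     }
/// }
/// ```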
pub enum DLPackDispatchResult<'a, T> {
    /// CPU tensor, exposed as a borrowed `ndarray` view over the DLPack data.
    Cpu(ndarray::ArrayViewD<'a, T>),
    /// Non-CPU tensor, summarized by its metadata.
    Gpu(CudaTensorInfo),
    /// Reserved for callers that want to treat other accelerators separately;
    /// the dispatch helpers below currently route every non-CPU device to
    /// `Gpu`.
    OtherDevice {
        device_type: i32,
        device_id: i32,
    },
}

/// Extracts [`CudaTensorInfo`] from a `DLTensor`, rejecting null data
/// pointers and CPU-resident tensors.
pub fn cuda_tensor_info_from_dltensor(tensor: &DLTensor) -> Result<CudaTensorInfo, DlpackError> {
    if tensor.data.is_null() {
        return Err(DlpackError::NullPointer);
    }
    if tensor.device.device_type == DLDeviceType::Cpu as i32 {
        return Err(DlpackError::NonCpuDevice);
    }
    let ndim = tensor.ndim.max(0) as usize;
    let shape = if ndim == 0 || tensor.shape.is_null() {
        Vec::new()
    } else {
        unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, ndim) }
            .iter()
            .map(|&d| d as usize)
            .collect()
    };
    Ok(CudaTensorInfo {
        device_id: tensor.device.device_id,
        shape,
        dtype: tensor.dtype,
        byte_offset: tensor.byte_offset,
        device_type_code: tensor.device.device_type,
    })
}

/// Dispatches a raw `f32` `DLTensor`: CPU tensors become a borrowed `ndarray`
/// view, anything else is summarized as [`CudaTensorInfo`].
///
/// # Safety
///
/// `tensor` must be non-null, properly aligned, and the pointed-to `DLTensor`
/// (including its `data` and `shape` buffers) must remain valid and
/// unmodified for the lifetime `'a` of any returned view.
pub unsafe fn dlpack_auto_dispatch_f32<'a>(
    tensor: *const DLTensor,
) -> Result<DLPackDispatchResult<'a, f32>, DlpackError> {
    let t = unsafe { &*tensor };
    match t.device.device_type {
        dt if dt == DLDeviceType::Cpu as i32 => {
            let view = unsafe { array_from_dlpack_f32(tensor)? };
            Ok(DLPackDispatchResult::Cpu(view))
        }
        _ => {
            let info = cuda_tensor_info_from_dltensor(t)?;
            Ok(DLPackDispatchResult::Gpu(info))
        }
    }
}

/// `f64` counterpart of [`dlpack_auto_dispatch_f32`].
///
/// # Safety
///
/// Same contract as [`dlpack_auto_dispatch_f32`]: `tensor` must be non-null,
/// aligned, and valid for the lifetime `'a` of any returned view.
pub unsafe fn dlpack_auto_dispatch_f64<'a>(
    tensor: *const DLTensor,
) -> Result<DLPackDispatchResult<'a, f64>, DlpackError> {
    let t = unsafe { &*tensor };
    match t.device.device_type {
        dt if dt == DLDeviceType::Cpu as i32 => {
            let view = unsafe { array_from_dlpack_f64(tensor)? };
            Ok(DLPackDispatchResult::Cpu(view))
        }
        _ => {
            let info = cuda_tensor_info_from_dltensor(t)?;
            Ok(DLPackDispatchResult::Gpu(info))
        }
    }
}

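/// Reads non-CPU tensor metadata out of a DLPack capsule (the object returned
/// by `__dlpack__()`). The capsule is inspected only; it is neither consumed
/// nor renamed, so its original owner can still take it.
///
/// A sketch of the intended call pattern, assuming PyTorch with CUDA is
/// importable at runtime (the `torch` calls are illustrative, not part of
/// this crate):
///
/// ```ignore
/// Python::with_gil(|py| -> PyResult<()> {
///     let torch = py.import("torch")?;
///     let t = torch
///         .call_method1("ones", ((2, 3),))?
///         .call_method1("to", ("cuda",))?;
///     let capsule = t.call_method0("__dlpack__")?;
///     let info = cuda_tensor_info(&capsule)?;
///     assert_eq!(info.device_str(), "cuda:0");
///     Ok(())
/// })
/// ```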
pub fn cuda_tensor_info(capsule: &Bound<'_, PyAny>) -> PyResult<CudaTensorInfo> {
    let raw_obj: *mut pyo3::ffi::PyObject = capsule.as_ptr();

    // A consumer that has taken ownership renames the capsule to
    // "used_dltensor"; reject such capsules up front.
    let is_used =
        unsafe { pyo3::ffi::PyCapsule_IsValid(raw_obj, USED_DLTENSOR_NAME.as_ptr()) == 1 };
    if is_used {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "DLPack capsule has already been consumed ('used_dltensor'). \
             Call __dlpack__() again on the original tensor.",
        ));
    }

    let raw_ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(raw_obj, DLTENSOR_NAME.as_ptr()) };

    if raw_ptr.is_null() {
        // PyCapsule_GetPointer has already set a Python exception; surface it.
        return Err(PyErr::fetch(capsule.py()));
    }

    let managed = unsafe { &*(raw_ptr as *const DLManagedTensor) };
    let dl_tensor = &managed.dl_tensor;

    cuda_tensor_info_from_dltensor(dl_tensor).map_err(|e| match e {
        DlpackError::NonCpuDevice => pyo3::exceptions::PyValueError::new_err(
            "cuda_tensor_info requires a non-CPU DLPack tensor. \
             Use the standard DLPack CPU path for CPU tensors.",
        ),
        DlpackError::NullPointer => {
            pyo3::exceptions::PyValueError::new_err("DLPack tensor has a null data pointer.")
        }
        other => pyo3::exceptions::PyValueError::new_err(format!("DLPack error: {other}")),
    })
}

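/// Python-facing wrapper around [`cuda_tensor_info`]: returns the metadata as
/// a `dict` with keys `device_id`, `shape`, `device_type`, and `device_str`.
///
/// A minimal sketch of the Rust-side call, assuming a `capsule` obtained as
/// in the [`cuda_tensor_info`] example:
///
/// ```ignore
/// let d = get_cuda_tensor_info(py, &capsule)?;
/// assert!(d.bind(py).contains("device_str")?);
/// ```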
#[pyfunction]
pub fn get_cuda_tensor_info(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<Py<PyDict>> {
    let info = cuda_tensor_info(obj)?;

    let dict = PyDict::new(py);
    dict.set_item("device_id", info.device_id)?;
    dict.set_item("shape", info.shape.clone())?;
    dict.set_item("device_type", info.device_type_code)?;
    dict.set_item("device_str", info.device_str())?;
    Ok(dict.into())
}

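/// Adds the DLPack/CUDA helpers to a parent `#[pymodule]`.
///
/// A minimal wiring sketch; the module name `my_extension` is illustrative:
///
/// ```ignore
/// #[pymodule]
/// fn my_extension(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
///     register_dlpack_cuda_module(py, m)?;
///     Ok(())
/// }
/// ```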
pub fn register_dlpack_cuda_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(get_cuda_tensor_info, m)?)?;
    Ok(())
}

/// Test helper: builds a `DLTensor` on an arbitrary non-CPU device. The
/// returned tensor stores a raw pointer into `shape`, so `shape` must outlive
/// every use of the tensor.
#[cfg(test)]
fn make_non_cpu_dltensor(device_type: i32, device_id: i32, shape: &[i64]) -> DLTensor {
    use crate::dlpack::{DLDataTypeCode, DLDevice};
    use std::ffi::c_void;
    // Non-null sentinel so the null-pointer check does not trip; the data is
    // never dereferenced for non-CPU devices.
    static SENTINEL: u8 = 0;
    DLTensor {
        data: &SENTINEL as *const u8 as *mut c_void,
        device: DLDevice {
            device_type,
            device_id,
        },
        ndim: shape.len() as i32,
        dtype: DLDataType {
            code: DLDataTypeCode::Float as u8,
            bits: 32,
            lanes: 1,
        },
        shape: shape.as_ptr() as *mut i64,
        strides: std::ptr::null_mut(),
        byte_offset: 0,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dlpack::{dlpack_from_slice, DLDeviceType};
    use std::ffi::c_void;

    #[test]
    fn test_cuda_tensor_info_rejects_cpu_device() {
        let data = [1.0_f64, 2.0, 3.0];
        let shape = [3_i64];
        let tensor = dlpack_from_slice(&data, &shape);
        let result = cuda_tensor_info_from_dltensor(&tensor);
        assert!(
            matches!(result, Err(DlpackError::NonCpuDevice)),
            "CPU tensor should be rejected by cuda_tensor_info_from_dltensor"
        );
    }

    #[test]
    fn test_cuda_tensor_info_rejects_null_data() {
        let shape = [4_i64, 4];
        let mut tensor = make_non_cpu_dltensor(2, 0, &shape);
        tensor.data = std::ptr::null_mut();
        let result = cuda_tensor_info_from_dltensor(&tensor);
        assert!(
            matches!(result, Err(DlpackError::NullPointer)),
            "null data pointer should be rejected"
        );
    }

    #[test]
    fn test_cuda_tensor_info_extracts_shape() {
        let shape = [3_i64, 4, 5];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor)
            .expect("CUDA tensor should produce CudaTensorInfo");
        assert_eq!(info.shape, vec![3, 4, 5], "shape mismatch");
        assert_eq!(info.numel(), 60, "numel mismatch");
    }

    #[test]
    fn test_cuda_tensor_info_extracts_device_id() {
        let shape = [8_i64];
        let tensor = make_non_cpu_dltensor(2, 3, &shape);
        let info =
            cuda_tensor_info_from_dltensor(&tensor).expect("should produce CudaTensorInfo");
        assert_eq!(info.device_id, 3, "device_id mismatch");
        assert_eq!(
            info.device_type_code, 2,
            "device_type_code should be CUDA (2)"
        );
    }

    #[test]
    fn test_cuda_tensor_info_device_str() {
        let shape = [1_i64];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.device_str(), "cuda:0");
    }

    #[test]
    fn test_rocm_tensor_info_device_str() {
        let shape = [1_i64];
        let tensor = make_non_cpu_dltensor(10, 1, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.device_str(), "rocm:1");
    }

    #[test]
    fn test_cuda_tensor_info_zero_dim_tensor() {
        use crate::dlpack::{DLDataType, DLDataTypeCode};
        static SENTINEL: u8 = 0;
        let tensor = DLTensor {
            data: &SENTINEL as *const u8 as *mut c_void,
            device: crate::dlpack::DLDevice {
                device_type: 2,
                device_id: 0,
            },
            ndim: 0,
            dtype: DLDataType {
                code: DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: std::ptr::null_mut(),
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("zero-dim should succeed");
        assert!(info.shape.is_empty(), "zero-dim shape should be empty");
        assert_eq!(info.numel(), 1, "empty product is 1");
    }

    #[test]
    fn test_dlpack_auto_dispatch_cpu_f32_returns_array() {
        let data = [1.0_f32, 2.0, 3.0, 4.0];
        let shape = [2_i64, 2];
        let tensor = crate::dlpack::DLTensor {
            data: data.as_ptr() as *mut c_void,
            device: crate::dlpack::DLDevice {
                device_type: DLDeviceType::Cpu as i32,
                device_id: 0,
            },
            ndim: 2,
            dtype: crate::dlpack::DLDataType {
                code: crate::dlpack::DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: shape.as_ptr() as *mut i64,
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        let result = unsafe { dlpack_auto_dispatch_f32(&tensor as *const _) }
            .expect("CPU dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Cpu(_)),
            "CPU tensor should return Cpu variant"
        );
        if let DLPackDispatchResult::Cpu(view) = result {
            assert_eq!(view.shape(), &[2, 2]);
            assert_eq!(view[[0, 0]], 1.0_f32);
        }
    }

    #[test]
    fn test_dlpack_auto_dispatch_cuda_f32_returns_gpu_info() {
        let shape = [8_i64];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Cuda as i32, 0, &shape);
        let result = unsafe { dlpack_auto_dispatch_f32(&tensor as *const _) }
            .expect("CUDA dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Gpu(_)),
            "CUDA tensor should return Gpu variant"
        );
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.shape, vec![8]);
            assert_eq!(info.device_type_code, 2);
        }
    }

    #[test]
    fn test_dlpack_auto_dispatch_cpu_f64_returns_array() {
        let data = [10.0_f64, 20.0, 30.0];
        let shape = [3_i64];
        let tensor = dlpack_from_slice(&data, &shape);
        let result = unsafe { dlpack_auto_dispatch_f64(&tensor as *const _) }
            .expect("CPU f64 dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Cpu(_)),
            "CPU f64 tensor should return Cpu variant"
        );
    }

    #[test]
    fn test_dlpack_auto_dispatch_cuda_f64_returns_gpu_info() {
        let shape = [4_i64, 4];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Cuda as i32, 1, &shape);
        let result = unsafe { dlpack_auto_dispatch_f64(&tensor as *const _) }
            .expect("CUDA f64 dispatch should succeed");
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.shape, vec![4, 4]);
            assert_eq!(info.device_id, 1);
        } else {
            panic!("expected Gpu variant");
        }
    }

    #[test]
    fn test_dlpack_other_device_passthrough() {
        let shape = [16_i64];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Metal as i32, 0, &shape);
        let result = unsafe { dlpack_auto_dispatch_f32(&tensor as *const _) }
            .expect("Metal dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Gpu(_)),
            "Metal tensor should return Gpu variant"
        );
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.device_str(), "metal:0");
        }
    }

    #[test]
    fn test_dlpack_rocm_passthrough() {
        let shape = [32_i64];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Rocm as i32, 2, &shape);
        let result = unsafe { dlpack_auto_dispatch_f64(&tensor as *const _) }
            .expect("ROCm dispatch should succeed");
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.device_type_code, 10);
            assert_eq!(info.device_id, 2);
        } else {
            panic!("expected Gpu variant for ROCm device");
        }
    }

    #[test]
    fn test_cuda_tensor_numel_empty_shape() {
        let shape: [i64; 0] = [];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.numel(), 1, "empty shape product is 1");
    }

    #[test]
    fn test_cuda_tensor_dtype_bits() {
        let shape = [4_i64];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.dtype_bits(), 32, "dtype bits should be 32");
    }
}