torsh_python/utils/conversion.rs

1//! Data conversion utilities for Python-Rust interop
2
3use crate::{device::PyDevice, dtype::PyDType, error::PyResult};
4use numpy::{PyArray1, PyArray2, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods};
5use pyo3::prelude::*;
6use pyo3::types::PyList;
7use torsh_core::{device::DeviceType, dtype::DType};
8
9/// Convert Python list to Vec<f32>
10pub fn python_list_to_vec(list: &Bound<'_, PyList>) -> PyResult<Vec<f32>> {
11    let mut data = Vec::new();
12    let len = list.len();
13
14    for i in 0..len {
15        let item = list.get_item(i)?;
16        if let Ok(val) = item.extract::<f32>() {
17            data.push(val);
18        } else {
19            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
20                "Cannot convert item at index {} to f32",
21                i
22            )));
23        }
24    }
25
26    Ok(data)
27}
28
29/// Convert device string to DeviceType
30pub fn parse_device_string(device_str: &str) -> PyResult<DeviceType> {
31    match device_str {
32        "cpu" => Ok(DeviceType::Cpu),
33        "cuda" | "cuda:0" => Ok(DeviceType::Cuda(0)),
34        "metal" | "metal:0" => Ok(DeviceType::Metal(0)),
35        s if s.starts_with("cuda:") => {
36            let id: usize = s[5..].parse().map_err(|_| {
37                PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
38                    "Invalid CUDA device ID: {}",
39                    &s[5..]
40                ))
41            })?;
42            Ok(DeviceType::Cuda(id))
43        }
44        s if s.starts_with("metal:") => {
45            let id: usize = s[6..].parse().map_err(|_| {
46                PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
47                    "Invalid Metal device ID: {}",
48                    &s[6..]
49                ))
50            })?;
51            Ok(DeviceType::Metal(id))
52        }
53        _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
54            "Unknown device: {}",
55            device_str
56        ))),
57    }
58}
59
60/// Convert dtype string to DType
61pub fn parse_dtype_string(dtype_str: &str) -> PyResult<DType> {
62    match dtype_str {
63        "float32" | "f32" => Ok(DType::F32),
64        "float64" | "f64" => Ok(DType::F64),
65        "int8" | "i8" => Ok(DType::I8),
66        "int16" | "i16" => Ok(DType::I16),
67        "int32" | "i32" => Ok(DType::I32),
68        "int64" | "i64" => Ok(DType::I64),
69        "uint8" | "u8" => Ok(DType::U8),
70        "uint32" | "u32" => Ok(DType::U32),
71        "uint64" | "u64" => Ok(DType::U64),
72        "bool" => Ok(DType::Bool),
73        "float16" | "f16" => Ok(DType::F16),
74        "bfloat16" | "bf16" => Ok(DType::BF16),
75        _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
76            "Unknown dtype: {}",
77            dtype_str
78        ))),
79    }
80}
81
82/// Convert Python objects to shape vector
83pub fn extract_shape(shape_obj: &Bound<'_, PyAny>) -> PyResult<Vec<usize>> {
84    if let Ok(shape_list) = shape_obj.extract::<Vec<i32>>() {
85        Ok(shape_list.into_iter().map(|s| s as usize).collect())
86    } else if let Ok(shape_tuple) = shape_obj.extract::<(i32,)>() {
87        Ok(vec![shape_tuple.0 as usize])
88    } else if let Ok(single_dim) = shape_obj.extract::<i32>() {
89        Ok(vec![single_dim as usize])
90    } else {
91        Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
92            "Shape must be an integer, tuple, or list of integers",
93        ))
94    }
95}