Skip to main content

torsh_python/utils/
conversion.rs

1//! Data conversion utilities for Python-Rust interop
2
3use crate::error::PyResult;
4use pyo3::prelude::*;
5use pyo3::types::PyList;
6use torsh_core::{device::DeviceType, dtype::DType};
7
8/// Convert Python list to Vec<f32>
9pub fn python_list_to_vec(list: &Bound<'_, PyList>) -> PyResult<Vec<f32>> {
10    let mut data = Vec::new();
11    let len = list.len();
12
13    for i in 0..len {
14        let item = list.get_item(i)?;
15        if let Ok(val) = item.extract::<f32>() {
16            data.push(val);
17        } else {
18            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
19                "Cannot convert item at index {} to f32",
20                i
21            )));
22        }
23    }
24
25    Ok(data)
26}
27
28/// Convert device string to DeviceType
29pub fn parse_device_string(device_str: &str) -> PyResult<DeviceType> {
30    match device_str {
31        "cpu" => Ok(DeviceType::Cpu),
32        "cuda" | "cuda:0" => Ok(DeviceType::Cuda(0)),
33        "metal" | "metal:0" => Ok(DeviceType::Metal(0)),
34        s if s.starts_with("cuda:") => {
35            let id: usize = s[5..].parse().map_err(|_| {
36                PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
37                    "Invalid CUDA device ID: {}",
38                    &s[5..]
39                ))
40            })?;
41            Ok(DeviceType::Cuda(id))
42        }
43        s if s.starts_with("metal:") => {
44            let id: usize = s[6..].parse().map_err(|_| {
45                PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
46                    "Invalid Metal device ID: {}",
47                    &s[6..]
48                ))
49            })?;
50            Ok(DeviceType::Metal(id))
51        }
52        _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
53            "Unknown device: {}",
54            device_str
55        ))),
56    }
57}
58
59/// Convert dtype string to DType
60pub fn parse_dtype_string(dtype_str: &str) -> PyResult<DType> {
61    match dtype_str {
62        "float32" | "f32" => Ok(DType::F32),
63        "float64" | "f64" => Ok(DType::F64),
64        "int8" | "i8" => Ok(DType::I8),
65        "int16" | "i16" => Ok(DType::I16),
66        "int32" | "i32" => Ok(DType::I32),
67        "int64" | "i64" => Ok(DType::I64),
68        "uint8" | "u8" => Ok(DType::U8),
69        "uint32" | "u32" => Ok(DType::U32),
70        "uint64" | "u64" => Ok(DType::U64),
71        "bool" => Ok(DType::Bool),
72        "float16" | "f16" => Ok(DType::F16),
73        "bfloat16" | "bf16" => Ok(DType::BF16),
74        _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
75            "Unknown dtype: {}",
76            dtype_str
77        ))),
78    }
79}
80
81/// Convert Python objects to shape vector
82pub fn extract_shape(shape_obj: &Bound<'_, PyAny>) -> PyResult<Vec<usize>> {
83    if let Ok(shape_list) = shape_obj.extract::<Vec<i32>>() {
84        Ok(shape_list.into_iter().map(|s| s as usize).collect())
85    } else if let Ok(shape_tuple) = shape_obj.extract::<(i32,)>() {
86        Ok(vec![shape_tuple.0 as usize])
87    } else if let Ok(single_dim) = shape_obj.extract::<i32>() {
88        Ok(vec![single_dim as usize])
89    } else {
90        Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
91            "Shape must be an integer, tuple, or list of integers",
92        ))
93    }
94}