torsh_python/utils/
conversion.rs1use crate::error::PyResult;
4use pyo3::prelude::*;
5use pyo3::types::PyList;
6use torsh_core::{device::DeviceType, dtype::DType};
7
8pub fn python_list_to_vec(list: &Bound<'_, PyList>) -> PyResult<Vec<f32>> {
10 let mut data = Vec::new();
11 let len = list.len();
12
13 for i in 0..len {
14 let item = list.get_item(i)?;
15 if let Ok(val) = item.extract::<f32>() {
16 data.push(val);
17 } else {
18 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
19 "Cannot convert item at index {} to f32",
20 i
21 )));
22 }
23 }
24
25 Ok(data)
26}
27
28pub fn parse_device_string(device_str: &str) -> PyResult<DeviceType> {
30 match device_str {
31 "cpu" => Ok(DeviceType::Cpu),
32 "cuda" | "cuda:0" => Ok(DeviceType::Cuda(0)),
33 "metal" | "metal:0" => Ok(DeviceType::Metal(0)),
34 s if s.starts_with("cuda:") => {
35 let id: usize = s[5..].parse().map_err(|_| {
36 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
37 "Invalid CUDA device ID: {}",
38 &s[5..]
39 ))
40 })?;
41 Ok(DeviceType::Cuda(id))
42 }
43 s if s.starts_with("metal:") => {
44 let id: usize = s[6..].parse().map_err(|_| {
45 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
46 "Invalid Metal device ID: {}",
47 &s[6..]
48 ))
49 })?;
50 Ok(DeviceType::Metal(id))
51 }
52 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
53 "Unknown device: {}",
54 device_str
55 ))),
56 }
57}
58
59pub fn parse_dtype_string(dtype_str: &str) -> PyResult<DType> {
61 match dtype_str {
62 "float32" | "f32" => Ok(DType::F32),
63 "float64" | "f64" => Ok(DType::F64),
64 "int8" | "i8" => Ok(DType::I8),
65 "int16" | "i16" => Ok(DType::I16),
66 "int32" | "i32" => Ok(DType::I32),
67 "int64" | "i64" => Ok(DType::I64),
68 "uint8" | "u8" => Ok(DType::U8),
69 "uint32" | "u32" => Ok(DType::U32),
70 "uint64" | "u64" => Ok(DType::U64),
71 "bool" => Ok(DType::Bool),
72 "float16" | "f16" => Ok(DType::F16),
73 "bfloat16" | "bf16" => Ok(DType::BF16),
74 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
75 "Unknown dtype: {}",
76 dtype_str
77 ))),
78 }
79}
80
81pub fn extract_shape(shape_obj: &Bound<'_, PyAny>) -> PyResult<Vec<usize>> {
83 if let Ok(shape_list) = shape_obj.extract::<Vec<i32>>() {
84 Ok(shape_list.into_iter().map(|s| s as usize).collect())
85 } else if let Ok(shape_tuple) = shape_obj.extract::<(i32,)>() {
86 Ok(vec![shape_tuple.0 as usize])
87 } else if let Ok(single_dim) = shape_obj.extract::<i32>() {
88 Ok(vec![single_dim as usize])
89 } else {
90 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
91 "Shape must be an integer, tuple, or list of integers",
92 ))
93 }
94}