torsh_python/utils/
conversion.rs1use crate::{device::PyDevice, dtype::PyDType, error::PyResult};
4use numpy::{PyArray1, PyArray2, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods};
5use pyo3::prelude::*;
6use pyo3::types::PyList;
7use torsh_core::{device::DeviceType, dtype::DType};
8
9pub fn python_list_to_vec(list: &Bound<'_, PyList>) -> PyResult<Vec<f32>> {
11 let mut data = Vec::new();
12 let len = list.len();
13
14 for i in 0..len {
15 let item = list.get_item(i)?;
16 if let Ok(val) = item.extract::<f32>() {
17 data.push(val);
18 } else {
19 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
20 "Cannot convert item at index {} to f32",
21 i
22 )));
23 }
24 }
25
26 Ok(data)
27}
28
29pub fn parse_device_string(device_str: &str) -> PyResult<DeviceType> {
31 match device_str {
32 "cpu" => Ok(DeviceType::Cpu),
33 "cuda" | "cuda:0" => Ok(DeviceType::Cuda(0)),
34 "metal" | "metal:0" => Ok(DeviceType::Metal(0)),
35 s if s.starts_with("cuda:") => {
36 let id: usize = s[5..].parse().map_err(|_| {
37 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
38 "Invalid CUDA device ID: {}",
39 &s[5..]
40 ))
41 })?;
42 Ok(DeviceType::Cuda(id))
43 }
44 s if s.starts_with("metal:") => {
45 let id: usize = s[6..].parse().map_err(|_| {
46 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
47 "Invalid Metal device ID: {}",
48 &s[6..]
49 ))
50 })?;
51 Ok(DeviceType::Metal(id))
52 }
53 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
54 "Unknown device: {}",
55 device_str
56 ))),
57 }
58}
59
60pub fn parse_dtype_string(dtype_str: &str) -> PyResult<DType> {
62 match dtype_str {
63 "float32" | "f32" => Ok(DType::F32),
64 "float64" | "f64" => Ok(DType::F64),
65 "int8" | "i8" => Ok(DType::I8),
66 "int16" | "i16" => Ok(DType::I16),
67 "int32" | "i32" => Ok(DType::I32),
68 "int64" | "i64" => Ok(DType::I64),
69 "uint8" | "u8" => Ok(DType::U8),
70 "uint32" | "u32" => Ok(DType::U32),
71 "uint64" | "u64" => Ok(DType::U64),
72 "bool" => Ok(DType::Bool),
73 "float16" | "f16" => Ok(DType::F16),
74 "bfloat16" | "bf16" => Ok(DType::BF16),
75 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
76 "Unknown dtype: {}",
77 dtype_str
78 ))),
79 }
80}
81
82pub fn extract_shape(shape_obj: &Bound<'_, PyAny>) -> PyResult<Vec<usize>> {
84 if let Ok(shape_list) = shape_obj.extract::<Vec<i32>>() {
85 Ok(shape_list.into_iter().map(|s| s as usize).collect())
86 } else if let Ok(shape_tuple) = shape_obj.extract::<(i32,)>() {
87 Ok(vec![shape_tuple.0 as usize])
88 } else if let Ok(single_dim) = shape_obj.extract::<i32>() {
89 Ok(vec![single_dim as usize])
90 } else {
91 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
92 "Shape must be an integer, tuple, or list of integers",
93 ))
94 }
95}