1use bon::bon;
2use snafu::ResultExt;
3use std::sync::Arc;
4
5use svod_device::{Buffer, registry};
6use svod_dtype::DType;
7use svod_dtype::ext::HasDType;
8use svod_ir::{DeviceSpec, SInt, UOp, shape::Shape};
9
10use crate::Tensor;
11use crate::error::*;
12use crate::tensor_registry;
13
14#[bon]
15impl Tensor {
16 pub fn from_slice<T: HasDType, C: AsRef<[T]>>(source: C) -> Self {
24 let source = source.as_ref();
25 Self::from_bytes_shaped(
26 unsafe { std::slice::from_raw_parts(source.as_ptr() as *const u8, source.len() * T::DTYPE.bytes()) },
27 &[source.len()],
28 T::DTYPE,
29 DeviceSpec::Cpu,
30 )
31 }
32
33 #[builder]
35 pub fn from_slice_with<T: HasDType, C: AsRef<[T]>>(
36 source: C,
37 #[builder(default = DeviceSpec::Cpu)] device: DeviceSpec,
38 ) -> Self {
39 let source = source.as_ref();
40 Self::from_bytes_shaped(
41 unsafe { std::slice::from_raw_parts(source.as_ptr() as *const u8, source.len() * T::DTYPE.bytes()) },
42 &[source.len()],
43 T::DTYPE,
44 device,
45 )
46 }
47}
48
49impl Tensor {
50 fn from_bytes_shaped(bytes: &[u8], shape: &[usize], dtype: DType, device: DeviceSpec) -> Self {
55 let numel: usize = shape.iter().product();
56 let ir_shape = Shape::from_iter(shape.iter().map(|&d| SInt::Const(d)));
57
58 let buffer_uop = UOp::new_buffer(device.clone(), numel, dtype.clone());
59 let buffer_uop_id = buffer_uop.id;
60
61 let allocator = match &device {
62 DeviceSpec::Cpu => registry::cpu().expect("CPU always should be accessible"),
63 _ => registry::cpu().expect("CPU fallback for unsupported device"),
64 };
65
66 let mut buffer = Buffer::new(allocator, dtype.clone(), shape.to_vec(), Default::default());
67 buffer.copyin(bytes).expect("Buffer write always successful");
68
69 let buffer_arc = Arc::new(buffer);
70 let uop = buffer_uop.try_reshape(&ir_shape).expect("shape matches element count");
71
72 let entry = tensor_registry::register_tensor_with_buffer(uop, buffer_arc.clone(), buffer_uop_id);
73 Self::with_buffer(entry, buffer_arc)
74 }
75
76 pub fn from_raw_bytes(data: &[u8], shape: &[usize], dtype: DType) -> Result<Self> {
82 let numel: usize = shape.iter().product();
83 let expected_bytes = numel * dtype.bytes();
84 if data.len() != expected_bytes {
85 return Err(Error::IrConstruction {
86 details: format!(
87 "from_raw_bytes: data length {} != expected {} ({} elements * {} bytes)",
88 data.len(),
89 expected_bytes,
90 numel,
91 dtype.bytes()
92 ),
93 });
94 }
95 Ok(Self::from_bytes_shaped(data, shape, dtype, DeviceSpec::Cpu))
96 }
97
98 pub fn from_ndarray<T, S, D>(array: &ndarray::ArrayBase<S, D>) -> Self
113 where
114 T: HasDType + Clone,
115 S: ndarray::Data<Elem = T>,
116 D: ndarray::Dimension,
117 {
118 let shape: Vec<usize> = array.shape().to_vec();
119 if array.is_empty() {
120 let t = Self::empty_zero(T::DTYPE);
121 if shape.len() <= 1 {
122 return t;
123 }
124 let isize_shape: Vec<isize> = shape.iter().map(|&d| d as isize).collect();
125 return t.try_reshape(&isize_shape).expect("empty reshape matches");
126 }
127 if let Some(slice) = array.as_slice() {
129 let bytes =
130 unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, slice.len() * T::DTYPE.bytes()) };
131 Self::from_bytes_shaped(bytes, &shape, T::DTYPE, DeviceSpec::Cpu)
132 } else {
133 let data: Vec<T> = array.iter().cloned().collect();
135 let bytes =
136 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * T::DTYPE.bytes()) };
137 Self::from_bytes_shaped(bytes, &shape, T::DTYPE, DeviceSpec::Cpu)
138 }
139 }
140
141 pub fn buffer(&self) -> Option<Buffer> {
146 if let Some(buf) = self.buffer.as_ref().or_else(|| self.entry.buffer()) {
148 return Some((**buf).clone());
149 }
150 crate::tensor_registry::get_buffer_arc(self.uop().base().id).map(|arc| (*arc).clone())
151 }
152
153 pub fn as_ndarray<T: HasDType + Default + Clone>(&self) -> Result<ndarray::ArrayD<T>> {
166 use ndarray::{ArrayD, IxDyn};
167
168 let uop = self.uop();
169 let shape = uop.shape().context(UOpSnafu)?.ok_or(Error::NoShape)?;
170
171 if shape.iter().any(|dim| dim.as_const().is_none()) {
173 return SymbolicShapeSnafu.fail();
174 }
175
176 let dims: Vec<usize> = shape.iter().map(|dim| dim.as_const().unwrap()).collect();
177
178 if dims.contains(&0) {
179 let arr = ArrayD::from_shape_vec(IxDyn(&dims), vec![]).context(NdarrayShapeSnafu)?;
180 return Ok(arr);
181 }
182
183 let buffer = self.buffer().ok_or(Error::NoBuffer)?;
184
185 if buffer.dtype() != T::DTYPE {
186 return TypeMismatchSnafu { expected: T::DTYPE, actual: buffer.dtype() }.fail();
187 }
188
189 let count = buffer.size() / T::DTYPE.bytes();
190 let mut data = vec![T::default(); count];
191 buffer
192 .copyout(unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, count * T::DTYPE.bytes()) })
193 .context(DeviceSnafu)?;
194
195 let arr = ArrayD::from_shape_vec(IxDyn(&dims), data).context(NdarrayShapeSnafu)?;
196 Ok(arr)
197 }
198
199 pub fn as_vec<T: HasDType + Default + Clone>(&self) -> Result<Vec<T>> {
212 let uop = self.uop();
213 if let Ok(Some(shape)) = uop.shape() {
214 if shape.iter().any(|dim| dim.as_const().is_none()) {
216 return SymbolicShapeSnafu.fail();
217 }
218 if shape.iter().any(|dim| dim.as_const() == Some(0)) {
219 return Ok(vec![]);
220 }
221 }
222
223 let buffer = self.buffer().ok_or(Error::NoBuffer)?;
224
225 if buffer.dtype() != T::DTYPE {
226 return TypeMismatchSnafu { expected: T::DTYPE, actual: buffer.dtype() }.fail();
227 }
228
229 let count = buffer.size() / T::DTYPE.bytes();
230 let mut data = vec![T::default(); count];
231 buffer
232 .copyout(unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, count * T::DTYPE.bytes()) })
233 .context(DeviceSnafu)?;
234
235 Ok(data)
236 }
237
238 pub fn array_view<T: HasDType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
252 let buffer_arc = self.buffer.as_ref().or_else(|| self.entry.buffer()).ok_or(Error::NoBuffer)?;
253 let flat = buffer_arc.as_array::<T>().context(DeviceSnafu)?;
254 if let Ok(shape) = self.shape() {
256 let dims: Vec<usize> = shape.iter().filter_map(|d| d.as_const()).collect();
257 if dims.len() == shape.len() {
258 return flat.into_shape_with_order(ndarray::IxDyn(&dims)).context(NdarrayShapeSnafu);
259 }
260 }
261 Ok(flat)
262 }
263
264 pub fn array_view_mut<T: HasDType>(&self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
275 let buffer_arc = self.buffer.as_ref().or_else(|| self.entry.buffer()).ok_or(Error::NoBuffer)?;
276 let flat = buffer_arc.as_array_mut::<T>().context(DeviceSnafu)?;
277 if let Ok(shape) = self.shape() {
278 let dims: Vec<usize> = shape.iter().filter_map(|d| d.as_const()).collect();
279 if dims.len() == shape.len() {
280 return flat.into_shape_with_order(ndarray::IxDyn(&dims)).context(NdarrayShapeSnafu);
281 }
282 }
283 Ok(flat)
284 }
285}