1use serde::{Deserialize, Serialize};
7use serde::{Deserializer, Serializer};
8use std::sync::Arc;
9
10use crate::buffer::{BufferHandle, CpuBuffer};
11use crate::dtype::DType;
12use crate::error::{Result, SapientError};
13use crate::shape::Shape;
14
15#[derive(Debug, Clone)]
19pub struct Tensor {
20 shape: Shape,
21 dtype: DType,
22 strides: Vec<usize>, buffer: BufferHandle,
24 offset: usize,
26}
27
28impl Tensor {
29 pub fn zeros(shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
33 let shape = shape.into();
34 shape.validate()?;
35 let numel = shape.numel();
36 let strides = shape.strides();
37 let buffer = BufferHandle::new(CpuBuffer::zeros(numel, dtype)?);
38 Ok(Self {
39 shape,
40 dtype,
41 strides,
42 buffer,
43 offset: 0,
44 })
45 }
46
47 pub fn from_f32(data: &[f32], shape: impl Into<Shape>) -> Result<Self> {
49 let shape = shape.into();
50 shape.validate()?;
51 if data.len() != shape.numel() {
52 return Err(SapientError::ShapeMismatch {
53 expected: shape.dims().to_vec(),
54 got: vec![data.len()],
55 });
56 }
57 let strides = shape.strides();
58 let buffer = BufferHandle::new(CpuBuffer::from_f32_slice(data)?);
59 Ok(Self {
60 shape,
61 dtype: DType::F32,
62 strides,
63 buffer,
64 offset: 0,
65 })
66 }
67
68 pub fn from_bf16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
71 let shape = shape.into();
72 shape.validate()?;
73 let expected_bytes = shape.numel() * 2;
74 if data.len() != expected_bytes {
75 return Err(SapientError::ShapeMismatch {
76 expected: shape.dims().to_vec(),
77 got: vec![data.len() / 2],
78 });
79 }
80 let strides = shape.strides();
81 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
82 Ok(Self {
83 shape,
84 dtype: DType::BF16,
85 strides,
86 buffer,
87 offset: 0,
88 })
89 }
90
91 pub fn from_f16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
93 let shape = shape.into();
94 shape.validate()?;
95 let expected_bytes = shape.numel() * 2;
96 if data.len() != expected_bytes {
97 return Err(SapientError::ShapeMismatch {
98 expected: shape.dims().to_vec(),
99 got: vec![data.len() / 2],
100 });
101 }
102 let strides = shape.strides();
103 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
104 Ok(Self {
105 shape,
106 dtype: DType::F16,
107 strides,
108 buffer,
109 offset: 0,
110 })
111 }
112
113 pub fn scalar_f32(v: f32) -> Result<Self> {
115 Self::from_f32(&[v], Shape::scalar())
116 }
117
118 pub fn from_buffer(
120 shape: impl Into<Shape>,
121 dtype: DType,
122 buffer: BufferHandle,
123 offset: usize,
124 ) -> Result<Self> {
125 let shape = shape.into();
126 shape.validate()?;
127 let required = dtype.byte_count(shape.numel());
128 if buffer.len() < offset + required {
129 return Err(SapientError::BufferSizeMismatch {
130 expected: offset + required,
131 got: buffer.len(),
132 });
133 }
134 let strides = shape.strides();
135 Ok(Self {
136 shape,
137 dtype,
138 strides,
139 buffer,
140 offset,
141 })
142 }
143
144 pub fn shape(&self) -> &Shape {
147 &self.shape
148 }
149 pub fn dtype(&self) -> DType {
150 self.dtype
151 }
152 pub fn ndim(&self) -> usize {
153 self.shape.ndim()
154 }
155 pub fn numel(&self) -> usize {
156 self.shape.numel()
157 }
158 pub fn strides(&self) -> &[usize] {
159 &self.strides
160 }
161 pub fn buffer(&self) -> &BufferHandle {
162 &self.buffer
163 }
164 pub fn offset(&self) -> usize {
165 self.offset
166 }
167
168 pub fn is_scalar(&self) -> bool {
170 self.shape.is_scalar() || self.numel() == 1
171 }
172
173 pub fn is_contiguous(&self) -> bool {
175 self.strides == self.shape.strides() && self.offset == 0
176 }
177
178 pub fn as_bytes(&self) -> &[u8] {
182 let bytes = self.buffer.as_bytes();
183 &bytes[self.offset..]
184 }
185
186 pub fn as_f32_slice(&self) -> &[f32] {
188 assert_eq!(
189 self.dtype,
190 DType::F32,
191 "Tensor dtype is not F32 — call to_f32_vec() instead"
192 );
193 let bytes = self.as_bytes();
194 assert_eq!(bytes.len() % 4, 0);
195 unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) }
197 }
198
199 pub fn to_f32_cow(&self) -> std::borrow::Cow<'_, [f32]> {
201 if self.dtype == DType::F32 {
202 std::borrow::Cow::Borrowed(self.as_f32_slice())
203 } else {
204 std::borrow::Cow::Owned(self.to_f32_vec())
205 }
206 }
207
208 pub fn to_f32_vec(&self) -> Vec<f32> {
211 match self.dtype {
212 DType::F32 => self.as_f32_slice().to_vec(),
213 DType::BF16 => {
214 let bytes = self.as_bytes();
215 bytes
216 .chunks_exact(2)
217 .map(|c| f32::from(half::bf16::from_le_bytes(c.try_into().unwrap())))
218 .collect()
219 }
220 DType::F16 => {
221 let bytes = self.as_bytes();
222 bytes
223 .chunks_exact(2)
224 .map(|c| half::f16::from_le_bytes(c.try_into().unwrap()).to_f32())
225 .collect()
226 }
227 _ => self.as_f32_slice().to_vec(), }
229 }
230
231 pub fn to_f32_tensor(&self) -> Result<Tensor> {
234 match self.dtype {
235 DType::F32 => Ok(self.clone()),
236 _ => Tensor::from_f32(&self.to_f32_vec(), self.shape.clone()),
237 }
238 }
239
240 pub fn as_f32_slice_mut(&mut self) -> Result<&mut [f32]> {
242 if self.dtype != DType::F32 {
243 return Err(SapientError::internal("Tensor dtype is not F32"));
244 }
245 let offset = self.offset;
246 let buf = Arc::get_mut(&mut self.buffer.0)
247 .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
248 let bytes = buf.as_bytes_mut();
249 let bytes = &mut bytes[offset..];
250 if bytes.len() % 4 != 0 {
251 return Err(SapientError::internal("Buffer length not a multiple of 4"));
252 }
253 Ok(unsafe {
255 std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut f32, bytes.len() / 4)
256 })
257 }
258
259 pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor> {
264 let new_shape = self.shape.reshape(new_shape.into().dims().to_vec())?;
265 let strides = new_shape.strides();
266 Ok(Tensor {
267 shape: new_shape,
268 dtype: self.dtype,
269 strides,
270 buffer: self.buffer.clone(),
271 offset: self.offset,
272 })
273 }
274
275 pub fn t(&self) -> Result<Tensor> {
277 if self.ndim() != 2 {
278 return Err(SapientError::internal("t() requires a 2-D tensor"));
279 }
280 let mut dims = self.shape.dims().to_vec();
281 let mut strides = self.strides.clone();
282 dims.swap(0, 1);
283 strides.swap(0, 1);
284 Ok(Tensor {
285 shape: Shape(dims),
286 dtype: self.dtype,
287 strides,
288 buffer: self.buffer.clone(),
289 offset: self.offset,
290 })
291 }
292
293 pub fn slice_axis(&self, axis: usize, start: usize, end: usize) -> Result<Tensor> {
295 let mut dims = self.shape.dims().to_vec();
296 if axis >= dims.len() {
297 return Err(SapientError::internal("slice axis out of bounds"));
298 }
299 if start > end || end > dims[axis] {
300 return Err(SapientError::internal("slice range out of bounds"));
301 }
302 dims[axis] = end - start;
303 let offset = self.offset + start * self.strides[axis] * self.dtype.element_size();
304 Ok(Tensor {
305 shape: Shape(dims),
306 dtype: self.dtype,
307 strides: self.strides.clone(),
308 buffer: self.buffer.clone(),
309 offset,
310 })
311 }
312
313 pub fn byte_size(&self) -> usize {
317 self.dtype.byte_count(self.numel())
318 }
319}
320
321impl std::fmt::Display for Tensor {
322 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 write!(
324 f,
325 "Tensor(shape={}, dtype={}, device={})",
326 self.shape,
327 self.dtype,
328 self.buffer.0.device()
329 )
330 }
331}
332
333#[derive(Serialize, Deserialize)]
337struct TensorProxy {
338 shape: Shape,
339 dtype: DType,
340 data: Vec<f32>,
342}
343
344impl Serialize for Tensor {
345 fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
346 let data: Vec<f32> = if self.dtype == DType::F32 {
347 self.as_f32_slice().to_vec()
348 } else {
349 vec![] };
351 TensorProxy {
352 shape: self.shape.clone(),
353 dtype: self.dtype,
354 data,
355 }
356 .serialize(serializer)
357 }
358}
359
360impl<'de> Deserialize<'de> for Tensor {
361 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
362 let proxy = TensorProxy::deserialize(deserializer)?;
363 if proxy.data.is_empty() {
364 Tensor::zeros(proxy.shape, proxy.dtype).map_err(serde::de::Error::custom)
365 } else {
366 Tensor::from_f32(&proxy.data, proxy.shape).map_err(serde::de::Error::custom)
367 }
368 }
369}
370
371#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct TensorMeta {
374 pub shape: Shape,
375 pub dtype: DType,
376}
377
378impl From<&Tensor> for TensorMeta {
379 fn from(t: &Tensor) -> Self {
380 Self {
381 shape: t.shape.clone(),
382 dtype: t.dtype,
383 }
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Six consecutive values shared by the data-carrying tests.
    const SAMPLE: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

    /// `zeros` reports the requested shape, dtype and element count.
    #[test]
    fn zeros_dtype_shape() {
        let tensor = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert_eq!(tensor.shape().dims(), &[2, 3]);
        assert_eq!(tensor.dtype(), DType::F32);
        assert_eq!(tensor.numel(), 6);
    }

    /// Data passed to `from_f32` is readable back unchanged.
    #[test]
    fn from_f32_roundtrip() {
        let tensor = Tensor::from_f32(&SAMPLE, vec![2, 3]).unwrap();
        assert_eq!(tensor.as_f32_slice(), &SAMPLE[..]);
    }

    /// Reshaping changes the dimensions but not the stored values.
    #[test]
    fn reshape_preserves_data() {
        let original = Tensor::from_f32(&SAMPLE, vec![2, 3]).unwrap();
        let reshaped = original.reshape(vec![3, 2]).unwrap();
        assert_eq!(reshaped.shape().dims(), &[3, 2]);
        assert_eq!(reshaped.as_f32_slice(), &SAMPLE[..]);
    }

    /// Reshaping to a different element count is rejected.
    #[test]
    fn reshape_wrong_numel() {
        let tensor = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        let outcome = tensor.reshape(vec![5]);
        assert!(outcome.is_err());
    }

    /// `t` swaps the two dimensions of a matrix.
    #[test]
    fn transpose_2d() {
        let matrix = Tensor::zeros(vec![3, 4], DType::F32).unwrap();
        let transposed = matrix.t().unwrap();
        assert_eq!(transposed.shape().dims(), &[4, 3]);
    }

    /// A 4x4 `F32` tensor occupies 64 bytes.
    #[test]
    fn byte_size() {
        let tensor = Tensor::zeros(vec![4, 4], DType::F32).unwrap();
        assert_eq!(tensor.byte_size(), 64);
    }
}