1use crate::TVec;
3use crate::blob::Blob;
4use crate::datum::{ClampCast, Datum, DatumType, QParams, round_ties_to_even, scale_by};
5use crate::dim::TDim;
6use crate::internal::*;
7use crate::opaque::Opaque;
8use half::f16;
9use itertools::{Itertools, izip};
10use ndarray::prelude::*;
11#[cfg(feature = "complex")]
12use num_complex::Complex;
13use num_traits::{Float, Zero};
14use std::borrow::Cow;
15use std::fmt;
16use std::hash::Hash;
17use std::ops::Range;
18use std::sync::Arc;
19
20pub mod litteral;
21pub mod view;
22
/// Tolerance preset used by `Tensor::close_enough` when comparing tensors.
///
/// `Custom(atol, rtol, outliers)` carries an explicit absolute tolerance,
/// relative tolerance and accepted fraction of mismatching elements.
#[derive(Copy, Clone, Default, Debug)]
pub enum Approximation {
    Exact,
    #[default]
    Close,
    Approximate,
    VeryApproximate,
    SuperApproximate,
    UltraApproximate,
    Custom(f32, f32, f32),
}

/// `Custom` presets compare field by field; every other preset compares by
/// variant only.
impl PartialEq for Approximation {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Custom(lhs_a, lhs_r, lhs_o), Self::Custom(rhs_a, rhs_r, rhs_o)) => {
                lhs_a == rhs_a && lhs_r == rhs_r && lhs_o == rhs_o
            }
            _ => std::mem::discriminant(self) == std::mem::discriminant(other),
        }
    }
}

// Asserted by hand: `Custom` holds `f32`s, so `Eq` cannot be derived.
impl Eq for Approximation {}

/// `true` selects `Approximate`, `false` selects `Exact`.
impl From<bool> for Approximation {
    fn from(approximate: bool) -> Self {
        match approximate {
            true => Self::Approximate,
            false => Self::Exact,
        }
    }
}
53
54impl Approximation {
55 fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
56 use Approximation::*;
57 match (self, dt) {
58 (Exact, _) => (0.0, 0.0, 0.0),
59 (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
60 (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
61 (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
62 (Close, _) => (1e-7, 1e-7, 0.0),
63 (Approximate, _) => (1e-4, 5e-4, 0.0),
64 (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
65 (SuperApproximate, _) => (0.1, 0.05, 0.0001),
66 (UltraApproximate, _) => (0.2, 0.1, 0.0005),
67 (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
68 }
69 }
70}
71
/// A type-erased n-dimensional array.
///
/// The element type is carried at runtime in `dt`; elements live in the
/// byte storage `data`.
#[derive(Eq)]
pub struct Tensor {
    /// Runtime element type (for quantized types, includes the quantization params).
    dt: DatumType,
    /// Extent of each axis; empty for a scalar.
    shape: TVec<usize>,
    /// Per-axis strides, counted in elements (not bytes).
    strides: TVec<isize>,
    /// Number of elements: the product of `shape`, or 1 for a scalar.
    len: usize,
    /// Backing storage, allocated with an explicit alignment.
    data: Blob,
}
81
// SAFETY (review note): these impls tell the compiler that moving or sharing a
// `Tensor` across threads is sound, and the compiler no longer checks it.
// `Tensor` exclusively owns its `Blob` storage. Presumably every element type
// it can hold (including `TDim`, `String`, `Blob`, `Opaque`) is itself
// `Send + Sync` — TODO confirm whenever a new datum type is added.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
84
/// Hashes the datum type, shape, storage alignment and element values.
///
/// Floating-point elements (real and complex) are hashed through same-width
/// integer views of their bits, because floats do not implement `Hash`.
/// Quantized tensors hash their raw integer storage; the quantization
/// parameters are covered by hashing `dt`. Strides are not hashed.
impl Hash for Tensor {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use DatumType::*;
        self.dt.hash(state);
        self.shape.hash(state);
        self.data.layout().align().hash(state);
        // Reinterpret the storage as the element type, or as a same-size
        // integer type for floats, so the slice can be hashed.
        unsafe {
            match self.dt {
                Bool => self.as_slice_unchecked::<bool>().hash(state),
                I8 => self.as_slice_unchecked::<i8>().hash(state),
                I16 => self.as_slice_unchecked::<i16>().hash(state),
                I32 => self.as_slice_unchecked::<i32>().hash(state),
                I64 => self.as_slice_unchecked::<i64>().hash(state),
                U8 => self.as_slice_unchecked::<u8>().hash(state),
                U16 => self.as_slice_unchecked::<u16>().hash(state),
                U32 => self.as_slice_unchecked::<u32>().hash(state),
                U64 => self.as_slice_unchecked::<u64>().hash(state),
                // Float bit patterns hashed as integers of the same width.
                F16 => self.as_slice_unchecked::<i16>().hash(state),
                F32 => self.as_slice_unchecked::<i32>().hash(state),
                F64 => self.as_slice_unchecked::<i64>().hash(state),
                TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
                String => self.as_slice_unchecked::<std::string::String>().hash(state),
                Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
                Opaque => self.as_slice_unchecked::<crate::opaque::Opaque>().hash(state),
                // Quantized types: hash the stored integers.
                QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
                QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
                QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
                // Complex floats: hashed through same-width complex integers.
                #[cfg(feature = "complex")]
                ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
            }
        }
    }
}
128
129impl Clone for Tensor {
130 fn clone(&self) -> Tensor {
131 self.deep_clone()
132 }
133}
134
135impl Default for Tensor {
136 fn default() -> Tensor {
137 litteral::tensor0(0f32)
138 }
139}
140
141impl Drop for Tensor {
142 fn drop(&mut self) {
143 macro_rules! drop_in_place {
144 ($t: ty) => {
145 if self.dt == <$t>::datum_type() {
146 unsafe {
147 let slice = self.as_slice_mut::<$t>().unwrap();
148 std::ptr::drop_in_place(slice as *mut [$t]);
149 }
150 }
151 };
152 }
153 drop_in_place!(Blob);
154 drop_in_place!(String);
155 drop_in_place!(TDim);
156 drop_in_place!(Opaque);
157 }
158}
159
/// Default alignment, in bytes, for freshly allocated tensor storage.
///
/// On x86_64 this is 64 bytes (512-bit registers) when the CPU reports
/// `avx512f`, and 32 bytes (256-bit) otherwise. Every other architecture uses
/// 16 bytes (128-bit).
#[allow(unreachable_code)]
pub fn vector_size() -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        let register_bits = if is_x86_feature_detected!("avx512f") { 512 } else { 256 };
        return register_bits / 8;
    }
    16
}
168
169impl Tensor {
170 #[inline]
172 pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
173 unsafe { Self::uninitialized_dt(T::datum_type(), shape) }
174 }
175
176 #[inline]
178 pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
179 unsafe { Self::uninitialized_aligned_dt(dt, shape, vector_size()) }
180 }
181
182 #[inline]
184 pub unsafe fn uninitialized_aligned<T: Datum>(
185 shape: &[usize],
186 alignment: usize,
187 ) -> TractResult<Tensor> {
188 unsafe { Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment) }
189 }
190
191 pub unsafe fn uninitialized_aligned_dt(
193 dt: DatumType,
194 shape: &[usize],
195 alignment: usize,
196 ) -> TractResult<Tensor> {
197 let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
198 let data = unsafe { Blob::new_for_size_and_align(bytes, alignment) };
199 let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), data, len: 0 };
200 if tensor.shape.len() == 0 {
201 tensor.len = 1;
202 } else {
203 tensor.update_strides_and_len();
204 }
205 if !tensor.data.is_empty() {
206 if dt == String::datum_type() || dt == Blob::datum_type() {
207 tensor.data.fill(0);
209 } else if dt == TDim::datum_type() {
210 unsafe {
211 tensor
212 .as_slice_mut_unchecked::<TDim>()
213 .iter_mut()
214 .for_each(|dim| std::ptr::write(dim, TDim::zero()))
215 }
216 } else if dt == Opaque::datum_type() {
217 unsafe {
218 tensor.as_slice_mut_unchecked::<Opaque>().iter_mut().for_each(|p| {
219 std::ptr::write(p, Opaque::default());
220 })
221 };
222 } else if cfg!(debug_assertions) {
223 assert!(dt.is_copy());
224 if dt == DatumType::F32 {
225 tensor.fill_t(f32::NAN).unwrap();
226 } else {
227 tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
229 }
230 }
231 }
232 Ok(tensor)
233 }
234
    /// Concatenates `tensors` along `axis` into a new tensor.
    ///
    /// There must be at least one input. All inputs must share rank and datum
    /// type and agree on every extent except `axis`. The output's `axis` extent
    /// is the sum of the inputs' extents on that axis.
    ///
    /// # Errors
    ///
    /// Fails (via `ensure!`) when any of the preconditions above does not hold.
    pub fn stack_tensors(
        axis: usize,
        tensors: &[impl std::borrow::Borrow<Tensor>],
    ) -> TractResult<Tensor> {
        ensure!(tensors.len() > 0);
        let rank = tensors[0].borrow().rank();
        ensure!(axis < rank);
        ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
        let dt = tensors[0].borrow().datum_type();
        ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
        let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
        // Every axis but `axis` must match the first tensor.
        for ax in 0..rank {
            if ax != axis {
                ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
            }
        }
        shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
        unsafe {
            let mut result = Tensor::uninitialized_dt(dt, &shape)?;
            // Fast path: for plain-data types where every axis before `axis` has
            // extent 1, each input fills one contiguous run of the output, so
            // the inputs' raw bytes are appended back to back. This assumes the
            // inputs use natural contiguous strides.
            if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
                let mut offset = 0isize;
                for v in tensors {
                    let v = v.borrow();
                    let len = v.data.len();
                    std::ptr::copy_nonoverlapping(
                        v.data.as_ptr(),
                        result.data.as_mut_ptr().offset(offset),
                        len,
                    );
                    offset += len as isize;
                }
            } else {
                // General path: assign each input into its slice of the output
                // along `axis`.
                let mut offset = 0;
                for t in tensors {
                    let t = t.borrow();
                    let len = t.shape()[axis];
                    result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
                    offset += len;
                }
            }

            Ok(result)
        }
    }
279
280 pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
281 self.fill_t(T::zero())
282 }
283
284 pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
285 unsafe {
286 let mut t = Tensor::uninitialized::<T>(shape)?;
287 t.clear::<T>()?;
288 Ok(t)
289 }
290 }
291
292 pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
293 Tensor::zero::<T>(&[])
294 }
295
296 pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
297 Tensor::zero_dt(dt, &[])
298 }
299
300 pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
301 Tensor::zero_aligned_dt(dt, shape, vector_size())
302 }
303
304 pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
305 self.as_slice_mut::<T>()?.iter_mut().for_each(|item| *item = value.clone());
306 Ok(())
307 }
308
309 pub fn zero_aligned_dt(
310 dt: DatumType,
311 shape: &[usize],
312 alignment: usize,
313 ) -> TractResult<Tensor> {
314 if shape.iter().product::<usize>() == 0 {
315 unsafe { return Tensor::uninitialized_dt(dt, shape) };
316 }
317 if dt.is_quantized() {
318 unsafe {
319 let mut t = Tensor::uninitialized_dt(dt, shape)?;
320 let zp = dt.zp_scale().0;
321 match dt.unquantized() {
322 DatumType::I8 => {
323 t.as_slice_mut::<i8>()?.iter_mut().for_each(|item| *item = zp as _)
324 }
325 DatumType::U8 => {
326 t.as_slice_mut::<u8>()?.iter_mut().for_each(|item| *item = zp as _)
327 }
328 DatumType::I32 => {
329 t.as_slice_mut::<i32>()?.iter_mut().for_each(|item| *item = zp as _)
330 }
331 _ => unreachable!(),
332 }
333 Ok(t)
334 }
335 } else {
336 dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
337 }
338 }
339
340 pub fn zero_aligned<T: Datum + num_traits::Zero>(
341 shape: &[usize],
342 alignment: usize,
343 ) -> TractResult<Tensor> {
344 unsafe {
345 let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
346 tensor.clear::<T>()?;
347 Ok(tensor)
348 }
349 }
350
351 pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
354 Self::from_shape_align(shape, data, vector_size())
355 }
356
357 pub fn from_shape_align<T: Datum + Copy>(
360 shape: &[usize],
361 data: &[T],
362 align: usize,
363 ) -> TractResult<Tensor> {
364 ensure!(
365 data.len() == shape.iter().product::<usize>(),
366 "Shape product must be equal to data length"
367 );
368 unsafe {
369 let bytes = std::slice::from_raw_parts(
370 data.as_ptr() as *const u8,
371 data.len() * T::datum_type().size_of(),
372 );
373 let dt = T::datum_type();
374 Self::from_raw_dt_align(dt, shape, bytes, align)
375 }
376 }
377
378 pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
382 unsafe { Tensor::from_raw_dt(T::datum_type(), shape, content) }
383 }
384
385 pub unsafe fn from_raw_aligned<T: Datum>(
386 shape: &[usize],
387 content: &[u8],
388 align: usize,
389 ) -> TractResult<Tensor> {
390 unsafe { Tensor::from_raw_dt_align(T::datum_type(), shape, content, align) }
391 }
392
393 pub unsafe fn from_raw_dt(
394 dt: DatumType,
395 shape: &[usize],
396 content: &[u8],
397 ) -> TractResult<Tensor> {
398 unsafe { Self::from_raw_dt_align(dt, shape, content, vector_size()) }
399 }
400
401 pub unsafe fn from_raw_dt_align(
402 dt: DatumType,
403 shape: &[usize],
404 content: &[u8],
405 align: usize,
406 ) -> TractResult<Tensor> {
407 let mut tensor = unsafe { Tensor::uninitialized_aligned_dt(dt, shape, align) }?;
408 tensor.as_bytes_mut().copy_from_slice(content);
409 Ok(tensor)
410 }
411
412 pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
413 let bytes = if content.len() == 0 {
414 &[]
415 } else {
416 unsafe {
417 std::slice::from_raw_parts(
418 content.as_ptr() as *const u8,
419 content.len() * T::datum_type().size_of(),
420 )
421 }
422 };
423 unsafe { Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align) }
424 }
425
426 #[inline]
428 pub fn rank(&self) -> usize {
429 self.shape.len()
430 }
431
432 #[inline]
434 pub fn shape(&self) -> &[usize] {
435 &self.shape
436 }
437
438 #[inline]
440 #[allow(clippy::len_without_is_empty)]
441 pub fn len(&self) -> usize {
442 self.len
443 }
444
445 #[inline]
447 #[allow(clippy::len_without_is_empty)]
448 pub fn volume(&self) -> usize {
449 self.len
450 }
451
452 #[inline]
454 pub fn strides(&self) -> &[isize] {
455 &self.strides
456 }
457
458 fn update_strides_and_len(&mut self) {
459 self.strides.clear();
460 if self.shape.len() == 0 {
461 self.len = 1;
462 return;
463 }
464 compute_natural_stride_to(&mut self.strides, &self.shape);
465 self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
466 }
467
468 pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
470 if shape != &*self.shape {
471 self.shape.clear();
472 self.shape.extend_from_slice(shape);
473 self.update_strides_and_len();
474 }
475 }
476
477 pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
479 self.shape.clear();
480 self.shape.extend_from_slice(shape);
481 self.strides.clear();
482 self.strides.extend_from_slice(strides);
483 }
484
485 pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
487 if self.len() != shape.iter().product::<usize>() {
488 bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
489 }
490 unsafe { self.set_shape_unchecked(shape) }
491 Ok(())
492 }
493
494 pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
495 ensure!(axes.iter().duplicates().next().is_none());
496 ensure!(axes.iter().all(|a| *a < self.rank()));
497 unsafe {
498 #[inline]
499 unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
500 unsafe { input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor() }
501 }
502 let dt = self.datum_type();
503 let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
504 t.set_datum_type(dt);
505 Ok(t)
506 }
507 }
508
509 pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
510 let mut permutation: Vec<usize> = (0..self.rank()).collect();
511 permutation.remove(from);
512 permutation.insert(to, from);
513 self.permute_axes(&permutation)
514 }
515
516 pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
517 let removed = self.shape.remove(axis + 1);
518 self.shape[axis] *= removed;
519 self.update_strides_and_len();
520 self
521 }
522
523 pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
524 if self.shape[axis] % outer_dim != 0 {
525 bail!(
526 "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
527 self.shape,
528 axis,
529 outer_dim
530 );
531 }
532 self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
533 self.shape[axis] = outer_dim;
534 self.update_strides_and_len();
535 Ok(self)
536 }
537
538 pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
540 self.set_shape(shape)?;
541 Ok(self)
542 }
543
544 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
545 self.shape.insert(axis, 1);
546 self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
547 Ok(())
548 }
549
550 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
551 ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
552 self.shape.remove(axis);
553 self.strides.remove(axis);
554 Ok(())
555 }
556
557 pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
558 self.broadcast_to_rank(rank)?;
559 self.update_strides_and_len();
560 Ok(self)
561 }
562
563 pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
564 if rank < self.rank() {
565 bail!("Can only broadcast to higher rank")
566 }
567 while self.shape.len() < rank {
568 self.shape.insert(0, 1)
569 }
570 self.update_strides_and_len();
571 Ok(())
572 }
573
574 pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
575 if self.rank() > 0 {
576 bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
577 }
578 unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
579 unsafe {
580 let value: &T = src.to_scalar_unchecked::<T>();
581 dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone())
582 };
583 }
584 unsafe {
585 let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
586 dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
587 Ok(t)
588 }
589 }
590
591 fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
592 unsafe {
593 let view = self.to_array_view_unchecked::<T>();
594 let mut output = view
595 .broadcast(shape)
596 .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
597 .into_owned()
598 .into_tensor();
599 output.set_datum_type(self.datum_type());
600 Ok(output)
601 }
602 }
603
604 pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
605 dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
606 }
607
608 pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
609 ensure!(self.rank() == 1);
610 ensure!(shape[axis] == self.len());
611 if !self.datum_type().is_copy() {
612 let mut vec_shape = vec![1; shape.len()];
613 vec_shape[axis] = self.len();
614 return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
615 }
616 unsafe {
617 let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
618 if output.len() == 0 {
619 return Ok(output);
620 }
621 let inner_len = shape[axis + 1..].iter().product::<usize>();
622
623 unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
624 where
625 T: Datum + Copy,
626 {
627 unsafe {
628 for ix in 0..input.len() {
629 let value: T = input.as_slice_unchecked()[ix];
630 output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
631 .iter_mut()
632 .for_each(|item| *item = value);
633 }
634 }
635 }
636 dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));
637
638 let outer_len = shape[0..axis].iter().product::<usize>();
639 let repeat_bytes_len = inner_len * self.as_bytes().len();
640 let bytes = output.as_bytes_mut();
641 for ix in 1..outer_len {
642 bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
643 }
644
645 Ok(output)
646 }
647 }
648
649 pub fn assign_slice(
650 &mut self,
651 range: impl std::ops::RangeBounds<usize>,
652 src: &Tensor,
653 src_range: impl std::ops::RangeBounds<usize>,
654 axis: usize,
655 ) -> TractResult<()> {
656 ensure!(self.rank() == src.rank());
657 ensure!(axis < self.rank());
658 let range = clip_range_bounds(self.shape[axis], range);
659 let src_range = clip_range_bounds(src.shape[axis], src_range);
660 ensure!(
661 src.datum_type() == self.datum_type(),
662 "Attempt to assign into {:?} from {:?}, datum type mismatch",
663 self.datum_type(),
664 src.datum_type()
665 );
666 ensure!(
667 src_range.len() == range.len(),
668 "Attempt to assign a range of {:?} from a range of {:?}",
669 range,
670 src_range,
671 );
672 ensure!(
673 itertools::izip!(0.., self.shape(), src.shape())
674 .all(|(ix, dst, src)| ix == axis || src == dst),
675 "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
676 axis,
677 self,
678 src
679 );
680 ensure!(
681 src_range.end <= src.shape()[axis],
682 "Assigning from invalid slice (axis {}, {:?}) of {:?}",
683 axis,
684 src_range,
685 src
686 );
687 ensure!(
688 range.end <= self.shape()[axis],
689 "Assigning to invalid slice (axis {}, {:?}) of {:?}",
690 axis,
691 range,
692 self
693 );
694 unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
695 Ok(())
696 }
697
698 pub unsafe fn assign_slice_unchecked(
699 &mut self,
700 range: impl std::ops::RangeBounds<usize>,
701 src: &Tensor,
702 src_range: impl std::ops::RangeBounds<usize>,
703 axis: usize,
704 ) {
705 let range = clip_range_bounds(self.shape[axis], range);
706 let src_range = clip_range_bounds(src.shape[axis], src_range);
707 unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
708 }
709
    /// Copies `src[src_range]` into `self[range]` along `axis`. The ranges are
    /// already resolved to concrete bounds, and nothing is validated.
    ///
    /// # Safety
    ///
    /// The caller must guarantee what `assign_slice` checks: equal ranks and
    /// datum types, equal range lengths, in-bounds ranges, and matching extents
    /// on every axis other than `axis`.
    #[allow(clippy::ptr_eq)]
    unsafe fn assign_slice_from_resolved(
        &mut self,
        range: std::ops::Range<usize>,
        src: &Tensor,
        src_range: std::ops::Range<usize>,
        axis: usize,
    ) {
        unsafe {
            use ndarray::Slice;
            /// Generic fallback: slice both tensors as typed ndarray views and assign.
            unsafe fn assign_slice_t<T: Datum>(
                to: &mut Tensor,
                to_range: Range<usize>,
                from: &Tensor,
                from_range: Range<usize>,
                axis: usize,
            ) {
                unsafe {
                    to.to_array_view_mut_unchecked::<T>()
                        .slice_axis_mut(Axis(axis), Slice::from(to_range))
                        .assign(
                            &from
                                .to_array_view_unchecked::<T>()
                                .slice_axis(Axis(axis), Slice::from(from_range)),
                        )
                }
            }
            // Fast path for plain-data types when every axis before `axis` has
            // extent 1: the selected range is then a single contiguous byte run
            // in both tensors. The byte stride of `axis` is taken from `self`
            // and reused for `src`. This relies on equal extents past `axis` and
            // (presumably) natural contiguous strides in both tensors.
            if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
                let stride = self.strides[axis] as usize * self.datum_type().size_of();
                let dst_start = (stride * range.start) as isize;
                let src_start = (stride * src_range.start) as isize;
                let len = stride * range.len();
                if len > 0 {
                    // Distinct buffers take the non-overlapping copy. The same
                    // buffer falls back to `ptr::copy`, which tolerates overlap.
                    if self.data.as_ptr() != src.data.as_ptr() {
                        std::ptr::copy_nonoverlapping(
                            src.data.as_ptr().offset(src_start),
                            self.data.as_mut_ptr().offset(dst_start),
                            len,
                        );
                    } else {
                        std::ptr::copy(
                            src.data.as_ptr().offset(src_start),
                            self.data.as_mut_ptr().offset(dst_start),
                            len,
                        );
                    }
                }
            } else {
                dispatch_datum!(assign_slice_t(self.datum_type())(
                    self, range, src, src_range, axis
                ));
            }
        }
    }
764
765 #[inline]
767 pub fn datum_type(&self) -> DatumType {
768 self.dt
769 }
770
771 #[inline]
773 pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
774 self.dt = dt
775 }
776
777 pub fn dump(&self, force_full: bool) -> TractResult<String> {
781 unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
782 unsafe {
783 if let Some(qp) = tensor.datum_type().qparams() {
784 let integers = tensor.cast_to::<i32>().unwrap();
785 integers.as_slice_unchecked::<i32>()[0..n]
786 .iter()
787 .map(|x| format!("[{}]({})", x, qp.dq(*x)))
788 .join(", ")
789 } else {
790 tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
791 }
792 }
793 }
794 unsafe {
795 let trunc = self.len() > 12 && !force_full;
796 let data = dispatch_datum!(dump_t(self.datum_type())(
797 self,
798 if trunc { 12 } else { self.len() }
799 ));
800 Ok(format!(
801 "{},{:?} {}{}",
802 self.shape.iter().join(","),
803 self.dt,
804 data,
805 if trunc { "..." } else { "" }
806 ))
807 }
808 }
809
810 pub fn close_enough(
812 &self,
813 other: &Self,
814 approx: impl Into<Approximation> + std::fmt::Debug,
815 ) -> TractResult<()> {
816 let approx = approx.into();
817 if self.shape() != other.shape() {
818 bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
819 }
820 let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
821 let ma = self.cast_to::<f32>()?;
822 let ma = ma.to_array_view::<f32>()?;
823 let mb = other.cast_to::<f32>()?;
824 let mb = mb.to_array_view::<f32>()?;
825 let mut first_outlier = None;
826 let mut outliers_count = 0;
827 ndarray::indices_of(&ma).into_iter().for_each(|indices| {
828 let a = ma[&indices];
829 let b = mb[&indices];
830 if !((a.is_nan() && b.is_nan())
831 || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
832 || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
833 {
834 if outliers_count == 0 {
835 first_outlier = Some(indices.as_array_view().to_vec());
836 }
837 outliers_count += 1;
838 }
839 });
840 if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
841 let indices = first_outlier.unwrap();
842 let a = ma[&*indices];
843 let b = mb[&*indices];
844 bail!(
845 "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
846 approx,
847 self.datum_type(),
848 indices,
849 a,
850 b,
851 outliers_count,
852 self.volume(),
853 outliers_count as f64 / self.volume() as f64,
854 outliers
855 );
856 }
857 Ok(())
858 }
859
860 pub fn into_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
862 Ok(self.to_array_view::<D>()?.to_owned())
863 }
864
865 pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
867 unsafe { self.to_array_view_unchecked::<D>().to_owned() }
868 }
869
870 fn check_for_access<D: Datum>(&self) -> TractResult<()> {
871 ensure!(
872 self.datum_type().unquantized() == D::datum_type().unquantized(),
873 "Tensor datum type error: tensor is {:?}, accessed as {:?}",
874 self.datum_type(),
875 D::datum_type(),
876 );
877 Ok(())
878 }
879
880 pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
882 self.check_for_access::<D>()?;
883 unsafe { Ok(self.to_array_view_unchecked()) }
884 }
885
886 pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
888 self.check_for_access::<D>()?;
889 unsafe { Ok(self.to_array_view_mut_unchecked()) }
890 }
891
892 pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
894 if self.len() != 0 {
895 unsafe { ArrayViewD::from_shape_ptr(&*self.shape, self.data.as_ptr() as *const D) }
896 } else {
897 ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
898 }
899 }
900
901 pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
903 if self.len() != 0 {
904 unsafe { ArrayViewMutD::from_shape_ptr(&*self.shape, self.data.as_mut_ptr() as *mut D) }
905 } else {
906 ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
907 }
908 }
909
910 pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
912 self.check_for_access::<D>()?;
913 Ok(self.data.as_ptr() as *const D)
914 }
915
916 pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
918 self.data.as_ptr() as *const D
919 }
920
921 pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
923 self.data.as_mut_ptr() as *mut D
924 }
925
926 pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
928 self.as_ptr::<D>().map(|p| p as *mut D)
929 }
930
931 pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
933 let ptr: *const D = self.as_ptr()?;
934 if self.data.len() == 0 {
935 Ok(&[])
936 } else {
937 unsafe { Ok(std::slice::from_raw_parts::<D>(ptr, self.len())) }
938 }
939 }
940
941 pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
943 let ptr: *mut D = self.as_ptr_mut()?;
944 if self.data.len() == 0 {
945 Ok(&mut [])
946 } else {
947 unsafe { Ok(std::slice::from_raw_parts_mut::<D>(ptr, self.len())) }
948 }
949 }
950
951 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
953 if self.data.len() == 0 {
954 &[]
955 } else {
956 unsafe { std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len()) }
957 }
958 }
959
960 pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
962 if self.data.len() == 0 {
963 &mut []
964 } else {
965 unsafe { std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len()) }
966 }
967 }
968
969 pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
971 self.check_for_access::<D>()?;
972 if self.len() == 0 {
973 bail!("to_scalar called on empty tensor ({:?})", self)
974 }
975 if self.len() > 1 {
976 bail!("to_scalar called on a tensor with multiple values ({:?})", self)
977 }
978 unsafe { Ok(self.to_scalar_unchecked()) }
979 }
980
981 pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
983 fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
984 Ok(litteral::tensor0(t.to_scalar::<D>()?.clone()))
985 }
986 dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
987 }
988
989 pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
991 unsafe { &*(self.data.as_ptr() as *const D) }
992 }
993
994 pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
996 self.check_for_access::<D>()?;
997 if self.len() == 0 {
998 bail!("to_scalar_mut called on empty tensor ({:?})", self)
999 }
1000 if self.len() > 1 {
1001 bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1002 }
1003 unsafe { Ok(self.to_scalar_mut_unchecked()) }
1004 }
1005
1006 pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
1008 unsafe { &mut *(self.data.as_mut_ptr() as *mut D) }
1009 }
1010
1011 pub fn as_bytes(&self) -> &[u8] {
1012 self.data.as_bytes()
1013 }
1014
1015 pub fn as_bytes_mut(&mut self) -> &mut [u8] {
1016 self.data.as_bytes_mut()
1017 }
1018
1019 unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
1020 let slice = unsafe { self.as_slice_unchecked::<T>() };
1021 slice[1..].iter().all(|x| x == &slice[0])
1022 }
1023
1024 pub fn is_uniform(&self) -> bool {
1025 if self.len() <= 1 {
1026 return true;
1027 }
1028 unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
1029 }
1030
1031 unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
1032 let v: T = unsafe { self.as_slice_unchecked::<T>() }[0].clone();
1033 litteral::tensor0(v)
1034 }
1035
1036 pub fn as_uniform(&self) -> Option<Tensor> {
1037 if self.len() >= 1 && self.is_uniform() {
1038 unsafe {
1039 let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
1040 t.set_datum_type(self.datum_type());
1041 Some(t)
1042 }
1043 } else {
1044 None
1045 }
1046 }
1047
1048 pub fn is_all_zero(&self) -> TractResult<bool> {
1049 Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1050 }
1051
1052 pub fn is_zero(&self) -> TractResult<bool> {
1053 Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
1054 }
1055
1056 unsafe fn natural_cast<
1057 Source: Datum + num_traits::AsPrimitive<Target>,
1058 Target: Datum + Copy,
1059 >(
1060 &self,
1061 other: &mut Tensor,
1062 ) {
1063 unsafe {
1064 self.as_slice_unchecked::<Source>()
1065 .iter()
1066 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1067 .for_each(|(s, d)| *d = s.as_())
1068 };
1069 }
1070
1071 unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1072 unsafe {
1073 self.as_slice_unchecked::<Source>()
1074 .iter()
1075 .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1076 .for_each(|(s, d)| *d = !s.is_zero());
1077 }
1078 }
1079
1080 unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
1081 &self,
1082 other: &mut Tensor,
1083 ) -> TractResult<()> {
1084 unsafe {
1085 for (s, d) in self
1086 .as_slice_unchecked::<String>()
1087 .iter()
1088 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1089 {
1090 *d = s
1091 .parse()
1092 .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
1093 }
1094 Ok(())
1095 }
1096 }
1097
1098 unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1099 unsafe {
1100 for (s, d) in self
1101 .as_slice_unchecked::<Source>()
1102 .iter()
1103 .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1104 {
1105 *d = s.to_string()
1106 }
1107 }
1108 }
1109
1110 pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<'_, Tensor>> {
1112 self.cast_to_dt(D::datum_type())
1113 }
1114
1115 #[allow(clippy::redundant_closure_call)]
1117 pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<'_, Tensor>> {
1118 unsafe {
1119 if self.dt == dst_dt {
1120 return Ok(Cow::Borrowed(self));
1121 }
1122 if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
1123 let slice = self.as_slice_unchecked::<TDim>();
1124 let mut ints = Self::uninitialized::<i64>(&self.shape)?;
1125 let ints_slice = ints.as_slice_mut_unchecked::<i64>();
1126 for i in 0..self.len() {
1127 ints_slice[i] = slice[i].to_i64()?;
1128 }
1129 return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1130 }
1131 if self.dt == bool::datum_type()
1132 && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
1133 {
1134 let slice = self.as_slice_unchecked::<bool>();
1135 let mut ints = Self::uninitialized::<i8>(&self.shape)?;
1136 let ints_slice = ints.as_slice_mut_unchecked::<i8>();
1137 for i in 0..self.len() {
1138 ints_slice[i] = slice[i] as usize as i8;
1139 }
1140 return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1141 }
1142 let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
1143 if self.dt == DatumType::String {
1144 dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
1145 return Ok(Cow::Owned(result));
1146 }
1147 if dst_dt == DatumType::String {
1148 dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
1149 return Ok(Cow::Owned(result));
1150 }
1151 macro_rules! n {
1152 ($source:ty) => {
1153 if <$source>::datum_type() == self.datum_type() {
1154 match dst_dt {
1155 DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
1156 DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
1157 DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
1158 DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
1159 DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
1160 DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
1161 DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
1162 DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
1163 DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
1164 DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
1165 DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
1166 DatumType::TDim => {
1167 let ints = self.cast_to::<i32>()?;
1168 let slice = ints.as_slice_unchecked::<i32>();
1169 let result = result.as_slice_mut_unchecked::<TDim>();
1170 for i in 0..self.len() {
1171 result[i] = slice[i].into();
1172 }
1173 }
1174 DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
1175 _ => todo!(),
1176 }
1177 return Ok(Cow::Owned(result));
1178 };
1179 };
1180 }
1181 if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
1183 n!(u8);
1184 n!(u16);
1185 n!(u32);
1186 n!(u64);
1187 n!(i8);
1188 n!(i16);
1189 n!(i32);
1190 n!(i64);
1191 n!(f16);
1192 n!(f32);
1193 n!(f64);
1194 } else {
1195 let (s_zp, s_scale) = self.datum_type().zp_scale();
1196 let (d_zp, d_scale) = dst_dt.zp_scale();
1197 if self.datum_type().is_quantized() && dst_dt.is_float() {
1198 macro_rules! q_to_fp {
1199 ($source:ty, $dest:ty) => {
1200 if <$source>::datum_type().unquantized()
1201 == self.datum_type().unquantized()
1202 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1203 {
1204 self.as_slice_unchecked::<$source>()
1205 .iter()
1206 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1207 .for_each(|(&s, d)| {
1208 *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
1209 });
1210 return Ok(Cow::Owned(result));
1211 }
1212 };
1213 }
1214 q_to_fp!(i8, f64);
1215 q_to_fp!(i8, f32);
1216 q_to_fp!(u8, f64);
1217 q_to_fp!(u8, f32);
1218 }
1219 macro_rules! q8_to_q8 {
1221 ($typ:ty) => {
1222 if dst_dt.unquantized() == <$typ>::datum_type() {
1223 self.as_slice_unchecked::<$typ>()
1224 .iter()
1225 .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
1226 .for_each(|(&s, d)| {
1227 *d = (d_zp as i32
1228 + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
1229 .clamp_cast()
1230 });
1231 return Ok(Cow::Owned(result));
1232 }
1233 };
1234 }
1235
1236 macro_rules! q_via_f32 {
1237 ($source:ty, $dest:ty, $round:expr) => {
1238 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1239 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1240 {
1241 self.as_slice_unchecked::<$source>()
1242 .iter()
1243 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1244 .for_each(|(&s, d)| {
1245 let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
1246 let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
1247 *d = $round(d_float);
1248 });
1249 return Ok(Cow::Owned(result));
1250 }
1251 };
1252 }
1253
1254 macro_rules! q_n {
1255 (clamp $source:ty, $dest:ty) => {{
1256 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1257 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1258 {
1259 self.as_slice_unchecked::<$source>()
1260 .iter()
1261 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1262 .for_each(|(&s, d)| {
1263 *d = s.clamp_cast();
1264 });
1265 return Ok(Cow::Owned(result));
1266 }
1267 }};
1268 ($source:ty, $dest:ty) => {{
1269 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1270 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1271 {
1272 self.as_slice_unchecked::<$source>()
1273 .iter()
1274 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1275 .for_each(|(&s, d)| {
1276 *d = s as $dest;
1277 });
1278 return Ok(Cow::Owned(result));
1279 }
1280 }};
1281 }
1282
1283 if dst_dt.unquantized() == self.datum_type().unquantized()
1284 && dst_dt.is_quantized()
1285 && self.datum_type().is_quantized()
1286 {
1287 q8_to_q8!(i8);
1288 q8_to_q8!(u8);
1289 }
1290
1291 q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
1292 q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
1293 q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
1294 q_via_f32!(i8, f32, |f| f);
1295 q_via_f32!(u8, f32, |f| f);
1296 q_via_f32!(i32, f32, |f| f);
1297
1298 if dst_dt.is_quantized() && self.datum_type().is_quantized() {
1299 q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
1300 q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
1301 q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
1302 q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
1303 q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
1304 q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());
1305
1306 q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
1308 q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
1309 }
1310
1311 q_n!(i8, i32);
1312 q_n!(i8, u32);
1313 q_n!(u8, i32);
1314 q_n!(u8, u32);
1315 q_n!(clamp i32, i8);
1316 q_n!(clamp i32, u8);
1317 q_n!(clamp u32, i8);
1318 q_n!(clamp u32, u8);
1319 q_n!(i8, i8);
1320 q_n!(u8, u8);
1321 q_n!(i32, i32);
1322 q_n!(u32, u32);
1323 }
1324
1325 bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
1326 }
1327 }
1328
1329 pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1331 let casted = self.cast_to::<D>()?;
1332 casted.to_scalar::<D>().copied()
1333 }
1334
1335 pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1337 if nth >= self.len() {
1338 bail!(
1339 "nth called with {}th element on a tensor of len {} ({:?}",
1340 nth,
1341 self.len(),
1342 self
1343 );
1344 }
1345 unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1346 unsafe {
1347 let value = me.as_slice_unchecked::<T>()[nth].clone();
1348 output.as_slice_mut_unchecked::<T>()[0] = value;
1349 }
1350 }
1351 unsafe {
1352 let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1353 dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1354 Ok(output)
1355 }
1356 }
1357
1358 fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1360 unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1361 unsafe {
1362 if D::datum_type().is_float() {
1363 return dispatch_floatlike!(float_eq_t(D::datum_type())(me, other));
1364 }
1365 Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1366 .all(|(a, b)| a == b))
1367 }
1368 }
1369
1370 unsafe fn float_eq_t<D: Datum + Float>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1371 unsafe {
1372 Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1373 .all(|(a, b)| (a.is_nan() && b.is_nan()) || a == b))
1374 }
1375 }
1376
1377 unsafe {
1378 Ok(self.datum_type() == other.datum_type()
1379 && self.shape() == other.shape()
1380 && dispatch_datum!(eq_t(self.dt)(self, other))?)
1381 }
1382 }
1383
    /// Builds a `Tensor` from a dynamic-rank `ndarray` array, taking ownership of its elements.
    ///
    /// Three strategies are tried in order:
    /// 1. `as_slice_mut` succeeds (contiguous, standard order): `is_copy` types are copied
    ///    as raw bytes; other types are moved out element by element with `std::mem::take`.
    /// 2. All strides are positive and `as_slice_memory_order` succeeds (contiguous in some
    ///    axis order): axes are sorted by source stride, runs of axes that are also
    ///    contiguous in the destination are merged, and `scatter_contig_data` writes the
    ///    data into `t`'s own strides.
    /// 3. Otherwise, elements are moved one at a time in the array's logical iteration order.
    ///
    /// Panics (via `unwrap`) if `uninitialized` fails.
    fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
        unsafe {
            let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
            // Strategy 1: standard-order contiguous source.
            if let Some(slice) = it.as_slice_mut() {
                if t.datum_type().is_copy() {
                    // NOTE(review): copies `t.data.layout().size()` bytes, assuming the blob
                    // layout size equals the source slice's byte length — confirm.
                    std::ptr::copy_nonoverlapping(
                        slice.as_ptr() as *const i8,
                        t.as_ptr_mut_unchecked(),
                        t.data.layout().size(),
                    );
                } else {
                    t.as_slice_mut_unchecked::<T>()
                        .iter_mut()
                        .zip(slice.iter_mut())
                        .for_each(|(t, s)| *t = std::mem::take(s));
                }
                return t;
            }
            // Strategy 2: contiguous in memory, but with permuted axes.
            if it.strides().iter().all(|&s| s > 0) && it.as_slice_memory_order().is_some() {
                // (length, destination stride) runs, innermost source axis first.
                let mut len_and_strides: TVec<(usize, usize)> = tvec!();
                for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
                    .sorted_by_key(|(_, src, _)| *src)
                    .map(|(l, _, dst)| (*l as isize, *dst))
                {
                    // Merge with the previous run when this axis starts exactly where the
                    // previous run ends in the destination (prev_stride * prev_len == stride).
                    if !len_and_strides.is_empty()
                        && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
                            == stride as usize
                    {
                        len_and_strides.last_mut().unwrap().0 *= len as usize;
                    } else {
                        len_and_strides.push((len as usize, stride as usize));
                    }
                }
                // Outermost run first for the scatter.
                len_and_strides.reverse();
                // NOTE(review): this path is also taken for non-`is_copy` `T`; if
                // `scatter_contig_data` copies bitwise, values would end up owned both by
                // `it` and `t` — confirm it clones or only sees copy types.
                crate::scatter::scatter_contig_data(
                    it.as_ptr(),
                    t.as_ptr_mut_unchecked(),
                    &len_and_strides,
                );
                return t;
            }
            // Strategy 3: generic fallback, logical iteration order.
            t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
            t
        }
    }
1430
1431 pub fn deep_clone(&self) -> Tensor {
1432 unsafe {
1433 let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1434 if self.len() > 0 {
1435 if self.dt.is_copy() {
1436 self.data.as_ptr().copy_to_nonoverlapping(
1437 tensor.as_bytes_mut().as_mut_ptr(),
1438 self.data.layout().size(),
1439 )
1440 } else if self.dt == DatumType::String {
1441 tensor
1442 .as_slice_mut_unchecked::<String>()
1443 .clone_from_slice(self.as_slice_unchecked());
1444 } else if self.dt == DatumType::Blob {
1445 tensor
1446 .as_slice_mut_unchecked::<Blob>()
1447 .clone_from_slice(self.as_slice_unchecked());
1448 } else if self.dt == DatumType::Opaque {
1449 tensor
1450 .as_slice_mut_unchecked::<Opaque>()
1451 .clone_from_slice(self.as_slice_unchecked());
1452 } else if self.dt == DatumType::TDim {
1453 tensor
1454 .as_slice_mut_unchecked::<TDim>()
1455 .clone_from_slice(self.as_slice_unchecked());
1456 }
1457 }
1458 tensor
1459 }
1460 }
1461
1462 pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1463 if axis >= self.rank() {
1464 bail!("Can not slice at axis {} tensor {:?}", axis, self);
1465 }
1466 if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1467 bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1468 }
1469 fn slice_t<T: Datum>(
1470 t: &Tensor,
1471 axis: usize,
1472 start: usize,
1473 end: usize,
1474 ) -> TractResult<Tensor> {
1475 Ok(t.to_array_view::<T>()?
1476 .slice_axis(ndarray::Axis(axis), (start..end).into())
1477 .into_owned()
1478 .into_tensor())
1479 }
1480 dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1481 }
1482
    /// Creates a [`view::TensorView`] borrowing this tensor, through the unsafe
    /// `TensorView::view` constructor.
    #[inline]
    pub fn view(&self) -> view::TensorView<'_> {
        // NOTE(review): the safety contract of `TensorView::view` lives in the `view`
        // module; the returned view keeps `self` borrowed for its lifetime.
        unsafe { view::TensorView::view(self) }
    }
1487
    /// Creates a checked view positioned by the leading coordinates `prefix`.
    ///
    /// NOTE(review): exact semantics and error conditions are those of
    /// `TensorView::at_prefix`.
    #[inline]
    pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::at_prefix(self, prefix)
    }
1492
    /// Creates a checked view offset to `coords`.
    ///
    /// NOTE(review): exact semantics and error conditions are those of
    /// `TensorView::offsetting`.
    #[inline]
    pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::offsetting(self, coords)
    }
1497
    /// Unchecked counterpart of [`Tensor::view_offsetting`].
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `TensorView::offsetting_unchecked` —
    /// presumably that `coords` are valid for this tensor's shape (TODO confirm).
    #[inline]
    pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView<'_> {
        unsafe { view::TensorView::offsetting_unchecked(self, coords) }
    }
1502
    /// Same as [`Tensor::view`], but requires an exclusive borrow of the tensor.
    #[inline]
    pub fn view_mut(&mut self) -> view::TensorView<'_> {
        // NOTE(review): the view is built through the same shared-reference constructor as
        // `view`; confirm that no dedicated mutable constructor was intended here.
        unsafe { view::TensorView::view(self) }
    }
1507
    /// Same as [`Tensor::view_at_prefix`], but requires an exclusive borrow of the tensor.
    ///
    /// NOTE(review): delegates to the same `TensorView::at_prefix` constructor as the
    /// shared variant.
    #[inline]
    pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::at_prefix(self, prefix)
    }
1512
    /// Same as [`Tensor::view_offsetting`], but requires an exclusive borrow of the tensor.
    ///
    /// NOTE(review): delegates to the same `TensorView::offsetting` constructor as the
    /// shared variant.
    #[inline]
    pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::offsetting(self, coords)
    }
1517
1518 pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1520 let mut t = if let DatumType::U8 = self.dt.unquantized() {
1521 self.to_array_view::<u8>().unwrap().mapv(|v| v.wrapping_sub(128) as i8).into_tensor()
1522 } else {
1523 return self.clone();
1524 };
1525
1526 if let DatumType::QU8(qp) = self.dt {
1527 if let QParams::ZpScale { zero_point, scale } = qp {
1528 t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1529 } else {
1530 t.dt = DatumType::QI8(qp);
1531 }
1532 }
1533
1534 t.into_arc_tensor()
1535 }
1536
1537 pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1539 let mut t = if let DatumType::I8 = self.dt.unquantized() {
1540 self.to_array_view::<i8>().unwrap().mapv(|v| (v as u8).wrapping_add(128)).into_tensor()
1541 } else {
1542 return self.clone();
1543 };
1544
1545 if let DatumType::QI8(qp) = self.dt {
1546 if let QParams::ZpScale { zero_point, scale } = qp {
1547 t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1548 } else {
1549 t.dt = DatumType::QU8(qp);
1550 }
1551 }
1552 t.into_arc_tensor()
1553 }
1554
1555 pub fn to_aligned_default(&self) -> TractResult<Self> {
1556 if self.dt.is_copy() {
1557 unsafe {
1558 let mut t = Self::uninitialized_dt(self.dt, &self.shape)?;
1559 t.as_bytes_mut().copy_from_slice(self.as_bytes());
1560 Ok(t)
1561 }
1562 } else {
1563 let mut t = Self::zero_dt(self.dt, &self.shape)?;
1564 if self.dt == String::datum_type() {
1565 t.as_slice_mut::<String>()?.clone_from_slice(self.as_slice()?);
1566 } else if self.dt == Blob::datum_type() {
1567 t.as_slice_mut::<Blob>()?.clone_from_slice(self.as_slice()?);
1568 } else if self.dt == TDim::datum_type() {
1569 t.as_slice_mut::<TDim>()?.clone_from_slice(self.as_slice()?);
1570 }
1571 Ok(t)
1572 }
1573 }
1574
1575 pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1576 let mut strides = tvec!();
1577 compute_natural_stride_to(&mut strides, shape);
1578 strides
1579 }
1580
1581 pub fn into_blob(mut self) -> TractResult<Blob> {
1582 ensure!(self.dt.is_copy());
1583 Ok(std::mem::take(&mut self.data))
1584 }
1585}
1586
1587impl PartialEq for Tensor {
1588 fn eq(&self, other: &Tensor) -> bool {
1589 if self.dt != other.dt || self.shape != other.shape {
1590 return false;
1591 }
1592 self.eq_dt(other).unwrap_or(false)
1593 }
1594}
1595
1596impl fmt::Debug for Tensor {
1597 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1598 let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1599 write!(formatter, "{content}")
1600 }
1601}
1602
/// Reinterprets a tensor whose last axis has length 2 as a tensor of complex numbers,
/// dropping that axis (pairs `[re, im]` become one complex element). No data is copied.
///
/// # Errors
///
/// Fails if the last dimension is not 2 or the datum type has no complex counterpart.
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    ensure!(
        matches!(t.shape().last(), Some(&2)),
        "The last dimension in the tensor shape {:?} must be 2",
        t.shape()
    );
    let complex_dt = t.datum_type().complexify()?;
    unsafe {
        t.shape.pop();
        t.set_datum_type(complex_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1617
/// Inverse of `reinterpret_inner_dim_as_complex`: views a complex tensor as its real
/// component type with an extra trailing axis of length 2. No data is copied.
///
/// # Errors
///
/// Fails if the datum type has no real counterpart (`decomplexify`).
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    let real_dt = t.datum_type().decomplexify()?;
    unsafe {
        t.shape.push(2);
        t.set_datum_type(real_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1627
/// Turns any `RangeBounds<usize>` into a half-open `Range<usize>`.
///
/// An unbounded start becomes `0` and an unbounded end becomes `len`. Bounded ends are
/// translated without clamping to `len`: an inclusive end `i` becomes `i + 1` and an
/// excluded start `i` becomes `i + 1`.
pub fn clip_range_bounds(len: usize, range: impl std::ops::RangeBounds<usize>) -> Range<usize> {
    use std::ops::Bound::{Excluded, Included, Unbounded};
    let start = match range.start_bound() {
        Unbounded => 0,
        Included(&first) => first,
        Excluded(&before) => before + 1,
    };
    let end = match range.end_bound() {
        Unbounded => len,
        Included(&last) => last + 1,
        Excluded(&past) => past,
    };
    start..end
}
1642
1643pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1644 let mut strides = tvec!();
1645 compute_natural_stride_to(&mut strides, shape);
1646 strides
1647}
1648
1649fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1650 match shape.len() {
1651 0 => (),
1652 1 => strides.push(1),
1653 2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1654 3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1655 4 => strides.extend_from_slice(&[
1656 (shape[1] * shape[2] * shape[3]) as isize,
1657 (shape[2] * shape[3]) as _,
1658 shape[3] as _,
1659 1,
1660 ]),
1661 _ => {
1662 strides.push(1);
1663 for dim in shape.as_ref().iter().skip(1).rev() {
1664 let previous = *strides.last().unwrap();
1665 strides.push(previous * *dim as isize)
1666 }
1667 strides.reverse();
1668 }
1669 }
1670}
1671
1672impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
1673 fn from(it: Array<T, D>) -> Tensor {
1674 Tensor::from_datum(it.into_dyn())
1675 }
1676}
1677
/// Conversion of a value into an owned [`Tensor`].
pub trait IntoTensor: Sized {
    /// Consumes `self` and returns it as a `Tensor`.
    fn into_tensor(self) -> Tensor;
}
1685
/// Conversion of a value into a shared, reference-counted [`Tensor`].
pub trait IntoArcTensor: Sized {
    /// Consumes `self` and returns it as an `Arc<Tensor>`.
    fn into_arc_tensor(self) -> Arc<Tensor>;
}
1693
1694impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
1695 fn into_tensor(self) -> Tensor {
1696 Tensor::from(self)
1697 }
1698}
1699
1700impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
1701 fn into_arc_tensor(self) -> Arc<Tensor> {
1702 Arc::new(Tensor::from(self))
1703 }
1704}
1705
/// Identity conversion: a `Tensor` is returned unchanged.
impl IntoTensor for Tensor {
    fn into_tensor(self) -> Tensor {
        self
    }
}
1711
1712impl IntoTensor for Arc<Tensor> {
1713 fn into_tensor(self) -> Tensor {
1714 Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
1715 }
1716}
1717
1718impl IntoArcTensor for Tensor {
1719 fn into_arc_tensor(self) -> Arc<Tensor> {
1720 Arc::new(self)
1721 }
1722}
1723
/// Identity conversion: an `Arc<Tensor>` is returned as-is, without cloning the tensor.
impl IntoArcTensor for Arc<Tensor> {
    fn into_arc_tensor(self) -> Arc<Tensor> {
        self
    }
}
1729
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    use proptest::collection::vec;
    use proptest::prelude::*;

    /// Checks `Tensor::from` on an ndarray whose axes were permuted (so its memory layout is
    /// not standard order) against a reference built from the array's logical iteration
    /// order.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        // Rank 0..8, every dimension in 1..5, and a random permutation of the axes.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            (0..8usize)
                .prop_flat_map(|rank| {
                    let permute: Vec<usize> = (0..rank).collect();
                    (proptest::collection::vec(1..5usize, rank), Just(permute).prop_shuffle())
                })
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        // Array filled with 1, 2, 3, ... in standard order, then axis-permuted.
        fn input(&self) -> ArrayD<i32> {
            let mut i = 0;
            ArrayD::from_shape_simple_fn(&*self.shape, || {
                i += 1;
                i
            })
            .permuted_axes(&*self.permutation)
        }

        // Expected tensor: values in logical order, reshaped to the permuted shape.
        fn reference(&self) -> Tensor {
            let values: Vec<i32> = self.input().iter().copied().collect();
            let shape = self.permutation.iter().map(|ix| self.shape[*ix]).collect::<TVec<usize>>();
            super::litteral::tensor1(&values).into_shape(&shape).unwrap()
        }

        // Tensor under test, built by the `From<Array>` conversion.
        fn tract(&self) -> Tensor {
            Tensor::from(self.input())
        }

        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    proptest::proptest! {
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    // Fixed regression cases for the permutation check.
    #[test]
    fn t_1_2() {
        PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] }.check().unwrap();
    }

    #[test]
    fn t_2_2() {
        PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
    }

    /// Checks `broadcast_vector_to_shape` against an explicit reshape + `broadcast_to_shape`.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            // All-ones shape, except the vector's length on `axis`.
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reference = input
                .clone()
                .into_shape(&intermediate)
                .unwrap()
                .broadcast_to_shape(&self.shape)
                .unwrap();
            prop_assert_eq!(
                reference,
                input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap()
            );
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        // A base shape of rank 0..4 (dims 0..5, so empty tensors are covered), a vector of
        // length 0..5, and an insertion axis; the vector length is inserted at that axis.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1)
                })
                .prop_map(|(vec, mut shape, axis)| {
                    shape.insert(axis, vec.len());
                    BroadcastVecToShape { vec, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    // [[re, im], ...] f32 pairs become a rank-1 complex tensor.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let input = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Same with integers and rank 3 -> rank 2.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let input =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Smoke test: cloning a TDim tensor that holds a symbol must not panic.
    #[test]
    fn clone_tdim_tensor() {
        let symbols = SymbolScope::default();
        let a = symbols.sym("a");
        let t = tensor0(TDim::from(a));
        let _ = t.clone();
    }
}