1use crate::blob::Blob;
3use crate::datum::{round_ties_to_even, scale_by, ClampCast, Datum, DatumType, QParams};
4use crate::dim::TDim;
5use crate::internal::*;
6use crate::opaque::Opaque;
7use crate::TVec;
8use half::f16;
9use itertools::{izip, Itertools};
10use ndarray::prelude::*;
11#[cfg(feature = "complex")]
12use num_complex::Complex;
13use num_traits::{Float, Zero};
14use std::borrow::Cow;
15use std::fmt;
16use std::hash::Hash;
17use std::ops::Range;
18use std::sync::Arc;
19
20pub mod litteral;
21pub mod view;
22
/// Tolerance level used when comparing two tensors with `Tensor::close_enough`.
///
/// Each variant resolves to an `(atol, rtol, outliers)` triple via
/// `Approximation::atol_rtol_outliers`; the resolved values depend on the datum type being
/// compared.
#[derive(Copy, Clone, Default, Debug)]
pub enum Approximation {
    /// Zero absolute and relative tolerance, no outliers allowed.
    Exact,
    /// Tight tolerance; the default level.
    #[default]
    Close,
    Approximate,
    VeryApproximate,
    SuperApproximate,
    /// Loosest built-in level.
    UltraApproximate,
    /// Explicit `(atol, rtol, outliers)`: absolute tolerance, relative tolerance and the
    /// maximum fraction of mismatching elements.
    Custom(f32, f32, f32),
}
34
35impl PartialEq for Approximation {
36 fn eq(&self, other: &Self) -> bool {
37 use Approximation::Custom;
38 if let (Custom(aa, ar, ao), Custom(ba, br, bo)) = (self, other) {
39 aa == ba && ar == br && bo == ao
40 } else {
41 std::mem::discriminant(self) == std::mem::discriminant(other)
42 }
43 }
44}
45
// NOTE(review): `Eq` is asserted manually on top of the hand-written `PartialEq`. A
// `Custom` level holding a NaN tolerance compares unequal to itself (f32 `==`), so
// reflexivity only holds for NaN-free values.
impl Eq for Approximation {}
47
48impl From<bool> for Approximation {
49 fn from(b: bool) -> Self {
50 if b {
51 Self::Approximate
52 } else {
53 Self::Exact
54 }
55 }
56}
57
58impl Approximation {
59 fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
60 use Approximation::*;
61 match (self, dt) {
62 (Exact, _) => (0.0, 0.0, 0.0),
63 (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
64 (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
65 (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
66 (Close, _) => (1e-7, 1e-7, 0.0),
67 (Approximate, _) => (1e-4, 5e-4, 0.0),
68 (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
69 (SuperApproximate, _) => (0.1, 0.05, 0.0001),
70 (UltraApproximate, _) => (0.2, 0.1, 0.0005),
71 (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
72 }
73 }
74}
75
/// An owned, dense tensor whose elements all share one datum type.
///
/// The elements live in a single `Blob` allocation; `shape` and `strides` describe how that
/// buffer is indexed.
#[derive(Eq)]
pub struct Tensor {
    /// Datum type of every element stored in `data`.
    dt: DatumType,
    /// Size of each axis; empty for a rank-0 (scalar) tensor.
    shape: TVec<usize>,
    /// Per-axis strides, counted in elements (they are multiplied by the datum size to get
    /// byte offsets). Reset to natural strides by `update_strides_and_len`, but may be set
    /// arbitrarily through `set_geometry_unchecked`.
    strides: TVec<isize>,
    /// Number of elements: 1 for a scalar, otherwise `strides[0] * shape[0]` as computed
    /// by `update_strides_and_len`.
    len: usize,
    /// Raw storage for the elements.
    data: Blob,
}
85
// SAFETY: NOTE(review) these impls are asserted manually, presumably because `Blob` wraps a
// raw allocation pointer and is therefore not auto-`Send`/`Sync`. Soundness relies on every
// datum type that can be stored (including `Opaque`) being safe to move and share across
// threads — confirm against the `Datum` implementations.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
88
/// Hashes the datum type, the shape, the buffer alignment and every element value.
///
/// Floating-point elements (and complex floats) are hashed through their bit patterns,
/// reinterpreted as same-sized integers, because float types do not implement `Hash`.
// NOTE(review): the buffer alignment is part of the hash, so two tensors with identical
// values but different allocation alignments hash differently. Confirm that `PartialEq`
// (implemented elsewhere) also tells them apart; otherwise the `Hash`/`Eq` contract breaks.
impl Hash for Tensor {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use DatumType::*;
        self.dt.hash(state);
        self.shape.hash(state);
        self.data.layout().align().hash(state);
        unsafe {
            match self.dt {
                Bool => self.as_slice_unchecked::<bool>().hash(state),
                I8 => self.as_slice_unchecked::<i8>().hash(state),
                I16 => self.as_slice_unchecked::<i16>().hash(state),
                I32 => self.as_slice_unchecked::<i32>().hash(state),
                I64 => self.as_slice_unchecked::<i64>().hash(state),
                U8 => self.as_slice_unchecked::<u8>().hash(state),
                U16 => self.as_slice_unchecked::<u16>().hash(state),
                U32 => self.as_slice_unchecked::<u32>().hash(state),
                U64 => self.as_slice_unchecked::<u64>().hash(state),
                // Floats: hash the raw bits through an integer of the same width.
                F16 => self.as_slice_unchecked::<i16>().hash(state),
                F32 => self.as_slice_unchecked::<i32>().hash(state),
                F64 => self.as_slice_unchecked::<i64>().hash(state),
                TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
                String => self.as_slice_unchecked::<std::string::String>().hash(state),
                Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
                Opaque => self.as_slice_unchecked::<crate::opaque::Opaque>().hash(state),
                // Quantized types hash their stored integer representation.
                QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
                QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
                QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
            }
        }
    }
}
132
133impl Clone for Tensor {
134 fn clone(&self) -> Tensor {
135 self.deep_clone()
136 }
137}
138
139impl Default for Tensor {
140 fn default() -> Tensor {
141 litteral::tensor0(0f32)
142 }
143}
144
145impl Drop for Tensor {
146 fn drop(&mut self) {
147 macro_rules! drop_in_place {
148 ($t: ty) => {
149 if self.dt == <$t>::datum_type() {
150 unsafe {
151 self.as_slice_mut::<$t>()
152 .unwrap()
153 .iter_mut()
154 .for_each(|s| std::ptr::drop_in_place(s as *mut $t));
155 }
156 }
157 };
158 }
159 drop_in_place!(Blob);
160 drop_in_place!(String);
161 drop_in_place!(TDim);
162 drop_in_place!(Opaque);
163 }
164}
165
/// Default buffer alignment, in bytes, used for freshly allocated tensors.
///
/// On x86_64 this is 64 bytes when AVX-512F is detected at runtime, 32 bytes otherwise;
/// every other architecture uses 16 bytes.
#[allow(unreachable_code)]
pub fn vector_size() -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        let register_bits = if is_x86_feature_detected!("avx512f") { 512 } else { 256 };
        return register_bits / 8;
    }
    128 / 8
}
174
175impl Tensor {
176 #[inline]
178 pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
179 unsafe { Self::uninitialized_dt(T::datum_type(), shape) }
180 }
181
182 #[inline]
184 pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
185 unsafe { Self::uninitialized_aligned_dt(dt, shape, vector_size()) }
186 }
187
188 #[inline]
190 pub unsafe fn uninitialized_aligned<T: Datum>(
191 shape: &[usize],
192 alignment: usize,
193 ) -> TractResult<Tensor> {
194 unsafe { Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment) }
195 }
196
197 pub unsafe fn uninitialized_aligned_dt(
199 dt: DatumType,
200 shape: &[usize],
201 alignment: usize,
202 ) -> TractResult<Tensor> {
203 let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
204 let data = unsafe { Blob::new_for_size_and_align(bytes, alignment) };
205 let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), data, len: 0 };
206 if tensor.shape.len() == 0 {
207 tensor.len = 1;
208 } else {
209 tensor.update_strides_and_len();
210 }
211 if !tensor.data.is_empty() {
212 if dt == String::datum_type() || dt == Blob::datum_type() {
213 tensor.data.fill(0);
215 } else if dt == TDim::datum_type() {
216 unsafe {
217 tensor
218 .as_slice_mut_unchecked::<TDim>()
219 .iter_mut()
220 .for_each(|dim| std::ptr::write(dim, TDim::zero()))
221 }
222 } else if dt == Opaque::datum_type() {
223 unsafe {
224 tensor.as_slice_mut_unchecked::<Opaque>().iter_mut().for_each(|p| {
225 std::ptr::write(p, Opaque::default());
226 })
227 };
228 } else if cfg!(debug_assertions) {
229 assert!(dt.is_copy());
230 if dt == DatumType::F32 {
231 tensor.fill_t(f32::NAN).unwrap();
232 } else {
233 tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
235 }
236 }
237 }
238 Ok(tensor)
239 }
240
    /// Concatenates `tensors` along `axis` into a newly allocated tensor.
    ///
    /// All inputs must share rank and datum type and agree on every dimension except `axis`.
    /// The output size along `axis` is the sum of the inputs' sizes along it.
    ///
    /// # Errors
    /// Fails if `tensors` is empty, `axis` is not below the rank, or the inputs disagree on
    /// rank, datum type or any dimension other than `axis`.
    pub fn stack_tensors(
        axis: usize,
        tensors: &[impl std::borrow::Borrow<Tensor>],
    ) -> TractResult<Tensor> {
        ensure!(tensors.len() > 0);
        let rank = tensors[0].borrow().rank();
        ensure!(axis < rank);
        ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
        let dt = tensors[0].borrow().datum_type();
        ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
        let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
        for ax in 0..rank {
            if ax != axis {
                ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
            }
        }
        shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
        unsafe {
            let mut result = Tensor::uninitialized_dt(dt, &shape)?;
            // Fast path: when every axis before `axis` has size 1, each input maps to one
            // contiguous run of the output, so the inputs' buffers are copied back to back.
            if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
                let mut offset = 0isize;
                for v in tensors {
                    let v = v.borrow();
                    // Length of the input's `Blob` (presumably in bytes — matches the byte
                    // size used at allocation).
                    let len = v.data.len();
                    std::ptr::copy_nonoverlapping(
                        v.data.as_ptr(),
                        result.data.as_mut_ptr().offset(offset),
                        len,
                    );
                    offset += len as isize;
                }
            } else {
                // General path: write each input into its slab of the output along `axis`.
                let mut offset = 0;
                for t in tensors {
                    let t = t.borrow();
                    let len = t.shape()[axis];
                    result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
                    offset += len;
                }
            }

            Ok(result)
        }
    }
285
286 pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
287 self.fill_t(T::zero())
288 }
289
290 pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
291 unsafe {
292 let mut t = Tensor::uninitialized::<T>(shape)?;
293 t.clear::<T>()?;
294 Ok(t)
295 }
296 }
297
298 pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
299 Tensor::zero::<T>(&[])
300 }
301
302 pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
303 Tensor::zero_dt(dt, &[])
304 }
305
306 pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
307 Tensor::zero_aligned_dt(dt, shape, vector_size())
308 }
309
310 pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
311 self.as_slice_mut::<T>()?.iter_mut().for_each(|item| *item = value.clone());
312 Ok(())
313 }
314
315 pub fn zero_aligned_dt(
316 dt: DatumType,
317 shape: &[usize],
318 alignment: usize,
319 ) -> TractResult<Tensor> {
320 if shape.iter().product::<usize>() == 0 {
321 unsafe { return Tensor::uninitialized_dt(dt, shape) };
322 }
323 if dt.is_quantized() {
324 unsafe {
325 let mut t = Tensor::uninitialized_dt(dt, shape)?;
326 let zp = dt.zp_scale().0;
327 match dt.unquantized() {
328 DatumType::I8 => {
329 t.as_slice_mut::<i8>()?.iter_mut().for_each(|item| *item = zp as _)
330 }
331 DatumType::U8 => {
332 t.as_slice_mut::<u8>()?.iter_mut().for_each(|item| *item = zp as _)
333 }
334 DatumType::I32 => {
335 t.as_slice_mut::<i32>()?.iter_mut().for_each(|item| *item = zp as _)
336 }
337 _ => unreachable!(),
338 }
339 Ok(t)
340 }
341 } else {
342 dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
343 }
344 }
345
346 pub fn zero_aligned<T: Datum + num_traits::Zero>(
347 shape: &[usize],
348 alignment: usize,
349 ) -> TractResult<Tensor> {
350 unsafe {
351 let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
352 tensor.clear::<T>()?;
353 Ok(tensor)
354 }
355 }
356
357 pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
360 Self::from_shape_align(shape, data, vector_size())
361 }
362
363 pub fn from_shape_align<T: Datum + Copy>(
366 shape: &[usize],
367 data: &[T],
368 align: usize,
369 ) -> TractResult<Tensor> {
370 ensure!(
371 data.len() == shape.iter().product::<usize>(),
372 "Shape product must be equal to data length"
373 );
374 unsafe {
375 let bytes = std::slice::from_raw_parts(
376 data.as_ptr() as *const u8,
377 data.len() * T::datum_type().size_of(),
378 );
379 let dt = T::datum_type();
380 Self::from_raw_dt_align(dt, shape, bytes, align)
381 }
382 }
383
384 pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
388 unsafe { Tensor::from_raw_dt(T::datum_type(), shape, content) }
389 }
390
391 pub unsafe fn from_raw_aligned<T: Datum>(
392 shape: &[usize],
393 content: &[u8],
394 align: usize,
395 ) -> TractResult<Tensor> {
396 unsafe { Tensor::from_raw_dt_align(T::datum_type(), shape, content, align) }
397 }
398
399 pub unsafe fn from_raw_dt(
400 dt: DatumType,
401 shape: &[usize],
402 content: &[u8],
403 ) -> TractResult<Tensor> {
404 unsafe { Self::from_raw_dt_align(dt, shape, content, vector_size()) }
405 }
406
407 pub unsafe fn from_raw_dt_align(
408 dt: DatumType,
409 shape: &[usize],
410 content: &[u8],
411 align: usize,
412 ) -> TractResult<Tensor> {
413 let mut tensor = unsafe { Tensor::uninitialized_aligned_dt(dt, shape, align) }?;
414 tensor.as_bytes_mut().copy_from_slice(content);
415 Ok(tensor)
416 }
417
418 pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
419 let bytes = if content.len() == 0 {
420 &[]
421 } else {
422 unsafe {
423 std::slice::from_raw_parts(
424 content.as_ptr() as *const u8,
425 content.len() * T::datum_type().size_of(),
426 )
427 }
428 };
429 unsafe { Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align) }
430 }
431
432 #[inline]
434 pub fn rank(&self) -> usize {
435 self.shape.len()
436 }
437
438 #[inline]
440 pub fn shape(&self) -> &[usize] {
441 &self.shape
442 }
443
    /// Number of elements in the tensor (1 for a scalar).
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.len
    }
450
451 #[inline]
453 #[allow(clippy::len_without_is_empty)]
454 pub fn volume(&self) -> usize {
455 self.len
456 }
457
458 #[inline]
460 pub fn strides(&self) -> &[isize] {
461 &self.strides
462 }
463
464 fn update_strides_and_len(&mut self) {
465 self.strides.clear();
466 if self.shape.len() == 0 {
467 self.len = 1;
468 return;
469 }
470 compute_natural_stride_to(&mut self.strides, &self.shape);
471 self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
472 }
473
474 pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
476 if shape != &*self.shape {
477 self.shape.clear();
478 self.shape.extend_from_slice(shape);
479 self.update_strides_and_len();
480 }
481 }
482
    /// Overwrites the shape and strides with arbitrary values, without any validation.
    ///
    /// # Safety
    /// The caller must keep the geometry consistent with the underlying buffer. `len` is
    /// not recomputed here, so presumably the new shape must cover the same element count as
    /// before — confirm with callers.
    pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
        self.shape.clear();
        self.shape.extend_from_slice(shape);
        self.strides.clear();
        self.strides.extend_from_slice(strides);
    }
490
491 pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
493 if self.len() != shape.iter().product::<usize>() {
494 bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
495 }
496 unsafe { self.set_shape_unchecked(shape) }
497 Ok(())
498 }
499
500 pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
501 ensure!(axes.iter().duplicates().next().is_none());
502 ensure!(axes.iter().all(|a| *a < self.rank()));
503 unsafe {
504 #[inline]
505 unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
506 unsafe { input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor() }
507 }
508 let dt = self.datum_type();
509 let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
510 t.set_datum_type(dt);
511 Ok(t)
512 }
513 }
514
515 pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
516 let mut permutation: Vec<usize> = (0..self.rank()).collect();
517 permutation.remove(from);
518 permutation.insert(to, from);
519 self.permute_axes(&permutation)
520 }
521
522 pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
523 let removed = self.shape.remove(axis + 1);
524 self.shape[axis] *= removed;
525 self.update_strides_and_len();
526 self
527 }
528
529 pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
530 if self.shape[axis] % outer_dim != 0 {
531 bail!(
532 "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
533 self.shape,
534 axis,
535 outer_dim
536 );
537 }
538 self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
539 self.shape[axis] = outer_dim;
540 self.update_strides_and_len();
541 Ok(self)
542 }
543
544 pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
546 self.set_shape(shape)?;
547 Ok(self)
548 }
549
550 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
551 self.shape.insert(axis, 1);
552 self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
553 Ok(())
554 }
555
556 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
557 ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
558 self.shape.remove(axis);
559 self.strides.remove(axis);
560 Ok(())
561 }
562
563 pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
564 self.broadcast_to_rank(rank)?;
565 self.update_strides_and_len();
566 Ok(self)
567 }
568
569 pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
570 if rank < self.rank() {
571 bail!("Can only broadcast to higher rank")
572 }
573 while self.shape.len() < rank {
574 self.shape.insert(0, 1)
575 }
576 self.update_strides_and_len();
577 Ok(())
578 }
579
580 pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
581 if self.rank() > 0 {
582 bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
583 }
584 unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
585 unsafe {
586 let value: &T = src.to_scalar_unchecked::<T>();
587 dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone())
588 };
589 }
590 unsafe {
591 let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
592 dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
593 Ok(t)
594 }
595 }
596
597 fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
598 unsafe {
599 let view = self.to_array_view_unchecked::<T>();
600 let mut output = view
601 .broadcast(shape)
602 .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
603 .into_owned()
604 .into_tensor();
605 output.set_datum_type(self.datum_type());
606 Ok(output)
607 }
608 }
609
610 pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
611 dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
612 }
613
614 pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
615 ensure!(self.rank() == 1);
616 ensure!(shape[axis] == self.len());
617 if !self.datum_type().is_copy() {
618 let mut vec_shape = vec![1; shape.len()];
619 vec_shape[axis] = self.len();
620 return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
621 }
622 unsafe {
623 let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
624 if output.len() == 0 {
625 return Ok(output);
626 }
627 let inner_len = shape[axis + 1..].iter().product::<usize>();
628
629 unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
630 where
631 T: Datum + Copy,
632 {
633 unsafe {
634 for ix in 0..input.len() {
635 let value: T = input.as_slice_unchecked()[ix];
636 output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
637 .iter_mut()
638 .for_each(|item| *item = value);
639 }
640 }
641 }
642 dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));
643
644 let outer_len = shape[0..axis].iter().product::<usize>();
645 let repeat_bytes_len = inner_len * self.as_bytes().len();
646 let bytes = output.as_bytes_mut();
647 for ix in 1..outer_len {
648 bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
649 }
650
651 Ok(output)
652 }
653 }
654
655 fn clip_range_bounds(
656 &self,
657 axis: usize,
658 range: impl std::ops::RangeBounds<usize>,
659 ) -> Range<usize> {
660 use std::ops::Bound;
661 let start = match range.start_bound() {
662 Bound::Included(ix) => *ix,
663 Bound::Excluded(ix) => ix + 1,
664 Bound::Unbounded => 0,
665 };
666 let end = match range.end_bound() {
667 Bound::Included(ix) => *ix + 1,
668 Bound::Excluded(ix) => *ix,
669 Bound::Unbounded => self.shape()[axis],
670 };
671 start..end
672 }
673
674 pub fn assign_slice(
675 &mut self,
676 range: impl std::ops::RangeBounds<usize>,
677 src: &Tensor,
678 src_range: impl std::ops::RangeBounds<usize>,
679 axis: usize,
680 ) -> TractResult<()> {
681 let range = self.clip_range_bounds(axis, range);
682 let src_range = src.clip_range_bounds(axis, src_range);
683 ensure!(
684 src.datum_type() == self.datum_type(),
685 "Attempt to assign into {:?} from {:?}, datum type mismatch",
686 self.datum_type(),
687 src.datum_type()
688 );
689 ensure!(
690 src_range.len() == range.len(),
691 "Attempt to assign a range of {:?} from a range of {:?}",
692 range,
693 src_range,
694 );
695 ensure!(
696 self.rank() == src.rank()
697 && itertools::izip!(0.., self.shape(), src.shape())
698 .all(|(ix, dst, src)| ix == axis || src == dst),
699 "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
700 axis,
701 self,
702 src
703 );
704 ensure!(
705 src_range.end <= src.shape()[axis],
706 "Assigning from invalid slice (axis {}, {:?}) of {:?}",
707 axis,
708 src_range,
709 src
710 );
711 ensure!(
712 range.end <= self.shape()[axis],
713 "Assigning to invalid slice (axis {}, {:?}) of {:?}",
714 axis,
715 range,
716 self
717 );
718 unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
719 Ok(())
720 }
721
722 pub unsafe fn assign_slice_unchecked(
723 &mut self,
724 range: impl std::ops::RangeBounds<usize>,
725 src: &Tensor,
726 src_range: impl std::ops::RangeBounds<usize>,
727 axis: usize,
728 ) {
729 let range = self.clip_range_bounds(axis, range);
730 let src_range = src.clip_range_bounds(axis, src_range);
731 unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
732 }
733
    /// Copies `src[.., src_range, ..]` into `self[.., range, ..]` along `axis`, with
    /// already-resolved ranges.
    ///
    /// # Safety
    /// Nothing is checked. Callers must guarantee:
    /// - identical datum types and ranks, and identical dimensions on every axis but `axis`;
    /// - ranges of equal length that lie within their tensor's axis.
    unsafe fn assign_slice_from_resolved(
        &mut self,
        range: std::ops::Range<usize>,
        src: &Tensor,
        src_range: std::ops::Range<usize>,
        axis: usize,
    ) {
        unsafe {
            use ndarray::Slice;
            // Generic path: element-wise assignment through ndarray views.
            unsafe fn assign_slice_t<T: Datum>(
                to: &mut Tensor,
                to_range: Range<usize>,
                from: &Tensor,
                from_range: Range<usize>,
                axis: usize,
            ) {
                unsafe {
                    to.to_array_view_mut_unchecked::<T>()
                        .slice_axis_mut(Axis(axis), Slice::from(to_range))
                        .assign(
                            &from
                                .to_array_view_unchecked::<T>()
                                .slice_axis(Axis(axis), Slice::from(from_range)),
                        )
                }
            }
            // Fast path: for copy types with only size-1 axes before `axis`, each slab is one
            // contiguous byte run. `stride` is the byte size of one step along `axis`, taken
            // from `self.strides` (elements) times the datum size.
            if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
                let stride = self.strides[axis] as usize * self.datum_type().size_of();
                let dst_start = (stride * range.start) as isize;
                let src_start = (stride * src_range.start) as isize;
                let len = stride * range.len();
                if len > 0 {
                    // Distinct buffers can use the non-overlapping copy; the same buffer needs
                    // the overlap-safe `copy`.
                    if self.data.as_ptr() != src.data.as_ptr() {
                        std::ptr::copy_nonoverlapping(
                            src.data.as_ptr().offset(src_start),
                            self.data.as_mut_ptr().offset(dst_start),
                            len,
                        );
                    } else {
                        std::ptr::copy(
                            src.data.as_ptr().offset(src_start),
                            self.data.as_mut_ptr().offset(dst_start),
                            len,
                        );
                    }
                }
            } else {
                dispatch_datum!(assign_slice_t(self.datum_type())(
                    self, range, src, src_range, axis
                ));
            }
        }
    }
788
    /// Datum type of the tensor's elements.
    #[inline]
    pub fn datum_type(&self) -> DatumType {
        self.dt
    }
794
    /// Relabels the datum type without touching the data.
    ///
    /// # Safety
    /// The new type must be compatible with the stored bytes (this file uses it to restore
    /// quantized types after size-based dispatch).
    #[inline]
    pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
        self.dt = dt
    }
800
801 pub fn dump(&self, force_full: bool) -> TractResult<String> {
805 unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
806 unsafe {
807 if let Some(qp) = tensor.datum_type().qparams() {
808 let integers = tensor.cast_to::<i32>().unwrap();
809 integers.as_slice_unchecked::<i32>()[0..n]
810 .iter()
811 .map(|x| format!("[{}]({})", x, qp.dq(*x)))
812 .join(", ")
813 } else {
814 tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
815 }
816 }
817 }
818 unsafe {
819 let trunc = self.len() > 12 && !force_full;
820 let data = dispatch_datum!(dump_t(self.datum_type())(
821 self,
822 if trunc { 12 } else { self.len() }
823 ));
824 Ok(format!(
825 "{},{:?} {}{}",
826 self.shape.iter().join(","),
827 self.dt,
828 data,
829 if trunc { "..." } else { "" }
830 ))
831 }
832 }
833
834 pub fn close_enough(
836 &self,
837 other: &Self,
838 approx: impl Into<Approximation> + std::fmt::Debug,
839 ) -> TractResult<()> {
840 let approx = approx.into();
841 if self.shape() != other.shape() {
842 bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
843 }
844 let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
845 let ma = self.cast_to::<f32>()?;
846 let ma = ma.to_array_view::<f32>()?;
847 let mb = other.cast_to::<f32>()?;
848 let mb = mb.to_array_view::<f32>()?;
849 let mut first_outlier = None;
850 let mut outliers_count = 0;
851 ndarray::indices_of(&ma).into_iter().for_each(|indices| {
852 let a = ma[&indices];
853 let b = mb[&indices];
854 if !((a.is_nan() && b.is_nan())
855 || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
856 || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
857 {
858 if outliers_count == 0 {
859 first_outlier = Some(indices.as_array_view().to_vec());
860 }
861 outliers_count += 1;
862 }
863 });
864 if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
865 let indices = first_outlier.unwrap();
866 let a = ma[&*indices];
867 let b = mb[&*indices];
868 bail!(
869 "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
870 approx,
871 self.datum_type(),
872 indices,
873 a,
874 b,
875 outliers_count,
876 self.volume(),
877 outliers_count as f64 / self.volume() as f64,
878 outliers
879 );
880 }
881 Ok(())
882 }
883
884 pub fn into_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
886 Ok(self.to_array_view::<D>()?.to_owned())
887 }
888
889 pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
891 unsafe { self.to_array_view_unchecked::<D>().to_owned() }
892 }
893
894 fn check_for_access<D: Datum>(&self) -> TractResult<()> {
895 ensure!(
896 self.datum_type().unquantized() == D::datum_type().unquantized(),
897 "Tensor datum type error: tensor is {:?}, accessed as {:?}",
898 self.datum_type(),
899 D::datum_type(),
900 );
901 Ok(())
902 }
903
904 pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
906 self.check_for_access::<D>()?;
907 unsafe { Ok(self.to_array_view_unchecked()) }
908 }
909
910 pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
912 self.check_for_access::<D>()?;
913 unsafe { Ok(self.to_array_view_mut_unchecked()) }
914 }
915
916 pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
918 if self.len() != 0 {
919 unsafe { ArrayViewD::from_shape_ptr(&*self.shape, self.data.as_ptr() as *const D) }
920 } else {
921 ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
922 }
923 }
924
925 pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
927 if self.len() != 0 {
928 unsafe { ArrayViewMutD::from_shape_ptr(&*self.shape, self.data.as_mut_ptr() as *mut D) }
929 } else {
930 ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
931 }
932 }
933
934 pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
936 self.check_for_access::<D>()?;
937 Ok(self.data.as_ptr() as *const D)
938 }
939
940 pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
942 self.data.as_ptr() as *const D
943 }
944
945 pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
947 self.data.as_mut_ptr() as *mut D
948 }
949
950 pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
952 self.as_ptr::<D>().map(|p| p as *mut D)
953 }
954
955 pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
957 let ptr: *const D = self.as_ptr()?;
958 if self.data.len() == 0 {
959 Ok(&[])
960 } else {
961 unsafe { Ok(std::slice::from_raw_parts::<D>(ptr, self.len())) }
962 }
963 }
964
965 pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
967 let ptr: *mut D = self.as_ptr_mut()?;
968 if self.data.len() == 0 {
969 Ok(&mut [])
970 } else {
971 unsafe { Ok(std::slice::from_raw_parts_mut::<D>(ptr, self.len())) }
972 }
973 }
974
975 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
977 if self.data.len() == 0 {
978 &[]
979 } else {
980 unsafe { std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len()) }
981 }
982 }
983
984 pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
986 if self.data.len() == 0 {
987 &mut []
988 } else {
989 unsafe { std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len()) }
990 }
991 }
992
993 pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
995 self.check_for_access::<D>()?;
996 if self.len() == 0 {
997 bail!("to_scalar called on empty tensor ({:?})", self)
998 }
999 if self.len() > 1 {
1000 bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1001 }
1002 unsafe { Ok(self.to_scalar_unchecked()) }
1003 }
1004
1005 pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
1007 fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
1008 Ok(litteral::tensor0(t.to_scalar::<D>()?.clone()))
1009 }
1010 dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
1011 }
1012
1013 pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
1015 unsafe { &*(self.data.as_ptr() as *const D) }
1016 }
1017
1018 pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
1020 self.check_for_access::<D>()?;
1021 if self.len() == 0 {
1022 bail!("to_scalar_mut called on empty tensor ({:?})", self)
1023 }
1024 if self.len() > 1 {
1025 bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1026 }
1027 unsafe { Ok(self.to_scalar_mut_unchecked()) }
1028 }
1029
1030 pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
1032 unsafe { &mut *(self.data.as_mut_ptr() as *mut D) }
1033 }
1034
    /// Raw bytes of the underlying buffer.
    pub fn as_bytes(&self) -> &[u8] {
        self.data.as_bytes()
    }
1038
    /// Mutable raw bytes of the underlying buffer.
    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
        self.data.as_bytes_mut()
    }
1042
1043 unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
1044 let slice = unsafe { self.as_slice_unchecked::<T>() };
1045 slice[1..].iter().all(|x| x == &slice[0])
1046 }
1047
1048 pub fn is_uniform(&self) -> bool {
1049 if self.len() <= 1 {
1050 return true;
1051 }
1052 unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
1053 }
1054
1055 unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
1056 let v: T = unsafe { self.as_slice_unchecked::<T>() }[0].clone();
1057 litteral::tensor0(v)
1058 }
1059
1060 pub fn as_uniform(&self) -> Option<Tensor> {
1061 if self.len() >= 1 && self.is_uniform() {
1062 unsafe {
1063 let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
1064 t.set_datum_type(self.datum_type());
1065 Some(t)
1066 }
1067 } else {
1068 None
1069 }
1070 }
1071
1072 pub fn is_all_zero(&self) -> TractResult<bool> {
1073 Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1074 }
1075
1076 pub fn is_zero(&self) -> TractResult<bool> {
1077 Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
1078 }
1079
1080 unsafe fn natural_cast<
1081 Source: Datum + num_traits::AsPrimitive<Target>,
1082 Target: Datum + Copy,
1083 >(
1084 &self,
1085 other: &mut Tensor,
1086 ) {
1087 unsafe {
1088 self.as_slice_unchecked::<Source>()
1089 .iter()
1090 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1091 .for_each(|(s, d)| *d = s.as_())
1092 };
1093 }
1094
1095 unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1096 unsafe {
1097 self.as_slice_unchecked::<Source>()
1098 .iter()
1099 .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1100 .for_each(|(s, d)| *d = !s.is_zero());
1101 }
1102 }
1103
1104 unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
1105 &self,
1106 other: &mut Tensor,
1107 ) -> TractResult<()> {
1108 unsafe {
1109 for (s, d) in self
1110 .as_slice_unchecked::<String>()
1111 .iter()
1112 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1113 {
1114 *d = s
1115 .parse()
1116 .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
1117 }
1118 Ok(())
1119 }
1120 }
1121
1122 unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1123 unsafe {
1124 for (s, d) in self
1125 .as_slice_unchecked::<Source>()
1126 .iter()
1127 .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1128 {
1129 *d = s.to_string()
1130 }
1131 }
1132 }
1133
1134 pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<'_, Tensor>> {
1136 self.cast_to_dt(D::datum_type())
1137 }
1138
1139 #[allow(clippy::redundant_closure_call)]
1141 pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<'_, Tensor>> {
1142 unsafe {
1143 if self.dt == dst_dt {
1144 return Ok(Cow::Borrowed(self));
1145 }
1146 if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
1147 let slice = self.as_slice_unchecked::<TDim>();
1148 let mut ints = Self::uninitialized::<i64>(&self.shape)?;
1149 let ints_slice = ints.as_slice_mut_unchecked::<i64>();
1150 for i in 0..self.len() {
1151 ints_slice[i] = slice[i].to_i64()?;
1152 }
1153 return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1154 }
1155 if self.dt == bool::datum_type()
1156 && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
1157 {
1158 let slice = self.as_slice_unchecked::<bool>();
1159 let mut ints = Self::uninitialized::<i8>(&self.shape)?;
1160 let ints_slice = ints.as_slice_mut_unchecked::<i8>();
1161 for i in 0..self.len() {
1162 ints_slice[i] = slice[i] as usize as i8;
1163 }
1164 return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1165 }
1166 let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
1167 if self.dt == DatumType::String {
1168 dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
1169 return Ok(Cow::Owned(result));
1170 }
1171 if dst_dt == DatumType::String {
1172 dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
1173 return Ok(Cow::Owned(result));
1174 }
1175 macro_rules! n {
1176 ($source:ty) => {
1177 if <$source>::datum_type() == self.datum_type() {
1178 match dst_dt {
1179 DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
1180 DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
1181 DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
1182 DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
1183 DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
1184 DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
1185 DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
1186 DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
1187 DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
1188 DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
1189 DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
1190 DatumType::TDim => {
1191 let ints = self.cast_to::<i32>()?;
1192 let slice = ints.as_slice_unchecked::<i32>();
1193 let result = result.as_slice_mut_unchecked::<TDim>();
1194 for i in 0..self.len() {
1195 result[i] = slice[i].into();
1196 }
1197 }
1198 DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
1199 _ => todo!(),
1200 }
1201 return Ok(Cow::Owned(result));
1202 };
1203 };
1204 }
1205 if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
1207 n!(u8);
1208 n!(u16);
1209 n!(u32);
1210 n!(u64);
1211 n!(i8);
1212 n!(i16);
1213 n!(i32);
1214 n!(i64);
1215 n!(f16);
1216 n!(f32);
1217 n!(f64);
1218 } else {
1219 let (s_zp, s_scale) = self.datum_type().zp_scale();
1220 let (d_zp, d_scale) = dst_dt.zp_scale();
1221 if self.datum_type().is_quantized() && dst_dt.is_float() {
1222 macro_rules! q_to_fp {
1223 ($source:ty, $dest:ty) => {
1224 if <$source>::datum_type().unquantized()
1225 == self.datum_type().unquantized()
1226 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1227 {
1228 self.as_slice_unchecked::<$source>()
1229 .iter()
1230 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1231 .for_each(|(&s, d)| {
1232 *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
1233 });
1234 return Ok(Cow::Owned(result));
1235 }
1236 };
1237 }
1238 q_to_fp!(i8, f64);
1239 q_to_fp!(i8, f32);
1240 q_to_fp!(u8, f64);
1241 q_to_fp!(u8, f32);
1242 }
1243 macro_rules! q8_to_q8 {
1245 ($typ:ty) => {
1246 if dst_dt.unquantized() == <$typ>::datum_type() {
1247 self.as_slice_unchecked::<$typ>()
1248 .iter()
1249 .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
1250 .for_each(|(&s, d)| {
1251 *d = (d_zp as i32
1252 + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
1253 .clamp_cast()
1254 });
1255 return Ok(Cow::Owned(result));
1256 }
1257 };
1258 }
1259
1260 macro_rules! q_via_f32 {
1261 ($source:ty, $dest:ty, $round:expr) => {
1262 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1263 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1264 {
1265 self.as_slice_unchecked::<$source>()
1266 .iter()
1267 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1268 .for_each(|(&s, d)| {
1269 let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
1270 let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
1271 *d = $round(d_float);
1272 });
1273 return Ok(Cow::Owned(result));
1274 }
1275 };
1276 }
1277
1278 macro_rules! q_n {
1279 (clamp $source:ty, $dest:ty) => {{
1280 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1281 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1282 {
1283 self.as_slice_unchecked::<$source>()
1284 .iter()
1285 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1286 .for_each(|(&s, d)| {
1287 *d = s.clamp_cast();
1288 });
1289 return Ok(Cow::Owned(result));
1290 }
1291 }};
1292 ($source:ty, $dest:ty) => {{
1293 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1294 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1295 {
1296 self.as_slice_unchecked::<$source>()
1297 .iter()
1298 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1299 .for_each(|(&s, d)| {
1300 *d = s as $dest;
1301 });
1302 return Ok(Cow::Owned(result));
1303 }
1304 }};
1305 }
1306
1307 if dst_dt.unquantized() == self.datum_type().unquantized()
1308 && dst_dt.is_quantized()
1309 && self.datum_type().is_quantized()
1310 {
1311 q8_to_q8!(i8);
1312 q8_to_q8!(u8);
1313 }
1314
1315 q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
1316 q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
1317 q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
1318 q_via_f32!(i8, f32, |f| f);
1319 q_via_f32!(u8, f32, |f| f);
1320 q_via_f32!(i32, f32, |f| f);
1321
1322 if dst_dt.is_quantized() && self.datum_type().is_quantized() {
1323 q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
1324 q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
1325 q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
1326 q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
1327 q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
1328 q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());
1329
1330 q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
1332 q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
1333 }
1334
1335 q_n!(i8, i32);
1336 q_n!(i8, u32);
1337 q_n!(u8, i32);
1338 q_n!(u8, u32);
1339 q_n!(clamp i32, i8);
1340 q_n!(clamp i32, u8);
1341 q_n!(clamp u32, i8);
1342 q_n!(clamp u32, u8);
1343 q_n!(i8, i8);
1344 q_n!(u8, u8);
1345 q_n!(i32, i32);
1346 q_n!(u32, u32);
1347 }
1348
1349 bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
1350 }
1351 }
1352
1353 pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1355 let casted = self.cast_to::<D>()?;
1356 casted.to_scalar::<D>().copied()
1357 }
1358
1359 pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1361 if nth >= self.len() {
1362 bail!(
1363 "nth called with {}th element on a tensor of len {} ({:?}",
1364 nth,
1365 self.len(),
1366 self
1367 );
1368 }
1369 unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1370 unsafe {
1371 let value = me.as_slice_unchecked::<T>()[nth].clone();
1372 output.as_slice_mut_unchecked::<T>()[0] = value;
1373 }
1374 }
1375 unsafe {
1376 let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1377 dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1378 Ok(output)
1379 }
1380 }
1381
1382 fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1384 unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1385 unsafe {
1386 if D::datum_type().is_float() {
1387 return dispatch_floatlike!(float_eq_t(D::datum_type())(me, other));
1388 }
1389 Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1390 .all(|(a, b)| a == b))
1391 }
1392 }
1393
1394 unsafe fn float_eq_t<D: Datum + Float>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1395 unsafe {
1396 Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1397 .all(|(a, b)| (a.is_nan() && b.is_nan()) || a == b))
1398 }
1399 }
1400
1401 unsafe {
1402 Ok(self.datum_type() == other.datum_type()
1403 && self.shape() == other.shape()
1404 && dispatch_datum!(eq_t(self.dt)(self, other))?)
1405 }
1406 }
1407
    /// Builds a tensor from an owned `ArrayD`, whatever its memory layout.
    ///
    /// Three paths are tried in order:
    /// 1. `it.as_slice_mut()` succeeds (standard contiguous order). `is_copy` types are
    ///    copied bytewise. Other types are moved element by element with `std::mem::take`,
    ///    which leaves default values behind in `it`.
    /// 2. All strides are positive and the data is contiguous in some memory order
    ///    (`as_slice_memory_order`). The axes are grouped into (len, destination stride)
    ///    runs and handed to `crate::scatter::scatter_contig_data`.
    /// 3. Anything else: the elements of `it` are moved into `t` in `it`'s iteration order.
    ///
    /// Panics if allocating the destination tensor fails (`unwrap`).
    fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
        unsafe {
            let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
            if let Some(slice) = it.as_slice_mut() {
                if t.datum_type().is_copy() {
                    // Raw byte copy; the count is the size, in bytes, of `t`'s storage layout.
                    std::ptr::copy_nonoverlapping(
                        slice.as_ptr() as *const i8,
                        t.as_ptr_mut_unchecked(),
                        t.data.layout().size(),
                    );
                } else {
                    // Non-copy elements are moved out, leaving `Default` values in `it`.
                    t.as_slice_mut_unchecked::<T>()
                        .iter_mut()
                        .zip(slice.iter_mut())
                        .for_each(|(t, s)| *t = std::mem::take(s));
                }
                return t;
            }
            if it.strides().iter().all(|&s| s > 0) && it.as_slice_memory_order().is_some() {
                // Visit axes from the smallest to the largest source stride, pairing each
                // axis length with the destination (`t`) stride of that axis. An axis is
                // folded into the previous run when prev_len * prev_stride == stride in the
                // destination, i.e. the two runs are contiguous there.
                let mut len_and_strides: TVec<(usize, usize)> = tvec!();
                for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
                    .sorted_by_key(|(_, src, _)| *src)
                    .map(|(l, _, dst)| (*l as isize, *dst))
                {
                    if !len_and_strides.is_empty()
                        && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
                            == stride as usize
                    {
                        len_and_strides.last_mut().unwrap().0 *= len as usize;
                    } else {
                        len_and_strides.push((len as usize, stride as usize));
                    }
                }
                // Runs were built innermost first; hand them over outermost first.
                len_and_strides.reverse();
                // NOTE(review): this path is also taken for non-`is_copy` T, and `it` is still
                // dropped when the function returns. Confirm that `scatter_contig_data`
                // clones rather than bit-copies, otherwise non-copy elements may be dropped
                // twice.
                crate::scatter::scatter_contig_data(
                    it.as_ptr(),
                    t.as_ptr_mut_unchecked(),
                    &len_and_strides,
                );
                return t;
            }
            // Fallback for arbitrary layouts: move each element in `it`'s iteration order.
            t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
            t
        }
    }
1454
1455 pub fn deep_clone(&self) -> Tensor {
1456 unsafe {
1457 let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1458 if self.len() > 0 {
1459 if self.dt.is_copy() {
1460 self.data.as_ptr().copy_to_nonoverlapping(
1461 tensor.as_bytes_mut().as_mut_ptr(),
1462 self.data.layout().size(),
1463 )
1464 } else if self.dt == DatumType::String {
1465 tensor
1466 .as_slice_mut_unchecked::<String>()
1467 .clone_from_slice(self.as_slice_unchecked());
1468 } else if self.dt == DatumType::Blob {
1469 tensor
1470 .as_slice_mut_unchecked::<Blob>()
1471 .clone_from_slice(self.as_slice_unchecked());
1472 } else if self.dt == DatumType::Opaque {
1473 tensor
1474 .as_slice_mut_unchecked::<Opaque>()
1475 .clone_from_slice(self.as_slice_unchecked());
1476 } else if self.dt == DatumType::TDim {
1477 tensor
1478 .as_slice_mut_unchecked::<TDim>()
1479 .clone_from_slice(self.as_slice_unchecked());
1480 }
1481 }
1482 tensor
1483 }
1484 }
1485
1486 pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1487 if axis >= self.rank() {
1488 bail!("Can not slice at axis {} tensor {:?}", axis, self);
1489 }
1490 if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1491 bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1492 }
1493 fn slice_t<T: Datum>(
1494 t: &Tensor,
1495 axis: usize,
1496 start: usize,
1497 end: usize,
1498 ) -> TractResult<Tensor> {
1499 Ok(t.to_array_view::<T>()?
1500 .slice_axis(ndarray::Axis(axis), (start..end).into())
1501 .into_owned()
1502 .into_tensor())
1503 }
1504 dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1505 }
1506
1507 #[inline]
1508 pub fn view(&self) -> view::TensorView<'_> {
1509 unsafe { view::TensorView::view(self) }
1510 }
1511
1512 #[inline]
1513 pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
1514 view::TensorView::at_prefix(self, prefix)
1515 }
1516
1517 #[inline]
1518 pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
1519 view::TensorView::offsetting(self, coords)
1520 }
1521
1522 #[inline]
1523 pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView<'_> {
1524 unsafe { view::TensorView::offsetting_unchecked(self, coords) }
1525 }
1526
1527 #[inline]
1528 pub fn view_mut(&mut self) -> view::TensorView<'_> {
1529 unsafe { view::TensorView::view(self) }
1530 }
1531
1532 #[inline]
1533 pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
1534 view::TensorView::at_prefix(self, prefix)
1535 }
1536
1537 #[inline]
1538 pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
1539 view::TensorView::offsetting(self, coords)
1540 }
1541
1542 pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1544 let mut t = if let DatumType::U8 = self.dt.unquantized() {
1545 self.to_array_view::<u8>().unwrap().mapv(|v| v.wrapping_sub(128) as i8).into_tensor()
1546 } else {
1547 return self.clone();
1548 };
1549
1550 if let DatumType::QU8(qp) = self.dt {
1551 if let QParams::ZpScale { zero_point, scale } = qp {
1552 t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1553 } else {
1554 t.dt = DatumType::QI8(qp);
1555 }
1556 }
1557
1558 t.into_arc_tensor()
1559 }
1560
1561 pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1563 let mut t = if let DatumType::I8 = self.dt.unquantized() {
1564 self.to_array_view::<i8>().unwrap().mapv(|v| (v as u8).wrapping_add(128)).into_tensor()
1565 } else {
1566 return self.clone();
1567 };
1568
1569 if let DatumType::QI8(qp) = self.dt {
1570 if let QParams::ZpScale { zero_point, scale } = qp {
1571 t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1572 } else {
1573 t.dt = DatumType::QU8(qp);
1574 }
1575 }
1576 t.into_arc_tensor()
1577 }
1578
1579 pub fn to_aligned_default(&self) -> TractResult<Self> {
1580 if self.dt.is_copy() {
1581 unsafe {
1582 let mut t = Self::uninitialized_dt(self.dt, &self.shape)?;
1583 t.as_bytes_mut().copy_from_slice(self.as_bytes());
1584 Ok(t)
1585 }
1586 } else {
1587 let mut t = Self::zero_dt(self.dt, &self.shape)?;
1588 if self.dt == String::datum_type() {
1589 t.as_slice_mut::<String>()?.clone_from_slice(self.as_slice()?);
1590 } else if self.dt == Blob::datum_type() {
1591 t.as_slice_mut::<Blob>()?.clone_from_slice(self.as_slice()?);
1592 } else if self.dt == TDim::datum_type() {
1593 t.as_slice_mut::<TDim>()?.clone_from_slice(self.as_slice()?);
1594 }
1595 Ok(t)
1596 }
1597 }
1598
1599 pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1600 let mut strides = tvec!();
1601 compute_natural_stride_to(&mut strides, shape);
1602 strides
1603 }
1604
1605 pub fn into_blob(mut self) -> TractResult<Blob> {
1606 ensure!(self.dt.is_copy());
1607 Ok(std::mem::take(&mut self.data))
1608 }
1609}
1610
1611impl PartialEq for Tensor {
1612 fn eq(&self, other: &Tensor) -> bool {
1613 if self.dt != other.dt || self.shape != other.shape {
1614 return false;
1615 }
1616 self.eq_dt(other).unwrap_or(false)
1617 }
1618}
1619
1620impl fmt::Debug for Tensor {
1621 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1622 let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1623 write!(formatter, "{content}")
1624 }
1625}
1626
/// Reinterprets a tensor whose last axis has length 2 as a tensor of complex values.
///
/// The last axis is dropped and the datum type becomes `complexify()` of the original.
/// No data is copied.
///
/// # Errors
///
/// Fails when the last axis is not 2, or when the datum type has no complex counterpart.
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    ensure!(
        t.shape().last() == Some(&2),
        "The last dimension in the tensor shape {:?} must be 2",
        t.shape()
    );
    unsafe {
        let complex_dt = t.datum_type().complexify()?;
        t.shape.pop();
        t.set_datum_type(complex_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1641
/// Reinterprets a complex tensor as its real components. A trailing axis of length 2 is
/// appended and the datum type becomes `decomplexify()` of the original. No data is
/// copied.
///
/// # Errors
///
/// Fails when the datum type cannot be decomplexified.
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    unsafe {
        let real_dt = t.datum_type().decomplexify()?;
        t.shape.push(2);
        t.set_datum_type(real_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1651
1652pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1653 let mut strides = tvec!();
1654 compute_natural_stride_to(&mut strides, shape);
1655 strides
1656}
1657
1658fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1659 match shape.len() {
1660 0 => (),
1661 1 => strides.push(1),
1662 2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1663 3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1664 4 => strides.extend_from_slice(&[
1665 (shape[1] * shape[2] * shape[3]) as isize,
1666 (shape[2] * shape[3]) as _,
1667 shape[3] as _,
1668 1,
1669 ]),
1670 _ => {
1671 strides.push(1);
1672 for dim in shape.as_ref().iter().skip(1).rev() {
1673 let previous = *strides.last().unwrap();
1674 strides.push(previous * *dim as isize)
1675 }
1676 strides.reverse();
1677 }
1678 }
1679}
1680
1681impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
1682 fn from(it: Array<T, D>) -> Tensor {
1683 Tensor::from_datum(it.into_dyn())
1684 }
1685}
1686
/// Conversion of a value into an owned [`Tensor`].
pub trait IntoTensor: Sized {
    /// Consumes `self` and returns it as an owned `Tensor`.
    fn into_tensor(self) -> Tensor;
}
1694
/// Conversion of a value into a reference-counted [`Tensor`].
pub trait IntoArcTensor: Sized {
    /// Consumes `self` and returns it as an `Arc<Tensor>`.
    fn into_arc_tensor(self) -> Arc<Tensor>;
}
1702
1703impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
1704 fn into_tensor(self) -> Tensor {
1705 Tensor::from(self)
1706 }
1707}
1708
1709impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
1710 fn into_arc_tensor(self) -> Arc<Tensor> {
1711 Arc::new(Tensor::from(self))
1712 }
1713}
1714
1715impl IntoTensor for Tensor {
1716 fn into_tensor(self) -> Tensor {
1717 self
1718 }
1719}
1720
1721impl IntoTensor for Arc<Tensor> {
1722 fn into_tensor(self) -> Tensor {
1723 Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
1724 }
1725}
1726
1727impl IntoArcTensor for Tensor {
1728 fn into_arc_tensor(self) -> Arc<Tensor> {
1729 Arc::new(self)
1730 }
1731}
1732
1733impl IntoArcTensor for Arc<Tensor> {
1734 fn into_arc_tensor(self) -> Arc<Tensor> {
1735 self
1736 }
1737}
1738
// Unit and property tests for tensor construction from non-standard ndarray layouts,
// vector broadcasting, complex reinterpretation and TDim cloning.
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    use proptest::collection::vec;
    use proptest::prelude::*;

    // A shape plus a permutation of its axes. It checks that `Tensor::from` on an ndarray
    // with permuted axes gives the same tensor as reading its values in logical order.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        // Rank in 0..8, every dim in 1..5, and a shuffled identity permutation of the axes.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            (0..8usize)
                .prop_flat_map(|rank| {
                    let permute: Vec<usize> = (0..rank).collect();
                    (proptest::collection::vec(1..5usize, rank), Just(permute).prop_shuffle())
                })
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        // Values 1, 2, 3, ... filled by `from_shape_simple_fn`, then viewed with permuted
        // axes (so in general no longer in standard memory order).
        fn input(&self) -> ArrayD<i32> {
            let mut i = 0;
            ArrayD::from_shape_simple_fn(&*self.shape, || {
                i += 1;
                i
            })
            .permuted_axes(&*self.permutation)
        }

        // Expected result: the input's values collected in its iteration order, reshaped
        // to the permuted shape.
        fn reference(&self) -> Tensor {
            let values: Vec<i32> = self.input().iter().copied().collect();
            let shape = self.permutation.iter().map(|ix| self.shape[*ix]).collect::<TVec<usize>>();
            super::litteral::tensor1(&values).into_shape(&shape).unwrap()
        }

        // Tensor under test: built directly from the permuted array.
        fn tract(&self) -> Tensor {
            Tensor::from(self.input())
        }

        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    proptest::proptest! {
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    // Fixed regression cases for rank-2 transposes.
    #[test]
    fn t_1_2() {
        PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] }.check().unwrap();
    }

    #[test]
    fn t_2_2() {
        PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
    }

    // A vector plus a target shape where `shape[axis] == vec.len()`. It checks
    // `broadcast_vector_to_shape` against reshaping to all-ones-but-`axis` and then
    // calling `broadcast_to_shape`.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reference = input
                .clone()
                .into_shape(&intermediate)
                .unwrap()
                .broadcast_to_shape(&self.shape)
                .unwrap();
            prop_assert_eq!(
                reference,
                input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap()
            );
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        // Base shape: 0..4 dims, each in 0..5 (zero-sized dims included). The vector's
        // length is then inserted into the shape at a random `axis`.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1)
                })
                .prop_map(|(vec, mut shape, axis)| {
                    shape.insert(axis, vec.len());
                    BroadcastVecToShape { vec, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    // (3, 2) f32 -> (3,) complex: each row becomes (re, im).
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let input = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Same on a rank-3 i32 tensor: (3, 2, 2) -> (3, 2) complex.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let input =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Smoke test: cloning a scalar tensor holding a symbolic TDim must not panic.
    #[test]
    fn clone_tdim_tensor() {
        let symbols = SymbolScope::default();
        let a = symbols.sym("a");
        let t = tensor0(TDim::from(a));
        let _ = t.clone();
    }
}