1use crate::TVec;
3use crate::blob::Blob;
4use crate::datum::{ClampCast, Datum, DatumType, QParams, round_ties_to_even, scale_by};
5use crate::dim::TDim;
6use crate::internal::*;
7use half::f16;
8use itertools::{Itertools, izip};
9use ndarray::prelude::*;
10#[cfg(feature = "complex")]
11use num_complex::Complex;
12use num_traits::{Float, Zero};
13use std::borrow::Cow;
14use std::fmt;
15use std::hash::Hash;
16use std::ops::Range;
17use std::sync::Arc;
18
19pub mod litteral;
20pub mod plain_view;
21pub mod storage;
22pub mod view;
23
24pub use plain_view::{PlainView, PlainViewMut};
25use storage::{PlainStorage, StorageKind, TensorStorage};
26
/// Tolerance preset used when comparing tensors for approximate equality.
///
/// Each preset is resolved to an `(atol, rtol, outliers)` triple by
/// `Approximation::atol_rtol_outliers`, which may depend on the datum type.
#[derive(Copy, Clone, Default, Debug)]
pub enum Approximation {
    /// No tolerance at all.
    Exact,
    /// The default preset.
    #[default]
    Close,
    Approximate,
    VeryApproximate,
    SuperApproximate,
    UltraApproximate,
    /// Caller-provided absolute tolerance, relative tolerance and allowed
    /// outlier ratio, in that order.
    Custom(f32, f32, f32),
}
38
39impl PartialEq for Approximation {
40 fn eq(&self, other: &Self) -> bool {
41 use Approximation::Custom;
42 if let (Custom(aa, ar, ao), Custom(ba, br, bo)) = (self, other) {
43 aa == ba && ar == br && bo == ao
44 } else {
45 std::mem::discriminant(self) == std::mem::discriminant(other)
46 }
47 }
48}
49
// NOTE(review): `Custom` carries `f32` fields compared with `==`, so a `Custom`
// holding a NaN is not equal to itself; confirm callers never build one.
impl Eq for Approximation {}
51
52impl From<bool> for Approximation {
53 fn from(b: bool) -> Self {
54 if b { Self::Approximate } else { Self::Exact }
55 }
56}
57
impl Approximation {
    /// Resolves this preset into `(absolute tolerance, relative tolerance,
    /// allowed outlier ratio)` for values of datum type `dt`.
    ///
    /// Arms are tried in order, so the `F16` and quantized special cases must
    /// stay above the generic `_` arms for the same preset.
    fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
        use Approximation::*;
        match (self, dt) {
            (Exact, _) => (0.0, 0.0, 0.0),
            (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
            (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
            // For quantized types the absolute tolerance is the quantization scale.
            (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
            (Close, _) => (1e-7, 1e-7, 0.0),
            (Approximate, _) => (1e-4, 5e-4, 0.0),
            (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
            (SuperApproximate, _) => (0.1, 0.05, 0.0001),
            (UltraApproximate, _) => (0.2, 0.1, 0.0005),
            (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
        }
    }
}
75
/// A typed, shaped, multi-dimensional array backed by plain or exotic storage.
pub struct Tensor {
    /// Datum type of the elements.
    dt: DatumType,
    /// Size of each axis.
    shape: TVec<usize>,
    /// Stride of each axis, as filled in by `compute_natural_stride_to`.
    strides: TVec<isize>,
    /// Number of elements (1 for a rank-0 tensor).
    len: usize,
    /// Backing storage, either plain bytes or an exotic implementation.
    storage: StorageKind,
}
84
// NOTE(review): these impls assert that `StorageKind` (including any exotic
// storage) can be moved and shared across threads; confirm against the
// storage module.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
87
/// Hashes the datum type, the shape and the content of the tensor.
///
/// For plain storage the allocation alignment is hashed too, then the
/// elements. Floating point elements are hashed through the integer type of
/// the same width, since floats do not implement `Hash`. Exotic storage
/// delegates to `dyn_hash`.
impl Hash for Tensor {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use DatumType::*;
        self.dt.hash(state);
        self.shape.hash(state);
        if let Some(plain) = self.storage.as_plain() {
            plain.layout().align().hash(state);
            unsafe {
                match self.dt {
                    Bool => self.as_slice_unchecked::<bool>().hash(state),
                    I8 => self.as_slice_unchecked::<i8>().hash(state),
                    I16 => self.as_slice_unchecked::<i16>().hash(state),
                    I32 => self.as_slice_unchecked::<i32>().hash(state),
                    I64 => self.as_slice_unchecked::<i64>().hash(state),
                    U8 => self.as_slice_unchecked::<u8>().hash(state),
                    U16 => self.as_slice_unchecked::<u16>().hash(state),
                    U32 => self.as_slice_unchecked::<u32>().hash(state),
                    U64 => self.as_slice_unchecked::<u64>().hash(state),
                    // Floats are reinterpreted as same-width integers for hashing.
                    F16 => self.as_slice_unchecked::<i16>().hash(state),
                    F32 => self.as_slice_unchecked::<i32>().hash(state),
                    F64 => self.as_slice_unchecked::<i64>().hash(state),
                    TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
                    String => self.as_slice_unchecked::<std::string::String>().hash(state),
                    Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
                    // Quantized types are hashed through their raw integer representation.
                    QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
                    QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
                    QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
                    #[cfg(feature = "complex")]
                    ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                    #[cfg(feature = "complex")]
                    ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                    #[cfg(feature = "complex")]
                    ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
                    #[cfg(feature = "complex")]
                    ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                    #[cfg(feature = "complex")]
                    ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                    #[cfg(feature = "complex")]
                    ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
                }
            }
        } else {
            self.storage.dyn_hash(state);
        }
    }
}
134
135impl Clone for Tensor {
136 fn clone(&self) -> Tensor {
137 self.deep_clone()
138 }
139}
140
141impl Default for Tensor {
142 fn default() -> Tensor {
143 litteral::tensor0(0f32)
144 }
145}
146
/// Runs element destructors for the datum types handled below
/// (`Blob`, `String` and `TDim`) when the tensor uses plain storage.
impl Drop for Tensor {
    fn drop(&mut self) {
        // Only plain storage is walked here.
        if self.is_plain() {
            // Drops every element in place when the tensor's datum type is `$t`.
            macro_rules! drop_in_place {
                ($t: ty) => {
                    if self.dt == <$t>::datum_type() {
                        unsafe {
                            let slice = self.as_slice_mut_unchecked::<$t>();
                            std::ptr::drop_in_place(slice as *mut [$t]);
                        }
                    }
                };
            }
            drop_in_place!(Blob);
            drop_in_place!(String);
            drop_in_place!(TDim);
        }
    }
}
167
/// Returns the SIMD vector width in bytes: 64 with AVX-512F, 32 on other
/// x86_64 CPUs, 16 on every other architecture.
///
/// The trailing expression is unreachable on x86_64, hence the `allow`.
#[allow(unreachable_code)]
pub fn vector_size() -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        let bits = if is_x86_feature_detected!("avx512f") { 512 } else { 256 };
        return bits / 8;
    }
    128 / 8
}
176
177impl Tensor {
178 #[inline]
179 fn plain_storage(&self) -> &PlainStorage {
180 self.storage.as_plain().expect("Non-plain storage")
181 }
182
183 #[inline]
184 fn plain_storage_mut(&mut self) -> &mut PlainStorage {
185 self.storage.as_plain_mut().expect("Non-plain storage")
186 }
187
188 pub fn storage_as<T: TensorStorage>(&self) -> Option<&T> {
189 self.storage.as_storage().downcast_ref::<T>()
190 }
191
192 pub fn try_storage_as<T: TensorStorage>(&self) -> TractResult<&T> {
193 self.storage_as::<T>().context("Unexpected tensor storage type")
194 }
195
196 pub fn from_storage(
197 dt: DatumType,
198 shape: &[usize],
199 storage: impl TensorStorage + 'static,
200 ) -> Tensor {
201 let len = shape.iter().product::<usize>();
202 let strides = Self::natural_strides(shape);
203 Tensor {
204 dt,
205 shape: shape.into(),
206 strides,
207 len,
208 storage: StorageKind::Exotic(Box::new(storage)),
209 }
210 }
211
212 #[inline]
214 pub fn as_plain(&self) -> Option<PlainView<'_>> {
215 let storage = self.storage.as_plain()?;
216 Some(PlainView::new(self, storage))
217 }
218
219 #[inline]
221 pub fn try_as_plain(&self) -> TractResult<PlainView<'_>> {
222 self.as_plain().context("Tensor storage is not plain")
223 }
224
225 #[inline]
227 pub fn is_plain(&self) -> bool {
228 self.storage.as_plain().is_some()
229 }
230
231 #[inline]
233 pub fn is_exotic(&self) -> bool {
234 !self.is_plain()
235 }
236
237 pub fn exotic_fact(&self) -> TractResult<Option<Box<dyn crate::exotic::ExoticFact>>> {
239 self.storage.as_storage().exotic_fact(&self.shape)
240 }
241
242 #[inline]
244 pub fn as_plain_mut(&mut self) -> Option<PlainViewMut<'_>> {
245 let storage = self.storage.as_plain_mut()?;
246 Some(PlainViewMut::new(self.dt, &self.shape, &self.strides, self.len, storage))
247 }
248
249 #[inline]
251 pub fn try_as_plain_mut(&mut self) -> TractResult<PlainViewMut<'_>> {
252 self.as_plain_mut().context("Tensor storage is not plain")
253 }
254
255 #[inline]
257 pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
258 unsafe { Self::uninitialized_dt(T::datum_type(), shape) }
259 }
260
261 #[inline]
263 pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
264 unsafe { Self::uninitialized_aligned_dt(dt, shape, vector_size()) }
265 }
266
267 #[inline]
269 pub unsafe fn uninitialized_aligned<T: Datum>(
270 shape: &[usize],
271 alignment: usize,
272 ) -> TractResult<Tensor> {
273 unsafe { Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment) }
274 }
275
276 pub unsafe fn uninitialized_aligned_dt(
278 dt: DatumType,
279 shape: &[usize],
280 alignment: usize,
281 ) -> TractResult<Tensor> {
282 let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
283 let storage = StorageKind::Plain(PlainStorage::from(unsafe {
284 Blob::new_for_size_and_align(bytes, alignment)
285 }));
286 let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), storage, len: 0 };
287 if tensor.shape.len() == 0 {
288 tensor.len = 1;
289 } else {
290 tensor.update_strides_and_len();
291 }
292 if !tensor.storage.is_empty() {
293 if dt == String::datum_type() || dt == Blob::datum_type() {
294 tensor.plain_storage_mut().as_bytes_mut().fill(0);
296 } else if dt == TDim::datum_type() {
297 unsafe {
298 tensor
299 .as_slice_mut_unchecked::<TDim>()
300 .iter_mut()
301 .for_each(|dim| std::ptr::write(dim, TDim::zero()))
302 }
303 } else if cfg!(debug_assertions) {
304 assert!(dt.is_copy());
305 if dt == DatumType::F32 {
306 tensor.fill_t(f32::NAN).unwrap();
307 } else {
308 tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
310 }
311 }
312 }
313 Ok(tensor)
314 }
315
    /// Concatenates `tensors` along `axis` into a new tensor.
    ///
    /// All inputs must share the same rank and datum type, and the same
    /// dimension on every axis other than `axis`.
    ///
    /// # Errors
    /// Fails when `tensors` is empty, when `axis` is not below the rank, or
    /// when the inputs disagree on rank, datum type or the other dimensions.
    pub fn stack_tensors(
        axis: usize,
        tensors: &[impl std::borrow::Borrow<Tensor>],
    ) -> TractResult<Tensor> {
        ensure!(tensors.len() > 0);
        let rank = tensors[0].borrow().rank();
        ensure!(axis < rank);
        ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
        let dt = tensors[0].borrow().datum_type();
        ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
        let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
        for ax in 0..rank {
            if ax != axis {
                ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
            }
        }
        // The output dimension on `axis` is the sum of the input dimensions.
        shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
        unsafe {
            let mut result = Tensor::uninitialized_dt(dt, &shape)?;
            // When every axis before `axis` has size 1, each input is one
            // contiguous run of the output: copy raw bytes back to back.
            if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
                let mut offset = 0isize;
                for v in tensors {
                    let v = v.borrow();
                    let len = v.storage.byte_len();
                    std::ptr::copy_nonoverlapping(
                        v.plain_storage().as_ptr(),
                        result.plain_storage_mut().as_mut_ptr().offset(offset),
                        len,
                    );
                    offset += len as isize;
                }
            } else {
                // General case: assign each input to its slice of the output.
                let mut offset = 0;
                for t in tensors {
                    let t = t.borrow();
                    let len = t.shape()[axis];
                    result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
                    offset += len;
                }
            }

            Ok(result)
        }
    }
360
361 pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
362 self.fill_t(T::zero())
363 }
364
365 pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
366 unsafe {
367 let mut t = Tensor::uninitialized::<T>(shape)?;
368 t.clear::<T>()?;
369 Ok(t)
370 }
371 }
372
373 pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
374 Tensor::zero::<T>(&[])
375 }
376
377 pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
378 Tensor::zero_dt(dt, &[])
379 }
380
381 pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
382 Tensor::zero_aligned_dt(dt, shape, vector_size())
383 }
384
385 pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
386 self.try_as_plain_mut()?
387 .as_slice_mut::<T>()?
388 .iter_mut()
389 .for_each(|item| *item = value.clone());
390 Ok(())
391 }
392
393 pub fn zero_aligned_dt(
394 dt: DatumType,
395 shape: &[usize],
396 alignment: usize,
397 ) -> TractResult<Tensor> {
398 if shape.iter().product::<usize>() == 0 {
399 unsafe { return Tensor::uninitialized_dt(dt, shape) };
400 }
401 if dt.is_quantized() {
402 unsafe {
403 let mut t = Tensor::uninitialized_dt(dt, shape)?;
404 let zp = dt.zp_scale().0;
405 match dt.unquantized() {
406 DatumType::I8 => t
407 .try_as_plain_mut()?
408 .as_slice_mut::<i8>()?
409 .iter_mut()
410 .for_each(|item| *item = zp as _),
411 DatumType::U8 => t
412 .try_as_plain_mut()?
413 .as_slice_mut::<u8>()?
414 .iter_mut()
415 .for_each(|item| *item = zp as _),
416 DatumType::I32 => t
417 .try_as_plain_mut()?
418 .as_slice_mut::<i32>()?
419 .iter_mut()
420 .for_each(|item| *item = zp as _),
421 _ => unreachable!(),
422 }
423 Ok(t)
424 }
425 } else if dt == DatumType::Bool {
426 let mut t = unsafe { Tensor::uninitialized_dt(dt, shape)? };
427 t.fill_t::<bool>(false)?;
428 Ok(t)
429 } else {
430 dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
431 }
432 }
433
434 pub fn zero_aligned<T: Datum + num_traits::Zero>(
435 shape: &[usize],
436 alignment: usize,
437 ) -> TractResult<Tensor> {
438 unsafe {
439 let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
440 tensor.clear::<T>()?;
441 Ok(tensor)
442 }
443 }
444
445 pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
448 Self::from_shape_align(shape, data, vector_size())
449 }
450
451 pub fn from_shape_align<T: Datum + Copy>(
454 shape: &[usize],
455 data: &[T],
456 align: usize,
457 ) -> TractResult<Tensor> {
458 ensure!(
459 data.len() == shape.iter().product::<usize>(),
460 "Shape product must be equal to data length"
461 );
462 unsafe {
463 let bytes = std::slice::from_raw_parts(
464 data.as_ptr() as *const u8,
465 data.len() * T::datum_type().size_of(),
466 );
467 let dt = T::datum_type();
468 Self::from_raw_dt_align(dt, shape, bytes, align)
469 }
470 }
471
472 pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
476 unsafe { Tensor::from_raw_dt(T::datum_type(), shape, content) }
477 }
478
479 pub unsafe fn from_raw_aligned<T: Datum>(
480 shape: &[usize],
481 content: &[u8],
482 align: usize,
483 ) -> TractResult<Tensor> {
484 unsafe { Tensor::from_raw_dt_align(T::datum_type(), shape, content, align) }
485 }
486
487 pub unsafe fn from_raw_dt(
488 dt: DatumType,
489 shape: &[usize],
490 content: &[u8],
491 ) -> TractResult<Tensor> {
492 unsafe { Self::from_raw_dt_align(dt, shape, content, vector_size()) }
493 }
494
495 pub unsafe fn from_raw_dt_align(
496 dt: DatumType,
497 shape: &[usize],
498 content: &[u8],
499 align: usize,
500 ) -> TractResult<Tensor> {
501 let mut tensor = unsafe { Tensor::uninitialized_aligned_dt(dt, shape, align) }?;
502 tensor.as_bytes_mut().copy_from_slice(content);
503 Ok(tensor)
504 }
505
506 pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
507 let bytes = if content.len() == 0 {
508 &[]
509 } else {
510 unsafe {
511 std::slice::from_raw_parts(
512 content.as_ptr() as *const u8,
513 content.len() * T::datum_type().size_of(),
514 )
515 }
516 };
517 unsafe { Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align) }
518 }
519
    /// Returns the number of axes of the tensor.
    #[inline]
    pub fn rank(&self) -> usize {
        self.shape.len()
    }
525
    /// Returns the size of each axis.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
531
    /// Returns the number of elements in the tensor.
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.len
    }
538
    /// Returns the number of elements in the tensor (same value as `len`).
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn volume(&self) -> usize {
        self.len
    }
545
    /// Returns the stride of each axis.
    #[inline]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }
551
552 fn update_strides_and_len(&mut self) {
553 self.strides.clear();
554 if self.shape.len() == 0 {
555 self.len = 1;
556 return;
557 }
558 compute_natural_stride_to(&mut self.strides, &self.shape);
559 self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
560 }
561
562 pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
564 if shape != &*self.shape {
565 self.shape.clear();
566 self.shape.extend_from_slice(shape);
567 self.update_strides_and_len();
568 }
569 }
570
    /// Replaces both shape and strides with the provided values.
    ///
    /// Note that `len` is left untouched by this method.
    ///
    /// # Safety
    /// No validation is performed here: the caller must make sure `shape`
    /// and `strides` are consistent with each other and with the storage.
    pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
        self.shape.clear();
        self.shape.extend_from_slice(shape);
        self.strides.clear();
        self.strides.extend_from_slice(strides);
    }
578
579 pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
581 if self.len() != shape.iter().product::<usize>() {
582 bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
583 }
584 unsafe { self.set_shape_unchecked(shape) }
585 Ok(())
586 }
587
588 pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
589 ensure!(axes.iter().duplicates().next().is_none());
590 ensure!(axes.iter().all(|a| *a < self.rank()));
591 unsafe {
592 #[inline]
593 unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
594 unsafe { input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor() }
595 }
596 let dt = self.datum_type();
597 let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
598 t.set_datum_type(dt);
599 Ok(t)
600 }
601 }
602
603 pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
604 let mut permutation: Vec<usize> = (0..self.rank()).collect();
605 permutation.remove(from);
606 permutation.insert(to, from);
607 self.permute_axes(&permutation)
608 }
609
610 pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
611 let removed = self.shape.remove(axis + 1);
612 self.shape[axis] *= removed;
613 self.update_strides_and_len();
614 self
615 }
616
617 pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
618 if self.shape[axis] % outer_dim != 0 {
619 bail!(
620 "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
621 self.shape,
622 axis,
623 outer_dim
624 );
625 }
626 self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
627 self.shape[axis] = outer_dim;
628 self.update_strides_and_len();
629 Ok(self)
630 }
631
632 pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
634 self.set_shape(shape)?;
635 Ok(self)
636 }
637
638 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
639 self.shape.insert(axis, 1);
640 self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
641 Ok(())
642 }
643
644 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
645 ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
646 self.shape.remove(axis);
647 self.strides.remove(axis);
648 Ok(())
649 }
650
651 pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
652 self.broadcast_to_rank(rank)?;
653 self.update_strides_and_len();
654 Ok(self)
655 }
656
657 pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
658 if rank < self.rank() {
659 bail!("Can only broadcast to higher rank")
660 }
661 while self.shape.len() < rank {
662 self.shape.insert(0, 1)
663 }
664 self.update_strides_and_len();
665 Ok(())
666 }
667
668 pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
669 if self.rank() > 0 {
670 bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
671 }
672 unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
673 unsafe {
674 let value: &T = src.to_scalar_unchecked::<T>();
675 dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone())
676 };
677 }
678 unsafe {
679 let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
680 dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
681 Ok(t)
682 }
683 }
684
685 fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
686 unsafe {
687 let view = self.to_array_view_unchecked::<T>();
688 let mut output = view
689 .broadcast(shape)
690 .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
691 .into_owned()
692 .into_tensor();
693 output.set_datum_type(self.datum_type());
694 Ok(output)
695 }
696 }
697
    /// Returns a new tensor holding this tensor broadcast to `shape`.
    ///
    /// # Errors
    /// Fails when the current shape can not be broadcast to `shape`.
    pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
        dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
    }
701
702 pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
703 ensure!(self.rank() == 1);
704 ensure!(shape[axis] == self.len());
705 if !self.datum_type().is_copy() {
706 let mut vec_shape = vec![1; shape.len()];
707 vec_shape[axis] = self.len();
708 return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
709 }
710 unsafe {
711 let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
712 if output.len() == 0 {
713 return Ok(output);
714 }
715 let inner_len = shape[axis + 1..].iter().product::<usize>();
716
717 unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
718 where
719 T: Datum + Copy,
720 {
721 unsafe {
722 for ix in 0..input.len() {
723 let value: T = input.as_slice_unchecked()[ix];
724 output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
725 .iter_mut()
726 .for_each(|item| *item = value);
727 }
728 }
729 }
730 dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));
731
732 let outer_len = shape[0..axis].iter().product::<usize>();
733 let repeat_bytes_len = inner_len * self.as_bytes().len();
734 let bytes = output.as_bytes_mut();
735 for ix in 1..outer_len {
736 bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
737 }
738
739 Ok(output)
740 }
741 }
742
    /// Copies the `src_range` slice of `src` along `axis` into the `range`
    /// slice of this tensor along the same axis.
    ///
    /// # Errors
    /// Fails when ranks differ, when `axis` is not below the rank, when datum
    /// types differ, when the two ranges have different lengths, when the
    /// shapes differ on an axis other than `axis`, or when a range ends
    /// beyond its tensor. Checks run in that order, which determines the
    /// error reported.
    pub fn assign_slice(
        &mut self,
        range: impl std::ops::RangeBounds<usize>,
        src: &Tensor,
        src_range: impl std::ops::RangeBounds<usize>,
        axis: usize,
    ) -> TractResult<()> {
        ensure!(self.rank() == src.rank());
        ensure!(axis < self.rank());
        // Resolve the range bounds against the actual axis sizes.
        let range = clip_range_bounds(self.shape[axis], range);
        let src_range = clip_range_bounds(src.shape[axis], src_range);
        ensure!(
            src.datum_type() == self.datum_type(),
            "Attempt to assign into {:?} from {:?}, datum type mismatch",
            self.datum_type(),
            src.datum_type()
        );
        ensure!(
            src_range.len() == range.len(),
            "Attempt to assign a range of {:?} from a range of {:?}",
            range,
            src_range,
        );
        ensure!(
            itertools::izip!(0.., self.shape(), src.shape())
                .all(|(ix, dst, src)| ix == axis || src == dst),
            "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
            axis,
            self,
            src
        );
        ensure!(
            src_range.end <= src.shape()[axis],
            "Assigning from invalid slice (axis {}, {:?}) of {:?}",
            axis,
            src_range,
            src
        );
        ensure!(
            range.end <= self.shape()[axis],
            "Assigning to invalid slice (axis {}, {:?}) of {:?}",
            axis,
            range,
            self
        );
        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
        Ok(())
    }
791
792 pub unsafe fn assign_slice_unchecked(
793 &mut self,
794 range: impl std::ops::RangeBounds<usize>,
795 src: &Tensor,
796 src_range: impl std::ops::RangeBounds<usize>,
797 axis: usize,
798 ) {
799 let range = clip_range_bounds(self.shape[axis], range);
800 let src_range = clip_range_bounds(src.shape[axis], src_range);
801 unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
802 }
803
    /// Copies `src_range` of `src` along `axis` into `range` of this tensor,
    /// with both ranges already resolved to concrete bounds.
    ///
    /// # Safety
    /// No validation is done here: ranges must lie within their tensors and
    /// have the same length, and both tensors must agree on datum type and
    /// on every axis other than `axis`.
    #[allow(clippy::ptr_eq)]
    unsafe fn assign_slice_from_resolved(
        &mut self,
        range: std::ops::Range<usize>,
        src: &Tensor,
        src_range: std::ops::Range<usize>,
        axis: usize,
    ) {
        unsafe {
            use ndarray::Slice;
            // Generic path: element-wise assignment through ndarray views.
            unsafe fn assign_slice_t<T: Datum>(
                to: &mut Tensor,
                to_range: Range<usize>,
                from: &Tensor,
                from_range: Range<usize>,
                axis: usize,
            ) {
                unsafe {
                    to.to_array_view_mut_unchecked::<T>()
                        .slice_axis_mut(Axis(axis), Slice::from(to_range))
                        .assign(
                            &from
                                .to_array_view_unchecked::<T>()
                                .slice_axis(Axis(axis), Slice::from(from_range)),
                        )
                }
            }
            // Fast path: for copy types with only size-1 axes before `axis`,
            // the slice is one contiguous byte run in both tensors.
            if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
                let stride = self.strides[axis] as usize * self.datum_type().size_of();
                let dst_start = (stride * range.start) as isize;
                let src_start = (stride * src_range.start) as isize;
                let len = stride * range.len();
                if len > 0 {
                    if self.plain_storage().as_ptr() != src.plain_storage().as_ptr() {
                        std::ptr::copy_nonoverlapping(
                            src.plain_storage().as_ptr().offset(src_start),
                            self.plain_storage_mut().as_mut_ptr().offset(dst_start),
                            len,
                        );
                    } else {
                        // Same buffer on both sides: the ranges may overlap.
                        std::ptr::copy(
                            src.plain_storage().as_ptr().offset(src_start),
                            self.plain_storage_mut().as_mut_ptr().offset(dst_start),
                            len,
                        );
                    }
                }
            } else {
                dispatch_datum!(assign_slice_t(self.datum_type())(
                    self, range, src, src_range, axis
                ));
            }
        }
    }
858
    /// Returns the datum type of the tensor's elements.
    #[inline]
    pub fn datum_type(&self) -> DatumType {
        self.dt
    }
864
    /// Overrides the datum type without touching the data.
    ///
    /// # Safety
    /// Only the type tag changes: the caller must make sure the existing
    /// bytes are valid for `dt`.
    #[inline]
    pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
        self.dt = dt
    }
870
871 pub fn dump(&self, force_full: bool) -> TractResult<String> {
875 if self.is_exotic() {
876 return Ok(format!(
877 "{},{:?} (non-plain storage)",
878 self.shape.iter().join(","),
879 self.dt,
880 ));
881 }
882 unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
883 unsafe {
884 if let Some(qp) = tensor.datum_type().qparams() {
885 let integers = tensor.cast_to::<i32>().unwrap();
886 integers.as_slice_unchecked::<i32>()[0..n]
887 .iter()
888 .map(|x| format!("[{}]({})", x, qp.dq(*x)))
889 .join(", ")
890 } else {
891 tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
892 }
893 }
894 }
895 unsafe {
896 let trunc = self.len() > 12 && !force_full;
897 let data = dispatch_datum!(dump_t(self.datum_type())(
898 self,
899 if trunc { 12 } else { self.len() }
900 ));
901 Ok(format!(
902 "{},{:?} {}{}",
903 self.shape.iter().join(","),
904 self.dt,
905 data,
906 if trunc { "..." } else { "" }
907 ))
908 }
909 }
910
911 pub fn close_enough(
913 &self,
914 other: &Self,
915 approx: impl Into<Approximation> + std::fmt::Debug,
916 ) -> TractResult<()> {
917 let approx = approx.into();
918 if self.shape() != other.shape() {
919 bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
920 }
921 let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
922 let ma = self.cast_to::<f32>()?;
923 let ma = ma.to_plain_array_view::<f32>()?;
924 let mb = other.cast_to::<f32>()?;
925 let mb = mb.to_plain_array_view::<f32>()?;
926 let mut first_outlier = None;
927 let mut outliers_count = 0;
928 ndarray::indices_of(&ma).into_iter().for_each(|indices| {
929 let a = ma[&indices];
930 let b = mb[&indices];
931 if !((a.is_nan() && b.is_nan())
932 || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
933 || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
934 {
935 if outliers_count == 0 {
936 first_outlier = Some(indices.as_array_view().to_vec());
937 }
938 outliers_count += 1;
939 }
940 });
941 if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
942 let indices = first_outlier.unwrap();
943 let a = ma[&*indices];
944 let b = mb[&*indices];
945 bail!(
946 "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
947 approx,
948 self.datum_type(),
949 indices,
950 a,
951 b,
952 outliers_count,
953 self.volume(),
954 outliers_count as f64 / self.volume() as f64,
955 outliers
956 );
957 }
958 Ok(())
959 }
960
961 pub fn into_plain_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
963 Ok(self.to_plain_array_view::<D>()?.to_owned())
964 }
965
966 pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
968 unsafe { self.to_array_view_unchecked::<D>().to_owned() }
969 }
970
    /// Returns a read-only ndarray view of the tensor as `D`.
    ///
    /// # Errors
    /// Fails when the storage is not plain, or when the plain view refuses
    /// access as `D`.
    #[inline]
    pub fn to_plain_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
        self.try_as_plain()?.to_array_view::<D>()
    }
978
    /// Returns a mutable ndarray view of the tensor as `D`.
    ///
    /// # Errors
    /// Fails when the tensor's datum type does not match `D` (ignoring
    /// quantization), or when the storage is not plain.
    #[inline]
    pub fn to_plain_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
        self.check_for_access::<D>()?;
        ensure!(self.storage.as_plain_mut().is_some(), "Tensor storage is not plain");
        // Type and storage were checked above.
        unsafe { Ok(self.to_array_view_mut_unchecked()) }
    }
988
989 fn check_for_access<D: Datum>(&self) -> TractResult<()> {
990 ensure!(
991 self.datum_type().unquantized() == D::datum_type().unquantized(),
992 "Tensor datum type error: tensor is {:?}, accessed as {:?}",
993 self.datum_type(),
994 D::datum_type(),
995 );
996 Ok(())
997 }
998
999 pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
1001 if self.len() != 0 {
1002 unsafe {
1003 ArrayViewD::from_shape_ptr(&*self.shape, self.plain_storage().as_ptr() as *const D)
1004 }
1005 } else {
1006 ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
1007 }
1008 }
1009
1010 pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
1012 if self.len() != 0 {
1013 unsafe {
1014 let ptr = self.plain_storage_mut().as_mut_ptr() as *mut D;
1015 ArrayViewMutD::from_shape_ptr(&*self.shape, ptr)
1016 }
1017 } else {
1018 ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
1019 }
1020 }
1021
1022 pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
1024 self.check_for_access::<D>()?;
1025 Ok(self.plain_storage().as_ptr() as *const D)
1026 }
1027
1028 pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
1030 self.plain_storage().as_ptr() as *const D
1031 }
1032
1033 pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
1035 self.plain_storage_mut().as_mut_ptr() as *mut D
1036 }
1037
1038 pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
1040 self.as_ptr::<D>().map(|p| p as *mut D)
1041 }
1042
1043 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
1045 if self.storage.byte_len() == 0 {
1046 &[]
1047 } else {
1048 unsafe { std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len()) }
1049 }
1050 }
1051
1052 pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
1054 if self.storage.byte_len() == 0 {
1055 &mut []
1056 } else {
1057 unsafe { std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len()) }
1058 }
1059 }
1060
1061 pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
1063 fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
1064 Ok(litteral::tensor0(t.try_as_plain()?.to_scalar::<D>()?.clone()))
1065 }
1066 dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
1067 }
1068
1069 pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
1071 unsafe { &*(self.plain_storage().as_ptr() as *const D) }
1072 }
1073
1074 pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
1076 self.check_for_access::<D>()?;
1077 if self.len() == 0 {
1078 bail!("to_scalar_mut called on empty tensor ({:?})", self)
1079 }
1080 if self.len() > 1 {
1081 bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1082 }
1083 unsafe { Ok(self.to_scalar_mut_unchecked()) }
1084 }
1085
1086 pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
1088 unsafe { &mut *(self.plain_storage_mut().as_mut_ptr() as *mut D) }
1089 }
1090
    /// Returns the raw bytes of the plain storage.
    ///
    /// # Panics
    /// Panics if the storage is not plain.
    pub fn as_bytes(&self) -> &[u8] {
        self.plain_storage().as_bytes()
    }
1094
    /// Returns the raw bytes of the plain storage, mutably.
    ///
    /// # Panics
    /// Panics if the storage is not plain.
    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
        self.plain_storage_mut().as_bytes_mut()
    }
1098
1099 unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
1100 let slice = unsafe { self.as_slice_unchecked::<T>() };
1101 slice[1..].iter().all(|x| x == &slice[0])
1102 }
1103
1104 pub fn is_uniform(&self) -> bool {
1105 if self.is_exotic() {
1106 return false;
1107 }
1108 if self.len() <= 1 {
1109 return true;
1110 }
1111 unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
1112 }
1113
1114 unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
1115 let v: T = unsafe { self.as_slice_unchecked::<T>() }[0].clone();
1116 litteral::tensor0(v)
1117 }
1118
1119 pub fn as_uniform(&self) -> Option<Tensor> {
1120 if self.len() >= 1 && self.is_uniform() {
1121 unsafe {
1122 let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
1123 t.set_datum_type(self.datum_type());
1124 Some(t)
1125 }
1126 } else {
1127 None
1128 }
1129 }
1130
1131 pub fn is_all_zero(&self) -> TractResult<bool> {
1132 Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1133 }
1134
1135 pub fn is_zero(&self) -> TractResult<bool> {
1136 Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
1137 }
1138
1139 unsafe fn natural_cast<
1140 Source: Datum + num_traits::AsPrimitive<Target>,
1141 Target: Datum + Copy,
1142 >(
1143 &self,
1144 other: &mut Tensor,
1145 ) {
1146 unsafe {
1147 self.as_slice_unchecked::<Source>()
1148 .iter()
1149 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1150 .for_each(|(s, d)| *d = s.as_())
1151 };
1152 }
1153
1154 unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1155 unsafe {
1156 self.as_slice_unchecked::<Source>()
1157 .iter()
1158 .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1159 .for_each(|(s, d)| *d = !s.is_zero());
1160 }
1161 }
1162
1163 unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
1164 &self,
1165 other: &mut Tensor,
1166 ) -> TractResult<()> {
1167 unsafe {
1168 for (s, d) in self
1169 .as_slice_unchecked::<String>()
1170 .iter()
1171 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1172 {
1173 *d = s
1174 .parse()
1175 .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
1176 }
1177 Ok(())
1178 }
1179 }
1180
1181 unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1182 unsafe {
1183 for (s, d) in self
1184 .as_slice_unchecked::<Source>()
1185 .iter()
1186 .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1187 {
1188 *d = s.to_string()
1189 }
1190 }
1191 }
1192
1193 pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<'_, Tensor>> {
1195 self.cast_to_dt(D::datum_type())
1196 }
1197
    /// Casts the tensor to `dst_dt`, borrowing `self` when there is nothing to convert.
    ///
    /// Conversion paths are tried in this order, the first applicable one wins:
    /// 1. identical datum types: `self` is returned as `Cow::Borrowed`;
    /// 2. `TDim` to integer or float: each dimension goes through `to_i64()`, then
    ///    the intermediate `i64` tensor is cast again;
    /// 3. `bool` to integer, float or `TDim`: goes through an intermediate `i8`
    ///    tensor of 0/1, which is cast again;
    /// 4. from `String`: each element is parsed (`cast_from_string`);
    /// 5. to `String`: each element is rendered (`cast_to_string`);
    /// 6. neither side quantized: numeric casts through the local `n!` macro;
    /// 7. at least one side quantized: the zero-point/scale aware macros below.
    ///
    /// # Errors
    /// Fails when an intermediate or result tensor can not be allocated, when a
    /// `TDim` has no `i64` value, when a string does not parse, or when no
    /// conversion path applies ("Unsupported cast from ... to ...").
    ///
    /// # Panics
    /// In path 6, a recognised numeric source with a destination type not listed
    /// in `n!` hits `todo!()`.
    // NOTE(review): the lint allowance is presumably for the `$round(d_float)`
    // call in `q_via_f32!`, where a closure literal is invoked in place.
    #[allow(clippy::redundant_closure_call)]
    pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<'_, Tensor>> {
        unsafe {
            // Path 1: nothing to convert.
            if self.dt == dst_dt {
                return Ok(Cow::Borrowed(self));
            }
            // Path 2: TDim -> number, through i64.
            if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
                let slice = self.as_slice_unchecked::<TDim>();
                let mut ints = Self::uninitialized::<i64>(&self.shape)?;
                let ints_slice = ints.as_slice_mut_unchecked::<i64>();
                for i in 0..self.len() {
                    ints_slice[i] = slice[i].to_i64()?;
                }
                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
            }
            // Path 3: bool -> number or TDim, through an i8 tensor holding 0 or 1.
            if self.dt == bool::datum_type()
                && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
            {
                let slice = self.as_slice_unchecked::<bool>();
                let mut ints = Self::uninitialized::<i8>(&self.shape)?;
                let ints_slice = ints.as_slice_mut_unchecked::<i8>();
                for i in 0..self.len() {
                    ints_slice[i] = slice[i] as usize as i8;
                }
                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
            }
            // All remaining paths write into this freshly allocated destination.
            let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
            // Path 4: String -> number.
            if self.dt == DatumType::String {
                dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
                return Ok(Cow::Owned(result));
            }
            // Path 5: anything -> String.
            if dst_dt == DatumType::String {
                dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
                return Ok(Cow::Owned(result));
            }
            // `n!(S)`: if the source holds `S`, convert to whatever `dst_dt` is and
            // return from `cast_to_dt`; otherwise fall through to the next `n!`.
            macro_rules! n {
                ($source:ty) => {
                    if <$source>::datum_type() == self.datum_type() {
                        match dst_dt {
                            DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
                            DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
                            DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
                            DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
                            DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
                            DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
                            DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
                            DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
                            DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
                            DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
                            DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
                            // Numbers reach TDim through an intermediate i32 cast.
                            DatumType::TDim => {
                                let ints = self.cast_to::<i32>()?;
                                let slice = ints.as_slice_unchecked::<i32>();
                                let result = result.as_slice_mut_unchecked::<TDim>();
                                for i in 0..self.len() {
                                    result[i] = slice[i].into();
                                }
                            }
                            DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
                            // Any other destination is not implemented and panics.
                            _ => todo!(),
                        }
                        return Ok(Cow::Owned(result));
                    };
                };
            }
            if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
                // Path 6: plain numeric to plain numeric.
                n!(u8);
                n!(u16);
                n!(u32);
                n!(u64);
                n!(i8);
                n!(i16);
                n!(i32);
                n!(i64);
                n!(f16);
                n!(f32);
                n!(f64);
            } else {
                // Path 7: at least one side is quantized.
                let (s_zp, s_scale) = self.datum_type().zp_scale();
                let (d_zp, d_scale) = dst_dt.zp_scale();
                if self.datum_type().is_quantized() && dst_dt.is_float() {
                    // Dequantize: (value - source zero point) * source scale.
                    macro_rules! q_to_fp {
                        ($source:ty, $dest:ty) => {
                            if <$source>::datum_type().unquantized()
                                == self.datum_type().unquantized()
                                && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
                            {
                                self.as_slice_unchecked::<$source>()
                                    .iter()
                                    .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
                                    .for_each(|(&s, d)| {
                                        *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
                                    });
                                return Ok(Cow::Owned(result));
                            }
                        };
                    }
                    q_to_fp!(i8, f64);
                    q_to_fp!(i8, f32);
                    q_to_fp!(u8, f64);
                    q_to_fp!(u8, f32);
                }
                // Requantize between two quantized types sharing the storage type
                // `$typ`, in integer arithmetic: rescale the zero-point-corrected
                // value by `s_scale / d_scale`, add the destination zero point, clamp.
                macro_rules! q8_to_q8 {
                    ($typ:ty) => {
                        if dst_dt.unquantized() == <$typ>::datum_type() {
                            self.as_slice_unchecked::<$typ>()
                                .iter()
                                .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
                                .for_each(|(&s, d)| {
                                    *d = (d_zp as i32
                                        + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
                                    .clamp_cast()
                                });
                            return Ok(Cow::Owned(result));
                        }
                    };
                }

                // Convert through f32: dequantize with the source parameters,
                // requantize with the destination parameters, then let `$round`
                // turn the f32 into the destination type.
                macro_rules! q_via_f32 {
                    ($source:ty, $dest:ty, $round:expr) => {
                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
                        {
                            self.as_slice_unchecked::<$source>()
                                .iter()
                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
                                .for_each(|(&s, d)| {
                                    let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
                                    let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
                                    *d = $round(d_float);
                                });
                            return Ok(Cow::Owned(result));
                        }
                    };
                }

                // Copy the stored integers without applying zero point or scale:
                // the `clamp` form saturates with `clamp_cast`, the plain form
                // uses an `as` cast.
                macro_rules! q_n {
                    (clamp $source:ty, $dest:ty) => {{
                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
                        {
                            self.as_slice_unchecked::<$source>()
                                .iter()
                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
                                .for_each(|(&s, d)| {
                                    *d = s.clamp_cast();
                                });
                            return Ok(Cow::Owned(result));
                        }
                    }};
                    ($source:ty, $dest:ty) => {{
                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
                        {
                            self.as_slice_unchecked::<$source>()
                                .iter()
                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
                                .for_each(|(&s, d)| {
                                    *d = s as $dest;
                                });
                            return Ok(Cow::Owned(result));
                        }
                    }};
                }

                // Both sides quantized over the same storage type: integer path.
                if dst_dt.unquantized() == self.datum_type().unquantized()
                    && dst_dt.is_quantized()
                    && self.datum_type().is_quantized()
                {
                    q8_to_q8!(i8);
                    q8_to_q8!(u8);
                }

                // Float <-> quantized integer storage. Quantizing rounds ties to
                // even then saturates; the reverse direction keeps the f32 as is.
                q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
                q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
                q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
                q_via_f32!(i8, f32, |f| f);
                q_via_f32!(u8, f32, |f| f);
                q_via_f32!(i32, f32, |f| f);

                // Quantized to quantized across different storage types.
                if dst_dt.is_quantized() && self.datum_type().is_quantized() {
                    q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
                    q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
                    q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
                    q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
                    q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
                    q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());

                    // NOTE(review): same-storage i8/u8 pairs look already handled
                    // by `q8_to_q8!` above; confirm whether these two can be reached.
                    q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
                    q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
                }

                // Last resort: move the raw stored integers across.
                q_n!(i8, i32);
                q_n!(i8, u32);
                q_n!(u8, i32);
                q_n!(u8, u32);
                q_n!(clamp i32, i8);
                q_n!(clamp i32, u8);
                q_n!(clamp u32, i8);
                q_n!(clamp u32, u8);
                q_n!(i8, i8);
                q_n!(u8, u8);
                q_n!(i32, i32);
                q_n!(u32, u32);
            }

            // No path matched.
            bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
        }
    }
1411
1412 pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1414 let casted = self.cast_to::<D>()?;
1415 casted.try_as_plain()?.to_scalar::<D>().copied()
1416 }
1417
1418 pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1420 if nth >= self.len() {
1421 bail!(
1422 "nth called with {}th element on a tensor of len {} ({:?}",
1423 nth,
1424 self.len(),
1425 self
1426 );
1427 }
1428 unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1429 unsafe {
1430 let value = me.as_slice_unchecked::<T>()[nth].clone();
1431 output.as_slice_mut_unchecked::<T>()[0] = value;
1432 }
1433 }
1434 unsafe {
1435 let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1436 dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1437 Ok(output)
1438 }
1439 }
1440
1441 fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1443 unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1444 unsafe {
1445 if D::datum_type().is_float() {
1446 return dispatch_floatlike!(float_eq_t(D::datum_type())(me, other));
1447 }
1448 Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1449 .all(|(a, b)| a == b))
1450 }
1451 }
1452
1453 unsafe fn float_eq_t<D: Datum + Float>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1454 unsafe {
1455 Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1456 .all(|(a, b)| (a.is_nan() && b.is_nan()) || a == b))
1457 }
1458 }
1459
1460 unsafe {
1461 Ok(self.datum_type() == other.datum_type()
1462 && self.shape() == other.shape()
1463 && dispatch_datum!(eq_t(self.dt)(self, other))?)
1464 }
1465 }
1466
    /// Builds a tensor from an owned dynamic-rank `ndarray` array.
    ///
    /// Three strategies are tried, from cheapest to most general:
    /// 1. the array exposes its data as a slice (`as_slice_mut`): bulk byte copy
    ///    for copy datum types, element-by-element `mem::take` otherwise;
    /// 2. all strides are positive and the data is one contiguous block in memory
    ///    order: the block is scattered into the tensor according to the strides;
    /// 3. otherwise the array is consumed element by element through its iterator.
    ///
    /// # Panics
    /// Panics if the destination tensor can not be allocated.
    fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
        unsafe {
            let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
            // Strategy 1: source and destination share the same linear layout.
            if let Some(slice) = it.as_slice_mut() {
                if t.datum_type().is_copy() {
                    // Raw byte copy of the whole buffer.
                    std::ptr::copy_nonoverlapping(
                        slice.as_ptr() as *const i8,
                        t.as_ptr_mut_unchecked(),
                        t.plain_storage().layout().size(),
                    );
                } else {
                    // Move each value out of the array, leaving `T::default()` behind.
                    t.as_slice_mut_unchecked::<T>()
                        .iter_mut()
                        .zip(slice.iter_mut())
                        .for_each(|(t, s)| *t = std::mem::take(s));
                }
                return t;
            }
            // Strategy 2: contiguous data whose axes are stored in another order.
            if it.strides().iter().all(|&s| s > 0) && it.as_slice_memory_order().is_some() {
                // (length, destination stride) pairs, visited by increasing source stride.
                let mut len_and_strides: TVec<(usize, usize)> = tvec!();
                for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
                    .sorted_by_key(|(_, src, _)| *src)
                    .map(|(l, _, dst)| (*l as isize, *dst))
                {
                    // Fuse this axis with the previous one when it continues exactly
                    // where the previous one ends in the destination.
                    if !len_and_strides.is_empty()
                        && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
                            == stride as usize
                    {
                        len_and_strides.last_mut().unwrap().0 *= len as usize;
                    } else {
                        len_and_strides.push((len as usize, stride as usize));
                    }
                }
                len_and_strides.reverse();
                crate::scatter::scatter_contig_data(
                    it.as_ptr(),
                    t.as_ptr_mut_unchecked(),
                    &len_and_strides,
                );
                return t;
            }
            // Strategy 3: generic fallback, one element at a time in iteration order.
            t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
            t
        }
    }
1513
1514 pub fn deep_clone(&self) -> Tensor {
1515 if self.is_exotic() {
1516 return Tensor {
1517 dt: self.dt,
1518 shape: self.shape.clone(),
1519 strides: self.strides.clone(),
1520 len: self.len,
1521 storage: self.storage.deep_clone(),
1522 };
1523 }
1524 unsafe {
1525 let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1526 if self.len() > 0 {
1527 if self.dt.is_copy() {
1528 self.plain_storage().as_ptr().copy_to_nonoverlapping(
1529 tensor.as_bytes_mut().as_mut_ptr(),
1530 self.plain_storage().layout().size(),
1531 )
1532 } else if self.dt == DatumType::String {
1533 tensor
1534 .as_slice_mut_unchecked::<String>()
1535 .clone_from_slice(self.as_slice_unchecked());
1536 } else if self.dt == DatumType::Blob {
1537 tensor
1538 .as_slice_mut_unchecked::<Blob>()
1539 .clone_from_slice(self.as_slice_unchecked());
1540 } else if self.dt == DatumType::TDim {
1541 tensor
1542 .as_slice_mut_unchecked::<TDim>()
1543 .clone_from_slice(self.as_slice_unchecked());
1544 }
1545 }
1546 tensor
1547 }
1548 }
1549
1550 pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1551 if axis >= self.rank() {
1552 bail!("Can not slice at axis {} tensor {:?}", axis, self);
1553 }
1554 if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1555 bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1556 }
1557 fn slice_t<T: Datum>(
1558 t: &Tensor,
1559 axis: usize,
1560 start: usize,
1561 end: usize,
1562 ) -> TractResult<Tensor> {
1563 Ok(t.to_plain_array_view::<T>()?
1564 .slice_axis(ndarray::Axis(axis), (start..end).into())
1565 .into_owned()
1566 .into_tensor())
1567 }
1568 dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1569 }
1570
    /// Builds a [`view::TensorView`] over this tensor by delegating to
    /// `TensorView::view`.
    #[inline]
    pub fn view(&self) -> view::TensorView<'_> {
        // NOTE(review): the safety contract of `TensorView::view` is not visible
        // from this file; presumably any well-formed tensor satisfies it.
        unsafe { view::TensorView::view(self) }
    }
1575
    /// Builds a [`view::TensorView`] for the given coordinate `prefix` by
    /// delegating to `TensorView::at_prefix`.
    ///
    /// # Errors
    /// Returns whatever error `TensorView::at_prefix` reports (presumably an
    /// invalid prefix; the check itself lives in the `view` module).
    #[inline]
    pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::at_prefix(self, prefix)
    }
1580
    /// Builds a [`view::TensorView`] offset by `coords` by delegating to
    /// `TensorView::offsetting`.
    ///
    /// # Errors
    /// Returns whatever error `TensorView::offsetting` reports (presumably
    /// invalid coordinates; the check itself lives in the `view` module).
    #[inline]
    pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::offsetting(self, coords)
    }
1585
    /// Infallible counterpart of [`Tensor::view_offsetting`], delegating to
    /// `TensorView::offsetting_unchecked`.
    ///
    /// # Safety
    /// The caller must uphold the contract of `TensorView::offsetting_unchecked`
    /// (presumably that `coords` is valid for this tensor; not visible from here).
    #[inline]
    pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView<'_> {
        unsafe { view::TensorView::offsetting_unchecked(self, coords) }
    }
1590
    /// Same as [`Tensor::view`], but requires exclusive access to the tensor.
    ///
    /// The returned type and the constructor used are the same as for the shared
    /// variant; only the receiver differs.
    #[inline]
    pub fn view_mut(&mut self) -> view::TensorView<'_> {
        // NOTE(review): the safety contract of `TensorView::view` is not visible
        // from this file; presumably any well-formed tensor satisfies it.
        unsafe { view::TensorView::view(self) }
    }
1595
    /// Same as [`Tensor::view_at_prefix`], but requires exclusive access to the tensor.
    ///
    /// # Errors
    /// Returns whatever error `TensorView::at_prefix` reports.
    #[inline]
    pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::at_prefix(self, prefix)
    }
1600
    /// Same as [`Tensor::view_offsetting`], but requires exclusive access to the tensor.
    ///
    /// # Errors
    /// Returns whatever error `TensorView::offsetting` reports.
    #[inline]
    pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::offsetting(self, coords)
    }
1605
1606 pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1608 let mut t = if let DatumType::U8 = self.dt.unquantized() {
1609 self.try_as_plain()
1610 .unwrap()
1611 .to_array_view::<u8>()
1612 .unwrap()
1613 .mapv(|v| v.wrapping_sub(128) as i8)
1614 .into_tensor()
1615 } else {
1616 return self.clone();
1617 };
1618
1619 if let DatumType::QU8(qp) = self.dt {
1620 if let QParams::ZpScale { zero_point, scale } = qp {
1621 t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1622 } else {
1623 t.dt = DatumType::QI8(qp);
1624 }
1625 }
1626
1627 t.into_arc_tensor()
1628 }
1629
1630 pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1632 let mut t = if let DatumType::I8 = self.dt.unquantized() {
1633 self.try_as_plain()
1634 .unwrap()
1635 .to_array_view::<i8>()
1636 .unwrap()
1637 .mapv(|v| (v as u8).wrapping_add(128))
1638 .into_tensor()
1639 } else {
1640 return self.clone();
1641 };
1642
1643 if let DatumType::QI8(qp) = self.dt {
1644 if let QParams::ZpScale { zero_point, scale } = qp {
1645 t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1646 } else {
1647 t.dt = DatumType::QU8(qp);
1648 }
1649 }
1650 t.into_arc_tensor()
1651 }
1652
1653 pub fn to_aligned_default(&self) -> TractResult<Self> {
1654 if self.dt.is_copy() {
1655 unsafe {
1656 let mut t = Self::uninitialized_dt(self.dt, &self.shape)?;
1657 t.as_bytes_mut().copy_from_slice(self.as_bytes());
1658 Ok(t)
1659 }
1660 } else {
1661 let mut t = Self::zero_dt(self.dt, &self.shape)?;
1662 if self.dt == String::datum_type() {
1663 t.try_as_plain_mut()?
1664 .as_slice_mut::<String>()?
1665 .clone_from_slice(self.try_as_plain()?.as_slice()?);
1666 } else if self.dt == Blob::datum_type() {
1667 t.try_as_plain_mut()?
1668 .as_slice_mut::<Blob>()?
1669 .clone_from_slice(self.try_as_plain()?.as_slice()?);
1670 } else if self.dt == TDim::datum_type() {
1671 t.try_as_plain_mut()?
1672 .as_slice_mut::<TDim>()?
1673 .clone_from_slice(self.try_as_plain()?.as_slice()?);
1674 }
1675 Ok(t)
1676 }
1677 }
1678
1679 pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1680 let mut strides = tvec!();
1681 compute_natural_stride_to(&mut strides, shape);
1682 strides
1683 }
1684
1685 pub fn into_blob(mut self) -> TractResult<Blob> {
1686 ensure!(self.dt.is_copy());
1687 let storage =
1688 std::mem::replace(&mut self.storage, StorageKind::Plain(PlainStorage::default()));
1689 Ok(storage.into_plain().context("Storage is not plain")?.into_blob())
1690 }
1691}
1692
1693impl PartialEq for Tensor {
1694 fn eq(&self, other: &Tensor) -> bool {
1695 if self.dt != other.dt || self.shape != other.shape {
1696 return false;
1697 }
1698 match (self.storage.as_plain(), other.storage.as_plain()) {
1699 (Some(_), Some(_)) => self.eq_dt(other).unwrap_or(false),
1700 (None, None) => self.storage == other.storage,
1701 _ => false,
1702 }
1703 }
1704}
1705
// Marker impl: declares `Tensor` equality to be a full equivalence relation.
// For plain float tensors this relies on `eq_dt` treating two NaN as equal.
impl Eq for Tensor {}
1707
1708impl fmt::Debug for Tensor {
1709 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1710 let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1711 write!(formatter, "{content}")
1712 }
1713}
1714
/// Reinterprets a tensor whose last axis has length 2 as a tensor of complex
/// numbers with that axis removed.
///
/// The data itself is not touched: the last dimension is dropped from the shape,
/// the datum type is switched to its complex counterpart, and strides and
/// length are recomputed.
///
/// # Errors
/// Fails when the last dimension is not 2, or when the datum type has no
/// complex counterpart.
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    ensure!(
        t.shape().last() == Some(&2),
        "The last dimension in the tensor shape {:?} must be 2",
        t.shape()
    );
    unsafe {
        t.shape.pop();
        let complex_dt = t.datum_type().complexify()?;
        t.set_datum_type(complex_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1729
/// Reinterprets a tensor of complex numbers as a tensor of their components,
/// with an extra innermost axis of length 2.
///
/// The data itself is not touched: a dimension of 2 is appended to the shape,
/// the datum type is switched to its non-complex counterpart, and strides and
/// length are recomputed.
///
/// # Errors
/// Fails when the datum type can not be "decomplexified".
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    unsafe {
        t.shape.push(2);
        let component_dt = t.datum_type().decomplexify()?;
        t.set_datum_type(component_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1739
/// Turns any kind of range over `usize` into a plain half-open `start..end` range.
///
/// A missing start becomes `0` and a missing end becomes `len`. Explicit bounds
/// are converted to their half-open equivalent but are otherwise kept as given:
/// they are not clamped against `len`.
pub fn clip_range_bounds(len: usize, range: impl std::ops::RangeBounds<usize>) -> Range<usize> {
    use std::ops::Bound::{Excluded, Included, Unbounded};
    let first = match range.start_bound() {
        Unbounded => 0,
        Included(&ix) => ix,
        Excluded(&ix) => ix + 1,
    };
    let past_last = match range.end_bound() {
        Unbounded => len,
        Excluded(&ix) => ix,
        Included(&ix) => ix + 1,
    };
    Range { start: first, end: past_last }
}
1754
1755pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1756 let mut strides = tvec!();
1757 compute_natural_stride_to(&mut strides, shape);
1758 strides
1759}
1760
1761fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1762 match shape.len() {
1763 0 => (),
1764 1 => strides.push(1),
1765 2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1766 3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1767 4 => strides.extend_from_slice(&[
1768 (shape[1] * shape[2] * shape[3]) as isize,
1769 (shape[2] * shape[3]) as _,
1770 shape[3] as _,
1771 1,
1772 ]),
1773 _ => {
1774 strides.push(1);
1775 for dim in shape.as_ref().iter().skip(1).rev() {
1776 let previous = *strides.last().unwrap();
1777 strides.push(previous * *dim as isize)
1778 }
1779 strides.reverse();
1780 }
1781 }
1782}
1783
1784impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
1785 fn from(it: Array<T, D>) -> Tensor {
1786 Tensor::from_datum(it.into_dyn())
1787 }
1788}
1789
/// Conversion of a value into an owned `Tensor`.
pub trait IntoTensor: Sized {
    /// Consumes `self` and returns it as a `Tensor`.
    fn into_tensor(self) -> Tensor;
}
1797
/// Conversion of a value into a shared, reference-counted `Tensor`.
pub trait IntoArcTensor: Sized {
    /// Consumes `self` and returns it as an `Arc<Tensor>`.
    fn into_arc_tensor(self) -> Arc<Tensor>;
}
1805
1806impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
1807 fn into_tensor(self) -> Tensor {
1808 Tensor::from(self)
1809 }
1810}
1811
1812impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
1813 fn into_arc_tensor(self) -> Arc<Tensor> {
1814 Arc::new(Tensor::from(self))
1815 }
1816}
1817
/// Identity conversion: a `Tensor` is already a `Tensor`.
impl IntoTensor for Tensor {
    fn into_tensor(self) -> Tensor {
        self
    }
}
1823
1824impl IntoTensor for Arc<Tensor> {
1825 fn into_tensor(self) -> Tensor {
1826 Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
1827 }
1828}
1829
1830impl IntoArcTensor for Tensor {
1831 fn into_arc_tensor(self) -> Arc<Tensor> {
1832 Arc::new(self)
1833 }
1834}
1835
/// Identity conversion: an `Arc<Tensor>` is already shared.
impl IntoArcTensor for Arc<Tensor> {
    fn into_arc_tensor(self) -> Arc<Tensor> {
        self
    }
}
1841
// Unit and property tests for tensor construction from permuted arrays,
// vector broadcasting, complex reinterpretation and cloning.
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    use proptest::collection::vec;
    use proptest::prelude::*;

    // A shape together with a permutation of its axes: exercises the conversion
    // of a non-standard-layout `ndarray` array into a `Tensor`.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        // Ranks 0 to 7, each dimension in 1..5, with a shuffled axis permutation.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            (0..8usize)
                .prop_flat_map(|rank| {
                    let permute: Vec<usize> = (0..rank).collect();
                    (proptest::collection::vec(1..5usize, rank), Just(permute).prop_shuffle())
                })
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        // Array filled with 1, 2, 3, ... in creation order, then axis-permuted.
        fn input(&self) -> ArrayD<i32> {
            let mut i = 0;
            ArrayD::from_shape_simple_fn(&*self.shape, || {
                i += 1;
                i
            })
            .permuted_axes(&*self.permutation)
        }

        // Expected tensor: the values in the array's iteration order, reshaped to
        // the permuted shape.
        fn reference(&self) -> Tensor {
            let values: Vec<i32> = self.input().iter().copied().collect();
            let shape = self.permutation.iter().map(|ix| self.shape[*ix]).collect::<TVec<usize>>();
            super::litteral::tensor1(&values).into_shape(&shape).unwrap()
        }

        // Tensor under test: direct conversion of the permuted array.
        fn tract(&self) -> Tensor {
            Tensor::from(self.input())
        }

        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    proptest::proptest! {
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    // Fixed case: transposition of a 2x1 array.
    #[test]
    fn t_1_2() {
        PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] }.check().unwrap();
    }

    // Fixed case: transposition of a 2x2 array.
    #[test]
    fn t_2_2() {
        PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
    }

    // A vector, a target shape and the axis of that shape the vector lies along.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        // `broadcast_vector_to_shape` must agree with reshaping the vector to
        // all-ones-but-`axis` and then calling `broadcast_to_shape`.
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reference = input
                .clone()
                .into_shape(&intermediate)
                .unwrap()
                .broadcast_to_shape(&self.shape)
                .unwrap();
            prop_assert_eq!(
                reference,
                input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap()
            );
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        // Up to 3 other dimensions in 0..5 and a vector of up to 4 values, whose
        // length is inserted into the shape at a random axis.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1)
                })
                .prop_map(|(vec, mut shape, axis)| {
                    shape.insert(axis, vec.len());
                    BroadcastVecToShape { vec, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    // A 3x2 f32 tensor becomes a vector of 3 complex numbers.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let input = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // A 3x2x2 i32 tensor becomes a 3x2 matrix of complex integers.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let input =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Smoke test: cloning a scalar tensor holding a symbolic `TDim` must not panic.
    #[test]
    fn clone_tdim_tensor() {
        let symbols = SymbolScope::default();
        let a = symbols.sym("a");
        let t = tensor0(TDim::from(a));
        let _ = t.clone();
    }
}