1use crate::TVec;
3use crate::blob::Blob;
4use crate::datum::{ClampCast, Datum, DatumType, QParams, round_ties_to_even, scale_by};
5use crate::dim::TDim;
6use crate::internal::*;
7use crate::opaque::Opaque;
8use half::f16;
9use itertools::{Itertools, izip};
10use ndarray::prelude::*;
11#[cfg(feature = "complex")]
12use num_complex::Complex;
13use num_traits::{Float, Zero};
14use std::borrow::Cow;
15use std::fmt;
16use std::hash::Hash;
17use std::ops::Range;
18use std::sync::Arc;
19
20pub mod dense_view;
21pub mod litteral;
22pub mod storage;
23pub mod view;
24
25pub use dense_view::{DenseView, DenseViewMut};
26use storage::{DenseStorage, StorageKind};
27
/// Tolerance preset used when comparing two tensors (see
/// `Tensor::close_enough`).
///
/// Each preset is translated to an (absolute tolerance, relative tolerance,
/// tolerated outlier ratio) triple by `Approximation::atol_rtol_outliers`;
/// the triple may depend on the datum type being compared.
#[derive(Copy, Clone, Default, Debug)]
pub enum Approximation {
    /// No tolerance at all: all three values are zero.
    Exact,
    /// Default preset.
    #[default]
    Close,
    /// Looser preset than `Close` for the datum types listed in
    /// `atol_rtol_outliers`.
    Approximate,
    /// Preset with a 5e-2 absolute and 1e-2 relative tolerance.
    VeryApproximate,
    /// Preset that also tolerates a small ratio of outliers.
    SuperApproximate,
    /// Loosest preset, with the largest tolerances and outlier ratio.
    UltraApproximate,
    /// Explicit (absolute tolerance, relative tolerance, outlier ratio).
    Custom(f32, f32, f32),
}
39
40impl PartialEq for Approximation {
41 fn eq(&self, other: &Self) -> bool {
42 use Approximation::Custom;
43 if let (Custom(aa, ar, ao), Custom(ba, br, bo)) = (self, other) {
44 aa == ba && ar == br && bo == ao
45 } else {
46 std::mem::discriminant(self) == std::mem::discriminant(other)
47 }
48 }
49}
50
// NOTE(review): `Custom` carries `f32` values compared with `==`, so a
// `Custom` holding a NaN is not equal to itself — confirm callers never build
// one if reflexivity is relied upon.
impl Eq for Approximation {}
52
53impl From<bool> for Approximation {
54 fn from(b: bool) -> Self {
55 if b { Self::Approximate } else { Self::Exact }
56 }
57}
58
impl Approximation {
    /// Returns the (absolute tolerance, relative tolerance, tolerated outlier
    /// ratio) triple for this preset when comparing values of type `dt`.
    ///
    /// Arms are tried in order: the `F16` and quantized cases must stay above
    /// the catch-all arms of the same preset.
    fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
        use Approximation::*;
        match (self, dt) {
            (Exact, _) => (0.0, 0.0, 0.0),
            (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
            (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
            // For quantized types the absolute tolerance is the second
            // component of `zp_scale()` (presumably the scale — confirm).
            (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
            (Close, _) => (1e-7, 1e-7, 0.0),
            (Approximate, _) => (1e-4, 5e-4, 0.0),
            (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
            (SuperApproximate, _) => (0.1, 0.05, 0.0001),
            (UltraApproximate, _) => (0.2, 0.1, 0.0005),
            (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
        }
    }
}
76
/// A typed, n-dimensional array of values.
pub struct Tensor {
    /// Type of the elements.
    dt: DatumType,
    /// Dimension of each axis.
    shape: TVec<usize>,
    /// Stride of each axis, counted in elements (not bytes).
    strides: TVec<isize>,
    /// Total number of elements.
    len: usize,
    /// Backing storage holding the element data.
    storage: StorageKind,
}
85
// NOTE(review): these impls assert that `Tensor` (and thus `StorageKind`) can
// be moved to and shared between threads; the justification is not visible in
// this part of the file — confirm against the storage implementation.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
88
89impl Hash for Tensor {
90 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
91 use DatumType::*;
92 self.dt.hash(state);
93 self.shape.hash(state);
94 self.dense_storage().layout().align().hash(state);
95 unsafe {
96 match self.dt {
97 Bool => self.as_slice_unchecked::<bool>().hash(state),
98 I8 => self.as_slice_unchecked::<i8>().hash(state),
99 I16 => self.as_slice_unchecked::<i16>().hash(state),
100 I32 => self.as_slice_unchecked::<i32>().hash(state),
101 I64 => self.as_slice_unchecked::<i64>().hash(state),
102 U8 => self.as_slice_unchecked::<u8>().hash(state),
103 U16 => self.as_slice_unchecked::<u16>().hash(state),
104 U32 => self.as_slice_unchecked::<u32>().hash(state),
105 U64 => self.as_slice_unchecked::<u64>().hash(state),
106 F16 => self.as_slice_unchecked::<i16>().hash(state),
107 F32 => self.as_slice_unchecked::<i32>().hash(state),
108 F64 => self.as_slice_unchecked::<i64>().hash(state),
109 TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
110 String => self.as_slice_unchecked::<std::string::String>().hash(state),
111 Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
112 Opaque => self.as_slice_unchecked::<crate::opaque::Opaque>().hash(state),
113 QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
114 QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
115 QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
116 #[cfg(feature = "complex")]
117 ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
118 #[cfg(feature = "complex")]
119 ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
120 #[cfg(feature = "complex")]
121 ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
122 #[cfg(feature = "complex")]
123 ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
124 #[cfg(feature = "complex")]
125 ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
126 #[cfg(feature = "complex")]
127 ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
128 }
129 }
130 }
131}
132
/// Cloning a tensor delegates to `deep_clone`.
impl Clone for Tensor {
    fn clone(&self) -> Tensor {
        self.deep_clone()
    }
}
138
/// The default tensor is built by `litteral::tensor0` from `0f32`.
impl Default for Tensor {
    fn default() -> Tensor {
        litteral::tensor0(0f32)
    }
}
144
145impl Drop for Tensor {
146 fn drop(&mut self) {
147 macro_rules! drop_in_place {
148 ($t: ty) => {
149 if self.dt == <$t>::datum_type() {
150 unsafe {
151 let slice = self.as_slice_mut_unchecked::<$t>();
152 std::ptr::drop_in_place(slice as *mut [$t]);
153 }
154 }
155 };
156 }
157 drop_in_place!(Blob);
158 drop_in_place!(String);
159 drop_in_place!(TDim);
160 drop_in_place!(Opaque);
161 }
162}
163
/// Returns the vector width, in bytes, used as the default tensor alignment.
///
/// On x86_64 this is 64 when AVX-512F is detected at runtime and 32
/// otherwise; on every other architecture it is 16.
pub fn vector_size() -> usize {
    #[cfg(target_arch = "x86_64")]
    let bits = if is_x86_feature_detected!("avx512f") { 512 } else { 256 };
    #[cfg(not(target_arch = "x86_64"))]
    let bits = 128;
    bits / 8
}
172
173impl Tensor {
    /// Returns the dense storage of this tensor.
    ///
    /// # Panics
    /// Panics with "Non-dense storage" if the storage is not dense.
    #[inline]
    fn dense_storage(&self) -> &DenseStorage {
        self.storage.as_dense().expect("Non-dense storage")
    }
178
    /// Returns the dense storage of this tensor, mutably.
    ///
    /// # Panics
    /// Panics with "Non-dense storage" if the storage is not dense.
    #[inline]
    fn dense_storage_mut(&mut self) -> &mut DenseStorage {
        self.storage.as_dense_mut().expect("Non-dense storage")
    }
183
184 #[inline]
186 pub fn as_dense(&self) -> Option<DenseView<'_>> {
187 let storage = self.storage.as_dense()?;
188 Some(DenseView::new(self, storage))
189 }
190
    /// Like `as_dense`, but returns an error ("Tensor storage is not dense")
    /// instead of `None`.
    #[inline]
    pub fn try_as_dense(&self) -> TractResult<DenseView<'_>> {
        self.as_dense().context("Tensor storage is not dense")
    }
196
    /// Returns a mutable dense view over this tensor, or `None` if its storage
    /// is not dense.
    #[inline]
    pub fn as_dense_mut(&mut self) -> Option<DenseViewMut<'_>> {
        // The storage is borrowed mutably while the geometry fields are read:
        // the fields are passed one by one to keep these borrows disjoint.
        let storage = self.storage.as_dense_mut()?;
        Some(DenseViewMut::new(self.dt, &self.shape, &self.strides, self.len, storage))
    }
203
    /// Like `as_dense_mut`, but returns an error ("Tensor storage is not
    /// dense") instead of `None`.
    #[inline]
    pub fn try_as_dense_mut(&mut self) -> TractResult<DenseViewMut<'_>> {
        self.as_dense_mut().context("Tensor storage is not dense")
    }
209
    /// Allocates a tensor of `T` with the given shape, without giving its
    /// elements a meaningful value.
    ///
    /// # Safety
    /// See `uninitialized_aligned_dt`: the elements must be written before
    /// they are read.
    #[inline]
    pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
        unsafe { Self::uninitialized_dt(T::datum_type(), shape) }
    }
215
    /// Allocates a tensor of datum type `dt` with the given shape, aligned on
    /// `vector_size()`, without giving its elements a meaningful value.
    ///
    /// # Safety
    /// See `uninitialized_aligned_dt`: the elements must be written before
    /// they are read.
    #[inline]
    pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
        unsafe { Self::uninitialized_aligned_dt(dt, shape, vector_size()) }
    }
221
    /// Allocates a tensor of `T` with the given shape and alignment, without
    /// giving its elements a meaningful value.
    ///
    /// # Safety
    /// See `uninitialized_aligned_dt`: the elements must be written before
    /// they are read.
    #[inline]
    pub unsafe fn uninitialized_aligned<T: Datum>(
        shape: &[usize],
        alignment: usize,
    ) -> TractResult<Tensor> {
        unsafe { Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment) }
    }
230
231 pub unsafe fn uninitialized_aligned_dt(
233 dt: DatumType,
234 shape: &[usize],
235 alignment: usize,
236 ) -> TractResult<Tensor> {
237 let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
238 let storage = StorageKind::Dense(DenseStorage::from(unsafe {
239 Blob::new_for_size_and_align(bytes, alignment)
240 }));
241 let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), storage, len: 0 };
242 if tensor.shape.len() == 0 {
243 tensor.len = 1;
244 } else {
245 tensor.update_strides_and_len();
246 }
247 if !tensor.storage.is_empty() {
248 if dt == String::datum_type() || dt == Blob::datum_type() {
249 tensor.dense_storage_mut().as_bytes_mut().fill(0);
251 } else if dt == TDim::datum_type() {
252 unsafe {
253 tensor
254 .as_slice_mut_unchecked::<TDim>()
255 .iter_mut()
256 .for_each(|dim| std::ptr::write(dim, TDim::zero()))
257 }
258 } else if dt == Opaque::datum_type() {
259 unsafe {
260 tensor.as_slice_mut_unchecked::<Opaque>().iter_mut().for_each(|p| {
261 std::ptr::write(p, Opaque::default());
262 })
263 };
264 } else if cfg!(debug_assertions) {
265 assert!(dt.is_copy());
266 if dt == DatumType::F32 {
267 tensor.fill_t(f32::NAN).unwrap();
268 } else {
269 tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
271 }
272 }
273 }
274 Ok(tensor)
275 }
276
277 pub fn stack_tensors(
278 axis: usize,
279 tensors: &[impl std::borrow::Borrow<Tensor>],
280 ) -> TractResult<Tensor> {
281 ensure!(tensors.len() > 0);
282 let rank = tensors[0].borrow().rank();
283 ensure!(axis < rank);
284 ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
285 let dt = tensors[0].borrow().datum_type();
286 ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
287 let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
288 for ax in 0..rank {
289 if ax != axis {
290 ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
291 }
292 }
293 shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
294 unsafe {
295 let mut result = Tensor::uninitialized_dt(dt, &shape)?;
296 if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
297 let mut offset = 0isize;
298 for v in tensors {
299 let v = v.borrow();
300 let len = v.storage.byte_len();
301 std::ptr::copy_nonoverlapping(
302 v.dense_storage().as_ptr(),
303 result.dense_storage_mut().as_mut_ptr().offset(offset),
304 len,
305 );
306 offset += len as isize;
307 }
308 } else {
309 let mut offset = 0;
310 for t in tensors {
311 let t = t.borrow();
312 let len = t.shape()[axis];
313 result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
314 offset += len;
315 }
316 }
317
318 Ok(result)
319 }
320 }
321
    /// Sets every element to `T::zero()`.
    ///
    /// # Errors
    /// Propagates the errors of `fill_t`.
    pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
        self.fill_t(T::zero())
    }
325
326 pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
327 unsafe {
328 let mut t = Tensor::uninitialized::<T>(shape)?;
329 t.clear::<T>()?;
330 Ok(t)
331 }
332 }
333
    /// Builds a rank-0 tensor of `T` holding `T::zero()`.
    pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
        Tensor::zero::<T>(&[])
    }
337
    /// Builds a rank-0 zero tensor of datum type `dt` (see `zero_dt`).
    pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
        Tensor::zero_dt(dt, &[])
    }
341
    /// Builds a zero tensor of datum type `dt` with the given shape, aligned
    /// on `vector_size()` (see `zero_aligned_dt`).
    pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
        Tensor::zero_aligned_dt(dt, shape, vector_size())
    }
345
346 pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
347 self.try_as_dense_mut()?
348 .as_slice_mut::<T>()?
349 .iter_mut()
350 .for_each(|item| *item = value.clone());
351 Ok(())
352 }
353
354 pub fn zero_aligned_dt(
355 dt: DatumType,
356 shape: &[usize],
357 alignment: usize,
358 ) -> TractResult<Tensor> {
359 if shape.iter().product::<usize>() == 0 {
360 unsafe { return Tensor::uninitialized_dt(dt, shape) };
361 }
362 if dt.is_quantized() {
363 unsafe {
364 let mut t = Tensor::uninitialized_dt(dt, shape)?;
365 let zp = dt.zp_scale().0;
366 match dt.unquantized() {
367 DatumType::I8 => t
368 .try_as_dense_mut()?
369 .as_slice_mut::<i8>()?
370 .iter_mut()
371 .for_each(|item| *item = zp as _),
372 DatumType::U8 => t
373 .try_as_dense_mut()?
374 .as_slice_mut::<u8>()?
375 .iter_mut()
376 .for_each(|item| *item = zp as _),
377 DatumType::I32 => t
378 .try_as_dense_mut()?
379 .as_slice_mut::<i32>()?
380 .iter_mut()
381 .for_each(|item| *item = zp as _),
382 _ => unreachable!(),
383 }
384 Ok(t)
385 }
386 } else {
387 dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
388 }
389 }
390
391 pub fn zero_aligned<T: Datum + num_traits::Zero>(
392 shape: &[usize],
393 alignment: usize,
394 ) -> TractResult<Tensor> {
395 unsafe {
396 let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
397 tensor.clear::<T>()?;
398 Ok(tensor)
399 }
400 }
401
    /// Builds a tensor with the given shape from a slice of elements, aligned
    /// on `vector_size()` (see `from_shape_align`).
    pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
        Self::from_shape_align(shape, data, vector_size())
    }
407
408 pub fn from_shape_align<T: Datum + Copy>(
411 shape: &[usize],
412 data: &[T],
413 align: usize,
414 ) -> TractResult<Tensor> {
415 ensure!(
416 data.len() == shape.iter().product::<usize>(),
417 "Shape product must be equal to data length"
418 );
419 unsafe {
420 let bytes = std::slice::from_raw_parts(
421 data.as_ptr() as *const u8,
422 data.len() * T::datum_type().size_of(),
423 );
424 let dt = T::datum_type();
425 Self::from_raw_dt_align(dt, shape, bytes, align)
426 }
427 }
428
    /// Builds a tensor of `T` with the given shape from raw bytes (see
    /// `from_raw_dt_align`).
    ///
    /// # Safety
    /// `content` must hold valid values of `T`: the bytes are copied as-is
    /// into the tensor.
    pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
        unsafe { Tensor::from_raw_dt(T::datum_type(), shape, content) }
    }
435
    /// Builds a tensor of `T` with the given shape and alignment from raw
    /// bytes (see `from_raw_dt_align`).
    ///
    /// # Safety
    /// `content` must hold valid values of `T`: the bytes are copied as-is
    /// into the tensor.
    pub unsafe fn from_raw_aligned<T: Datum>(
        shape: &[usize],
        content: &[u8],
        align: usize,
    ) -> TractResult<Tensor> {
        unsafe { Tensor::from_raw_dt_align(T::datum_type(), shape, content, align) }
    }
443
    /// Builds a tensor of datum type `dt` with the given shape from raw bytes,
    /// aligned on `vector_size()` (see `from_raw_dt_align`).
    ///
    /// # Safety
    /// `content` must hold valid values of type `dt`: the bytes are copied
    /// as-is into the tensor.
    pub unsafe fn from_raw_dt(
        dt: DatumType,
        shape: &[usize],
        content: &[u8],
    ) -> TractResult<Tensor> {
        unsafe { Self::from_raw_dt_align(dt, shape, content, vector_size()) }
    }
451
452 pub unsafe fn from_raw_dt_align(
453 dt: DatumType,
454 shape: &[usize],
455 content: &[u8],
456 align: usize,
457 ) -> TractResult<Tensor> {
458 let mut tensor = unsafe { Tensor::uninitialized_aligned_dt(dt, shape, align) }?;
459 tensor.as_bytes_mut().copy_from_slice(content);
460 Ok(tensor)
461 }
462
463 pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
464 let bytes = if content.len() == 0 {
465 &[]
466 } else {
467 unsafe {
468 std::slice::from_raw_parts(
469 content.as_ptr() as *const u8,
470 content.len() * T::datum_type().size_of(),
471 )
472 }
473 };
474 unsafe { Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align) }
475 }
476
    /// Returns the number of axes of the tensor.
    #[inline]
    pub fn rank(&self) -> usize {
        self.shape.len()
    }
482
    /// Returns the dimension of each axis.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
488
    /// Returns the total number of elements.
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.len
    }
495
    /// Returns the total number of elements (same value as `len`).
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn volume(&self) -> usize {
        self.len
    }
502
    /// Returns the stride of each axis.
    #[inline]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }
508
509 fn update_strides_and_len(&mut self) {
510 self.strides.clear();
511 if self.shape.len() == 0 {
512 self.len = 1;
513 return;
514 }
515 compute_natural_stride_to(&mut self.strides, &self.shape);
516 self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
517 }
518
519 pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
521 if shape != &*self.shape {
522 self.shape.clear();
523 self.shape.extend_from_slice(shape);
524 self.update_strides_and_len();
525 }
526 }
527
    /// Replaces both the shape and the strides of the tensor.
    ///
    /// Note that the cached element count (`len`) is not updated here.
    ///
    /// # Safety
    /// No check is made that the new geometry is consistent with the data
    /// held by the storage, nor that `shape` and `strides` agree.
    pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
        self.shape.clear();
        self.shape.extend_from_slice(shape);
        self.strides.clear();
        self.strides.extend_from_slice(strides);
    }
535
536 pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
538 if self.len() != shape.iter().product::<usize>() {
539 bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
540 }
541 unsafe { self.set_shape_unchecked(shape) }
542 Ok(())
543 }
544
545 pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
546 ensure!(axes.iter().duplicates().next().is_none());
547 ensure!(axes.iter().all(|a| *a < self.rank()));
548 unsafe {
549 #[inline]
550 unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
551 unsafe { input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor() }
552 }
553 let dt = self.datum_type();
554 let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
555 t.set_datum_type(dt);
556 Ok(t)
557 }
558 }
559
560 pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
561 let mut permutation: Vec<usize> = (0..self.rank()).collect();
562 permutation.remove(from);
563 permutation.insert(to, from);
564 self.permute_axes(&permutation)
565 }
566
    /// Merges axis `axis` and the following one into a single axis whose
    /// dimension is the product of both.
    ///
    /// NOTE(review): `axis + 1` is used as an index without a check, so this
    /// presumably panics when it is not a valid axis — confirm callers.
    pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
        let removed = self.shape.remove(axis + 1);
        self.shape[axis] *= removed;
        self.update_strides_and_len();
        self
    }
573
574 pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
575 if self.shape[axis] % outer_dim != 0 {
576 bail!(
577 "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
578 self.shape,
579 axis,
580 outer_dim
581 );
582 }
583 self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
584 self.shape[axis] = outer_dim;
585 self.update_strides_and_len();
586 Ok(self)
587 }
588
    /// Consumes the tensor and returns it reshaped (see `set_shape`).
    ///
    /// # Errors
    /// Fails if the new shape does not hold the same number of elements.
    pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
        self.set_shape(shape)?;
        Ok(self)
    }
594
595 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
596 self.shape.insert(axis, 1);
597 self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
598 Ok(())
599 }
600
601 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
602 ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
603 self.shape.remove(axis);
604 self.strides.remove(axis);
605 Ok(())
606 }
607
608 pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
609 self.broadcast_to_rank(rank)?;
610 self.update_strides_and_len();
611 Ok(self)
612 }
613
614 pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
615 if rank < self.rank() {
616 bail!("Can only broadcast to higher rank")
617 }
618 while self.shape.len() < rank {
619 self.shape.insert(0, 1)
620 }
621 self.update_strides_and_len();
622 Ok(())
623 }
624
625 pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
626 if self.rank() > 0 {
627 bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
628 }
629 unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
630 unsafe {
631 let value: &T = src.to_scalar_unchecked::<T>();
632 dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone())
633 };
634 }
635 unsafe {
636 let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
637 dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
638 Ok(t)
639 }
640 }
641
642 fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
643 unsafe {
644 let view = self.to_array_view_unchecked::<T>();
645 let mut output = view
646 .broadcast(shape)
647 .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
648 .into_owned()
649 .into_tensor();
650 output.set_datum_type(self.datum_type());
651 Ok(output)
652 }
653 }
654
    /// Broadcasts the tensor to `shape`, dispatching on its datum type to
    /// `broadcast_to_shape_t`.
    pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
        dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
    }
658
659 pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
660 ensure!(self.rank() == 1);
661 ensure!(shape[axis] == self.len());
662 if !self.datum_type().is_copy() {
663 let mut vec_shape = vec![1; shape.len()];
664 vec_shape[axis] = self.len();
665 return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
666 }
667 unsafe {
668 let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
669 if output.len() == 0 {
670 return Ok(output);
671 }
672 let inner_len = shape[axis + 1..].iter().product::<usize>();
673
674 unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
675 where
676 T: Datum + Copy,
677 {
678 unsafe {
679 for ix in 0..input.len() {
680 let value: T = input.as_slice_unchecked()[ix];
681 output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
682 .iter_mut()
683 .for_each(|item| *item = value);
684 }
685 }
686 }
687 dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));
688
689 let outer_len = shape[0..axis].iter().product::<usize>();
690 let repeat_bytes_len = inner_len * self.as_bytes().len();
691 let bytes = output.as_bytes_mut();
692 for ix in 1..outer_len {
693 bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
694 }
695
696 Ok(output)
697 }
698 }
699
    /// Copies the `src_range` slice of `src` along `axis` into the `range`
    /// slice of this tensor along the same axis.
    ///
    /// # Errors
    /// Fails if the tensors differ in rank or datum type, if `axis` is not a
    /// valid axis, if the two ranges differ in length or exceed the dimension
    /// of `axis`, or if the shapes differ on any axis but `axis`.
    pub fn assign_slice(
        &mut self,
        range: impl std::ops::RangeBounds<usize>,
        src: &Tensor,
        src_range: impl std::ops::RangeBounds<usize>,
        axis: usize,
    ) -> TractResult<()> {
        // Rank and axis are validated first: the indexing below relies on it.
        ensure!(self.rank() == src.rank());
        ensure!(axis < self.rank());
        let range = clip_range_bounds(self.shape[axis], range);
        let src_range = clip_range_bounds(src.shape[axis], src_range);
        ensure!(
            src.datum_type() == self.datum_type(),
            "Attempt to assign into {:?} from {:?}, datum type mismatch",
            self.datum_type(),
            src.datum_type()
        );
        ensure!(
            src_range.len() == range.len(),
            "Attempt to assign a range of {:?} from a range of {:?}",
            range,
            src_range,
        );
        ensure!(
            itertools::izip!(0.., self.shape(), src.shape())
                .all(|(ix, dst, src)| ix == axis || src == dst),
            "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
            axis,
            self,
            src
        );
        ensure!(
            src_range.end <= src.shape()[axis],
            "Assigning from invalid slice (axis {}, {:?}) of {:?}",
            axis,
            src_range,
            src
        );
        ensure!(
            range.end <= self.shape()[axis],
            "Assigning to invalid slice (axis {}, {:?}) of {:?}",
            axis,
            range,
            self
        );
        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
        Ok(())
    }
748
    /// Same as `assign_slice`, without any of its validations.
    ///
    /// # Safety
    /// The caller must uphold everything `assign_slice` checks: same rank and
    /// datum type, valid axis, ranges of equal length within bounds, and
    /// shapes agreeing on every axis but `axis`.
    pub unsafe fn assign_slice_unchecked(
        &mut self,
        range: impl std::ops::RangeBounds<usize>,
        src: &Tensor,
        src_range: impl std::ops::RangeBounds<usize>,
        axis: usize,
    ) {
        let range = clip_range_bounds(self.shape[axis], range);
        let src_range = clip_range_bounds(src.shape[axis], src_range);
        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
    }
760
    /// Copies the `src_range` slice of `src` along `axis` into the `range`
    /// slice of this tensor, with both ranges already resolved.
    ///
    /// # Safety
    /// No validation is done here: see `assign_slice` for the conditions the
    /// caller must guarantee.
    #[allow(clippy::ptr_eq)]
    unsafe fn assign_slice_from_resolved(
        &mut self,
        range: std::ops::Range<usize>,
        src: &Tensor,
        src_range: std::ops::Range<usize>,
        axis: usize,
    ) {
        unsafe {
            use ndarray::Slice;
            // Generic path: element-wise assignment through ndarray views.
            unsafe fn assign_slice_t<T: Datum>(
                to: &mut Tensor,
                to_range: Range<usize>,
                from: &Tensor,
                from_range: Range<usize>,
                axis: usize,
            ) {
                unsafe {
                    to.to_array_view_mut_unchecked::<T>()
                        .slice_axis_mut(Axis(axis), Slice::from(to_range))
                        .assign(
                            &from
                                .to_array_view_unchecked::<T>()
                                .slice_axis(Axis(axis), Slice::from(from_range)),
                        )
                }
            }
            // Fast path: with plain-data elements and only 1-sized axes before
            // `axis`, each slice is one contiguous run of bytes.
            if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
                // Byte size of one index along `axis`. The same value is used
                // for `src`, which relies on both tensors having the same
                // dimensions after `axis`.
                let stride = self.strides[axis] as usize * self.datum_type().size_of();
                let dst_start = (stride * range.start) as isize;
                let src_start = (stride * src_range.start) as isize;
                let len = stride * range.len();
                if len > 0 {
                    if self.dense_storage().as_ptr() != src.dense_storage().as_ptr() {
                        std::ptr::copy_nonoverlapping(
                            src.dense_storage().as_ptr().offset(src_start),
                            self.dense_storage_mut().as_mut_ptr().offset(dst_start),
                            len,
                        );
                    } else {
                        // Same buffer on both sides: the regions may overlap.
                        std::ptr::copy(
                            src.dense_storage().as_ptr().offset(src_start),
                            self.dense_storage_mut().as_mut_ptr().offset(dst_start),
                            len,
                        );
                    }
                }
            } else {
                dispatch_datum!(assign_slice_t(self.datum_type())(
                    self, range, src, src_range, axis
                ));
            }
        }
    }
815
    /// Returns the datum type of the elements.
    #[inline]
    pub fn datum_type(&self) -> DatumType {
        self.dt
    }
821
    /// Overrides the datum type of the tensor, leaving the data untouched.
    ///
    /// # Safety
    /// The stored bytes are not converted: the caller must make sure they are
    /// valid for the new datum type.
    #[inline]
    pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
        self.dt = dt
    }
827
828 pub fn dump(&self, force_full: bool) -> TractResult<String> {
832 unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
833 unsafe {
834 if let Some(qp) = tensor.datum_type().qparams() {
835 let integers = tensor.cast_to::<i32>().unwrap();
836 integers.as_slice_unchecked::<i32>()[0..n]
837 .iter()
838 .map(|x| format!("[{}]({})", x, qp.dq(*x)))
839 .join(", ")
840 } else {
841 tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
842 }
843 }
844 }
845 unsafe {
846 let trunc = self.len() > 12 && !force_full;
847 let data = dispatch_datum!(dump_t(self.datum_type())(
848 self,
849 if trunc { 12 } else { self.len() }
850 ));
851 Ok(format!(
852 "{},{:?} {}{}",
853 self.shape.iter().join(","),
854 self.dt,
855 data,
856 if trunc { "..." } else { "" }
857 ))
858 }
859 }
860
861 pub fn close_enough(
863 &self,
864 other: &Self,
865 approx: impl Into<Approximation> + std::fmt::Debug,
866 ) -> TractResult<()> {
867 let approx = approx.into();
868 if self.shape() != other.shape() {
869 bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
870 }
871 let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
872 let ma = self.cast_to::<f32>()?;
873 let ma = ma.to_dense_array_view::<f32>()?;
874 let mb = other.cast_to::<f32>()?;
875 let mb = mb.to_dense_array_view::<f32>()?;
876 let mut first_outlier = None;
877 let mut outliers_count = 0;
878 ndarray::indices_of(&ma).into_iter().for_each(|indices| {
879 let a = ma[&indices];
880 let b = mb[&indices];
881 if !((a.is_nan() && b.is_nan())
882 || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
883 || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
884 {
885 if outliers_count == 0 {
886 first_outlier = Some(indices.as_array_view().to_vec());
887 }
888 outliers_count += 1;
889 }
890 });
891 if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
892 let indices = first_outlier.unwrap();
893 let a = ma[&*indices];
894 let b = mb[&*indices];
895 bail!(
896 "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
897 approx,
898 self.datum_type(),
899 indices,
900 a,
901 b,
902 outliers_count,
903 self.volume(),
904 outliers_count as f64 / self.volume() as f64,
905 outliers
906 );
907 }
908 Ok(())
909 }
910
    /// Consumes the tensor and returns its content as an owned ndarray array
    /// of `D`.
    ///
    /// # Errors
    /// Propagates the errors of `to_dense_array_view`.
    pub fn into_dense_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
        Ok(self.to_dense_array_view::<D>()?.to_owned())
    }
915
    /// Consumes the tensor and returns its content as an owned ndarray array
    /// of `D`, without checking the datum type.
    ///
    /// # Safety
    /// The tensor data must be valid when read as values of `D` (see
    /// `to_array_view_unchecked`).
    pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
        unsafe { self.to_array_view_unchecked::<D>().to_owned() }
    }
920
    /// Returns an ndarray view of `D` over the tensor data.
    ///
    /// # Errors
    /// Fails if the storage is not dense, or if the dense view refuses the
    /// access as `D`.
    #[inline]
    pub fn to_dense_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
        self.try_as_dense()?.to_array_view::<D>()
    }
928
    /// Returns a mutable ndarray view of `D` over the tensor data.
    ///
    /// # Errors
    /// Fails if the tensor can not be accessed as `D` (see
    /// `check_for_access`) or if the storage is not dense.
    #[inline]
    pub fn to_dense_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
        self.check_for_access::<D>()?;
        ensure!(self.storage.as_dense_mut().is_some(), "Tensor storage is not dense");
        // Datum type and storage kind have both been checked above.
        unsafe { Ok(self.to_array_view_mut_unchecked()) }
    }
938
939 fn check_for_access<D: Datum>(&self) -> TractResult<()> {
940 ensure!(
941 self.datum_type().unquantized() == D::datum_type().unquantized(),
942 "Tensor datum type error: tensor is {:?}, accessed as {:?}",
943 self.datum_type(),
944 D::datum_type(),
945 );
946 Ok(())
947 }
948
949 pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
951 if self.len() != 0 {
952 unsafe {
953 ArrayViewD::from_shape_ptr(&*self.shape, self.dense_storage().as_ptr() as *const D)
954 }
955 } else {
956 ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
957 }
958 }
959
960 pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
962 if self.len() != 0 {
963 unsafe {
964 let ptr = self.dense_storage_mut().as_mut_ptr() as *mut D;
965 ArrayViewMutD::from_shape_ptr(&*self.shape, ptr)
966 }
967 } else {
968 ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
969 }
970 }
971
    /// Returns a pointer to the tensor data, typed as `D`.
    ///
    /// # Errors
    /// Fails if the tensor can not be accessed as `D` (see
    /// `check_for_access`).
    ///
    /// # Panics
    /// Panics if the storage is not dense.
    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
        self.check_for_access::<D>()?;
        Ok(self.dense_storage().as_ptr() as *const D)
    }
977
    /// Returns a pointer to the tensor data typed as `D`, without checking
    /// the datum type.
    ///
    /// # Safety
    /// The tensor data must be valid when read as values of `D`.
    ///
    /// # Panics
    /// Panics if the storage is not dense.
    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
        self.dense_storage().as_ptr() as *const D
    }
982
    /// Returns a mutable pointer to the tensor data typed as `D`, without
    /// checking the datum type.
    ///
    /// # Safety
    /// The tensor data must be valid when accessed as values of `D`.
    ///
    /// # Panics
    /// Panics if the storage is not dense.
    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
        self.dense_storage_mut().as_mut_ptr() as *mut D
    }
987
988 pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
990 self.as_ptr::<D>().map(|p| p as *mut D)
991 }
992
993 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
995 if self.storage.byte_len() == 0 {
996 &[]
997 } else {
998 unsafe { std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len()) }
999 }
1000 }
1001
1002 pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
1004 if self.storage.byte_len() == 0 {
1005 &mut []
1006 } else {
1007 unsafe { std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len()) }
1008 }
1009 }
1010
    /// Builds a rank-0 tensor holding a clone of the scalar value of this
    /// tensor.
    ///
    /// # Errors
    /// Fails if the storage is not dense or if the dense view can not provide
    /// a scalar.
    pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
        // Typed implementation, selected from the datum type below.
        fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
            Ok(litteral::tensor0(t.try_as_dense()?.to_scalar::<D>()?.clone()))
        }
        dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
    }
1018
    /// Returns a reference to the first element, read as `D`, without any
    /// check.
    ///
    /// # Safety
    /// The tensor must hold at least one element, valid as a value of `D`.
    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
        unsafe { &*(self.dense_storage().as_ptr() as *const D) }
    }
1023
1024 pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
1026 self.check_for_access::<D>()?;
1027 if self.len() == 0 {
1028 bail!("to_scalar_mut called on empty tensor ({:?})", self)
1029 }
1030 if self.len() > 1 {
1031 bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1032 }
1033 unsafe { Ok(self.to_scalar_mut_unchecked()) }
1034 }
1035
    /// Returns a mutable reference to the first element, accessed as `D`,
    /// without any check.
    ///
    /// # Safety
    /// The tensor must hold at least one element, valid as a value of `D`.
    pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
        unsafe { &mut *(self.dense_storage_mut().as_mut_ptr() as *mut D) }
    }
1040
    /// Returns the tensor data as raw bytes.
    ///
    /// # Panics
    /// Panics if the storage is not dense.
    pub fn as_bytes(&self) -> &[u8] {
        self.dense_storage().as_bytes()
    }
1044
    /// Returns the tensor data as mutable raw bytes.
    ///
    /// # Panics
    /// Panics if the storage is not dense.
    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
        self.dense_storage_mut().as_bytes_mut()
    }
1048
    /// Returns true if every element, read as `T`, is equal to the first one.
    ///
    /// # Safety
    /// The tensor data must be valid when read as values of `T`.
    ///
    /// # Panics
    /// Panics on an empty tensor (`is_uniform` filters this case out).
    unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
        let slice = unsafe { self.as_slice_unchecked::<T>() };
        slice[1..].iter().all(|x| x == &slice[0])
    }
1053
1054 pub fn is_uniform(&self) -> bool {
1055 if self.len() <= 1 {
1056 return true;
1057 }
1058 unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
1059 }
1060
    /// Builds a rank-0 tensor from a clone of the first element, read as `T`.
    ///
    /// # Safety
    /// The tensor data must be valid when read as values of `T`.
    ///
    /// # Panics
    /// Panics on an empty tensor.
    unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
        let v: T = unsafe { self.as_slice_unchecked::<T>() }[0].clone();
        litteral::tensor0(v)
    }
1065
1066 pub fn as_uniform(&self) -> Option<Tensor> {
1067 if self.len() >= 1 && self.is_uniform() {
1068 unsafe {
1069 let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
1070 t.set_datum_type(self.datum_type());
1071 Some(t)
1072 }
1073 } else {
1074 None
1075 }
1076 }
1077
1078 pub fn is_all_zero(&self) -> TractResult<bool> {
1079 Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1080 }
1081
1082 pub fn is_zero(&self) -> TractResult<bool> {
1083 Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
1084 }
1085
1086 unsafe fn natural_cast<
1087 Source: Datum + num_traits::AsPrimitive<Target>,
1088 Target: Datum + Copy,
1089 >(
1090 &self,
1091 other: &mut Tensor,
1092 ) {
1093 unsafe {
1094 self.as_slice_unchecked::<Source>()
1095 .iter()
1096 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1097 .for_each(|(s, d)| *d = s.as_())
1098 };
1099 }
1100
1101 unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1102 unsafe {
1103 self.as_slice_unchecked::<Source>()
1104 .iter()
1105 .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1106 .for_each(|(s, d)| *d = !s.is_zero());
1107 }
1108 }
1109
1110 unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
1111 &self,
1112 other: &mut Tensor,
1113 ) -> TractResult<()> {
1114 unsafe {
1115 for (s, d) in self
1116 .as_slice_unchecked::<String>()
1117 .iter()
1118 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1119 {
1120 *d = s
1121 .parse()
1122 .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
1123 }
1124 Ok(())
1125 }
1126 }
1127
1128 unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1129 unsafe {
1130 for (s, d) in self
1131 .as_slice_unchecked::<Source>()
1132 .iter()
1133 .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1134 {
1135 *d = s.to_string()
1136 }
1137 }
1138 }
1139
    /// Casts the tensor to the datum type of `D` (see `cast_to_dt`).
    pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<'_, Tensor>> {
        self.cast_to_dt(D::datum_type())
    }
1144
1145 #[allow(clippy::redundant_closure_call)]
1147 pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<'_, Tensor>> {
1148 unsafe {
1149 if self.dt == dst_dt {
1150 return Ok(Cow::Borrowed(self));
1151 }
1152 if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
1153 let slice = self.as_slice_unchecked::<TDim>();
1154 let mut ints = Self::uninitialized::<i64>(&self.shape)?;
1155 let ints_slice = ints.as_slice_mut_unchecked::<i64>();
1156 for i in 0..self.len() {
1157 ints_slice[i] = slice[i].to_i64()?;
1158 }
1159 return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1160 }
1161 if self.dt == bool::datum_type()
1162 && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
1163 {
1164 let slice = self.as_slice_unchecked::<bool>();
1165 let mut ints = Self::uninitialized::<i8>(&self.shape)?;
1166 let ints_slice = ints.as_slice_mut_unchecked::<i8>();
1167 for i in 0..self.len() {
1168 ints_slice[i] = slice[i] as usize as i8;
1169 }
1170 return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1171 }
1172 let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
1173 if self.dt == DatumType::String {
1174 dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
1175 return Ok(Cow::Owned(result));
1176 }
1177 if dst_dt == DatumType::String {
1178 dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
1179 return Ok(Cow::Owned(result));
1180 }
1181 macro_rules! n {
1182 ($source:ty) => {
1183 if <$source>::datum_type() == self.datum_type() {
1184 match dst_dt {
1185 DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
1186 DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
1187 DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
1188 DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
1189 DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
1190 DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
1191 DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
1192 DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
1193 DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
1194 DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
1195 DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
1196 DatumType::TDim => {
1197 let ints = self.cast_to::<i32>()?;
1198 let slice = ints.as_slice_unchecked::<i32>();
1199 let result = result.as_slice_mut_unchecked::<TDim>();
1200 for i in 0..self.len() {
1201 result[i] = slice[i].into();
1202 }
1203 }
1204 DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
1205 _ => todo!(),
1206 }
1207 return Ok(Cow::Owned(result));
1208 };
1209 };
1210 }
1211 if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
1213 n!(u8);
1214 n!(u16);
1215 n!(u32);
1216 n!(u64);
1217 n!(i8);
1218 n!(i16);
1219 n!(i32);
1220 n!(i64);
1221 n!(f16);
1222 n!(f32);
1223 n!(f64);
1224 } else {
1225 let (s_zp, s_scale) = self.datum_type().zp_scale();
1226 let (d_zp, d_scale) = dst_dt.zp_scale();
1227 if self.datum_type().is_quantized() && dst_dt.is_float() {
1228 macro_rules! q_to_fp {
1229 ($source:ty, $dest:ty) => {
1230 if <$source>::datum_type().unquantized()
1231 == self.datum_type().unquantized()
1232 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1233 {
1234 self.as_slice_unchecked::<$source>()
1235 .iter()
1236 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1237 .for_each(|(&s, d)| {
1238 *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
1239 });
1240 return Ok(Cow::Owned(result));
1241 }
1242 };
1243 }
1244 q_to_fp!(i8, f64);
1245 q_to_fp!(i8, f32);
1246 q_to_fp!(u8, f64);
1247 q_to_fp!(u8, f32);
1248 }
1249 macro_rules! q8_to_q8 {
1251 ($typ:ty) => {
1252 if dst_dt.unquantized() == <$typ>::datum_type() {
1253 self.as_slice_unchecked::<$typ>()
1254 .iter()
1255 .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
1256 .for_each(|(&s, d)| {
1257 *d = (d_zp as i32
1258 + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
1259 .clamp_cast()
1260 });
1261 return Ok(Cow::Owned(result));
1262 }
1263 };
1264 }
1265
1266 macro_rules! q_via_f32 {
1267 ($source:ty, $dest:ty, $round:expr) => {
1268 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1269 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1270 {
1271 self.as_slice_unchecked::<$source>()
1272 .iter()
1273 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1274 .for_each(|(&s, d)| {
1275 let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
1276 let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
1277 *d = $round(d_float);
1278 });
1279 return Ok(Cow::Owned(result));
1280 }
1281 };
1282 }
1283
1284 macro_rules! q_n {
1285 (clamp $source:ty, $dest:ty) => {{
1286 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1287 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1288 {
1289 self.as_slice_unchecked::<$source>()
1290 .iter()
1291 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1292 .for_each(|(&s, d)| {
1293 *d = s.clamp_cast();
1294 });
1295 return Ok(Cow::Owned(result));
1296 }
1297 }};
1298 ($source:ty, $dest:ty) => {{
1299 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1300 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1301 {
1302 self.as_slice_unchecked::<$source>()
1303 .iter()
1304 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1305 .for_each(|(&s, d)| {
1306 *d = s as $dest;
1307 });
1308 return Ok(Cow::Owned(result));
1309 }
1310 }};
1311 }
1312
1313 if dst_dt.unquantized() == self.datum_type().unquantized()
1314 && dst_dt.is_quantized()
1315 && self.datum_type().is_quantized()
1316 {
1317 q8_to_q8!(i8);
1318 q8_to_q8!(u8);
1319 }
1320
1321 q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
1322 q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
1323 q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
1324 q_via_f32!(i8, f32, |f| f);
1325 q_via_f32!(u8, f32, |f| f);
1326 q_via_f32!(i32, f32, |f| f);
1327
1328 if dst_dt.is_quantized() && self.datum_type().is_quantized() {
1329 q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
1330 q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
1331 q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
1332 q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
1333 q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
1334 q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());
1335
1336 q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
1338 q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
1339 }
1340
1341 q_n!(i8, i32);
1342 q_n!(i8, u32);
1343 q_n!(u8, i32);
1344 q_n!(u8, u32);
1345 q_n!(clamp i32, i8);
1346 q_n!(clamp i32, u8);
1347 q_n!(clamp u32, i8);
1348 q_n!(clamp u32, u8);
1349 q_n!(i8, i8);
1350 q_n!(u8, u8);
1351 q_n!(i32, i32);
1352 q_n!(u32, u32);
1353 }
1354
1355 bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
1356 }
1357 }
1358
1359 pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1361 let casted = self.cast_to::<D>()?;
1362 casted.try_as_dense()?.to_scalar::<D>().copied()
1363 }
1364
1365 pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1367 if nth >= self.len() {
1368 bail!(
1369 "nth called with {}th element on a tensor of len {} ({:?}",
1370 nth,
1371 self.len(),
1372 self
1373 );
1374 }
1375 unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1376 unsafe {
1377 let value = me.as_slice_unchecked::<T>()[nth].clone();
1378 output.as_slice_mut_unchecked::<T>()[0] = value;
1379 }
1380 }
1381 unsafe {
1382 let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1383 dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1384 Ok(output)
1385 }
1386 }
1387
1388 fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1390 unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1391 unsafe {
1392 if D::datum_type().is_float() {
1393 return dispatch_floatlike!(float_eq_t(D::datum_type())(me, other));
1394 }
1395 Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1396 .all(|(a, b)| a == b))
1397 }
1398 }
1399
1400 unsafe fn float_eq_t<D: Datum + Float>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1401 unsafe {
1402 Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1403 .all(|(a, b)| (a.is_nan() && b.is_nan()) || a == b))
1404 }
1405 }
1406
1407 unsafe {
1408 Ok(self.datum_type() == other.datum_type()
1409 && self.shape() == other.shape()
1410 && dispatch_datum!(eq_t(self.dt)(self, other))?)
1411 }
1412 }
1413
1414 fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
1415 unsafe {
1416 let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
1417 if let Some(slice) = it.as_slice_mut() {
1418 if t.datum_type().is_copy() {
1419 std::ptr::copy_nonoverlapping(
1420 slice.as_ptr() as *const i8,
1421 t.as_ptr_mut_unchecked(),
1422 t.dense_storage().layout().size(),
1423 );
1424 } else {
1425 t.as_slice_mut_unchecked::<T>()
1426 .iter_mut()
1427 .zip(slice.iter_mut())
1428 .for_each(|(t, s)| *t = std::mem::take(s));
1429 }
1430 return t;
1431 }
1432 if it.strides().iter().all(|&s| s > 0) && it.as_slice_memory_order().is_some() {
1433 let mut len_and_strides: TVec<(usize, usize)> = tvec!();
1434 for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
1435 .sorted_by_key(|(_, src, _)| *src)
1436 .map(|(l, _, dst)| (*l as isize, *dst))
1437 {
1438 if !len_and_strides.is_empty()
1439 && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
1440 == stride as usize
1441 {
1442 len_and_strides.last_mut().unwrap().0 *= len as usize;
1443 } else {
1444 len_and_strides.push((len as usize, stride as usize));
1445 }
1446 }
1447 len_and_strides.reverse();
1448 crate::scatter::scatter_contig_data(
1449 it.as_ptr(),
1450 t.as_ptr_mut_unchecked(),
1451 &len_and_strides,
1452 );
1453 return t;
1454 }
1455 t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
1457 t
1458 }
1459 }
1460
1461 pub fn deep_clone(&self) -> Tensor {
1462 unsafe {
1463 let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1464 if self.len() > 0 {
1465 if self.dt.is_copy() {
1466 self.dense_storage().as_ptr().copy_to_nonoverlapping(
1467 tensor.as_bytes_mut().as_mut_ptr(),
1468 self.dense_storage().layout().size(),
1469 )
1470 } else if self.dt == DatumType::String {
1471 tensor
1472 .as_slice_mut_unchecked::<String>()
1473 .clone_from_slice(self.as_slice_unchecked());
1474 } else if self.dt == DatumType::Blob {
1475 tensor
1476 .as_slice_mut_unchecked::<Blob>()
1477 .clone_from_slice(self.as_slice_unchecked());
1478 } else if self.dt == DatumType::Opaque {
1479 tensor
1480 .as_slice_mut_unchecked::<Opaque>()
1481 .clone_from_slice(self.as_slice_unchecked());
1482 } else if self.dt == DatumType::TDim {
1483 tensor
1484 .as_slice_mut_unchecked::<TDim>()
1485 .clone_from_slice(self.as_slice_unchecked());
1486 }
1487 }
1488 tensor
1489 }
1490 }
1491
1492 pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1493 if axis >= self.rank() {
1494 bail!("Can not slice at axis {} tensor {:?}", axis, self);
1495 }
1496 if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1497 bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1498 }
1499 fn slice_t<T: Datum>(
1500 t: &Tensor,
1501 axis: usize,
1502 start: usize,
1503 end: usize,
1504 ) -> TractResult<Tensor> {
1505 Ok(t.to_dense_array_view::<T>()?
1506 .slice_axis(ndarray::Axis(axis), (start..end).into())
1507 .into_owned()
1508 .into_tensor())
1509 }
1510 dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1511 }
1512
1513 #[inline]
1514 pub fn view(&self) -> view::TensorView<'_> {
1515 unsafe { view::TensorView::view(self) }
1516 }
1517
1518 #[inline]
1519 pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
1520 view::TensorView::at_prefix(self, prefix)
1521 }
1522
1523 #[inline]
1524 pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
1525 view::TensorView::offsetting(self, coords)
1526 }
1527
1528 #[inline]
1529 pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView<'_> {
1530 unsafe { view::TensorView::offsetting_unchecked(self, coords) }
1531 }
1532
1533 #[inline]
1534 pub fn view_mut(&mut self) -> view::TensorView<'_> {
1535 unsafe { view::TensorView::view(self) }
1536 }
1537
1538 #[inline]
1539 pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
1540 view::TensorView::at_prefix(self, prefix)
1541 }
1542
1543 #[inline]
1544 pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
1545 view::TensorView::offsetting(self, coords)
1546 }
1547
1548 pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1550 let mut t = if let DatumType::U8 = self.dt.unquantized() {
1551 self.try_as_dense()
1552 .unwrap()
1553 .to_array_view::<u8>()
1554 .unwrap()
1555 .mapv(|v| v.wrapping_sub(128) as i8)
1556 .into_tensor()
1557 } else {
1558 return self.clone();
1559 };
1560
1561 if let DatumType::QU8(qp) = self.dt {
1562 if let QParams::ZpScale { zero_point, scale } = qp {
1563 t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1564 } else {
1565 t.dt = DatumType::QI8(qp);
1566 }
1567 }
1568
1569 t.into_arc_tensor()
1570 }
1571
1572 pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1574 let mut t = if let DatumType::I8 = self.dt.unquantized() {
1575 self.try_as_dense()
1576 .unwrap()
1577 .to_array_view::<i8>()
1578 .unwrap()
1579 .mapv(|v| (v as u8).wrapping_add(128))
1580 .into_tensor()
1581 } else {
1582 return self.clone();
1583 };
1584
1585 if let DatumType::QI8(qp) = self.dt {
1586 if let QParams::ZpScale { zero_point, scale } = qp {
1587 t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1588 } else {
1589 t.dt = DatumType::QU8(qp);
1590 }
1591 }
1592 t.into_arc_tensor()
1593 }
1594
1595 pub fn to_aligned_default(&self) -> TractResult<Self> {
1596 if self.dt.is_copy() {
1597 unsafe {
1598 let mut t = Self::uninitialized_dt(self.dt, &self.shape)?;
1599 t.as_bytes_mut().copy_from_slice(self.as_bytes());
1600 Ok(t)
1601 }
1602 } else {
1603 let mut t = Self::zero_dt(self.dt, &self.shape)?;
1604 if self.dt == String::datum_type() {
1605 t.try_as_dense_mut()?
1606 .as_slice_mut::<String>()?
1607 .clone_from_slice(self.try_as_dense()?.as_slice()?);
1608 } else if self.dt == Blob::datum_type() {
1609 t.try_as_dense_mut()?
1610 .as_slice_mut::<Blob>()?
1611 .clone_from_slice(self.try_as_dense()?.as_slice()?);
1612 } else if self.dt == TDim::datum_type() {
1613 t.try_as_dense_mut()?
1614 .as_slice_mut::<TDim>()?
1615 .clone_from_slice(self.try_as_dense()?.as_slice()?);
1616 }
1617 Ok(t)
1618 }
1619 }
1620
1621 pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1622 let mut strides = tvec!();
1623 compute_natural_stride_to(&mut strides, shape);
1624 strides
1625 }
1626
1627 pub fn into_blob(mut self) -> TractResult<Blob> {
1628 ensure!(self.dt.is_copy());
1629 let storage =
1630 std::mem::replace(&mut self.storage, StorageKind::Dense(DenseStorage::default()));
1631 Ok(storage.into_dense().context("Storage is not dense")?.into_blob())
1632 }
1633}
1634
1635impl PartialEq for Tensor {
1636 fn eq(&self, other: &Tensor) -> bool {
1637 if self.dt != other.dt || self.shape != other.shape {
1638 return false;
1639 }
1640 self.eq_dt(other).unwrap_or(false)
1641 }
1642}
1643
1644impl Eq for Tensor {}
1645
1646impl fmt::Debug for Tensor {
1647 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1648 let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1649 write!(formatter, "{content}")
1650 }
1651}
1652
/// Reinterprets a tensor whose innermost dimension is 2 as a tensor of
/// complex values with that dimension removed.
///
/// # Errors
/// Fails when the last dimension is not 2 or when the datum type has no
/// complex counterpart (`complexify`).
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    if t.shape().last() != Some(&2) {
        bail!("The last dimension in the tensor shape {:?} must be 2", t.shape());
    }
    unsafe {
        t.shape.pop();
        let complex_dt = t.datum_type().complexify()?;
        t.set_datum_type(complex_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1667
/// Reinterprets a complex tensor as a tensor of its component type with an
/// extra innermost dimension of size 2.
///
/// # Errors
/// Fails when the datum type cannot be `decomplexify`-ed.
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    unsafe {
        t.shape.push(2);
        let component_dt = t.datum_type().decomplexify()?;
        t.set_datum_type(component_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1677
/// Converts any `RangeBounds<usize>` into a concrete half-open `Range`.
///
/// An unbounded start becomes `0` and an unbounded end becomes `len`.
/// Explicit bounds are converted as given: they are not clamped to `len`.
pub fn clip_range_bounds(len: usize, range: impl std::ops::RangeBounds<usize>) -> Range<usize> {
    use std::ops::Bound::{Excluded, Included, Unbounded};

    let first = match range.start_bound() {
        Unbounded => 0,
        Included(&ix) => ix,
        Excluded(&ix) => ix + 1,
    };
    let past_last = match range.end_bound() {
        Unbounded => len,
        Excluded(&ix) => ix,
        Included(&ix) => ix + 1,
    };
    Range { start: first, end: past_last }
}
1692
1693pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1694 let mut strides = tvec!();
1695 compute_natural_stride_to(&mut strides, shape);
1696 strides
1697}
1698
1699fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1700 match shape.len() {
1701 0 => (),
1702 1 => strides.push(1),
1703 2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1704 3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1705 4 => strides.extend_from_slice(&[
1706 (shape[1] * shape[2] * shape[3]) as isize,
1707 (shape[2] * shape[3]) as _,
1708 shape[3] as _,
1709 1,
1710 ]),
1711 _ => {
1712 strides.push(1);
1713 for dim in shape.as_ref().iter().skip(1).rev() {
1714 let previous = *strides.last().unwrap();
1715 strides.push(previous * *dim as isize)
1716 }
1717 strides.reverse();
1718 }
1719 }
1720}
1721
1722impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
1723 fn from(it: Array<T, D>) -> Tensor {
1724 Tensor::from_datum(it.into_dyn())
1725 }
1726}
1727
1728pub trait IntoTensor: Sized {
1730 fn into_tensor(self) -> Tensor;
1734}
1735
1736pub trait IntoArcTensor: Sized {
1738 fn into_arc_tensor(self) -> Arc<Tensor>;
1742}
1743
1744impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
1745 fn into_tensor(self) -> Tensor {
1746 Tensor::from(self)
1747 }
1748}
1749
1750impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
1751 fn into_arc_tensor(self) -> Arc<Tensor> {
1752 Arc::new(Tensor::from(self))
1753 }
1754}
1755
1756impl IntoTensor for Tensor {
1757 fn into_tensor(self) -> Tensor {
1758 self
1759 }
1760}
1761
1762impl IntoTensor for Arc<Tensor> {
1763 fn into_tensor(self) -> Tensor {
1764 Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
1765 }
1766}
1767
1768impl IntoArcTensor for Tensor {
1769 fn into_arc_tensor(self) -> Arc<Tensor> {
1770 Arc::new(self)
1771 }
1772}
1773
1774impl IntoArcTensor for Arc<Tensor> {
1775 fn into_arc_tensor(self) -> Arc<Tensor> {
1776 self
1777 }
1778}
1779
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    use proptest::collection::vec;
    use proptest::prelude::*;

    /// Builds a tensor from an ndarray whose axes have been permuted.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            // For a given rank: dimensions in 1..5 and a shuffled axis order.
            let for_rank = |rank: usize| {
                let identity = (0..rank).collect::<Vec<usize>>();
                (vec(1..5usize, rank), Just(identity).prop_shuffle())
            };
            (0..8usize)
                .prop_flat_map(for_rank)
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        /// Array filled with 1, 2, 3, ... then permuted.
        fn input(&self) -> ArrayD<i32> {
            let mut counter = 0;
            let sequential = ArrayD::from_shape_simple_fn(self.shape.as_slice(), || {
                counter += 1;
                counter
            });
            sequential.permuted_axes(self.permutation.as_slice())
        }

        /// Expected tensor: the permuted array's elements in logical order.
        fn reference(&self) -> Tensor {
            let flat = self.input().iter().copied().collect::<Vec<i32>>();
            let permuted_shape: TVec<usize> =
                self.permutation.iter().map(|&axis| self.shape[axis]).collect();
            super::litteral::tensor1(&flat).into_shape(&permuted_shape).unwrap()
        }

        /// Tensor produced by the conversion under test.
        fn tract(&self) -> Tensor {
            self.input().into()
        }

        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    proptest::proptest! {
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    #[test]
    fn t_1_2() {
        let problem = PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] };
        problem.check().unwrap();
    }

    #[test]
    fn t_2_2() {
        let problem = PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] };
        problem.check().unwrap();
    }

    /// Compares `broadcast_vector_to_shape` with reshape + `broadcast_to_shape`.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            // Shape of all ones except on `axis`, where the vector lives.
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reshaped = input.clone().into_shape(&intermediate).unwrap();
            let reference = reshaped.broadcast_to_shape(&self.shape).unwrap();
            let actual = input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap();
            prop_assert_eq!(reference, actual);
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    let axis_count = shape.len() + 1;
                    (vec(-10f32..10f32, 0usize..5), Just(shape), 0..axis_count)
                })
                .prop_map(|(values, mut shape, axis)| {
                    // The vector's length becomes the dimension at `axis`.
                    shape.insert(axis, values.len());
                    BroadcastVecToShape { vec: values, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let pairs = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let as_complex = reinterpret_inner_dim_as_complex(pairs)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, as_complex);
        Ok(())
    }

    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let pairs =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let as_complex = reinterpret_inner_dim_as_complex(pairs)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, as_complex);
        Ok(())
    }

    /// Cloning a tensor holding a symbolic `TDim` must not panic.
    #[test]
    fn clone_tdim_tensor() {
        let scope = SymbolScope::default();
        let symbol = scope.sym("a");
        let scalar = tensor0(TDim::from(symbol));
        let _ = scalar.clone();
    }
}