1use crate::blob::Blob;
3use crate::datum::{round_ties_to_even, scale_by, ClampCast, Datum, DatumType, QParams};
4use crate::dim::TDim;
5use crate::internal::*;
6use crate::opaque::Opaque;
7use crate::TVec;
8use half::f16;
9use itertools::Itertools;
10use ndarray::prelude::*;
11#[cfg(feature = "complex")]
12use num_complex::Complex;
13use num_traits::Zero;
14use std::borrow::Cow;
15use std::fmt;
16use std::hash::Hash;
17use std::ops::Range;
18use std::sync::Arc;
19
20pub mod litteral;
21pub mod view;
22
/// Tolerance preset used when comparing two tensors.
///
/// A preset is resolved to an `(atol, rtol, outliers)` triple by
/// `Approximation::atol_rtol_outliers`.
#[derive(Debug, Default, Clone, Copy)]
pub enum Approximation {
    /// Zero tolerance.
    Exact,
    /// Default preset.
    #[default]
    Close,
    Approximate,
    VeryApproximate,
    SuperApproximate,
    UltraApproximate,
    /// Explicit `(atol, rtol, outliers)` triple.
    Custom(f32, f32, f32),
}
34
35impl PartialEq for Approximation {
36 fn eq(&self, other: &Self) -> bool {
37 use Approximation::Custom;
38 if let (Custom(aa, ar, ao), Custom(ba, br, bo)) = (self, other) {
39 aa == ba && ar == br && bo == ao
40 } else {
41 std::mem::discriminant(self) == std::mem::discriminant(other)
42 }
43 }
44}
45
46impl Eq for Approximation {}
47
48impl From<bool> for Approximation {
49 fn from(b: bool) -> Self {
50 if b {
51 Self::Approximate
52 } else {
53 Self::Exact
54 }
55 }
56}
57
58impl Approximation {
59 fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
60 use Approximation::*;
61 match (self, dt) {
62 (Exact, _) => (0.0, 0.0, 0.0),
63 (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
64 (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
65 (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
66 (Close, _) => (1e-7, 1e-7, 0.0),
67 (Approximate, _) => (1e-4, 5e-4, 0.0),
68 (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
69 (SuperApproximate, _) => (0.1, 0.05, 0.0001),
70 (UltraApproximate, _) => (0.2, 0.1, 0.0005),
71 (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
72 }
73 }
74}
75
76#[derive(Eq)]
78pub struct Tensor {
79 dt: DatumType,
80 shape: TVec<usize>,
81 strides: TVec<isize>,
82 len: usize,
83 data: Blob,
84}
85
86unsafe impl Send for Tensor {}
87unsafe impl Sync for Tensor {}
88
89impl Hash for Tensor {
90 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
91 use DatumType::*;
92 self.dt.hash(state);
93 self.shape.hash(state);
94 self.data.layout().align().hash(state);
95 unsafe {
96 match self.dt {
97 Bool => self.as_slice_unchecked::<bool>().hash(state),
98 I8 => self.as_slice_unchecked::<i8>().hash(state),
99 I16 => self.as_slice_unchecked::<i16>().hash(state),
100 I32 => self.as_slice_unchecked::<i32>().hash(state),
101 I64 => self.as_slice_unchecked::<i64>().hash(state),
102 U8 => self.as_slice_unchecked::<u8>().hash(state),
103 U16 => self.as_slice_unchecked::<u16>().hash(state),
104 U32 => self.as_slice_unchecked::<u32>().hash(state),
105 U64 => self.as_slice_unchecked::<u64>().hash(state),
106 F16 => self.as_slice_unchecked::<i16>().hash(state),
107 F32 => self.as_slice_unchecked::<i32>().hash(state),
108 F64 => self.as_slice_unchecked::<i64>().hash(state),
109 TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
110 String => self.as_slice_unchecked::<std::string::String>().hash(state),
111 Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
112 Opaque => self.as_slice_unchecked::<crate::opaque::Opaque>().hash(state),
113 QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
114 QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
115 QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
116 #[cfg(feature = "complex")]
117 ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
118 #[cfg(feature = "complex")]
119 ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
120 #[cfg(feature = "complex")]
121 ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
122 #[cfg(feature = "complex")]
123 ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
124 #[cfg(feature = "complex")]
125 ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
126 #[cfg(feature = "complex")]
127 ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
128 }
129 }
130 }
131}
132
133impl Clone for Tensor {
134 fn clone(&self) -> Tensor {
135 self.deep_clone()
136 }
137}
138
139impl Default for Tensor {
140 fn default() -> Tensor {
141 litteral::tensor0(0f32)
142 }
143}
144
145impl Drop for Tensor {
146 fn drop(&mut self) {
147 macro_rules! drop_in_place {
148 ($t: ty) => {
149 if self.dt == <$t>::datum_type() {
150 unsafe {
151 self.as_slice_mut::<$t>()
152 .unwrap()
153 .iter_mut()
154 .for_each(|s| std::ptr::drop_in_place(s as *mut $t));
155 }
156 }
157 };
158 }
159 drop_in_place!(Blob);
160 drop_in_place!(String);
161 drop_in_place!(TDim);
162 drop_in_place!(Opaque);
163 }
164}
165
/// Returns a SIMD vector width in bytes: 64 when AVX-512F is detected on
/// x86_64, 32 on other x86_64 CPUs, and 16 on every other architecture.
pub fn vector_size() -> usize {
    #[cfg(target_arch = "x86_64")]
    let bits: usize = if is_x86_feature_detected!("avx512f") { 512 } else { 256 };
    #[cfg(not(target_arch = "x86_64"))]
    let bits: usize = 128;
    bits / 8
}
174
175impl Tensor {
176 #[allow(unreachable_code, unexpected_cfgs)]
177 #[inline]
178 pub fn default_alignment(dt: DatumType, shape: &[usize]) -> usize {
179 if shape.len() == 0 {
180 dt.alignment()
181 } else {
182 vector_size()
183 }
184 }
185
186 #[inline]
188 pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
189 Self::uninitialized_dt(T::datum_type(), shape)
190 }
191
192 #[inline]
194 pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
195 Self::uninitialized_aligned_dt(dt, shape, Self::default_alignment(dt, shape))
196 }
197
198 #[inline]
200 pub unsafe fn uninitialized_aligned<T: Datum>(
201 shape: &[usize],
202 alignment: usize,
203 ) -> TractResult<Tensor> {
204 Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment)
205 }
206
207 pub unsafe fn uninitialized_aligned_dt(
209 dt: DatumType,
210 shape: &[usize],
211 alignment: usize,
212 ) -> TractResult<Tensor> {
213 let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
214 let data = Blob::new_for_size_and_align(bytes, alignment);
215 let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), data, len: 0 };
216 if tensor.shape.len() == 0 {
217 tensor.len = 1;
218 } else {
219 tensor.update_strides_and_len();
220 }
221 if !tensor.data.is_empty() {
222 if dt == String::datum_type() || dt == Blob::datum_type() {
223 tensor.data.fill(0);
225 } else if dt == TDim::datum_type() {
226 tensor
227 .as_slice_mut_unchecked::<TDim>()
228 .iter_mut()
229 .for_each(|dim| std::ptr::write(dim, TDim::zero()))
230 } else if dt == Opaque::datum_type() {
231 tensor.as_slice_mut_unchecked::<Opaque>().iter_mut().for_each(|p| {
232 std::ptr::write(p, Opaque::default());
233 });
234 } else if cfg!(debug_assertions) {
235 assert!(dt.is_copy());
236 if dt == DatumType::F32 {
237 tensor.fill_t(f32::NAN).unwrap();
238 } else {
239 tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
241 }
242 }
243 }
244 Ok(tensor)
245 }
246
247 pub fn stack_tensors(
248 axis: usize,
249 tensors: &[impl std::borrow::Borrow<Tensor>],
250 ) -> TractResult<Tensor> {
251 ensure!(tensors.len() > 0);
252 let rank = tensors[0].borrow().rank();
253 ensure!(axis < rank);
254 ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
255 let dt = tensors[0].borrow().datum_type();
256 ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
257 let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
258 for ax in 0..rank {
259 if ax != axis {
260 ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
261 }
262 }
263 shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
264 unsafe {
265 let mut result = Tensor::uninitialized_dt(dt, &shape)?;
266 if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
267 let mut offset = 0isize;
268 for v in tensors {
269 let v = v.borrow();
270 let len = v.data.len();
271 std::ptr::copy_nonoverlapping(
272 v.data.as_ptr(),
273 result.data.as_mut_ptr().offset(offset),
274 len,
275 );
276 offset += len as isize;
277 }
278 } else {
279 let mut offset = 0;
280 for t in tensors {
281 let t = t.borrow();
282 let len = t.shape()[axis];
283 result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
284 offset += len;
285 }
286 }
287
288 Ok(result)
289 }
290 }
291
292 pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
293 self.fill_t(T::zero())
294 }
295
296 pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
297 unsafe {
298 let mut t = Tensor::uninitialized::<T>(shape)?;
299 t.clear::<T>()?;
300 Ok(t)
301 }
302 }
303
304 pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
305 Tensor::zero::<T>(&[])
306 }
307
308 pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
309 Tensor::zero_dt(dt, &[])
310 }
311
312 pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
313 Tensor::zero_aligned_dt(dt, shape, Self::default_alignment(dt, shape))
314 }
315
316 pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
317 self.as_slice_mut::<T>()?.iter_mut().for_each(|item| *item = value.clone());
318 Ok(())
319 }
320
321 pub fn zero_aligned_dt(
322 dt: DatumType,
323 shape: &[usize],
324 alignment: usize,
325 ) -> TractResult<Tensor> {
326 if shape.iter().product::<usize>() == 0 {
327 unsafe { return Tensor::uninitialized_dt(dt, shape) };
328 }
329 if dt.is_quantized() {
330 unsafe {
331 let mut t = Tensor::uninitialized_dt(dt, shape)?;
332 let zp = dt.zp_scale().0;
333 match dt.unquantized() {
334 DatumType::I8 => {
335 t.as_slice_mut::<i8>()?.iter_mut().for_each(|item| *item = zp as _)
336 }
337 DatumType::U8 => {
338 t.as_slice_mut::<u8>()?.iter_mut().for_each(|item| *item = zp as _)
339 }
340 DatumType::I32 => {
341 t.as_slice_mut::<i32>()?.iter_mut().for_each(|item| *item = zp as _)
342 }
343 _ => unreachable!(),
344 }
345 Ok(t)
346 }
347 } else {
348 dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
349 }
350 }
351
352 pub fn zero_aligned<T: Datum + num_traits::Zero>(
353 shape: &[usize],
354 alignment: usize,
355 ) -> TractResult<Tensor> {
356 unsafe {
357 let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
358 tensor.clear::<T>()?;
359 Ok(tensor)
360 }
361 }
362
363 pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
366 let dt = T::datum_type();
367 Self::from_shape_align(shape, data, dt.alignment())
368 }
369
370 pub fn from_shape_align<T: Datum + Copy>(
373 shape: &[usize],
374 data: &[T],
375 align: usize,
376 ) -> TractResult<Tensor> {
377 ensure!(
378 data.len() == shape.iter().product::<usize>(),
379 "Shape product must be equal to data length"
380 );
381 unsafe {
382 let bytes = std::slice::from_raw_parts(
383 data.as_ptr() as *const u8,
384 data.len() * T::datum_type().size_of(),
385 );
386 let dt = T::datum_type();
387 Self::from_raw_dt_align(dt, shape, bytes, align)
388 }
389 }
390
391 pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
395 Tensor::from_raw_dt(T::datum_type(), shape, content)
396 }
397
398 pub unsafe fn from_raw_aligned<T: Datum>(
399 shape: &[usize],
400 content: &[u8],
401 align: usize,
402 ) -> TractResult<Tensor> {
403 Tensor::from_raw_dt_align(T::datum_type(), shape, content, align)
404 }
405
406 pub unsafe fn from_raw_dt(
407 dt: DatumType,
408 shape: &[usize],
409 content: &[u8],
410 ) -> TractResult<Tensor> {
411 Self::from_raw_dt_align(dt, shape, content, dt.alignment())
412 }
413
414 pub unsafe fn from_raw_dt_align(
415 dt: DatumType,
416 shape: &[usize],
417 content: &[u8],
418 align: usize,
419 ) -> TractResult<Tensor> {
420 let mut tensor = Tensor::uninitialized_aligned_dt(dt, shape, align)?;
421 tensor.as_bytes_mut().copy_from_slice(content);
422 Ok(tensor)
423 }
424
425 pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
426 let bytes = if content.len() == 0 {
427 &[]
428 } else {
429 std::slice::from_raw_parts(
430 content.as_ptr() as *const u8,
431 content.len() * T::datum_type().size_of(),
432 )
433 };
434 Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align)
435 }
436
437 #[inline]
439 pub fn rank(&self) -> usize {
440 self.shape.len()
441 }
442
443 #[inline]
445 pub fn shape(&self) -> &[usize] {
446 &self.shape
447 }
448
449 #[inline]
451 #[allow(clippy::len_without_is_empty)]
452 pub fn len(&self) -> usize {
453 self.len
454 }
455
456 #[inline]
458 #[allow(clippy::len_without_is_empty)]
459 pub fn volume(&self) -> usize {
460 self.len
461 }
462
463 #[inline]
465 pub fn strides(&self) -> &[isize] {
466 &self.strides
467 }
468
469 fn update_strides_and_len(&mut self) {
470 self.strides.clear();
471 if self.shape.len() == 0 {
472 self.len = 1;
473 return;
474 }
475 compute_natural_stride_to(&mut self.strides, &self.shape);
476 self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
477 }
478
479 pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
481 if shape != &*self.shape {
482 self.shape.clear();
483 self.shape.extend_from_slice(shape);
484 self.update_strides_and_len();
485 }
486 }
487
488 pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
490 self.shape.clear();
491 self.shape.extend_from_slice(shape);
492 self.strides.clear();
493 self.strides.extend_from_slice(strides);
494 }
495
496 pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
498 if self.len() != shape.iter().product::<usize>() {
499 bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
500 }
501 unsafe { self.set_shape_unchecked(shape) }
502 Ok(())
503 }
504
505 pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
506 ensure!(axes.iter().duplicates().next().is_none());
507 ensure!(axes.iter().all(|a| *a < self.rank()));
508 unsafe {
509 #[inline]
510 unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
511 input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor()
512 }
513 let dt = self.datum_type();
514 let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
515 t.set_datum_type(dt);
516 Ok(t)
517 }
518 }
519
520 pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
521 let mut permutation: Vec<usize> = (0..self.rank()).collect();
522 permutation.remove(from);
523 permutation.insert(to, from);
524 self.permute_axes(&permutation)
525 }
526
527 pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
528 let removed = self.shape.remove(axis + 1);
529 self.shape[axis] *= removed;
530 self.update_strides_and_len();
531 self
532 }
533
534 pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
535 if self.shape[axis] % outer_dim != 0 {
536 bail!(
537 "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
538 self.shape,
539 axis,
540 outer_dim
541 );
542 }
543 self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
544 self.shape[axis] = outer_dim;
545 self.update_strides_and_len();
546 Ok(self)
547 }
548
549 pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
551 self.set_shape(shape)?;
552 Ok(self)
553 }
554
555 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
556 self.shape.insert(axis, 1);
557 self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
558 Ok(())
559 }
560
561 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
562 ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
563 self.shape.remove(axis);
564 self.strides.remove(axis);
565 Ok(())
566 }
567
568 pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
569 self.broadcast_to_rank(rank)?;
570 self.update_strides_and_len();
571 Ok(self)
572 }
573
574 pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
575 if rank < self.rank() {
576 bail!("Can only broadcast to higher rank")
577 }
578 while self.shape.len() < rank {
579 self.shape.insert(0, 1)
580 }
581 self.update_strides_and_len();
582 Ok(())
583 }
584
585 pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
586 if self.rank() > 0 {
587 bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
588 }
589 unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
590 let value: &T = src.to_scalar_unchecked::<T>();
591 dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone());
592 }
593 unsafe {
594 let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
595 dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
596 Ok(t)
597 }
598 }
599
600 fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
601 unsafe {
602 let view = self.to_array_view_unchecked::<T>();
603 let mut output = view
604 .broadcast(shape)
605 .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
606 .into_owned()
607 .into_tensor();
608 output.set_datum_type(self.datum_type());
609 Ok(output)
610 }
611 }
612
613 pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
614 dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
615 }
616
617 pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
618 ensure!(self.rank() == 1);
619 ensure!(shape[axis] == self.len());
620 if !self.datum_type().is_copy() {
621 let mut vec_shape = vec![1; shape.len()];
622 vec_shape[axis] = self.len();
623 return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
624 }
625 unsafe {
626 let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
627 if output.len() == 0 {
628 return Ok(output);
629 }
630 let inner_len = shape[axis + 1..].iter().product::<usize>();
631
632 unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
633 where
634 T: Datum + Copy,
635 {
636 for ix in 0..input.len() {
637 let value: T = input.as_slice_unchecked()[ix];
638 output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
639 .iter_mut()
640 .for_each(|item| *item = value);
641 }
642 }
643 dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));
644
645 let outer_len = shape[0..axis].iter().product::<usize>();
646 let repeat_bytes_len = inner_len * self.as_bytes().len();
647 let bytes = output.as_bytes_mut();
648 for ix in 1..outer_len {
649 bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
650 }
651
652 Ok(output)
653 }
654 }
655
656 fn clip_range_bounds(
657 &self,
658 axis: usize,
659 range: impl std::ops::RangeBounds<usize>,
660 ) -> Range<usize> {
661 use std::ops::Bound;
662 let start = match range.start_bound() {
663 Bound::Included(ix) => *ix,
664 Bound::Excluded(ix) => ix + 1,
665 Bound::Unbounded => 0,
666 };
667 let end = match range.end_bound() {
668 Bound::Included(ix) => *ix + 1,
669 Bound::Excluded(ix) => *ix,
670 Bound::Unbounded => self.shape()[axis],
671 };
672 start..end
673 }
674
675 pub fn assign_slice(
676 &mut self,
677 range: impl std::ops::RangeBounds<usize>,
678 src: &Tensor,
679 src_range: impl std::ops::RangeBounds<usize>,
680 axis: usize,
681 ) -> TractResult<()> {
682 let range = self.clip_range_bounds(axis, range);
683 let src_range = src.clip_range_bounds(axis, src_range);
684 ensure!(
685 src.datum_type() == self.datum_type(),
686 "Attempt to assign into {:?} from {:?}, datum type mismatch",
687 self.datum_type(),
688 src.datum_type()
689 );
690 ensure!(
691 src_range.len() == range.len(),
692 "Attempt to assign a range of {:?} from a range of {:?}",
693 range,
694 src_range,
695 );
696 ensure!(
697 self.rank() == src.rank()
698 && itertools::izip!(0.., self.shape(), src.shape())
699 .all(|(ix, dst, src)| ix == axis || src == dst),
700 "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
701 axis,
702 self,
703 src
704 );
705 ensure!(
706 src_range.end <= src.shape()[axis],
707 "Assigning from invalid slice (axis {}, {:?}) of {:?}",
708 axis,
709 src_range,
710 src
711 );
712 ensure!(
713 range.end <= self.shape()[axis],
714 "Assigning to invalid slice (axis {}, {:?}) of {:?}",
715 axis,
716 range,
717 self
718 );
719 unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
720 Ok(())
721 }
722
723 pub unsafe fn assign_slice_unchecked(
724 &mut self,
725 range: impl std::ops::RangeBounds<usize>,
726 src: &Tensor,
727 src_range: impl std::ops::RangeBounds<usize>,
728 axis: usize,
729 ) {
730 let range = self.clip_range_bounds(axis, range);
731 let src_range = src.clip_range_bounds(axis, src_range);
732 self.assign_slice_from_resolved(range, src, src_range, axis);
733 }
734
735 unsafe fn assign_slice_from_resolved(
736 &mut self,
737 range: std::ops::Range<usize>,
738 src: &Tensor,
739 src_range: std::ops::Range<usize>,
740 axis: usize,
741 ) {
742 use ndarray::Slice;
743 unsafe fn assign_slice_t<T: Datum>(
744 to: &mut Tensor,
745 to_range: Range<usize>,
746 from: &Tensor,
747 from_range: Range<usize>,
748 axis: usize,
749 ) {
750 to.to_array_view_mut_unchecked::<T>()
751 .slice_axis_mut(Axis(axis), Slice::from(to_range))
752 .assign(
753 &from
754 .to_array_view_unchecked::<T>()
755 .slice_axis(Axis(axis), Slice::from(from_range)),
756 )
757 }
758 if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
759 let stride = self.strides[axis] as usize * self.datum_type().size_of();
760 let dst_start = (stride * range.start) as isize;
761 let src_start = (stride * src_range.start) as isize;
762 let len = stride * range.len();
763 if len > 0 {
764 if self.data.as_ptr() != src.data.as_ptr() {
765 std::ptr::copy_nonoverlapping(
766 src.data.as_ptr().offset(src_start),
767 self.data.as_mut_ptr().offset(dst_start),
768 len,
769 );
770 } else {
771 std::ptr::copy(
772 src.data.as_ptr().offset(src_start),
773 self.data.as_mut_ptr().offset(dst_start),
774 len,
775 );
776 }
777 }
778 } else {
779 dispatch_datum!(assign_slice_t(self.datum_type())(self, range, src, src_range, axis));
780 }
781 }
782
783 #[inline]
785 pub fn datum_type(&self) -> DatumType {
786 self.dt
787 }
788
789 #[inline]
791 pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
792 self.dt = dt
793 }
794
795 pub fn dump(&self, force_full: bool) -> TractResult<String> {
799 unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
800 if let Some(qp) = tensor.datum_type().qparams() {
801 let integers = tensor.cast_to::<i32>().unwrap();
802 integers.as_slice_unchecked::<i32>()[0..n]
803 .iter()
804 .map(|x| format!("[{}]({})", x, qp.dq(*x)))
805 .join(", ")
806 } else {
807 tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
808 }
809 }
810 unsafe {
811 let trunc = self.len() > 12 && !force_full;
812 let data = dispatch_datum!(dump_t(self.datum_type())(
813 self,
814 if trunc { 12 } else { self.len() }
815 ));
816 Ok(format!(
817 "{},{:?} {}{}",
818 self.shape.iter().join(","),
819 self.dt,
820 data,
821 if trunc { "..." } else { "" }
822 ))
823 }
824 }
825
826 pub fn close_enough(
828 &self,
829 other: &Self,
830 approx: impl Into<Approximation> + std::fmt::Debug,
831 ) -> TractResult<()> {
832 let approx = approx.into();
833 if self.shape() != other.shape() {
834 bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
835 }
836 let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
837 let ma = self.cast_to::<f32>()?;
838 let ma = ma.to_array_view::<f32>()?;
839 let mb = other.cast_to::<f32>()?;
840 let mb = mb.to_array_view::<f32>()?;
841 let mut first_outlier = None;
842 let mut outliers_count = 0;
843 ndarray::indices_of(&ma).into_iter().for_each(|indices| {
844 let a = ma[&indices];
845 let b = mb[&indices];
846 if !((a.is_nan() && b.is_nan())
847 || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
848 || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
849 {
850 if outliers_count == 0 {
851 first_outlier = Some(indices.as_array_view().to_vec());
852 }
853 outliers_count += 1;
854 }
855 });
856 if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
857 let indices = first_outlier.unwrap();
858 let a = ma[&*indices];
859 let b = mb[&*indices];
860 bail!(
861 "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
862 approx,
863 self.datum_type(),
864 indices,
865 a,
866 b,
867 outliers_count,
868 self.volume(),
869 outliers_count as f64 / self.volume() as f64,
870 outliers
871 );
872 }
873 Ok(())
874 }
875
876 pub fn into_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
878 Ok(self.to_array_view::<D>()?.to_owned())
879 }
880
881 pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
883 self.to_array_view_unchecked::<D>().to_owned()
884 }
885
886 fn check_for_access<D: Datum>(&self) -> TractResult<()> {
887 ensure!(
888 self.datum_type().unquantized() == D::datum_type().unquantized(),
889 "Tensor datum type error: tensor is {:?}, accessed as {:?}",
890 self.datum_type(),
891 D::datum_type(),
892 );
893 Ok(())
894 }
895
896 pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<D>> {
898 self.check_for_access::<D>()?;
899 unsafe { Ok(self.to_array_view_unchecked()) }
900 }
901
902 pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<D>> {
904 self.check_for_access::<D>()?;
905 unsafe { Ok(self.to_array_view_mut_unchecked()) }
906 }
907
908 pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<D> {
910 if self.len() != 0 {
911 ArrayViewD::from_shape_ptr(&*self.shape, self.data.as_ptr() as *const D)
912 } else {
913 ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
914 }
915 }
916
917 pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<D> {
919 if self.len() != 0 {
920 ArrayViewMutD::from_shape_ptr(&*self.shape, self.data.as_mut_ptr() as *mut D)
921 } else {
922 ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
923 }
924 }
925
926 pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
928 self.check_for_access::<D>()?;
929 Ok(self.data.as_ptr() as *const D)
930 }
931
932 pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
934 self.data.as_ptr() as *const D
935 }
936
937 pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
939 self.data.as_mut_ptr() as *mut D
940 }
941
942 pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
944 self.as_ptr::<D>().map(|p| p as *mut D)
945 }
946
947 pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
949 let ptr: *const D = self.as_ptr()?;
950 if self.data.len() == 0 {
951 Ok(&[])
952 } else {
953 unsafe { Ok(std::slice::from_raw_parts::<D>(ptr, self.len())) }
954 }
955 }
956
957 pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
959 let ptr: *mut D = self.as_ptr_mut()?;
960 if self.data.len() == 0 {
961 Ok(&mut [])
962 } else {
963 unsafe { Ok(std::slice::from_raw_parts_mut::<D>(ptr, self.len())) }
964 }
965 }
966
967 pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
969 if self.data.len() == 0 {
970 &[]
971 } else {
972 std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len())
973 }
974 }
975
976 pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
978 if self.data.len() == 0 {
979 &mut []
980 } else {
981 std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len())
982 }
983 }
984
985 pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
987 self.check_for_access::<D>()?;
988 if self.len() == 0 {
989 bail!("to_scalar called on empty tensor ({:?})", self)
990 }
991 if self.len() > 1 {
992 bail!("to_scalar called on a tensor with multiple values ({:?})", self)
993 }
994 unsafe { Ok(self.to_scalar_unchecked()) }
995 }
996
997 pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
999 fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
1000 Ok(litteral::tensor0(t.to_scalar::<D>()?.clone()))
1001 }
1002 dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
1003 }
1004
1005 pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
1007 &*(self.data.as_ptr() as *const D)
1008 }
1009
1010 pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
1012 self.check_for_access::<D>()?;
1013 if self.len() == 0 {
1014 bail!("to_scalar_mut called on empty tensor ({:?})", self)
1015 }
1016 if self.len() > 1 {
1017 bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1018 }
1019 unsafe { Ok(self.to_scalar_mut_unchecked()) }
1020 }
1021
1022 pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
1024 &mut *(self.data.as_mut_ptr() as *mut D)
1025 }
1026
1027 pub fn as_bytes(&self) -> &[u8] {
1028 self.data.as_bytes()
1029 }
1030
1031 pub fn as_bytes_mut(&mut self) -> &mut [u8] {
1032 self.data.as_bytes_mut()
1033 }
1034
1035 unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
1036 let slice = self.as_slice_unchecked::<T>();
1037 slice[1..].iter().all(|x| x == &slice[0])
1038 }
1039
1040 pub fn is_uniform(&self) -> bool {
1041 if self.len() <= 1 {
1042 return true;
1043 }
1044 unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
1045 }
1046
1047 unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
1048 let v: T = self.as_slice_unchecked::<T>()[0].clone();
1049 litteral::tensor0(v)
1050 }
1051
1052 pub fn as_uniform(&self) -> Option<Tensor> {
1053 if self.len() >= 1 && self.is_uniform() {
1054 unsafe {
1055 let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
1056 t.set_datum_type(self.datum_type());
1057 Some(t)
1058 }
1059 } else {
1060 None
1061 }
1062 }
1063
1064 pub fn is_all_zero(&self) -> TractResult<bool> {
1065 Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1066 }
1067
1068 pub fn is_zero(&self) -> TractResult<bool> {
1069 Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
1070 }
1071
1072 unsafe fn natural_cast<
1073 Source: Datum + num_traits::AsPrimitive<Target>,
1074 Target: Datum + Copy,
1075 >(
1076 &self,
1077 other: &mut Tensor,
1078 ) {
1079 self.as_slice_unchecked::<Source>()
1080 .iter()
1081 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1082 .for_each(|(s, d)| *d = s.as_());
1083 }
1084
1085 unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1086 self.as_slice_unchecked::<Source>()
1087 .iter()
1088 .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1089 .for_each(|(s, d)| *d = !s.is_zero());
1090 }
1091
1092 unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
1093 &self,
1094 other: &mut Tensor,
1095 ) -> TractResult<()> {
1096 for (s, d) in self
1097 .as_slice_unchecked::<String>()
1098 .iter()
1099 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1100 {
1101 *d = s
1102 .parse()
1103 .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
1104 }
1105 Ok(())
1106 }
1107
1108 unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1109 for (s, d) in self
1110 .as_slice_unchecked::<Source>()
1111 .iter()
1112 .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1113 {
1114 *d = s.to_string()
1115 }
1116 }
1117
1118 pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<Tensor>> {
1120 self.cast_to_dt(D::datum_type())
1121 }
1122
1123 #[allow(clippy::redundant_closure_call)]
1125 pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<Tensor>> {
1126 unsafe {
1127 if self.dt == dst_dt {
1128 return Ok(Cow::Borrowed(self));
1129 }
1130 if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
1131 let slice = self.as_slice_unchecked::<TDim>();
1132 let mut ints = Self::uninitialized::<i64>(&self.shape)?;
1133 let ints_slice = ints.as_slice_mut_unchecked::<i64>();
1134 for i in 0..self.len() {
1135 ints_slice[i] = slice[i].to_i64()?;
1136 }
1137 return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1138 }
1139 if self.dt == bool::datum_type()
1140 && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
1141 {
1142 let slice = self.as_slice_unchecked::<bool>();
1143 let mut ints = Self::uninitialized::<i8>(&self.shape)?;
1144 let ints_slice = ints.as_slice_mut_unchecked::<i8>();
1145 for i in 0..self.len() {
1146 ints_slice[i] = slice[i] as usize as i8;
1147 }
1148 return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1149 }
1150 let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
1151 if self.dt == DatumType::String {
1152 dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
1153 return Ok(Cow::Owned(result));
1154 }
1155 if dst_dt == DatumType::String {
1156 dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
1157 return Ok(Cow::Owned(result));
1158 }
1159 macro_rules! n {
1160 ($source:ty) => {
1161 if <$source>::datum_type() == self.datum_type() {
1162 match dst_dt {
1163 DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
1164 DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
1165 DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
1166 DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
1167 DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
1168 DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
1169 DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
1170 DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
1171 DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
1172 DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
1173 DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
1174 DatumType::TDim => {
1175 let ints = self.cast_to::<i32>()?;
1176 let slice = ints.as_slice_unchecked::<i32>();
1177 let result = result.as_slice_mut_unchecked::<TDim>();
1178 for i in 0..self.len() {
1179 result[i] = slice[i].into();
1180 }
1181 }
1182 DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
1183 _ => todo!(),
1184 }
1185 return Ok(Cow::Owned(result));
1186 };
1187 };
1188 }
1189 if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
1191 n!(u8);
1192 n!(u16);
1193 n!(u32);
1194 n!(u64);
1195 n!(i8);
1196 n!(i16);
1197 n!(i32);
1198 n!(i64);
1199 n!(f16);
1200 n!(f32);
1201 n!(f64);
1202 } else {
1203 let (s_zp, s_scale) = self.datum_type().zp_scale();
1204 let (d_zp, d_scale) = dst_dt.zp_scale();
1205 if self.datum_type().is_quantized() && dst_dt.is_float() {
1206 macro_rules! q_to_fp {
1207 ($source:ty, $dest:ty) => {
1208 if <$source>::datum_type().unquantized()
1209 == self.datum_type().unquantized()
1210 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1211 {
1212 self.as_slice_unchecked::<$source>()
1213 .iter()
1214 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1215 .for_each(|(&s, d)| {
1216 *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
1217 });
1218 return Ok(Cow::Owned(result));
1219 }
1220 };
1221 }
1222 q_to_fp!(i8, f64);
1223 q_to_fp!(i8, f32);
1224 q_to_fp!(u8, f64);
1225 q_to_fp!(u8, f32);
1226 }
1227 macro_rules! q8_to_q8 {
1229 ($typ:ty) => {
1230 if dst_dt.unquantized() == <$typ>::datum_type() {
1231 self.as_slice_unchecked::<$typ>()
1232 .iter()
1233 .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
1234 .for_each(|(&s, d)| {
1235 *d = (d_zp as i32
1236 + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
1237 .clamp_cast()
1238 });
1239 return Ok(Cow::Owned(result));
1240 }
1241 };
1242 }
1243
1244 macro_rules! q_via_f32 {
1245 ($source:ty, $dest:ty, $round:expr) => {
1246 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1247 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1248 {
1249 self.as_slice_unchecked::<$source>()
1250 .iter()
1251 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1252 .for_each(|(&s, d)| {
1253 let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
1254 let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
1255 *d = $round(d_float);
1256 });
1257 return Ok(Cow::Owned(result));
1258 }
1259 };
1260 }
1261
1262 macro_rules! q_n {
1263 (clamp $source:ty, $dest:ty) => {{
1264 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1265 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1266 {
1267 self.as_slice_unchecked::<$source>()
1268 .iter()
1269 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1270 .for_each(|(&s, d)| {
1271 *d = s.clamp_cast();
1272 });
1273 return Ok(Cow::Owned(result));
1274 }
1275 }};
1276 ($source:ty, $dest:ty) => {{
1277 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1278 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1279 {
1280 self.as_slice_unchecked::<$source>()
1281 .iter()
1282 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1283 .for_each(|(&s, d)| {
1284 *d = s as $dest;
1285 });
1286 return Ok(Cow::Owned(result));
1287 }
1288 }};
1289 }
1290
1291 if dst_dt.unquantized() == self.datum_type().unquantized()
1292 && dst_dt.is_quantized()
1293 && self.datum_type().is_quantized()
1294 {
1295 q8_to_q8!(i8);
1296 q8_to_q8!(u8);
1297 }
1298
1299 q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
1300 q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
1301 q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
1302 q_via_f32!(i8, f32, |f| f);
1303 q_via_f32!(u8, f32, |f| f);
1304 q_via_f32!(i32, f32, |f| f);
1305
1306 if dst_dt.is_quantized() && self.datum_type().is_quantized() {
1307 q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
1308 q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
1309 q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
1310 q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
1311 q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
1312 q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());
1313
1314 q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
1316 q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
1317 }
1318
1319 q_n!(i8, i32);
1320 q_n!(i8, u32);
1321 q_n!(u8, i32);
1322 q_n!(u8, u32);
1323 q_n!(clamp i32, i8);
1324 q_n!(clamp i32, u8);
1325 q_n!(clamp u32, i8);
1326 q_n!(clamp u32, u8);
1327 q_n!(i8, i8);
1328 q_n!(u8, u8);
1329 q_n!(i32, i32);
1330 q_n!(u32, u32);
1331 }
1332
1333 bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
1334 }
1335 }
1336
1337 pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1339 let casted = self.cast_to::<D>()?;
1340 casted.to_scalar::<D>().copied()
1341 }
1342
1343 pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1345 if nth >= self.len() {
1346 bail!(
1347 "nth called with {}th element on a tensor of len {} ({:?}",
1348 nth,
1349 self.len(),
1350 self
1351 );
1352 }
1353 unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1354 let value = me.as_slice_unchecked::<T>()[nth].clone();
1355 output.as_slice_mut_unchecked::<T>()[0] = value;
1356 }
1357 unsafe {
1358 let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1359 dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1360 Ok(output)
1361 }
1362 }
1363
1364 fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1366 unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> bool {
1367 me.as_slice_unchecked::<D>() == other.as_slice_unchecked::<D>()
1368 }
1369
1370 unsafe {
1371 Ok(self.datum_type() == other.datum_type()
1372 && self.shape() == other.shape()
1373 && dispatch_datum!(eq_t(self.dt)(self, other)))
1374 }
1375 }
1376
/// Builds a tensor holding the content of the dynamic-rank array `it`.
///
/// Three strategies are tried in turn:
/// 1. `it.as_slice_mut()` yields a slice: the content is copied byte-wise for copy datum
///    types, or moved out element by element with `std::mem::take` otherwise;
/// 2. all strides of `it` are strictly positive: the data is transferred with
///    `crate::scatter::scatter_contig_data`, driven by a list of (length, stride) pairs;
/// 3. otherwise: elements are moved one by one in the array's iteration order.
///
/// Panics if the destination tensor cannot be allocated (`unwrap()` on allocation).
fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
    unsafe {
        let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
        if let Some(slice) = it.as_slice_mut() {
            if t.datum_type().is_copy() {
                // Plain-data elements: copy the whole buffer, sized by the tensor's own layout.
                std::ptr::copy_nonoverlapping(
                    slice.as_ptr() as *const i8,
                    t.as_ptr_mut_unchecked(),
                    t.data.layout().size(),
                );
            } else {
                // Non-copy elements: move each value out, leaving `T::default()` in the array.
                t.as_slice_mut_unchecked::<T>()
                    .iter_mut()
                    .zip(slice.iter_mut())
                    .for_each(|(t, s)| *t = std::mem::take(s));
            }
            return t;
        }
        if it.strides().iter().all(|&s| s > 0) {
            // Pairs of (length, destination stride), visited by increasing source stride.
            let mut len_and_strides: TVec<(usize, usize)> = tvec!();
            for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
                .sorted_by_key(|(_, src, _)| *src)
                .map(|(l, _, dst)| (*l as isize, *dst))
            {
                // Fuse this axis into the previous entry when the previous entry's
                // length * stride equals this axis' destination stride.
                if !len_and_strides.is_empty()
                    && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
                        == stride as usize
                {
                    len_and_strides.last_mut().unwrap().0 *= len as usize;
                } else {
                    len_and_strides.push((len as usize, stride as usize));
                }
            }
            len_and_strides.reverse();
            // NOTE(review): the exact contract of `scatter_contig_data` is not visible
            // here; presumably it reads `it` sequentially and writes following the
            // (length, stride) list. Confirm against crate::scatter.
            crate::scatter::scatter_contig_data(
                it.as_ptr(),
                t.as_ptr_mut_unchecked(),
                &len_and_strides,
            );
            return t;
        }
        // Fallback: consume the array in iteration order, one element at a time.
        t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
        t
    }
}
1423
1424 pub fn deep_clone(&self) -> Tensor {
1425 unsafe {
1426 let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1427 if self.len() > 0 {
1428 if self.dt.is_copy() {
1429 self.data.as_ptr().copy_to_nonoverlapping(
1430 tensor.as_bytes_mut().as_mut_ptr(),
1431 self.data.layout().size(),
1432 )
1433 } else if self.dt == DatumType::String {
1434 tensor
1435 .as_slice_mut_unchecked::<String>()
1436 .clone_from_slice(self.as_slice_unchecked());
1437 } else if self.dt == DatumType::Blob {
1438 tensor
1439 .as_slice_mut_unchecked::<Blob>()
1440 .clone_from_slice(self.as_slice_unchecked());
1441 } else if self.dt == DatumType::Opaque {
1442 tensor
1443 .as_slice_mut_unchecked::<Opaque>()
1444 .clone_from_slice(self.as_slice_unchecked());
1445 } else if self.dt == DatumType::TDim {
1446 tensor
1447 .as_slice_mut_unchecked::<TDim>()
1448 .clone_from_slice(self.as_slice_unchecked());
1449 }
1450 }
1451 tensor
1452 }
1453 }
1454
1455 pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1456 if axis >= self.rank() {
1457 bail!("Can not slice at axis {} tensor {:?}", axis, self);
1458 }
1459 if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1460 bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1461 }
1462 fn slice_t<T: Datum>(
1463 t: &Tensor,
1464 axis: usize,
1465 start: usize,
1466 end: usize,
1467 ) -> TractResult<Tensor> {
1468 Ok(t.to_array_view::<T>()?
1469 .slice_axis(ndarray::Axis(axis), (start..end).into())
1470 .into_owned()
1471 .into_tensor())
1472 }
1473 dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1474 }
1475
/// Returns a `view::TensorView` over the whole tensor.
///
/// Delegates to the unsafe constructor `view::TensorView::view`.
// NOTE(review): the safety contract of `TensorView::view` is not visible from here.
#[inline]
pub fn view(&self) -> view::TensorView {
    unsafe { view::TensorView::view(self) }
}
1480
/// Returns a `view::TensorView` built by `view::TensorView::at_prefix` for `prefix`.
///
/// Errors reported by `at_prefix` are propagated unchanged.
#[inline]
pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView> {
    view::TensorView::at_prefix(self, prefix)
}
1485
/// Returns a `view::TensorView` built by `view::TensorView::offsetting` for `coords`.
///
/// Errors reported by `offsetting` are propagated unchanged.
#[inline]
pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView> {
    view::TensorView::offsetting(self, coords)
}
1490
/// Unchecked counterpart of [`Tensor::view_offsetting`]: delegates to
/// `view::TensorView::offsetting_unchecked` and cannot report an error.
///
/// # Safety
///
/// The caller must uphold the requirements of `TensorView::offsetting_unchecked`.
// NOTE(review): presumably `coords` must be valid coordinates for this tensor's
// shape — confirm against the `view` module.
#[inline]
pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView {
    view::TensorView::offsetting_unchecked(self, coords)
}
1495
/// Same as [`Tensor::view`], but requires an exclusive borrow of the tensor.
///
/// Delegates to the same unsafe constructor, `view::TensorView::view`.
#[inline]
pub fn view_mut(&mut self) -> view::TensorView {
    unsafe { view::TensorView::view(self) }
}
1500
/// Same as [`Tensor::view_at_prefix`], but requires an exclusive borrow of the tensor.
#[inline]
pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView> {
    view::TensorView::at_prefix(self, prefix)
}
1505
/// Same as [`Tensor::view_offsetting`], but requires an exclusive borrow of the tensor.
#[inline]
pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView> {
    view::TensorView::offsetting(self, coords)
}
1510
1511 pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1513 let mut t = if let DatumType::U8 = self.dt.unquantized() {
1514 self.to_array_view::<u8>().unwrap().mapv(|v| v.wrapping_sub(128) as i8).into_tensor()
1515 } else {
1516 return self.clone();
1517 };
1518
1519 if let DatumType::QU8(qp) = self.dt {
1520 if let QParams::ZpScale { zero_point, scale } = qp {
1521 t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1522 } else {
1523 t.dt = DatumType::QI8(qp);
1524 }
1525 }
1526
1527 t.into_arc_tensor()
1528 }
1529
1530 pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1532 let mut t = if let DatumType::I8 = self.dt.unquantized() {
1533 self.to_array_view::<i8>().unwrap().mapv(|v| (v as u8).wrapping_add(128)).into_tensor()
1534 } else {
1535 return self.clone();
1536 };
1537
1538 if let DatumType::QI8(qp) = self.dt {
1539 if let QParams::ZpScale { zero_point, scale } = qp {
1540 t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1541 } else {
1542 t.dt = DatumType::QU8(qp);
1543 }
1544 }
1545 t.into_arc_tensor()
1546 }
1547
1548 pub fn to_aligned_default(&self) -> TractResult<Self> {
1549 if self.dt.is_copy() {
1550 unsafe {
1551 let mut t = Self::uninitialized_aligned_dt(
1552 self.dt,
1553 &self.shape,
1554 Self::default_alignment(self.dt, &self.shape),
1555 )?;
1556 t.as_bytes_mut().copy_from_slice(self.as_bytes());
1557 Ok(t)
1558 }
1559 } else {
1560 let mut t = Self::zero_aligned_dt(
1561 self.dt,
1562 &self.shape,
1563 Self::default_alignment(self.dt, &self.shape),
1564 )?;
1565 if self.dt == String::datum_type() {
1566 t.as_slice_mut::<String>()?.clone_from_slice(self.as_slice()?);
1567 } else if self.dt == Blob::datum_type() {
1568 t.as_slice_mut::<Blob>()?.clone_from_slice(self.as_slice()?);
1569 } else if self.dt == TDim::datum_type() {
1570 t.as_slice_mut::<TDim>()?.clone_from_slice(self.as_slice()?);
1571 }
1572 Ok(t)
1573 }
1574 }
1575
1576 pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1577 let mut strides = tvec!();
1578 compute_natural_stride_to(&mut strides, shape);
1579 strides
1580 }
1581
1582 pub fn into_blob(mut self) -> TractResult<Blob> {
1583 ensure!(self.dt.is_copy());
1584 Ok(std::mem::take(&mut self.data))
1585 }
1586}
1587
1588impl PartialEq for Tensor {
1589 fn eq(&self, other: &Tensor) -> bool {
1590 if self.dt != other.dt || self.shape != other.shape {
1591 return false;
1592 }
1593 self.eq_dt(other).unwrap_or(false)
1594 }
1595}
1596
1597impl fmt::Debug for Tensor {
1598 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1599 let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1600 write!(formatter, "{content}")
1601 }
1602}
1603
/// Reinterprets a tensor whose innermost dimension is 2 as a tensor of complex numbers.
///
/// The last axis is removed from the shape, the datum type is replaced by its
/// `complexify()` counterpart, and strides and length are recomputed. The data buffer
/// itself is not touched by this function.
///
/// # Errors
///
/// Fails when the last dimension is not 2, or when the datum type cannot be complexified.
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    ensure!(
        matches!(t.shape().last(), Some(&2)),
        "The last dimension in the tensor shape {:?} must be 2",
        t.shape()
    );
    unsafe {
        t.shape.pop();
        let complex_dt = t.datum_type().complexify()?;
        t.set_datum_type(complex_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1618
/// Reinterprets a tensor of complex numbers as a tensor with an extra innermost
/// dimension of size 2.
///
/// A trailing axis of 2 is appended to the shape, the datum type is replaced by its
/// `decomplexify()` counterpart, and strides and length are recomputed. The data buffer
/// itself is not touched by this function.
///
/// # Errors
///
/// Fails when the datum type cannot be decomplexified.
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    unsafe {
        t.shape.push(2);
        let scalar_dt = t.datum_type().decomplexify()?;
        t.set_datum_type(scalar_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1628
1629pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1630 let mut strides = tvec!();
1631 compute_natural_stride_to(&mut strides, shape);
1632 strides
1633}
1634
1635fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1636 match shape.len() {
1637 0 => (),
1638 1 => strides.push(1),
1639 2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1640 3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1641 4 => strides.extend_from_slice(&[
1642 (shape[1] * shape[2] * shape[3]) as isize,
1643 (shape[2] * shape[3]) as _,
1644 shape[3] as _,
1645 1,
1646 ]),
1647 _ => {
1648 strides.push(1);
1649 for dim in shape.as_ref().iter().skip(1).rev() {
1650 let previous = *strides.last().unwrap();
1651 strides.push(previous * *dim as isize)
1652 }
1653 strides.reverse();
1654 }
1655 }
1656}
1657
1658impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
1659 fn from(it: Array<T, D>) -> Tensor {
1660 Tensor::from_datum(it.into_dyn())
1661 }
1662}
1663
/// Conversion of a value into an owned [`Tensor`], consuming the value.
pub trait IntoTensor: Sized {
    /// Consumes `self` and returns it as a `Tensor`.
    fn into_tensor(self) -> Tensor;
}
1671
/// Conversion of a value into a shared `Arc<Tensor>`, consuming the value.
pub trait IntoArcTensor: Sized {
    /// Consumes `self` and returns it as an `Arc<Tensor>`.
    fn into_arc_tensor(self) -> Arc<Tensor>;
}
1679
1680impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
1681 fn into_tensor(self) -> Tensor {
1682 Tensor::from(self)
1683 }
1684}
1685
1686impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
1687 fn into_arc_tensor(self) -> Arc<Tensor> {
1688 Arc::new(Tensor::from(self))
1689 }
1690}
1691
/// Identity conversion: a `Tensor` is returned unchanged.
impl IntoTensor for Tensor {
    fn into_tensor(self) -> Tensor {
        self
    }
}
1697
1698impl IntoTensor for Arc<Tensor> {
1699 fn into_tensor(self) -> Tensor {
1700 Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
1701 }
1702}
1703
/// Wraps the tensor in a newly allocated `Arc`.
impl IntoArcTensor for Tensor {
    fn into_arc_tensor(self) -> Arc<Tensor> {
        Arc::new(self)
    }
}
1709
/// Identity conversion: an `Arc<Tensor>` is returned unchanged.
impl IntoArcTensor for Arc<Tensor> {
    fn into_arc_tensor(self) -> Arc<Tensor> {
        self
    }
}
1715
// Unit and property tests for tensor construction from arrays, vector broadcasting,
// complex reinterpretation and cloning.
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    use proptest::collection::vec;
    use proptest::prelude::*;

    /// Property-test case: an array of the given `shape` whose axes are then permuted.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        // Ranks 0 to 7, each dimension in 1..5, with a shuffled axis order.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            (0..8usize)
                .prop_flat_map(|rank| {
                    let permute: Vec<usize> = (0..rank).collect();
                    (proptest::collection::vec(1..5usize, rank), Just(permute).prop_shuffle())
                })
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        /// Array filled with 1, 2, 3, ... in creation order, then axis-permuted.
        fn input(&self) -> ArrayD<i32> {
            let mut i = 0;
            ArrayD::from_shape_simple_fn(&*self.shape, || {
                i += 1;
                i
            })
            .permuted_axes(&*self.permutation)
        }

        /// Expected tensor: the permuted array's elements collected in iteration order,
        /// reshaped to the permuted shape.
        fn reference(&self) -> Tensor {
            let values: Vec<i32> = self.input().iter().copied().collect();
            let shape = self.permutation.iter().map(|ix| self.shape[*ix]).collect::<TVec<usize>>();
            super::litteral::tensor1(&values).into_shape(&shape).unwrap()
        }

        /// Tensor under test: direct conversion of the permuted array.
        fn tract(&self) -> Tensor {
            Tensor::from(self.input())
        }

        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    proptest::proptest! {
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    // Fixed cases: transposition of a 2x1 and of a 2x2 array.
    #[test]
    fn t_1_2() {
        PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] }.check().unwrap();
    }

    #[test]
    fn t_2_2() {
        PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
    }

    /// Property-test case: broadcast `vec`, placed along `axis`, to the full `shape`.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        // `broadcast_vector_to_shape` must agree with reshaping the vector to
        // [1, .., len, .., 1] and calling `broadcast_to_shape`.
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reference = input
                .clone()
                .into_shape(&intermediate)
                .unwrap()
                .broadcast_to_shape(&self.shape)
                .unwrap();
            prop_assert_eq!(
                reference,
                input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap()
            );
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        // Up to 3 outer dimensions in 0..5, a vector of 0 to 4 values, and an insertion
        // axis; the vector length is inserted into the shape at that axis.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1)
                })
                .prop_map(|(vec, mut shape, axis)| {
                    shape.insert(axis, vec.len());
                    BroadcastVecToShape { vec, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    // A [3, 2] f32 tensor becomes a [3] tensor of complex numbers.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let input = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // A [3, 2, 2] i32 tensor becomes a [3, 2] tensor of complex numbers.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let input =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Smoke test: cloning a scalar tensor holding a symbolic TDim must not panic.
    #[test]
    fn clone_tdim_tensor() {
        let symbols = SymbolScope::default();
        let a = symbols.sym("a");
        let t = tensor0(TDim::from(a));
        let _ = t.clone();
    }
}