Skip to main content

tract_gpu/tensor/
owned.rs

1use downcast_rs::{Downcast, impl_downcast};
2use dyn_clone::DynClone;
3use dyn_eq::DynEq;
4use std::fmt::Debug;
5use tract_core::dyn_clone;
6use tract_core::internal::*;
7
8use crate::device::DeviceBuffer;
9
10use super::DeviceTensor;
11
12#[allow(clippy::len_without_is_empty)]
13pub trait OwnedDeviceTensor: Downcast + DynClone + Send + Sync + Debug + DynEq {
14    fn datum_type(&self) -> DatumType;
15
16    fn shape(&self) -> &[usize];
17
18    fn strides(&self) -> &[isize];
19
20    #[inline]
21    fn len(&self) -> usize {
22        self.shape().iter().product()
23    }
24
25    fn reshaped(&self, shape: TVec<usize>) -> TractResult<DeviceTensor>;
26    fn restrided(&self, shape: TVec<isize>) -> TractResult<DeviceTensor>;
27
28    fn exotic_fact(&self) -> Option<&dyn ExoticFact>;
29    fn get_bytes_slice(&self, offset: usize, len: usize) -> Vec<u8>;
30    fn device_buffer(&self) -> &dyn DeviceBuffer;
31    fn to_host(&self) -> TractResult<Arc<Tensor>>;
32}
33
34impl_downcast!(OwnedDeviceTensor);
35dyn_hash::hash_trait_object!(OwnedDeviceTensor);
36dyn_clone::clone_trait_object!(OwnedDeviceTensor);
37dyn_eq::eq_trait_object!(OwnedDeviceTensor);