tract_gpu/tensor/
arena_view.rs1use num_traits::AsPrimitive;
2use std::ffi::c_void;
3use std::fmt::Display;
4use tract_core::internal::*;
5
6use crate::device::DeviceBuffer;
7use crate::utils::check_strides_validity;
8
9use super::OwnedDeviceTensor;
10
11#[derive(Debug, Clone, Hash)]
12pub struct DeviceArenaView {
13 pub(crate) arena: Arc<Box<dyn OwnedDeviceTensor>>,
14 pub(crate) dt: DatumType,
15 pub(crate) len: usize,
16 pub(crate) shape: TVec<usize>,
17 pub(crate) strides: TVec<isize>,
18 pub(crate) offset_bytes: usize,
19}
20
21impl DeviceArenaView {
22 #[inline]
23 pub fn shape(&self) -> &[usize] {
24 self.shape.as_slice()
25 }
26
27 #[inline]
29 pub fn datum_type(&self) -> DatumType {
30 self.dt
31 }
32
33 #[inline]
34 pub fn strides(&self) -> &[isize] {
35 self.strides.as_slice()
36 }
37
38 pub fn device_buffer(&self) -> &dyn DeviceBuffer {
40 self.arena.device_buffer()
41 }
42
43 pub fn device_buffer_ptr(&self) -> *const c_void {
44 self.arena.device_buffer().ptr()
45 }
46
47 pub fn buffer_offset<I: Copy + 'static>(&self) -> I
49 where
50 usize: AsPrimitive<I>,
51 {
52 self.offset_bytes.as_()
53 }
54
55 #[inline]
57 #[allow(clippy::len_without_is_empty)]
58 pub fn len(&self) -> usize {
59 self.len
60 }
61
62 pub fn as_bytes(&self) -> &[u8] {
63 &self.arena.as_arc_tensor().unwrap().as_bytes()
64 [self.offset_bytes..self.offset_bytes + self.len() * self.dt.size_of()]
65 }
66
67 #[inline]
68 pub fn view(&self) -> TensorView<'_> {
69 unsafe {
70 TensorView::from_bytes(
71 self.arena.as_arc_tensor().unwrap(),
72 self.offset_bytes as _,
73 self.shape.as_slice(),
74 self.strides.as_slice(),
75 )
76 }
77 }
78
79 pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
81 let shape = shape.into();
82 if self.len() != shape.iter().product::<usize>() {
83 bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
84 }
85 if shape.as_slice() != self.shape() {
86 Ok(Self {
87 arena: Arc::clone(&self.arena),
88 dt: self.dt,
89 len: self.len,
90 strides: Tensor::natural_strides(&shape),
91 shape,
92 offset_bytes: self.offset_bytes,
93 })
94 } else {
95 Ok(self.clone())
96 }
97 }
98
99 pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
100 let strides = strides.into();
101 check_strides_validity(self.shape().into(), strides.clone())?;
102
103 if strides.as_slice() != self.strides() {
104 Ok(Self {
105 arena: Arc::clone(&self.arena),
106 dt: self.dt,
107 len: self.len,
108 strides,
109 shape: self.shape.clone(),
110 offset_bytes: self.offset_bytes,
111 })
112 } else {
113 Ok(self.clone())
114 }
115 }
116}
117
118impl Display for DeviceArenaView {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 let content =
121 self.clone().into_tensor().dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
122 write!(f, "DeviceArenaView: {{ {content} }}")
123 }
124}
125
126impl IntoTensor for DeviceArenaView {
127 fn into_tensor(self) -> Tensor {
128 unsafe {
129 Tensor::from_raw_dt(self.dt, &self.shape, self.as_bytes())
130 .expect("Could not transform a DeviceArenaView to tensor")
131 }
132 }
133}