tract_gpu/tensor/
arena_view.rs1use num_traits::AsPrimitive;
2use std::ffi::c_void;
3use std::fmt::Display;
4use tract_core::internal::*;
5
6use crate::device::{DeviceBuffer, get_context};
7use crate::utils::check_strides_validity;
8
9use super::OwnedDeviceTensor;
10
11#[derive(Debug, Clone, Hash)]
12pub struct DeviceArenaView {
13 pub(crate) arena: Arc<Box<dyn OwnedDeviceTensor>>,
14 pub(crate) dt: DatumType,
15 pub(crate) len: usize,
16 pub(crate) shape: TVec<usize>,
17 pub(crate) strides: TVec<isize>,
18 pub(crate) offset_bytes: usize,
19 pub(crate) opaque_fact: Option<Box<dyn OpaqueFact>>,
20}
21
22impl DeviceArenaView {
23 #[inline]
24 pub fn shape(&self) -> &[usize] {
25 self.shape.as_slice()
26 }
27
28 #[inline]
30 pub fn datum_type(&self) -> DatumType {
31 self.dt
32 }
33
34 #[inline]
35 pub fn strides(&self) -> &[isize] {
36 self.strides.as_slice()
37 }
38
39 pub fn device_buffer(&self) -> &dyn DeviceBuffer {
41 self.arena.device_buffer()
42 }
43
44 pub fn device_buffer_ptr(&self) -> *const c_void {
45 self.arena.device_buffer().ptr()
46 }
47
48 pub fn buffer_offset<I: Copy + 'static>(&self) -> I
50 where
51 usize: AsPrimitive<I>,
52 {
53 self.offset_bytes.as_()
54 }
55
56 pub fn opaque_fact(&self) -> Option<&dyn OpaqueFact> {
57 self.opaque_fact.as_deref()
58 }
59
60 #[inline]
62 #[allow(clippy::len_without_is_empty)]
63 pub fn len(&self) -> usize {
64 self.len
65 }
66
67 pub fn as_bytes(&self) -> Vec<u8> {
68 let len = if let Some(of) = &self.opaque_fact {
69 of.mem_size().as_i64().unwrap() as usize
70 } else {
71 self.len() * self.dt.size_of()
72 };
73 self.arena.get_bytes_slice(self.offset_bytes, len)
74 }
75
76 pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
78 ensure!(self.opaque_fact.is_none(), "Can't reshape opaque tensor");
79 let shape = shape.into();
80 if self.len() != shape.iter().product::<usize>() {
81 bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
82 }
83 if shape.as_slice() != self.shape() {
84 Ok(Self {
85 arena: Arc::clone(&self.arena),
86 dt: self.dt,
87 len: self.len,
88 strides: Tensor::natural_strides(&shape),
89 shape,
90 offset_bytes: self.offset_bytes,
91 opaque_fact: None,
92 })
93 } else {
94 Ok(self.clone())
95 }
96 }
97
98 pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
99 ensure!(self.opaque_fact.is_none(), "Can't restride opaque tensor");
100 let strides = strides.into();
101 check_strides_validity(self.shape().into(), strides.clone())?;
102
103 if strides.as_slice() != self.strides() {
104 Ok(Self {
105 arena: Arc::clone(&self.arena),
106 dt: self.dt,
107 len: self.len,
108 strides,
109 shape: self.shape.clone(),
110 offset_bytes: self.offset_bytes,
111 opaque_fact: None,
112 })
113 } else {
114 Ok(self.clone())
115 }
116 }
117
118 pub fn to_host(&self) -> TractResult<Tensor> {
119 get_context()?.synchronize()?;
120 let content = self.as_bytes();
121 unsafe {
122 if self.dt == DatumType::Opaque {
123 ensure!(self.len == 1, "Expected scalar Opaque");
124 Ok(tensor0(Opaque(Arc::new(BlobWithFact {
125 fact: self
126 .opaque_fact
127 .clone()
128 .context("Expected Opaque Fact for Opaque ArenaView")?,
129 value: Arc::new(Blob::from_bytes(&content)?),
130 }))))
131 } else {
132 Tensor::from_raw_dt(self.dt, &self.shape, &content)
133 }
134 }
135 }
136}
137
138impl Display for DeviceArenaView {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 let content = self
141 .clone()
142 .to_host()
143 .unwrap()
144 .dump(false)
145 .unwrap_or_else(|e| format!("Error : {e:?}"));
146 write!(f, "DeviceArenaView: {{ {content} }}")
147 }
148}