Skip to main content

tract_gpu/memory/
pool.rs

1use crate::device::get_context;
2use crate::memory::DeviceResolvedMemSchema;
3use crate::tensor::DeviceArenaView;
4use crate::tensor::DeviceTensor;
5use crate::tensor::OwnedDeviceTensor;
6
7use tract_core::internal::*;
8
9#[derive(Debug)]
10pub struct DeviceMemoryPool {
11    storage: Arc<Box<dyn OwnedDeviceTensor>>,
12    resolved_schema: DeviceResolvedMemSchema,
13}
14
15impl DeviceMemoryPool {
16    pub fn from_schema(resolved_schema: DeviceResolvedMemSchema) -> TractResult<Self> {
17        Ok(Self {
18            storage: Arc::new(
19                get_context()?
20                    .uninitialized_device_tensor(&[resolved_schema.memory_size], DatumType::U8)?,
21            ),
22            resolved_schema,
23        })
24    }
25
26    pub fn tensor_for_node(
27        &self,
28        node_id: usize,
29        dt: DatumType,
30        shape: &[usize],
31    ) -> TractResult<DeviceTensor> {
32        ensure!(dt != DatumType::Opaque, "Use opaque_tensor for node instead");
33        self.resolved_schema.offsets_by_node[node_id]
34            .as_ref()
35            .map(|offsets| {
36                ensure!(
37                    offsets.len() == 1 && offsets[0].len() == 1,
38                    "'tensor_for_node' is for mono-output nodes only"
39                );
40                Ok(DeviceArenaView {
41                    arena: Arc::clone(&self.storage),
42                    dt,
43                    len: shape.iter().product(),
44                    shape: shape.into(),
45                    strides: Tensor::natural_strides(shape),
46                    offset_bytes: offsets[0][0],
47                    opaque_fact: None,
48                }
49                .into())
50            })
51            .unwrap_or_else(|| DeviceTensor::uninitialized_dt(dt, shape))
52    }
53
54    pub fn scalar_opaque_tensor_for_node(
55        &self,
56        node_id: usize,
57        opaque_fact: Box<dyn OpaqueFact>,
58    ) -> TractResult<DeviceTensor> {
59        match self.resolved_schema.offsets_by_node[node_id].as_ref() {
60            Some(offsets) => {
61                ensure!(
62                    offsets.len() == 1 && offsets[0].len() == 2,
63                    "'scalar_opaque_tensor_for_node' is for mono-output nodes only"
64                );
65                Ok(DeviceArenaView {
66                    arena: Arc::clone(&self.storage),
67                    dt: DatumType::Opaque,
68                    len: 1,
69                    shape: tvec!(),
70                    strides: tvec!(),
71                    offset_bytes: offsets[0][1],
72                    opaque_fact: Some(opaque_fact.clone()),
73                }
74                .into())
75            }
76            None => DeviceTensor::uninitialized_opaque(opaque_fact),
77        }
78    }
79}