Skip to main content

tract_gpu/memory/
pool.rs

1use crate::device::get_context;
2use crate::memory::DeviceResolvedMemSchema;
3use crate::tensor::DeviceArenaView;
4use crate::tensor::DeviceTensor;
5use crate::tensor::OwnedDeviceTensor;
6
7use tract_core::internal::*;
8
9#[derive(Debug)]
10pub struct DeviceMemoryPool {
11    storage: Arc<Box<dyn OwnedDeviceTensor>>,
12    resolved_schema: DeviceResolvedMemSchema,
13}
14
15impl DeviceMemoryPool {
16    pub fn from_schema(resolved_schema: DeviceResolvedMemSchema) -> TractResult<Self> {
17        Ok(Self {
18            storage: Arc::new(
19                get_context()?
20                    .uninitialized_device_tensor(&[resolved_schema.memory_size], DatumType::U8)?,
21            ),
22            resolved_schema,
23        })
24    }
25
26    pub fn tensor_for_node(
27        &self,
28        node_id: usize,
29        dt: DatumType,
30        shape: &[usize],
31    ) -> TractResult<DeviceTensor> {
32        self.resolved_schema.offsets_by_node[node_id]
33            .as_ref()
34            .map(|offsets| {
35                ensure!(
36                    offsets.len() == 1 && offsets[0].len() == 1,
37                    "'tensor_for_node' is for mono-output nodes only"
38                );
39                Ok(DeviceArenaView {
40                    arena: Arc::clone(&self.storage),
41                    dt,
42                    len: shape.iter().product(),
43                    shape: shape.into(),
44                    strides: Tensor::natural_strides(shape),
45                    offset_bytes: offsets[0][0],
46                    exotic_fact: None,
47                }
48                .into())
49            })
50            .unwrap_or_else(|| DeviceTensor::uninitialized_dt(dt, shape))
51    }
52
53    pub fn scalar_exotic_tensor_for_node(
54        &self,
55        node_id: usize,
56        dt: DatumType,
57        exotic_fact: Box<dyn ExoticFact>,
58    ) -> TractResult<DeviceTensor> {
59        match self.resolved_schema.offsets_by_node[node_id].as_ref() {
60            Some(offsets) => {
61                ensure!(
62                    offsets.len() == 1 && offsets[0].len() == 2,
63                    "'scalar_exotic_tensor_for_node' is for mono-output nodes only"
64                );
65                Ok(DeviceArenaView {
66                    arena: Arc::clone(&self.storage),
67                    dt,
68                    len: 1,
69                    shape: tvec!(),
70                    strides: tvec!(),
71                    offset_bytes: offsets[0][1],
72                    exotic_fact: Some(exotic_fact.clone()),
73                }
74                .into())
75            }
76            None => DeviceTensor::uninitialized_exotic(exotic_fact),
77        }
78    }
79}