Skip to main content

tract_gpu/memory/
pool.rs

1use crate::device::get_context;
2use crate::memory::DeviceResolvedMemSchema;
3use crate::tensor::DeviceArenaView;
4use crate::tensor::DeviceTensor;
5use crate::tensor::OwnedDeviceTensor;
6
7use std::cell::RefCell;
8use std::collections::HashSet;
9use tract_core::internal::*;
10
11#[derive(Debug)]
12pub struct DeviceMemoryPool {
13    storage: Arc<Box<dyn OwnedDeviceTensor>>,
14    resolved_schema: DeviceResolvedMemSchema,
15    node_seen: RefCell<HashSet<usize>>,
16}
17
18impl DeviceMemoryPool {
19    pub fn from_schema(resolved_schema: DeviceResolvedMemSchema) -> TractResult<Self> {
20        Ok(Self {
21            storage: Arc::new(
22                get_context()?
23                    .uninitialized_device_tensor(&[resolved_schema.memory_size], DatumType::U8)?,
24            ),
25            resolved_schema,
26            node_seen: RefCell::new(HashSet::new()),
27        })
28    }
29
30    pub fn tensor_for_node(
31        &self,
32        node_id: usize,
33        dt: DatumType,
34        shape: &[usize],
35    ) -> TractResult<DeviceTensor> {
36        ensure!(
37            !self.node_seen.borrow().contains(&node_id),
38            "Tensor for node {:?} was already requested. Maybe the memory pool was not reset properly.",
39            node_id
40        );
41        self.resolved_schema.offsets_by_node[node_id]
42            .map(|offset| {
43                Ok(DeviceArenaView {
44                    arena: Arc::clone(&self.storage),
45                    dt,
46                    len: shape.iter().product(),
47                    shape: shape.into(),
48                    strides: Tensor::natural_strides(shape),
49                    offset_bytes: offset,
50                }
51                .into())
52            })
53            .unwrap_or_else(|| DeviceTensor::uninitialized_dt(dt, shape))
54    }
55
56    pub fn reset(&self) {
57        self.node_seen.borrow_mut().clear();
58    }
59}