Skip to main content

tract_gpu/memory/
pool.rs

1use crate::memory::DeviceResolvedMemSchema;
2use crate::tensor::DeviceArenaView;
3use crate::tensor::DeviceTensor;
4use crate::tensor::IntoDevice;
5use crate::tensor::OwnedDeviceTensor;
6
7use std::cell::RefCell;
8use std::collections::HashSet;
9use tract_core::internal::*;
10
11#[derive(Debug)]
12pub struct DeviceMemoryPool {
13    storage: Arc<OwnedDeviceTensor>,
14    resolved_schema: DeviceResolvedMemSchema,
15    node_seen: RefCell<HashSet<usize>>,
16}
17
18impl DeviceMemoryPool {
19    pub fn from_schema(resolved_schema: DeviceResolvedMemSchema) -> TractResult<Self> {
20        let tensor = unsafe {
21            Tensor::uninitialized_dt(DatumType::U8, &[resolved_schema.memory_size]).with_context(
22                || {
23                    format!(
24                        "Error while allocating a tensor of {:?} bytes",
25                        resolved_schema.memory_size
26                    )
27                },
28            )?
29        };
30        let storage = Arc::new(OwnedDeviceTensor::from_tensor(tensor)?);
31
32        Ok(Self { storage, resolved_schema, node_seen: RefCell::new(HashSet::new()) })
33    }
34
35    pub fn tensor_for_node(
36        &self,
37        node_id: usize,
38        dt: DatumType,
39        shape: &[usize],
40    ) -> TractResult<DeviceTensor> {
41        ensure!(!self.node_seen.borrow().contains(&node_id), "Tensor for node {:?} was already requested. Maybe the memory pool was not reset properly.", node_id);
42        self.resolved_schema.offsets_by_node[node_id]
43            .map(|offset| {
44                // self.node_seen.borrow_mut().insert(node_id);
45                Ok(DeviceArenaView {
46                    arena: Arc::clone(&self.storage),
47                    dt,
48                    len: shape.iter().product(),
49                    shape: shape.into(),
50                    strides: Tensor::natural_strides(shape),
51                    offset_bytes: offset,
52                }
53                .into())
54            })
55            .unwrap_or_else(|| unsafe { Tensor::uninitialized_dt(dt, shape)?.into_device() })
56    }
57
58    pub fn reset(&self) {
59        self.node_seen.borrow_mut().clear();
60    }
61}