1use crate::device::get_context;
2use crate::memory::DeviceResolvedMemSchema;
3use crate::tensor::DeviceArenaView;
4use crate::tensor::DeviceTensor;
5use crate::tensor::OwnedDeviceTensor;
6
7use tract_core::internal::*;
8
9#[derive(Debug)]
10pub struct DeviceMemoryPool {
11 storage: Arc<Box<dyn OwnedDeviceTensor>>,
12 resolved_schema: DeviceResolvedMemSchema,
13}
14
15impl DeviceMemoryPool {
16 pub fn from_schema(resolved_schema: DeviceResolvedMemSchema) -> TractResult<Self> {
17 Ok(Self {
18 storage: Arc::new(
19 get_context()?
20 .uninitialized_device_tensor(&[resolved_schema.memory_size], DatumType::U8)?,
21 ),
22 resolved_schema,
23 })
24 }
25
26 pub fn tensor_for_node(
27 &self,
28 node_id: usize,
29 dt: DatumType,
30 shape: &[usize],
31 ) -> TractResult<DeviceTensor> {
32 self.resolved_schema.offsets_by_node[node_id]
33 .as_ref()
34 .map(|offsets| {
35 ensure!(
36 offsets.len() == 1 && offsets[0].len() == 1,
37 "'tensor_for_node' is for mono-output nodes only"
38 );
39 Ok(DeviceArenaView {
40 arena: Arc::clone(&self.storage),
41 dt,
42 len: shape.iter().product(),
43 shape: shape.into(),
44 strides: Tensor::natural_strides(shape),
45 offset_bytes: offsets[0][0],
46 exotic_fact: None,
47 }
48 .into())
49 })
50 .unwrap_or_else(|| DeviceTensor::uninitialized_dt(dt, shape))
51 }
52
53 pub fn scalar_exotic_tensor_for_node(
54 &self,
55 node_id: usize,
56 dt: DatumType,
57 exotic_fact: Box<dyn ExoticFact>,
58 ) -> TractResult<DeviceTensor> {
59 match self.resolved_schema.offsets_by_node[node_id].as_ref() {
60 Some(offsets) => {
61 ensure!(
62 offsets.len() == 1 && offsets[0].len() == 2,
63 "'scalar_exotic_tensor_for_node' is for mono-output nodes only"
64 );
65 Ok(DeviceArenaView {
66 arena: Arc::clone(&self.storage),
67 dt,
68 len: 1,
69 shape: tvec!(),
70 strides: tvec!(),
71 offset_bytes: offsets[0][1],
72 exotic_fact: Some(exotic_fact.clone()),
73 }
74 .into())
75 }
76 None => DeviceTensor::uninitialized_exotic(exotic_fact),
77 }
78 }
79}