1use crate::device::get_context;
2use crate::memory::DeviceResolvedMemSchema;
3use crate::tensor::DeviceArenaView;
4use crate::tensor::DeviceTensor;
5use crate::tensor::OwnedDeviceTensor;
6
7use tract_core::internal::*;
8
9#[derive(Debug)]
10pub struct DeviceMemoryPool {
11 storage: Arc<Box<dyn OwnedDeviceTensor>>,
12 resolved_schema: DeviceResolvedMemSchema,
13}
14
15impl DeviceMemoryPool {
16 pub fn from_schema(resolved_schema: DeviceResolvedMemSchema) -> TractResult<Self> {
17 Ok(Self {
18 storage: Arc::new(
19 get_context()?
20 .uninitialized_device_tensor(&[resolved_schema.memory_size], DatumType::U8)?,
21 ),
22 resolved_schema,
23 })
24 }
25
26 pub fn tensor_for_node(
27 &self,
28 node_id: usize,
29 dt: DatumType,
30 shape: &[usize],
31 ) -> TractResult<DeviceTensor> {
32 ensure!(dt != DatumType::Opaque, "Use opaque_tensor for node instead");
33 self.resolved_schema.offsets_by_node[node_id]
34 .as_ref()
35 .map(|offsets| {
36 ensure!(
37 offsets.len() == 1 && offsets[0].len() == 1,
38 "'tensor_for_node' is for mono-output nodes only"
39 );
40 Ok(DeviceArenaView {
41 arena: Arc::clone(&self.storage),
42 dt,
43 len: shape.iter().product(),
44 shape: shape.into(),
45 strides: Tensor::natural_strides(shape),
46 offset_bytes: offsets[0][0],
47 opaque_fact: None,
48 }
49 .into())
50 })
51 .unwrap_or_else(|| DeviceTensor::uninitialized_dt(dt, shape))
52 }
53
54 pub fn scalar_opaque_tensor_for_node(
55 &self,
56 node_id: usize,
57 opaque_fact: Box<dyn OpaqueFact>,
58 ) -> TractResult<DeviceTensor> {
59 match self.resolved_schema.offsets_by_node[node_id].as_ref() {
60 Some(offsets) => {
61 ensure!(
62 offsets.len() == 1 && offsets[0].len() == 2,
63 "'scalar_opaque_tensor_for_node' is for mono-output nodes only"
64 );
65 Ok(DeviceArenaView {
66 arena: Arc::clone(&self.storage),
67 dt: DatumType::Opaque,
68 len: 1,
69 shape: tvec!(),
70 strides: tvec!(),
71 offset_bytes: offsets[0][1],
72 opaque_fact: Some(opaque_fact.clone()),
73 }
74 .into())
75 }
76 None => DeviceTensor::uninitialized_opaque(opaque_fact),
77 }
78 }
79}