tract_gpu/
session_handler.rs

1use crate::memory::DeviceMemSchema;
2use crate::memory::DeviceMemoryPool;
3use crate::tensor::DeviceTensor;
4use std::borrow::Borrow;
5use tract_core::internal::*;
6
7#[derive(Debug, Clone)]
8pub struct DeviceSessionHandler {
9    pub mem_schema: DeviceMemSchema,
10}
11
12impl DeviceSessionHandler {
13    pub fn from_plan<M, P>(plan: P, memory_hint: &SymbolValues) -> TractResult<Self>
14    where
15        M: Borrow<Graph<TypedFact, Box<dyn TypedOp>>>,
16        P: Borrow<TypedSimplePlan<M>> + Clone,
17    {
18        let mem_schema = DeviceMemSchema::build(
19            plan.borrow().model(),
20            plan.borrow().order_without_consts(),
21            memory_hint,
22        )?;
23        Ok(Self { mem_schema })
24    }
25}
26
27impl SessionStateHandler for DeviceSessionHandler {
28    fn before_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
29        let resolved_mem_schema = self.mem_schema.resolve(&session_state.resolved_symbols)?;
30        let memory_pool = DeviceMemoryPool::from_schema(resolved_mem_schema)?;
31
32        session_state.scratch_extensions.insert(memory_pool);
33        ensure!(session_state.scratch_extensions.get::<DeviceMemoryPool>().is_some());
34        Ok(())
35    }
36
37    fn after_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
38        session_state.scratch_extensions.remove::<DeviceMemoryPool>();
39        Ok(())
40    }
41}
42
43pub fn make_tensor_for_node(
44    session: &SessionState,
45    node_id: usize,
46    dt: DatumType,
47    shape: &[usize],
48) -> TractResult<DeviceTensor> {
49    session
50        .scratch_extensions
51        .get::<DeviceMemoryPool>()
52        .map(|mem| mem.tensor_for_node(node_id, dt, shape))
53        .unwrap_or_else(|| DeviceTensor::uninitialized_dt(dt, shape))
54}