Skip to main content

tract_cuda/
lib.rs

1mod context;
2pub mod kernels;
3pub mod ops;
4mod rewrite_rules;
5mod tensor;
6mod transform;
7pub mod utils;
8
9pub use context::CUDA_STREAM;
10use tract_core::internal::*;
11use tract_core::transform::ModelTransform;
12pub use transform::CudaTransform;
13
14use crate::utils::ensure_cuda_runtime_dependencies;
/// Row padding constant equal to 512 (`1 << 9`).
///
/// NOTE(review): not referenced in this file. Presumably child modules use it
/// when laying out Q4_0-quantized rows for the CUDA kernels. Confirm the unit
/// (elements vs. bytes) against its users.
const Q40_ROW_PADDING: usize = 1 << 9;
16
17#[derive(Debug)]
18struct CudaRuntime;
19
20impl Runtime for CudaRuntime {
21    fn name(&self) -> StaticName {
22        "cuda".into()
23    }
24
25    fn prepare_with_options(
26        &self,
27        mut model: TypedModel,
28        options: &RunOptions,
29    ) -> TractResult<Box<dyn Runnable>> {
30        ensure_cuda_runtime_dependencies("cuda runtime supported dependencies not found.")?;
31        CudaTransform.transform(&mut model)?;
32        model.optimize()?;
33
34        let options = RunOptions { skip_order_opt_ram: true, ..options.clone() };
35
36        let mut runnable = TypedSimplePlan::build(model, &options)?;
37        if let Some(hints) = options.memory_sizing_hints {
38            let session_handler =
39                tract_gpu::session_handler::DeviceSessionHandler::from_plan(&runnable, &hints)
40                    .context("While sizing memory arena. Missing hint ?")?;
41            runnable = runnable.with_session_handler(session_handler);
42        }
43
44        Ok(Box::new(Arc::new(runnable)))
45    }
46}
47
48register_runtime!(CudaRuntime = CudaRuntime);