1mod context;
2pub mod kernels;
3pub mod ops;
4mod rewrite_rules;
5mod tensor;
6mod transform;
7pub mod utils;
8
9pub use context::CUDA_STREAM;
10use tract_core::internal::*;
11use tract_core::transform::ModelTransform;
12pub use transform::CudaTransform;
13
14use crate::utils::ensure_cuda_runtime_dependencies;
/// Row padding constant for "Q40" data (value: 512).
///
/// NOTE(review): presumably relates to Q4_0-quantized rows, and whether the unit is
/// elements or bytes is decided by the users of this constant elsewhere in the crate.
/// Confirm there before relying on either reading.
const Q40_ROW_PADDING: usize = 512;
16
17#[derive(Debug)]
18struct CudaRuntime;
19
20impl Runtime for CudaRuntime {
21 fn name(&self) -> StaticName {
22 "cuda".into()
23 }
24
25 fn prepare_with_options(
26 &self,
27 mut model: TypedModel,
28 options: &RunOptions,
29 ) -> TractResult<Box<dyn Runnable>> {
30 ensure_cuda_runtime_dependencies("cuda runtime supported dependencies not found.")?;
31 CudaTransform.transform(&mut model)?;
32 model.optimize()?;
33
34 let options = RunOptions { skip_order_opt_ram: true, ..options.clone() };
35
36 let mut runnable = TypedSimplePlan::build(model, &options)?;
37 if let Some(hints) = options.memory_sizing_hints {
38 let session_handler =
39 tract_gpu::session_handler::DeviceSessionHandler::from_plan(&runnable, &hints)
40 .context("While sizing memory arena. Missing hint ?")?;
41 runnable = runnable.with_session_handler(session_handler);
42 }
43
44 Ok(Box::new(Arc::new(runnable)))
45 }
46}
47
48register_runtime!(CudaRuntime = CudaRuntime);