Skip to main content

sp1_gpu_tracegen/riscv/
mod.rs

1mod global;
2
3use slop_alloc::mem::CopyError;
4use sp1_core_machine::riscv::RiscvAir;
5use sp1_gpu_cudart::{DeviceMle, TaskScope};
6
7use crate::{CudaTracegenAir, F};
8
9impl CudaTracegenAir<F> for RiscvAir<F> {
10    fn supports_device_main_tracegen(&self) -> bool {
11        match self {
12            Self::Global(chip) => chip.supports_device_main_tracegen(),
13            // Other chips don't have `CudaTracegenAir` implemented yet.
14            _ => false,
15        }
16    }
17
18    async fn generate_trace_device(
19        &self,
20        input: &Self::Record,
21        output: &mut Self::Record,
22        scope: &TaskScope,
23    ) -> Result<DeviceMle<F>, CopyError> {
24        match self {
25            Self::Global(chip) => chip.generate_trace_device(input, output, scope).await,
26            // Other chips don't have `CudaTracegenAir` implemented yet.
27            _ => unimplemented!(),
28        }
29    }
30}