rustsim/compute.rs
1//! Compute backend detection and GPU-accelerated batch stepping.
2//!
3//! The CUDA accelerator works on SoA (Structure-of-Arrays) buffers extracted
4//! from the ABM agent store. The flow is:
5//!
6//! 1. Extract agent fields into flat `Vec<f32>` columns via `SoaExtractable`
7//! 2. Upload columns to GPU device memory
8//! 3. Launch a user-provided PTX kernel that processes all agents in parallel
9//! 4. Download results back to host
10//! 5. Write columns back into the agent store
11//!
12//! When CUDA is unavailable (no `cuda` feature or no device), the same SoA
13//! buffers are processed on CPU via a user-provided closure.
14//!
15//! # Determinism and backend selection
16//!
17//! `cpu_batch_step` is replayable when:
18//! - SoA extraction order is deterministic for the chosen store and workload
19//! - the supplied CPU kernel is itself deterministic
20//!
21//! `auto_batch_step` and `auto_device_step` do **not** guarantee a fixed backend
22//! across machines or runs, because backend selection depends on:
23//! - compile-time `cuda` support
24//! - runtime device availability
25//! - the `RUSTSIM_BACKEND` environment variable
26//! - CUDA failure fallback to CPU
27//!
28//! Exact bitwise equivalence between CPU and CUDA results is not guaranteed.
29//! Floating-point behavior, execution order, and kernel implementation details
30//! may differ across backends.
31//!
32//! # CUDA safety and failure surfaces
33//!
//! The `unsafe` operations in this module are CUDA kernel launches and
//! pinned host-memory allocation (`alloc_pinned`), both via `cudarc`. The
//! kernel launches rely on the following invariants:
36//! - `block_size > 0`
37//! - the PTX kernel signature matches the launched argument tuple
38//! - each device buffer points to a valid uploaded SoA column
39//! - the kernel performs bounds checks for `idx < n`
40//! - the kernel does not read or write out of bounds
41//!
42//! Failure surfaces are explicit `Err(String)` results from:
43//! - CUDA device initialization
44//! - PTX load/module lookup
45//! - host-to-device transfer
46//! - invalid launch configuration such as `block_size == 0`
47//! - unsupported SoA arity outside `1..=8`
48//! - kernel launch / synchronization
49//! - device-to-host transfer
50//!
51//! `auto_batch_step` and `auto_device_step` treat those CUDA errors as runtime
52//! fallback triggers and continue on CPU.
53//!
54//! # Persistent Device Store
55//!
56//! For multi-step runs, use [`DeviceSoaStore`](crate::device_store::DeviceSoaStore)
57//! to avoid per-step SoA extraction overhead. This mirrors FlameGPU2's design
58//! where agent data lives on the GPU across steps.
59
/// The compute backend a batch step runs on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComputeBackend {
    /// Work is executed on the host CPU.
    Cpu,
    /// Work is executed on a GPU through CUDA.
    Cuda,
}

impl std::fmt::Display for ComputeBackend {
    /// Writes the short upper-case backend label: `CPU` or `CUDA`.
    ///
    /// The label is written verbatim; formatter width/fill flags are not
    /// applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Cpu => "CPU",
            Self::Cuda => "CUDA",
        };
        f.write_str(label)
    }
}
75
76/// Probe the system for CUDA availability.
77///
78/// # Backend selection precedence
79///
80/// The result is determined in this fixed order:
81///
82/// 1. **`RUSTSIM_BACKEND` environment variable**, if set to a recognized value:
83/// - `cuda` / `gpu` → force [`ComputeBackend::Cuda`]
84/// - `cpu` → force [`ComputeBackend::Cpu`]
85/// - any other value is ignored and detection continues.
86/// 2. **`cuda` feature enabled** (default): probe
87/// a panic-safe CUDA context probe; on success return
88/// [`ComputeBackend::Cuda`]. Because the umbrella crate enables the
89/// `cuda` feature by default, on a host with a working NVIDIA driver
90/// this step routes to the GPU automatically with no opt-in required;
91/// hosts without a driver or runtime fall through to step 4.
92/// 3. **`cuda` feature disabled** (built with `--no-default-features`):
93/// probe `nvidia-smi --query-gpu=name` on `PATH`; on success return
94/// [`ComputeBackend::Cuda`]. In this mode the Cuda variant is advisory
95/// only — the actual CUDA call sites are feature-gated out, so callers
96/// of `auto_batch_step` / `auto_device_step` will still run on the
97/// CPU.
98/// 4. **Otherwise**: return [`ComputeBackend::Cpu`].
99///
100/// Callers can still force a specific backend per call via `auto_batch_step`
101/// / `auto_device_step` arguments or the [`cpu_batch_step`] / `cuda_batch_step`
102/// entry points. Backend selection is **not** a determinism guarantee — see
103/// [`docs/determinism.md`](https://github.com/rustsim/rustsim/blob/main/docs/determinism.md).
104pub fn detect_backend() -> ComputeBackend {
105 if let Ok(val) = std::env::var("RUSTSIM_BACKEND") {
106 match val.to_lowercase().as_str() {
107 "cuda" | "gpu" => {
108 tracing::info!(backend = "CUDA", "backend override via RUSTSIM_BACKEND");
109 return ComputeBackend::Cuda;
110 }
111 "cpu" => {
112 tracing::info!(backend = "CPU", "backend override via RUSTSIM_BACKEND");
113 return ComputeBackend::Cpu;
114 }
115 _ => {}
116 }
117 }
118
119 #[cfg(feature = "cuda")]
120 {
121 if crate::cuda_context::new_context(0).is_ok() {
122 tracing::info!("CUDA device detected");
123 return ComputeBackend::Cuda;
124 }
125 }
126
127 #[cfg(not(feature = "cuda"))]
128 {
129 match std::process::Command::new("nvidia-smi")
130 .arg("--query-gpu=name")
131 .arg("--format=csv,noheader")
132 .output()
133 {
134 Ok(output) if output.status.success() => {
135 tracing::info!("CUDA device detected via nvidia-smi");
136 return ComputeBackend::Cuda;
137 }
138 _ => {}
139 }
140 }
141
142 tracing::debug!("no CUDA device found, using CPU");
143 ComputeBackend::Cpu
144}
145
146// ---------------------------------------------------------------------------
147// GPU Accelerator
148// ---------------------------------------------------------------------------
149
150use rustsim_core::soa::{self, SoaExtractable, SoaExtractableF64};
151use rustsim_core::store::AgentStore;
152
153/// Result of a GPU (or CPU-fallback) batch step.
154#[derive(Debug)]
155pub struct AccelStepResult {
156 /// Which backend was used.
157 pub backend: ComputeBackend,
158 /// Number of agents processed.
159 pub agent_count: usize,
160 /// Wall-clock time for the kernel / CPU work (excludes extract + write-back).
161 pub kernel_us: u128,
162}
163
164impl AccelStepResult {
165 /// Kernel/runtime duration in milliseconds.
166 pub fn kernel_ms(&self) -> f64 {
167 self.kernel_us as f64 / 1_000.0
168 }
169
170 /// Approximate processed-agent throughput in agents/second.
171 pub fn agents_per_second(&self) -> f64 {
172 if self.kernel_us == 0 {
173 return 0.0;
174 }
175 self.agent_count as f64 / (self.kernel_us as f64 / 1_000_000.0)
176 }
177}
178
179/// CPU-side batch step over SoA columns.
180///
181/// `kernel` receives `(columns, n)` where each `columns[c]` is a mutable
182/// slice of length `n`. The closure should update the columns in place,
183/// operating on all `n` agents.
184pub fn cpu_batch_step<A, S, F>(store: &S, mut kernel: F) -> AccelStepResult
185where
186 A: SoaExtractable,
187 S: AgentStore<A>,
188 F: FnMut(&mut [Vec<f32>], usize),
189{
190 let (ids, mut columns) = soa::extract_soa::<A, S>(store);
191 let n = ids.len();
192
193 let t0 = std::time::Instant::now();
194 kernel(&mut columns, n);
195 let kernel_us = t0.elapsed().as_micros();
196
197 soa::write_back_soa::<A, S>(store, &ids, &columns);
198
199 tracing::debug!(
200 backend = "CPU",
201 agents = n,
202 kernel_us,
203 "cpu_batch_step completed"
204 );
205
206 AccelStepResult {
207 backend: ComputeBackend::Cpu,
208 agent_count: n,
209 kernel_us,
210 }
211}
212
213/// CPU-side batch step over **`f64`** SoA columns.
214///
215/// Parallel to [`cpu_batch_step`] but preserves double precision end-to-end.
216/// Use this when `f32` would introduce unacceptable rounding — e.g.
217/// long-horizon integrators, stiff dynamics, or scientific workloads.
218///
219/// `kernel` receives `(columns, n)` where each `columns[c]` is a mutable
220/// `Vec<f64>` of length `n`.
221pub fn cpu_batch_step_f64<A, S, F>(store: &S, mut kernel: F) -> AccelStepResult
222where
223 A: SoaExtractableF64,
224 S: AgentStore<A>,
225 F: FnMut(&mut [Vec<f64>], usize),
226{
227 let (ids, mut columns) = soa::extract_soa_f64::<A, S>(store);
228 let n = ids.len();
229
230 let t0 = std::time::Instant::now();
231 kernel(&mut columns, n);
232 let kernel_us = t0.elapsed().as_micros();
233
234 soa::write_back_soa_f64::<A, S>(store, &ids, &columns);
235
236 tracing::debug!(
237 backend = "CPU",
238 precision = "f64",
239 agents = n,
240 kernel_us,
241 "cpu_batch_step_f64 completed"
242 );
243
244 AccelStepResult {
245 backend: ComputeBackend::Cpu,
246 agent_count: n,
247 kernel_us,
248 }
249}
250
/// Chunked, parallel CPU batch step over `f32` SoA columns.
///
/// Available only with the `rayon` feature. The SoA columns are extracted
/// from `store`, every column is exposed as a mutable slice, and the slices
/// are handed to `crate::parallel::par_apply_chunks_multi` together with
/// `chunk_size` and `kernel`. The results are then written back to `store`.
///
/// `kernel(chunk_start, &mut [&mut [f32]])` sees a mutable sub-slice of
/// every column for the same `[chunk_start, chunk_start + chunk_len)` index
/// range. Chunks are disjoint, so kernel invocations do not alias each
/// other.
///
/// Use this when the per-agent work is non-trivial. For very small kernels
/// the serial [`cpu_batch_step`] may be faster due to lower overhead.
///
/// The returned `kernel_us` covers building the slice views and the parallel
/// apply, not extraction or write-back.
#[cfg(feature = "rayon")]
pub fn par_batch_step<A, S, F>(store: &S, chunk_size: usize, kernel: F) -> AccelStepResult
where
    A: SoaExtractable,
    S: AgentStore<A>,
    F: Fn(usize, &mut [&mut [f32]]) + Send + Sync,
{
    let (agent_ids, mut soa_columns) = soa::extract_soa::<A, S>(store);
    let agent_count = agent_ids.len();

    let started = std::time::Instant::now();
    {
        // The mutable views borrow `soa_columns`; the inner scope ends that
        // borrow before write-back reads the columns again.
        let mut column_views: Vec<&mut [f32]> =
            soa_columns.iter_mut().map(Vec::as_mut_slice).collect();
        crate::parallel::par_apply_chunks_multi(&mut column_views, chunk_size, kernel);
    }
    let kernel_us = started.elapsed().as_micros();

    soa::write_back_soa::<A, S>(store, &agent_ids, &soa_columns);

    let outcome = AccelStepResult {
        backend: ComputeBackend::Cpu,
        agent_count,
        kernel_us,
    };

    tracing::debug!(
        backend = "CPU",
        parallel = true,
        chunk_size,
        agents = outcome.agent_count,
        kernel_us = outcome.kernel_us,
        "par_batch_step completed"
    );

    outcome
}
297
/// CUDA batch step over SoA columns.
///
/// Extracts the SoA columns from `store`, uploads them to the GPU on the
/// context's default stream, launches the named kernel from the provided
/// PTX source, then downloads the results to the host and writes them back
/// into the agent store.
///
/// Supports any number of SoA columns — the launch uses the stream-based
/// [`cudarc::driver::CudaStream::launch_builder`] API and pushes
/// `(col_0, col_1, …, col_{k-1}, n)` as argument slots in order, where `n`
/// is the agent count passed as a `u32`. The grid is one-dimensional with
/// `ceil(n / block_size)` blocks.
///
/// An empty store returns `Ok` with `agent_count == 0` and `kernel_us == 0`
/// without loading the PTX or launching anything.
///
/// `kernel_us` covers the kernel launch plus stream synchronization; the
/// uploads and downloads are outside the timed region.
///
/// # Errors
/// Failure surfaces returned as `Err(String)` include:
/// - invalid `block_size` (zero)
/// - an agent count that does not fit in `u32`
/// - CUDA context initialization
/// - PTX compile/load or kernel lookup
/// - host/device transfer failures
/// - kernel launch or stream synchronization failures
///
/// Nothing is written back to `store` when an error is returned.
///
/// # Safety requirements for the PTX kernel
/// - the kernel signature must match the launched argument list
/// - the kernel must bounds-check against `n`
/// - the kernel must not read or write outside the provided column buffers
///
/// # Arguments
/// - `store` -- the agent store to extract from / write back to
/// - `ptx_source` -- PTX source string (compile your `.cu` to PTX offline or embed it)
/// - `module_name` -- name for the loaded module (unused with cudarc 0.19,
///   kept for source-compatibility with the previous API)
/// - `kernel_name` -- the `__global__` function name inside the PTX
/// - `block_size` -- CUDA threads per block (e.g. 256)
#[cfg(feature = "cuda")]
pub fn cuda_batch_step<A, S>(
    store: &S,
    ptx_source: &str,
    _module_name: &str,
    kernel_name: &str,
    block_size: u32,
) -> Result<AccelStepResult, String>
where
    A: SoaExtractable,
    S: AgentStore<A>,
{
    use cudarc::driver::{LaunchConfig, PushKernelArg};

    if block_size == 0 {
        return Err("block_size must be positive".to_string());
    }

    // CudaContext owns the device; streams are scheduled off it.
    let ctx = crate::cuda_context::new_context(0)?;
    let stream = ctx.default_stream();

    // Extract SoA
    let (ids, mut columns) = soa::extract_soa::<A, S>(store);
    let n = ids.len();
    if n == 0 {
        return Ok(AccelStepResult {
            backend: ComputeBackend::Cuda,
            agent_count: 0,
            kernel_us: 0,
        });
    }

    // The kernel receives the row count as a `u32`. A truncating cast would
    // hand it a wrong `n`, so reject oversized stores before any upload.
    let n_u32 = u32::try_from(n)
        .map_err(|_| format!("agent count {n} exceeds the u32 row count passed to the kernel"))?;

    // Compile/load PTX module and look up the kernel.
    let ptx = cudarc::nvrtc::Ptx::from_src(ptx_source);
    let module = ctx
        .load_module(ptx)
        .map_err(|e| format!("PTX load failed: {e}"))?;
    let func = module
        .load_function(kernel_name)
        .map_err(|e| format!("kernel '{kernel_name}' not found: {e}"))?;

    // Upload columns to device on the compute stream.
    let mut d_columns = Vec::with_capacity(columns.len());
    for col in &columns {
        let d_col = stream
            .clone_htod(col.as_slice())
            .map_err(|e| format!("htod failed: {e}"))?;
        d_columns.push(d_col);
    }

    // One thread per agent; the last block may be partially filled, which is
    // why the kernel must bounds-check against `n`.
    let cfg = LaunchConfig {
        grid_dim: (n_u32.div_ceil(block_size), 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: 0,
    };

    let t0 = std::time::Instant::now();

    // SAFETY:
    // - each `d_columns[i]` is a valid device buffer allocated on `stream`
    // - `n_u32` is the logical row count passed to the kernel for bounds checking
    // - callers are responsible for providing a PTX kernel whose signature matches
    //   the argument list built below and whose implementation does not access out
    //   of bounds
    // - `block_size > 0` has been validated above
    // - the `launch_builder` / `launch` pair schedules work on `stream` only, and
    //   `stream.synchronize()` is called below before any host-visible read
    unsafe {
        let mut builder = stream.launch_builder(&func);
        for d in d_columns.iter_mut() {
            builder.arg(d);
        }
        builder.arg(&n_u32);
        builder
            .launch(cfg)
            .map_err(|e| format!("kernel launch failed: {e}"))?;
    }

    stream
        .synchronize()
        .map_err(|e| format!("stream sync failed: {e}"))?;
    let kernel_us = t0.elapsed().as_micros();

    // Download results into the host columns, pairing each device buffer
    // with the host column it was uploaded from.
    for (d_col, host_col) in d_columns.iter().zip(columns.iter_mut()) {
        stream
            .memcpy_dtoh(d_col, host_col)
            .map_err(|e| format!("dtoh failed: {e}"))?;
    }

    // Write back
    soa::write_back_soa::<A, S>(store, &ids, &columns);

    Ok(AccelStepResult {
        backend: ComputeBackend::Cuda,
        agent_count: n,
        kernel_us,
    })
}
431
/// CUDA batch step over SoA columns using **pinned host memory and dedicated
/// non-default CUDA streams** for host/device transfer overlap.
///
/// Same contract as [`cuda_batch_step`] but:
/// - SoA columns are staged through page-locked (pinned) host buffers
///   allocated via [`cudarc::driver::CudaContext::alloc_pinned`], letting the
///   driver issue truly asynchronous `memcpy_htod` / `memcpy_dtoh`.
/// - Host-to-device uploads and device-to-host downloads run on a dedicated
///   *copy* stream, while the kernel launch runs on a dedicated *compute*
///   stream; the two are serialized via
///   [`cudarc::driver::CudaStream::join`]. This is the standard CUDA pattern
///   for overlapping transfer and compute across successive steps.
///
/// A single invocation still runs in-order on the host timeline; the benefit
/// materializes when multiple kernels are scheduled back-to-back and the
/// driver is free to overlap the download of step *N* with the upload of
/// step *N+1*. For persistent device-side data use
/// [`crate::device_store::DeviceSoaStore::step_cuda_pinned`].
///
/// An empty store returns `Ok` with `agent_count == 0` and `kernel_us == 0`
/// without loading the PTX or launching anything.
///
/// Unlike [`cuda_batch_step`], the reported `kernel_us` spans the kernel
/// launch **and** the device-to-host copies, because the only host-side
/// synchronization point is after the downloads have been enqueued.
///
/// # Errors
/// Failure surfaces returned as `Err(String)` are identical to
/// [`cuda_batch_step`] (including an agent count that does not fit in
/// `u32`) plus pinned-allocation, pinned-access and stream-creation
/// failures. Nothing is written back to `store` when an error is returned.
///
/// # Safety requirements for the PTX kernel
/// - the kernel signature must match the launched argument list
/// - the kernel must bounds-check against `n`
/// - the kernel must not read or write outside the provided column buffers
///
/// # Arguments
/// Same as [`cuda_batch_step`].
#[cfg(feature = "cuda")]
pub fn cuda_batch_step_pinned<A, S>(
    store: &S,
    ptx_source: &str,
    _module_name: &str,
    kernel_name: &str,
    block_size: u32,
) -> Result<AccelStepResult, String>
where
    A: SoaExtractable,
    S: AgentStore<A>,
{
    use cudarc::driver::{LaunchConfig, PushKernelArg};

    if block_size == 0 {
        return Err("block_size must be positive".to_string());
    }

    let ctx = crate::cuda_context::new_context(0)?;
    let copy_stream = ctx
        .new_stream()
        .map_err(|e| format!("copy stream init failed: {e}"))?;
    let compute_stream = ctx
        .new_stream()
        .map_err(|e| format!("compute stream init failed: {e}"))?;

    let (ids, mut columns) = soa::extract_soa::<A, S>(store);
    let n = ids.len();
    if n == 0 {
        return Ok(AccelStepResult {
            backend: ComputeBackend::Cuda,
            agent_count: 0,
            kernel_us: 0,
        });
    }

    // The kernel receives the row count as a `u32`. A truncating cast would
    // hand it a wrong `n`, so reject oversized stores before any allocation.
    let n_u32 = u32::try_from(n)
        .map_err(|_| format!("agent count {n} exceeds the u32 row count passed to the kernel"))?;

    let ptx = cudarc::nvrtc::Ptx::from_src(ptx_source);
    let module = ctx
        .load_module(ptx)
        .map_err(|e| format!("PTX load failed: {e}"))?;
    let func = module
        .load_function(kernel_name)
        .map_err(|e| format!("kernel '{kernel_name}' not found: {e}"))?;

    // Allocate pinned host staging and fill from extracted SoA columns.
    let mut pinned: Vec<cudarc::driver::PinnedHostSlice<f32>> = Vec::with_capacity(columns.len());
    for col in &columns {
        // SAFETY: `alloc_pinned` returns uninitialized pinned host memory.
        // We immediately fill the entire slice via `copy_from_slice` before
        // any read, so no uninitialized byte is ever observed.
        let mut p = unsafe { ctx.alloc_pinned::<f32>(col.len()) }
            .map_err(|e| format!("pinned alloc failed: {e}"))?;
        p.as_mut_slice()
            .map_err(|e| format!("pinned access failed: {e}"))?
            .copy_from_slice(col);
        pinned.push(p);
    }

    // Host -> device on the copy stream.
    let mut d_columns: Vec<cudarc::driver::CudaSlice<f32>> = Vec::with_capacity(pinned.len());
    for p in &pinned {
        let d = copy_stream
            .clone_htod(p)
            .map_err(|e| format!("htod failed: {e}"))?;
        d_columns.push(d);
    }

    // The kernel must not start before the uploads have completed.
    compute_stream
        .join(&copy_stream)
        .map_err(|e| format!("compute.join(copy) failed: {e}"))?;

    // One thread per agent; the last block may be partially filled, which is
    // why the kernel must bounds-check against `n`.
    let cfg = LaunchConfig {
        grid_dim: (n_u32.div_ceil(block_size), 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: 0,
    };

    let t0 = std::time::Instant::now();

    // SAFETY:
    // - each `d_columns[i]` is a valid device buffer allocated on `copy_stream`
    //   and made visible to `compute_stream` via `compute_stream.join(&copy_stream)`
    // - `n_u32` is the logical row count passed to the kernel for bounds checking
    // - callers are responsible for providing a PTX kernel whose signature matches
    //   the argument list built below and whose implementation does not access out
    //   of bounds
    // - `block_size > 0` has been validated above
    // - work is scheduled on `compute_stream`; the copy stream waits on it via
    //   `copy_stream.join(&compute_stream)` below before issuing dtoh, and the
    //   copy stream is then synchronized before any host-visible read
    unsafe {
        let mut builder = compute_stream.launch_builder(&func);
        for d in d_columns.iter_mut() {
            builder.arg(d);
        }
        builder.arg(&n_u32);
        builder
            .launch(cfg)
            .map_err(|e| format!("kernel launch failed: {e}"))?;
    }

    // Downloads must not start before the kernel has finished.
    copy_stream
        .join(&compute_stream)
        .map_err(|e| format!("copy.join(compute) failed: {e}"))?;

    // Device -> pinned host staging, pairing each device buffer with the
    // staging buffer it was uploaded from.
    for (d_col, staging) in d_columns.iter().zip(pinned.iter_mut()) {
        copy_stream
            .memcpy_dtoh(d_col, staging)
            .map_err(|e| format!("dtoh failed: {e}"))?;
    }

    copy_stream
        .synchronize()
        .map_err(|e| format!("stream sync failed: {e}"))?;
    let kernel_us = t0.elapsed().as_micros();

    // Pinned staging -> ordinary host columns used for write-back.
    for (host_col, staging) in columns.iter_mut().zip(pinned.iter()) {
        host_col.copy_from_slice(
            staging
                .as_slice()
                .map_err(|e| format!("pinned readback failed: {e}"))?,
        );
    }

    soa::write_back_soa::<A, S>(store, &ids, &columns);

    Ok(AccelStepResult {
        backend: ComputeBackend::Cuda,
        agent_count: n,
        kernel_us,
    })
}
595
596/// Automatically choose CUDA or CPU for a batch step.
597///
598/// If the `cuda` feature is enabled and a device is found, uses `cuda_batch_step`.
599/// Otherwise falls back to `cpu_batch_step`.
600///
601/// If CUDA is selected but fails at runtime (device init, PTX load, launch,
602/// synchronization, transfer, or invalid CUDA configuration), the function logs
603/// a warning and continues on CPU.
604pub fn auto_batch_step<A, S, F>(
605 store: &S,
606 cpu_kernel: F,
607 #[cfg(feature = "cuda")] ptx_source: &str,
608 #[cfg(feature = "cuda")] module_name: &str,
609 #[cfg(feature = "cuda")] kernel_name: &str,
610 #[cfg(feature = "cuda")] block_size: u32,
611) -> AccelStepResult
612where
613 A: SoaExtractable,
614 S: AgentStore<A>,
615 F: FnMut(&mut [Vec<f32>], usize),
616{
617 #[cfg(feature = "cuda")]
618 {
619 if detect_backend() == ComputeBackend::Cuda {
620 match cuda_batch_step::<A, S>(store, ptx_source, module_name, kernel_name, block_size) {
621 Ok(result) => return result,
622 Err(e) => {
623 tracing::warn!(error = %e, "CUDA batch step failed, falling back to CPU");
624 }
625 }
626 }
627 }
628
629 cpu_batch_step::<A, S, F>(store, cpu_kernel)
630}
631
632/// Step a [`DeviceSoaStore`](crate::device_store::DeviceSoaStore) using CUDA or CPU.
633///
634/// Unlike `auto_batch_step`, this operates on persistent SoA storage,
635/// avoiding the extract/write-back cycle each step.
636///
637/// Returns the kernel time in microseconds and the backend used.
638///
639/// If CUDA is selected but fails at runtime, the function logs a warning and
640/// continues on the CPU path over the persistent SoA buffers.
641pub fn auto_device_step(
642 device: &mut crate::device_store::DeviceSoaStore,
643 mut cpu_kernel: impl FnMut(&mut [Vec<f32>], usize),
644 #[cfg(feature = "cuda")] ptx_source: &str,
645 #[cfg(feature = "cuda")] module_name: &str,
646 #[cfg(feature = "cuda")] kernel_name: &str,
647 #[cfg(feature = "cuda")] block_size: u32,
648) -> AccelStepResult {
649 #[cfg(feature = "cuda")]
650 {
651 if detect_backend() == ComputeBackend::Cuda {
652 match device.step_cuda(ptx_source, module_name, kernel_name, block_size) {
653 Ok(kernel_us) => {
654 return AccelStepResult {
655 backend: ComputeBackend::Cuda,
656 agent_count: device.agent_count(),
657 kernel_us,
658 };
659 }
660 Err(e) => {
661 tracing::warn!(error = %e, "CUDA device step failed, falling back to CPU");
662 }
663 }
664 }
665 }
666
667 let kernel_us = device.step_cpu(&mut cpu_kernel);
668 AccelStepResult {
669 backend: ComputeBackend::Cpu,
670 agent_count: device.agent_count(),
671 kernel_us,
672 }
673}