sparkl3d_kernels/cuda/
reset_grid.rs

1use crate::gpu_grid::{GpuGrid, GpuGridNode};
2use crate::BlockHeaderId;
3use cuda_std::thread;
4use cuda_std::*;
5use na::vector;
6
7// NOTE: the number of threads must be 4x4x4 (3D) or 4x4 (2D)
8#[cfg_attr(target_os = "cuda", kernel)]
9pub unsafe fn reset_grid(mut next_grid: GpuGrid) {
10    let bid = BlockHeaderId(thread::block_idx_x());
11    #[cfg(feature = "dim2")]
12    let shift = vector![
13        thread::thread_idx_x() as usize,
14        thread::thread_idx_y() as usize
15    ];
16    #[cfg(feature = "dim3")]
17    let shift = vector![
18        thread::thread_idx_x() as usize,
19        thread::thread_idx_y() as usize,
20        thread::thread_idx_z() as usize
21    ];
22
23    let node_id = bid.to_physical().node_id_unchecked(shift);
24    if let Some(cell) = next_grid.get_node_mut(node_id) {
25        *cell = GpuGridNode::default();
26    }
27}
28
29// NOTE: the number of threads must be 4x4x4 (3D) or 4x4 (2D)
30#[cfg_attr(target_os = "cuda", kernel)]
31pub unsafe fn copy_grid_projection_data(prev_grid: GpuGrid, mut next_grid: GpuGrid) {
32    let next_bid = BlockHeaderId(thread::block_idx_x());
33    #[cfg(feature = "dim2")]
34    let shift = vector![
35        thread::thread_idx_x() as usize,
36        thread::thread_idx_y() as usize
37    ];
38    #[cfg(feature = "dim3")]
39    let shift = vector![
40        thread::thread_idx_x() as usize,
41        thread::thread_idx_y() as usize,
42        thread::thread_idx_z() as usize
43    ];
44
45    let next_block = next_grid.active_block_unchecked(next_bid);
46    if let Some(prev_bid) = prev_grid.get_header_block_id(next_block.virtual_id) {
47        let next_node_id = next_bid.to_physical().node_id_unchecked(shift);
48        let prev_node_id = prev_bid.to_physical().node_id_unchecked(shift);
49
50        if let (Some(prev_cell), Some(next_cell)) = (
51            prev_grid.get_node(prev_node_id),
52            next_grid.get_node_mut(next_node_id),
53        ) {
54            next_cell.prev_mass = prev_cell.mass;
55            next_cell.projection_status = prev_cell.projection_status;
56            next_cell.collision_normal = prev_cell.collision_normal;
57            next_cell.projection_scaled_dir = prev_cell.projection_scaled_dir;
58        }
59    }
60}