1use crate::cuda::atomic::{AtomicAdd, AtomicInt};
2use crate::cuda::{DefaultParticleUpdater, ParticleUpdater};
3use crate::gpu_grid::{GpuGrid, GpuGridProjectionStatus};
4use crate::{
5    BlockVirtualId, GpuCollider, GpuColliderSet, GpuParticleModel, NBH_SHIFTS, NBH_SHIFTS_SHARED,
6    NUM_CELL_PER_BLOCK,
7};
8use cuda_std::thread;
9use cuda_std::*;
10use nalgebra::vector;
11use sparkl_core::math::{Kernel, Matrix, Real, Vector};
12use sparkl_core::prelude::{
13    DamageModel, ParticlePhase, ParticlePosition, ParticleStatus, ParticleVelocity, ParticleVolume,
14};
15
// Shared-memory tile staged by each CUDA block: 2 grid blocks of 4 cells along every axis,
// i.e. a tile 8 cells wide (64 cells in 2D, 512 in 3D). Shared indices are packed with the
// matching strides `1`, `8` and `8 * 8`.
#[cfg(feature = "dim2")]
const NUM_SHARED_CELLS: usize = 8 * 8;
#[cfg(feature = "dim3")]
const NUM_SHARED_CELLS: usize = 8 * 8 * 8;

/// Value held by a shared cell's `lock` word while no thread owns it.
const FREE: u32 = u32::MAX;
21
22struct GridGatherData {
23    prev_mass: Real,
24    mass: Real,
25    momentum: Vector<Real>,
26    velocity: Vector<Real>,
27    psi_mass: Real,
28    psi_momentum: Real,
29    psi_velocity: Real,
30    projection_scaled_dir: Vector<Real>,
31    projection_status: GpuGridProjectionStatus,
32    lock: u32,
40}
41
42pub struct InterpolatedParticleData {
43    pub velocity: Vector<Real>,
44    pub velocity_gradient: Matrix<Real>,
45    pub psi_pos_momentum: Real,
46    pub velocity_gradient_det: Real,
47    pub projection_scaled_dir: Vector<Real>,
48    pub projection_status: GpuGridProjectionStatus,
49}
50
51impl Default for InterpolatedParticleData {
52    fn default() -> Self {
53        Self {
54            velocity: na::zero(),
55            velocity_gradient: na::zero(),
56            psi_pos_momentum: na::zero(),
57            velocity_gradient_det: na::zero(),
58            projection_scaled_dir: na::zero(),
59            projection_status: GpuGridProjectionStatus::NotComputed,
60        }
61    }
62}
63
64#[cfg_attr(target_os = "cuda", kernel)]
65pub unsafe fn g2p2g(
66    dt: Real,
67    colliders_ptr: *const GpuCollider,
68    num_colliders: usize,
69    particles_status: *mut ParticleStatus,
70    particles_pos: *mut ParticlePosition,
71    particles_vel: *mut ParticleVelocity,
72    particles_volume: *mut ParticleVolume,
73    particles_phase: *mut ParticlePhase,
74    sorted_particle_ids: *const u32,
75    models: *mut GpuParticleModel,
76    curr_grid: GpuGrid,
77    next_grid: GpuGrid,
78    damage_model: DamageModel,
79    halo: bool,
80) {
81    g2p2g_generic(
82        dt,
83        colliders_ptr,
84        num_colliders,
85        particles_status,
86        particles_pos,
87        particles_vel,
88        particles_volume,
89        particles_phase,
90        sorted_particle_ids,
91        curr_grid,
92        next_grid,
93        damage_model,
94        halo,
95        DefaultParticleUpdater { models },
96    )
97}
98
99pub unsafe fn g2p2g_generic(
101    dt: Real,
102    colliders_ptr: *const GpuCollider,
103    num_colliders: usize,
104    particles_status: *mut ParticleStatus,
105    particles_pos: *mut ParticlePosition,
106    particles_vel: *mut ParticleVelocity,
107    particles_volume: *mut ParticleVolume,
108    particles_phase: *mut ParticlePhase,
109    sorted_particle_ids: *const u32,
110    curr_grid: GpuGrid,
111    mut next_grid: GpuGrid,
112    damage_model: DamageModel,
113    halo: bool,
114    particle_updater: impl ParticleUpdater,
115) {
116    let shared_nodes = shared_array![GridGatherData; NUM_SHARED_CELLS];
117
118    let bid = thread::block_idx_x();
119    let tid = thread::thread_idx_x();
120
121    let dispatch2active = if halo {
122        next_grid.dispatch_halo_block_to_active_block
123    } else {
124        next_grid.dispatch_block_to_active_block
125    };
126
127    let collider_set = GpuColliderSet {
128        ptr: colliders_ptr,
129        len: num_colliders,
130    };
131
132    let dispatch_block_to_active_block = *dispatch2active.as_ptr().add(bid as usize);
133    let active_block =
134        *next_grid.active_block_unchecked(dispatch_block_to_active_block.active_block_id);
135
136    transfer_global_blocks_to_shared_memory(shared_nodes, &curr_grid, active_block.virtual_id);
137
138    thread::sync_threads();
140
141    if dispatch_block_to_active_block.first_particle + tid
142        < active_block.first_particle + active_block.num_particles
143    {
144        let particle_id = *sorted_particle_ids
145            .add((dispatch_block_to_active_block.first_particle + tid) as usize);
146        let mut particle_status_i = *particles_status.add(particle_id as usize);
147        let mut particle_pos_i = *particles_pos.add(particle_id as usize);
148        let mut particle_vel_i = *particles_vel.add(particle_id as usize);
149        let mut particle_volume_i = *particles_volume.add(particle_id as usize);
150        let mut particle_phase_i = *particles_phase.add(particle_id as usize);
151
152        particle_g2p2g(
153            dt,
154            particle_id,
155            &collider_set,
156            &mut particle_status_i,
157            &mut particle_pos_i,
158            &mut particle_vel_i,
159            &mut particle_volume_i,
160            &mut particle_phase_i,
161            shared_nodes,
162            next_grid.cell_width(),
163            damage_model,
164            particle_updater,
165        );
166
167        *particles_status.add(particle_id as usize) = particle_status_i;
168        *particles_pos.add(particle_id as usize) = particle_pos_i;
169        *particles_vel.add(particle_id as usize) = particle_vel_i;
170        *particles_volume.add(particle_id as usize) = particle_volume_i;
171        *particles_phase.add(particle_id as usize) = particle_phase_i;
172    }
173
174    thread::sync_threads();
176    transfer_shared_blocks_to_grid(shared_nodes, &mut next_grid, active_block.virtual_id);
177}
178
179unsafe fn particle_g2p2g(
180    dt: Real,
181    particle_id: u32,
182    colliders: &GpuColliderSet,
183    particle_status: &mut ParticleStatus,
184    particle_pos: &mut ParticlePosition,
185    particle_vel: &mut ParticleVelocity,
186    particle_volume: &mut ParticleVolume,
187    particle_phase: &mut ParticlePhase,
188    shared_nodes: *mut GridGatherData,
189    cell_width: Real,
190    _damage_model: DamageModel,
191    particle_updater: impl ParticleUpdater,
192) {
193    let tid = thread::thread_idx_x();
194    let inv_d = Kernel::inv_d(cell_width);
195
196    let ref_elt_pos_minus_particle_pos = particle_pos.dir_to_associated_grid_node(cell_width);
197
198    let mut interpolated_data = InterpolatedParticleData::default();
200
201    let w = Kernel::precompute_weights(ref_elt_pos_minus_particle_pos, cell_width);
202
203    let assoc_cell_before_integration = particle_pos.point.map(|e| (e / cell_width).round());
204    let assoc_cell_index_in_block =
205        particle_pos.associated_cell_index_in_block_off_by_two(cell_width);
206    #[cfg(feature = "dim2")]
207    let packed_cell_index_in_block =
208        (assoc_cell_index_in_block.x + 1) + (assoc_cell_index_in_block.y + 1) * 8;
209    #[cfg(feature = "dim3")]
210    let packed_cell_index_in_block = (assoc_cell_index_in_block.x + 1)
211        + (assoc_cell_index_in_block.y + 1) * 8
212        + (assoc_cell_index_in_block.z + 1) * 8 * 8;
213
214    let midcell_mass = {
215        let midcell = &*shared_nodes
216            .add(packed_cell_index_in_block as usize + NBH_SHIFTS_SHARED.last().unwrap());
217        midcell.prev_mass
218    };
219
220    let artificial_pressure_stiffness = particle_updater.artificial_pressure_stiffness();
221    let mut artificial_pressure_force = Vector::zeros();
222
223    for (shift, packed_shift) in NBH_SHIFTS.iter().zip(NBH_SHIFTS_SHARED.iter()) {
224        let dpt = ref_elt_pos_minus_particle_pos + shift.cast::<Real>() * cell_width;
225        #[cfg(feature = "dim2")]
226        let weight = w[0][shift.x] * w[1][shift.y];
227        #[cfg(feature = "dim3")]
228        let weight = w[0][shift.x] * w[1][shift.y] * w[2][shift.z];
229
230        let cell = &*shared_nodes.add(packed_cell_index_in_block as usize + packed_shift);
231        interpolated_data.velocity += weight * cell.velocity;
232        interpolated_data.velocity_gradient += (weight * inv_d) * cell.velocity * dpt.transpose();
233        interpolated_data.psi_pos_momentum += weight * cell.psi_velocity;
234        interpolated_data.velocity_gradient_det += weight * cell.velocity.dot(&dpt) * inv_d;
235
236        if artificial_pressure_stiffness != 0.0
239            && !particle_status.is_static
240            && cell.projection_status.is_outside()
241        {
242            artificial_pressure_force +=
243                weight * (midcell_mass - cell.prev_mass) * dpt * artificial_pressure_stiffness;
244        }
245    }
246
247    {
248        let shift = NBH_SHIFTS[NBH_SHIFTS.len() - 1];
249        let packed_shift = NBH_SHIFTS_SHARED[NBH_SHIFTS_SHARED.len() - 1];
250        let dpt = ref_elt_pos_minus_particle_pos + shift.cast::<Real>() * cell_width;
251        let cell = &*shared_nodes.add(packed_cell_index_in_block as usize + packed_shift);
252
253        let proj_norm = cell.projection_scaled_dir.norm();
254
255        if proj_norm > 1.0e-5 {
256            let normal = cell.projection_scaled_dir / proj_norm;
257            interpolated_data.projection_scaled_dir =
258                cell.projection_scaled_dir - normal * dpt.dot(&normal);
259
260            if interpolated_data.projection_scaled_dir.dot(&normal) < 0.0 {
261                interpolated_data.projection_status = cell.projection_status.flip();
262            } else {
263                interpolated_data.projection_status = cell.projection_status;
264            }
265        }
266    }
267
268    if let Some((stress, force)) = particle_updater.update_particle_and_compute_kirchhoff_stress(
269        dt,
270        cell_width,
271        colliders,
272        particle_id,
273        particle_status,
274        particle_pos,
275        particle_vel,
276        particle_volume,
277        particle_phase,
278        &mut interpolated_data,
279    ) {
280        let inv_d = Kernel::inv_d(cell_width);
281        let ref_elt_pos_minus_particle_pos = particle_pos.dir_to_associated_grid_node(cell_width);
282        let w = Kernel::precompute_weights(ref_elt_pos_minus_particle_pos, cell_width);
283
284        let affine = particle_volume.mass * interpolated_data.velocity_gradient
285            - (particle_volume.volume0 * inv_d * dt) * stress;
286        let momentum =
287            particle_volume.mass * particle_vel.vector + (force + artificial_pressure_force) * dt;
288
289        let psi_mass = if particle_phase.phase > 0.0 && !particle_status.failed {
290            particle_volume.mass
291        } else {
292            0.0
293        };
294
295        let psi_pos_momentum = psi_mass * particle_phase.psi_pos;
296
297        let assoc_cell_after_integration = particle_pos.point.map(|e| (e / cell_width).round());
298        let particle_cell_movement =
299            (assoc_cell_after_integration - assoc_cell_before_integration).map(|e| e as i64);
300
301        #[cfg(feature = "dim2")]
302        let packed_cell_index_in_block = (packed_cell_index_in_block as i64
303            + (particle_cell_movement.x)
304            + (particle_cell_movement.y) * 8) as u32;
305        #[cfg(feature = "dim3")]
306        let packed_cell_index_in_block = (packed_cell_index_in_block as i64
307            + (particle_cell_movement.x)
308            + (particle_cell_movement.y) * 8
309            + (particle_cell_movement.z) * 8 * 8) as u32;
310
311        for (shift, packed_shift) in NBH_SHIFTS.iter().zip(NBH_SHIFTS_SHARED.iter()) {
312            let dpt = ref_elt_pos_minus_particle_pos + shift.cast::<Real>() * cell_width;
313            #[cfg(feature = "dim2")]
314            let weight = w[0][shift.x] * w[1][shift.y];
315            #[cfg(feature = "dim3")]
316            let weight = w[0][shift.x] * w[1][shift.y] * w[2][shift.z];
317
318            let added_mass = weight * particle_volume.mass;
319            let added_momentum = weight * (affine * dpt + momentum);
320            let added_psi_momentum = weight * psi_pos_momentum;
321            let added_psi_mass = weight * psi_mass;
322
323            let cell = &mut *shared_nodes.add(packed_cell_index_in_block as usize + packed_shift);
328
329            loop {
330                let old = cell.lock.shared_atomic_exch_acq(tid);
331                if old == FREE {
332                    cell.mass += added_mass;
333                    cell.momentum += added_momentum;
334                    cell.psi_momentum += added_psi_momentum;
335                    cell.psi_mass += added_psi_mass;
336                    cell.lock.shared_atomic_exch_rel(FREE);
337                    break;
338                }
339            }
340
341            }
350    }
351}
352
353unsafe fn transfer_global_blocks_to_shared_memory(
354    shared_nodes: *mut GridGatherData,
355    curr_grid: &GpuGrid,
356    active_block_vid: BlockVirtualId,
357) {
358    let tid = thread::thread_idx_x();
359    let blk_sz = thread::block_dim_x();
360
361    let num_cell_per_thread = NUM_SHARED_CELLS / blk_sz as usize;
365    let first_transfer_cell_id = tid as u64 * num_cell_per_thread as u64;
369    let octant = first_transfer_cell_id / NUM_CELL_PER_BLOCK;
373    assert!(octant < 8);
374
375    #[cfg(feature = "dim2")]
376    let octant = vector![(octant & 0b0001) >> 0, (octant & 0b0010) >> 1];
377    #[cfg(feature = "dim3")]
378    let octant = vector![
379        (octant & 0b0001) >> 0,
380        (octant & 0b0010) >> 1,
381        (octant & 0b0100) >> 2
382    ];
383
384    let base_block_pos_int = active_block_vid.unpack();
385    let octant_hid = curr_grid.get_header_block_id(BlockVirtualId::pack(
386        base_block_pos_int + octant.cast::<usize>(),
387    ));
388
389    let num_cell_per_thread = NUM_SHARED_CELLS / blk_sz as usize;
393    let first_cell_in_octant = (tid as u64 * num_cell_per_thread as u64) % NUM_CELL_PER_BLOCK;
396
397    if let Some(octant_hid) = octant_hid {
398        for id in first_cell_in_octant..first_cell_in_octant + num_cell_per_thread as u64 {
399            #[cfg(feature = "dim2")]
402            let shift_in_octant = vector![id & 0b0011, (id >> 2) & 0b0011];
403            #[cfg(feature = "dim3")]
404            let shift_in_octant = vector![id & 0b0011, (id >> 2) & 0b0011, id >> 4];
405
406            let shift_in_shared = octant * 4 + shift_in_octant;
408            #[cfg(feature = "dim2")]
410            let id_in_shared = shift_in_shared.x + shift_in_shared.y * 8;
411            #[cfg(feature = "dim3")]
412            let id_in_shared =
413                shift_in_shared.x + shift_in_shared.y * 8 + shift_in_shared.z * 8 * 8;
414            let id_in_global = octant_hid
416                .to_physical()
417                .node_id_unchecked(shift_in_octant.cast::<usize>());
418
419            let shared_node = &mut *shared_nodes.add(id_in_shared as usize);
420
421            if let Some(global_node) = curr_grid.get_node(id_in_global) {
422                shared_node.velocity = global_node.momentum_velocity;
423                shared_node.psi_velocity = global_node.psi_momentum_velocity;
424                shared_node.projection_scaled_dir = global_node.projection_scaled_dir;
425                shared_node.projection_status = global_node.projection_status;
426                shared_node.prev_mass = global_node.prev_mass;
427            } else {
428                shared_node.velocity = na::zero();
429                shared_node.psi_velocity = na::zero();
430                shared_node.projection_scaled_dir = na::zero();
431                shared_node.projection_status = GpuGridProjectionStatus::NotComputed;
432                shared_node.prev_mass = 0.0;
433            }
434
435            shared_node.psi_momentum = 0.0;
436            shared_node.psi_mass = 0.0;
437            shared_node.momentum.fill(0.0);
438            shared_node.mass = 0.0;
439            shared_node.lock = FREE;
440        }
441    } else {
442        for id in first_cell_in_octant..first_cell_in_octant + num_cell_per_thread as u64 {
445            #[cfg(feature = "dim2")]
448            let shift_in_octant = vector![id & 0b0011, (id >> 2) & 0b0011];
449            #[cfg(feature = "dim3")]
450            let shift_in_octant = vector![id & 0b0011, (id >> 2) & 0b0011, id >> 4];
451            let shift_in_shared = octant * 4 + shift_in_octant;
453            #[cfg(feature = "dim2")]
455            let id_in_shared = shift_in_shared.x + shift_in_shared.y * 8;
456            #[cfg(feature = "dim3")]
457            let id_in_shared =
458                shift_in_shared.x + shift_in_shared.y * 8 + shift_in_shared.z * 8 * 8;
459
460            let shared_node = &mut *shared_nodes.add(id_in_shared as usize);
461
462            shared_node.velocity = na::zero();
463            shared_node.psi_velocity = na::zero();
464            shared_node.psi_momentum = 0.0;
465            shared_node.psi_mass = 0.0;
466            shared_node.momentum.fill(0.0);
467            shared_node.mass = 0.0;
468            shared_node.prev_mass = 0.0;
469            shared_node.projection_scaled_dir = na::zero();
470            shared_node.projection_status = GpuGridProjectionStatus::NotComputed;
471            shared_node.lock = FREE;
472        }
473    }
474}
475
476unsafe fn transfer_shared_blocks_to_grid(
477    shared_nodes: *const GridGatherData,
478    next_grid: &mut GpuGrid,
479    active_block_vid: BlockVirtualId,
480) {
481    let tid = thread::thread_idx_x();
482    let blk_sz = thread::block_dim_x();
483
484    let num_cell_per_thread = NUM_SHARED_CELLS / blk_sz as usize;
488    let first_transfer_cell_id = tid as u64 * num_cell_per_thread as u64;
492    let octant = first_transfer_cell_id / NUM_CELL_PER_BLOCK;
496    assert!(octant < 8);
497    #[cfg(feature = "dim2")]
498    let octant = vector![(octant & 0b0001) >> 0, (octant & 0b0010) >> 1];
499    #[cfg(feature = "dim3")]
500    let octant = vector![
501        (octant & 0b0001) >> 0,
502        (octant & 0b0010) >> 1,
503        (octant & 0b0100) >> 2
504    ];
505
506    let base_block_pos_int = active_block_vid.unpack();
507    let octant_hid = next_grid
508        .get_header_block_id(BlockVirtualId::pack(
509            base_block_pos_int + octant.cast::<usize>(),
510        ))
511        .unwrap();
512
513    let num_cell_per_thread = NUM_SHARED_CELLS / blk_sz as usize;
517    let first_cell_in_octant = (tid as u64 * num_cell_per_thread as u64) % NUM_CELL_PER_BLOCK;
520
521    for id in first_cell_in_octant..first_cell_in_octant + num_cell_per_thread as u64 {
522        #[cfg(feature = "dim2")]
525        let shift_in_octant = vector![id & 0b0011, (id >> 2) & 0b0011];
526        #[cfg(feature = "dim3")]
527        let shift_in_octant = vector![id & 0b0011, (id >> 2) & 0b0011, id >> 4];
528        let shift_in_shared = octant * 4 + shift_in_octant;
530        #[cfg(feature = "dim2")]
532        let id_in_shared = shift_in_shared.x + shift_in_shared.y * 8;
533        #[cfg(feature = "dim3")]
534        let id_in_shared = shift_in_shared.x + shift_in_shared.y * 8 + shift_in_shared.z * 8 * 8;
535        let id_in_global = octant_hid
537            .to_physical()
538            .node_id_unchecked(shift_in_octant.cast::<usize>());
539
540        let shared_node = &*shared_nodes.add(id_in_shared as usize);
541
542        if let Some(global_node) = next_grid.get_node_mut(id_in_global) {
543            global_node.mass.global_red_add(shared_node.mass);
544            global_node
545                .momentum_velocity
546                .global_red_add(shared_node.momentum);
547            global_node
548                .psi_momentum_velocity
549                .global_red_add(shared_node.psi_momentum);
550            global_node.psi_mass.global_red_add(shared_node.psi_mass);
551        }
552    }
553}