Skip to main content

rustsim_core/
two_phase.rs

//! Two-phase stepping utilities inspired by FlameGPU2.
//!
//! FlameGPU2's stepping model separates agent computation into distinct
//! phases via the layer system:
//!
//! 1. **Output phase**: agents write messages (no reading)
//! 2. **Spatial index build**: messages are spatially sorted (partition
//!    boundary matrix, "PBM", construction; spatial messaging only)
//! 3. **Input phase**: agents read messages and update state
//!
//! This module provides helper functions ([`two_phase_brute_force`],
//! [`two_phase_spatial_2d`], [`two_phase_spatial_3d`]) that automate
//! this pattern for use with [`StandardModel`] and spatial messaging.
//!
//! [`StandardModel`]: crate::standard::StandardModel

16use crate::agent::Agent;
17use crate::messaging::{BruteForceMessages, SpatialMessages2D, SpatialMessages3D};
18use crate::store::AgentStore;
19use crate::types::AgentId;
20
/// Timing and bookkeeping for a single two-phase step.
///
/// Produced by the two-phase helpers in this module. Durations are
/// wall-clock times expressed in whole microseconds.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TwoPhaseResult {
    /// Time for the output phase in microseconds.
    pub output_us: u128,
    /// Time for the finalize (spatial index build) phase in microseconds.
    pub finalize_us: u128,
    /// Time for the input phase in microseconds.
    pub input_us: u128,
    /// Number of messages produced.
    pub message_count: usize,
    /// Number of agents processed.
    pub agent_count: usize,
}

impl TwoPhaseResult {
    /// Total time spent across the output, finalize and input phases, in
    /// microseconds.
    ///
    /// Saturates at `u128::MAX` instead of overflowing.
    #[must_use]
    pub fn total_us(&self) -> u128 {
        self.output_us
            .saturating_add(self.finalize_us)
            .saturating_add(self.input_us)
    }
}
35
36/// Execute a two-phase step using brute-force messaging.
37///
38/// - `output_fn`: called once per agent; should push messages.
39/// - `input_fn`: called once per agent after all messages are available.
40///
41/// This pattern decouples reading from writing, enabling future GPU
42/// parallelization where all agents output concurrently before any
43/// agent reads.
44pub fn two_phase_brute_force<A, S, M, FOut, FIn>(
45    store: &S,
46    ids: &[AgentId],
47    messages: &mut BruteForceMessages<M>,
48    mut output_fn: FOut,
49    mut input_fn: FIn,
50) -> TwoPhaseResult
51where
52    A: Agent,
53    M: Clone,
54    S: AgentStore<A>,
55    FOut: FnMut(&A, &mut BruteForceMessages<M>),
56    FIn: FnMut(&mut A, &[M]),
57{
58    // Phase 1: Output
59    let t_output = std::time::Instant::now();
60    for &id in ids {
61        if let Some(agent) = store.get(id) {
62            output_fn(&*agent, messages);
63        }
64    }
65    let output_us = t_output.elapsed().as_micros();
66
67    // Finalize
68    let t_finalize = std::time::Instant::now();
69    messages.finalize();
70    let finalize_us = t_finalize.elapsed().as_micros();
71    let message_count = messages.len();
72
73    // Phase 2: Input
74    let t_input = std::time::Instant::now();
75    let all_msgs = messages.read_all();
76    for &id in ids {
77        if let Some(mut agent) = store.get_mut(id) {
78            input_fn(&mut *agent, all_msgs);
79        }
80    }
81    let input_us = t_input.elapsed().as_micros();
82
83    messages.clear();
84
85    TwoPhaseResult {
86        output_us,
87        finalize_us,
88        input_us,
89        message_count,
90        agent_count: ids.len(),
91    }
92}
93
94/// Execute a two-phase step using 2D spatial messaging.
95///
96/// - `output_fn`: called once per agent; should push messages with positions.
97/// - `input_fn`: called once per agent after the spatial index is built.
98///   Receives the agent and the spatial message list for neighbor queries.
99pub fn two_phase_spatial_2d<A, S, M, FOut, FIn>(
100    store: &S,
101    ids: &[AgentId],
102    messages: &mut SpatialMessages2D<M>,
103    mut output_fn: FOut,
104    mut input_fn: FIn,
105) -> TwoPhaseResult
106where
107    A: Agent,
108    M: Clone,
109    S: AgentStore<A>,
110    FOut: FnMut(&A, &mut SpatialMessages2D<M>),
111    FIn: FnMut(&mut A, &SpatialMessages2D<M>),
112{
113    // Phase 1: Output
114    let t_output = std::time::Instant::now();
115    for &id in ids {
116        if let Some(agent) = store.get(id) {
117            output_fn(&*agent, messages);
118        }
119    }
120    let output_us = t_output.elapsed().as_micros();
121
122    // Finalize (build PBM-like spatial index)
123    let t_finalize = std::time::Instant::now();
124    messages.finalize();
125    let finalize_us = t_finalize.elapsed().as_micros();
126    let message_count = messages.len();
127
128    // Phase 2: Input
129    let t_input = std::time::Instant::now();
130    for &id in ids {
131        if let Some(mut agent) = store.get_mut(id) {
132            input_fn(&mut *agent, messages);
133        }
134    }
135    let input_us = t_input.elapsed().as_micros();
136
137    messages.clear();
138
139    TwoPhaseResult {
140        output_us,
141        finalize_us,
142        input_us,
143        message_count,
144        agent_count: ids.len(),
145    }
146}
147
148/// Execute a two-phase step using 3D spatial messaging.
149///
150/// Same as [`two_phase_spatial_2d`] but for 3D spatial messages.
151pub fn two_phase_spatial_3d<A, S, M, FOut, FIn>(
152    store: &S,
153    ids: &[AgentId],
154    messages: &mut SpatialMessages3D<M>,
155    mut output_fn: FOut,
156    mut input_fn: FIn,
157) -> TwoPhaseResult
158where
159    A: Agent,
160    M: Clone,
161    S: AgentStore<A>,
162    FOut: FnMut(&A, &mut SpatialMessages3D<M>),
163    FIn: FnMut(&mut A, &SpatialMessages3D<M>),
164{
165    // Phase 1: Output
166    let t_output = std::time::Instant::now();
167    for &id in ids {
168        if let Some(agent) = store.get(id) {
169            output_fn(&*agent, messages);
170        }
171    }
172    let output_us = t_output.elapsed().as_micros();
173
174    // Finalize
175    let t_finalize = std::time::Instant::now();
176    messages.finalize();
177    let finalize_us = t_finalize.elapsed().as_micros();
178    let message_count = messages.len();
179
180    // Phase 2: Input
181    let t_input = std::time::Instant::now();
182    for &id in ids {
183        if let Some(mut agent) = store.get_mut(id) {
184            input_fn(&mut *agent, messages);
185        }
186    }
187    let input_us = t_input.elapsed().as_micros();
188
189    messages.clear();
190
191    TwoPhaseResult {
192        output_us,
193        finalize_us,
194        input_us,
195        message_count,
196        agent_count: ids.len(),
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;

    #[derive(Debug, Clone)]
    struct Boid {
        id: AgentId,
        x: f32,
        y: f32,
        vx: f32,
        vy: f32,
    }

    impl Agent for Boid {
        fn id(&self) -> AgentId {
            self.id
        }
    }

    /// Message emitted by each boid during the output phase.
    #[derive(Debug, Clone)]
    struct BoidMessage {
        id: AgentId,
        x: f32,
        y: f32,
        // Velocity is carried to mirror a realistic boids message, but these
        // simplified tests steer on position only, so it is never read.
        #[allow(dead_code)]
        vx: f32,
        #[allow(dead_code)]
        vy: f32,
    }

    #[test]
    fn two_phase_brute_force_boids() {
        let mut store = HashMapStore::new();
        for i in 1..=100 {
            store.insert(Boid {
                id: i,
                x: (i as f32) * 0.01,
                y: (i as f32) * 0.01,
                vx: 0.001,
                vy: 0.001,
            });
        }

        let ids: Vec<AgentId> = store.iter_ids();
        let mut messages = BruteForceMessages::with_capacity(100);

        let result = two_phase_brute_force(
            &store,
            &ids,
            &mut messages,
            |agent: &Boid, msgs| {
                msgs.output(BoidMessage {
                    id: agent.id,
                    x: agent.x,
                    y: agent.y,
                    vx: agent.vx,
                    vy: agent.vy,
                });
            },
            |agent: &mut Boid, all_msgs| {
                // Simple cohesion: average position of all others
                let mut sum_x = 0.0;
                let mut sum_y = 0.0;
                let mut count = 0;
                for msg in all_msgs {
                    if msg.id != agent.id {
                        sum_x += msg.x;
                        sum_y += msg.y;
                        count += 1;
                    }
                }
                if count > 0 {
                    let avg_x = sum_x / count as f32;
                    let avg_y = sum_y / count as f32;
                    agent.vx += (avg_x - agent.x) * 0.01;
                    agent.vy += (avg_y - agent.y) * 0.01;
                }
                agent.x += agent.vx;
                agent.y += agent.vy;
            },
        );

        assert_eq!(result.message_count, 100);
        assert_eq!(result.agent_count, 100);

        // The input phase must have mutated agents: boid 1 starts at x = 0.01
        // with positive vx and is pulled toward the centroid of the others
        // (x = 0.51), so it can only move right.
        let boid_1_x = store.get(1).expect("boid 1 was inserted").x;
        assert!(boid_1_x > 0.01, "boid 1 did not move: x = {}", boid_1_x);

        // The helper clears the buffer so it can be reused next step.
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn two_phase_spatial_2d_boids() {
        let mut store = HashMapStore::new();
        for i in 1..=50 {
            store.insert(Boid {
                id: i,
                x: (i as f32) * 0.1,
                y: (i as f32) * 0.1,
                vx: 0.001,
                vy: 0.001,
            });
        }

        let ids: Vec<AgentId> = store.iter_ids();
        let mut messages = SpatialMessages2D::new(1.0).unwrap();

        let result = two_phase_spatial_2d(
            &store,
            &ids,
            &mut messages,
            |agent: &Boid, msgs| {
                msgs.output(
                    BoidMessage {
                        id: agent.id,
                        x: agent.x,
                        y: agent.y,
                        vx: agent.vx,
                        vy: agent.vy,
                    },
                    agent.x,
                    agent.y,
                );
            },
            |agent: &mut Boid, msgs| {
                // Cohesion along x using only nearby messages.
                let mut sum_x = 0.0;
                let mut count = 0;
                for (msg, _dist_sq) in msgs.read_nearby(agent.x, agent.y, 1.0) {
                    if msg.id != agent.id {
                        sum_x += msg.x;
                        count += 1;
                    }
                }
                if count > 0 {
                    agent.vx += (sum_x / count as f32 - agent.x) * 0.01;
                }
                agent.x += agent.vx;
                agent.y += agent.vy;
            },
        );

        assert_eq!(result.message_count, 50);
        assert_eq!(result.agent_count, 50);

        // y only advances by the constant vy = 0.001, so boid 1 must end up
        // at y = 0.1 + 0.001 if the input phase ran exactly once for it.
        let boid_1_y = store.get(1).expect("boid 1 was inserted").y;
        assert!(
            (boid_1_y - 0.101).abs() < 1e-6,
            "unexpected y for boid 1: {}",
            boid_1_y
        );

        // The helper clears the buffer so it can be reused next step.
        assert_eq!(messages.len(), 0);
    }
}