1use crate::agent::Agent;
17use crate::messaging::{BruteForceMessages, SpatialMessages2D, SpatialMessages3D};
18use crate::store::AgentStore;
19use crate::types::AgentId;
20
/// Timing and size statistics for one two-phase messaging round
/// (output → finalize → input), as returned by the `two_phase_*` functions
/// in this module.
///
/// All durations are wall-clock times in microseconds, measured with
/// [`std::time::Instant`] and converted via `Duration::as_micros`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TwoPhaseResult {
    /// Time spent invoking the output callback for every requested agent.
    pub output_us: u128,
    /// Time spent in the message buffer's `finalize` step.
    pub finalize_us: u128,
    /// Time spent invoking the input callback for every requested agent.
    pub input_us: u128,
    /// Number of messages in the buffer immediately after finalization.
    pub message_count: usize,
    /// Number of agent ids the round was asked to process (`ids.len()`),
    /// including ids the store did not return an agent for.
    pub agent_count: usize,
}
35
36pub fn two_phase_brute_force<A, S, M, FOut, FIn>(
45 store: &S,
46 ids: &[AgentId],
47 messages: &mut BruteForceMessages<M>,
48 mut output_fn: FOut,
49 mut input_fn: FIn,
50) -> TwoPhaseResult
51where
52 A: Agent,
53 M: Clone,
54 S: AgentStore<A>,
55 FOut: FnMut(&A, &mut BruteForceMessages<M>),
56 FIn: FnMut(&mut A, &[M]),
57{
58 let t_output = std::time::Instant::now();
60 for &id in ids {
61 if let Some(agent) = store.get(id) {
62 output_fn(&*agent, messages);
63 }
64 }
65 let output_us = t_output.elapsed().as_micros();
66
67 let t_finalize = std::time::Instant::now();
69 messages.finalize();
70 let finalize_us = t_finalize.elapsed().as_micros();
71 let message_count = messages.len();
72
73 let t_input = std::time::Instant::now();
75 let all_msgs = messages.read_all();
76 for &id in ids {
77 if let Some(mut agent) = store.get_mut(id) {
78 input_fn(&mut *agent, all_msgs);
79 }
80 }
81 let input_us = t_input.elapsed().as_micros();
82
83 messages.clear();
84
85 TwoPhaseResult {
86 output_us,
87 finalize_us,
88 input_us,
89 message_count,
90 agent_count: ids.len(),
91 }
92}
93
94pub fn two_phase_spatial_2d<A, S, M, FOut, FIn>(
100 store: &S,
101 ids: &[AgentId],
102 messages: &mut SpatialMessages2D<M>,
103 mut output_fn: FOut,
104 mut input_fn: FIn,
105) -> TwoPhaseResult
106where
107 A: Agent,
108 M: Clone,
109 S: AgentStore<A>,
110 FOut: FnMut(&A, &mut SpatialMessages2D<M>),
111 FIn: FnMut(&mut A, &SpatialMessages2D<M>),
112{
113 let t_output = std::time::Instant::now();
115 for &id in ids {
116 if let Some(agent) = store.get(id) {
117 output_fn(&*agent, messages);
118 }
119 }
120 let output_us = t_output.elapsed().as_micros();
121
122 let t_finalize = std::time::Instant::now();
124 messages.finalize();
125 let finalize_us = t_finalize.elapsed().as_micros();
126 let message_count = messages.len();
127
128 let t_input = std::time::Instant::now();
130 for &id in ids {
131 if let Some(mut agent) = store.get_mut(id) {
132 input_fn(&mut *agent, messages);
133 }
134 }
135 let input_us = t_input.elapsed().as_micros();
136
137 messages.clear();
138
139 TwoPhaseResult {
140 output_us,
141 finalize_us,
142 input_us,
143 message_count,
144 agent_count: ids.len(),
145 }
146}
147
148pub fn two_phase_spatial_3d<A, S, M, FOut, FIn>(
152 store: &S,
153 ids: &[AgentId],
154 messages: &mut SpatialMessages3D<M>,
155 mut output_fn: FOut,
156 mut input_fn: FIn,
157) -> TwoPhaseResult
158where
159 A: Agent,
160 M: Clone,
161 S: AgentStore<A>,
162 FOut: FnMut(&A, &mut SpatialMessages3D<M>),
163 FIn: FnMut(&mut A, &SpatialMessages3D<M>),
164{
165 let t_output = std::time::Instant::now();
167 for &id in ids {
168 if let Some(agent) = store.get(id) {
169 output_fn(&*agent, messages);
170 }
171 }
172 let output_us = t_output.elapsed().as_micros();
173
174 let t_finalize = std::time::Instant::now();
176 messages.finalize();
177 let finalize_us = t_finalize.elapsed().as_micros();
178 let message_count = messages.len();
179
180 let t_input = std::time::Instant::now();
182 for &id in ids {
183 if let Some(mut agent) = store.get_mut(id) {
184 input_fn(&mut *agent, messages);
185 }
186 }
187 let input_us = t_input.elapsed().as_micros();
188
189 messages.clear();
190
191 TwoPhaseResult {
192 output_us,
193 finalize_us,
194 input_us,
195 message_count,
196 agent_count: ids.len(),
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;

    #[derive(Debug, Clone)]
    struct Boid {
        id: AgentId,
        x: f32,
        y: f32,
        vx: f32,
        vy: f32,
    }

    impl Agent for Boid {
        fn id(&self) -> AgentId {
            self.id
        }
    }

    #[derive(Debug, Clone)]
    struct BoidMessage {
        id: AgentId,
        x: f32,
        y: f32,
        // Velocity is broadcast to mirror a realistic boid message, but none
        // of these tests read it back, hence the dead_code allowances.
        #[allow(dead_code)]
        vx: f32,
        #[allow(dead_code)]
        vy: f32,
    }

    #[test]
    fn two_phase_brute_force_boids() {
        let mut store = HashMapStore::new();
        for i in 1..=100 {
            store.insert(Boid {
                id: i,
                x: (i as f32) * 0.01,
                y: (i as f32) * 0.01,
                vx: 0.001,
                vy: 0.001,
            });
        }

        let ids: Vec<AgentId> = store.iter_ids();
        let mut messages = BruteForceMessages::with_capacity(100);

        // Observed from inside the input callback, checked after the round.
        let mut input_calls = 0usize;
        let mut inbox_sizes_ok = true;

        let result = two_phase_brute_force(
            &store,
            &ids,
            &mut messages,
            |agent: &Boid, msgs| {
                msgs.output(BoidMessage {
                    id: agent.id,
                    x: agent.x,
                    y: agent.y,
                    vx: agent.vx,
                    vy: agent.vy,
                });
            },
            |agent: &mut Boid, all_msgs| {
                input_calls += 1;
                // Brute force: every agent should see every emitted message.
                inbox_sizes_ok &= all_msgs.len() == 100;

                let mut sum_x = 0.0;
                let mut sum_y = 0.0;
                let mut count = 0;
                for msg in all_msgs {
                    if msg.id != agent.id {
                        sum_x += msg.x;
                        sum_y += msg.y;
                        count += 1;
                    }
                }
                if count > 0 {
                    let avg_x = sum_x / count as f32;
                    let avg_y = sum_y / count as f32;
                    agent.vx += (avg_x - agent.x) * 0.01;
                    agent.vy += (avg_y - agent.y) * 0.01;
                }
                agent.x += agent.vx;
                agent.y += agent.vy;
            },
        );

        assert_eq!(result.message_count, 100);
        assert_eq!(result.agent_count, 100);
        // The input phase must have visited every agent with the full inbox.
        assert_eq!(input_calls, 100);
        assert!(inbox_sizes_ok);
        // The round must leave the buffer empty for the next step.
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn two_phase_spatial_2d_boids() {
        let mut store = HashMapStore::new();
        for i in 1..=50 {
            store.insert(Boid {
                id: i,
                x: (i as f32) * 0.1,
                y: (i as f32) * 0.1,
                vx: 0.001,
                vy: 0.001,
            });
        }

        let ids: Vec<AgentId> = store.iter_ids();
        let mut messages = SpatialMessages2D::new(1.0).unwrap();

        let mut input_calls = 0usize;

        let result = two_phase_spatial_2d(
            &store,
            &ids,
            &mut messages,
            |agent: &Boid, msgs| {
                msgs.output(
                    BoidMessage {
                        id: agent.id,
                        x: agent.x,
                        y: agent.y,
                        vx: agent.vx,
                        vy: agent.vy,
                    },
                    agent.x,
                    agent.y,
                );
            },
            |agent: &mut Boid, msgs| {
                input_calls += 1;
                let nearby: Vec<_> = msgs.read_nearby(agent.x, agent.y, 1.0).collect();
                let mut sum_x = 0.0;
                let mut count = 0;
                for (msg, _dist_sq) in &nearby {
                    if msg.id != agent.id {
                        sum_x += msg.x;
                        count += 1;
                    }
                }
                if count > 0 {
                    agent.vx += (sum_x / count as f32 - agent.x) * 0.01;
                }
                agent.x += agent.vx;
                agent.y += agent.vy;
            },
        );

        assert_eq!(result.message_count, 50);
        assert_eq!(result.agent_count, 50);
        // The input phase must have visited every agent.
        assert_eq!(input_calls, 50);
        // The round must leave the buffer empty for the next step.
        assert_eq!(messages.len(), 0);
    }
}