1pub mod realization;
2
3pub use realization::Realization;
4
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use crate::config::{SimConfig, SweepMode};
8use crate::geometry::Lattice;
9use crate::statistics::{
10 sokal_tau, AutocorrAccum, ClusterStats, Diagnostics, EquilDiagnosticAccum, OverlapAccum,
11 Statistics, SweepResult,
12};
13use crate::{clusters, mcmc, spins};
14use rayon::prelude::*;
15use validator::Validate;
16
17pub fn run_sweep_loop(
28 lattice: &Lattice,
29 real: &mut Realization,
30 n_replicas: usize,
31 n_temps: usize,
32 config: &SimConfig,
33 interrupted: &AtomicBool,
34 on_sweep: &(dyn Fn() + Sync),
35) -> Result<SweepResult, String> {
36 config.validate().map_err(|e| format!("{e}"))?;
37
38 let n_spins = lattice.n_spins;
39 let n_systems = n_replicas * n_temps;
40 let n_sweeps = config.n_sweeps;
41 let warmup_sweeps = config.warmup_sweeps;
42
43 let n_modes = config.overlap_cluster.as_ref().map_or(0, |h| h.modes.len());
44
45 if let Some(ref oc_cfg) = config.overlap_cluster {
46 let max_gs = oc_cfg.max_group_size();
47 if n_replicas < max_gs {
48 return Err(format!(
49 "overlap cluster requires n_replicas >= max group_size ({n_replicas} < {max_gs})"
50 ));
51 }
52 }
53
54 let n_pairs = n_replicas / 2;
55
56 let mut fk_csd_accum: Vec<Vec<u64>> = (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
57 let mut sw_csd_buf: Vec<Vec<u64>> = (0..n_systems).map(|_| vec![0u64; n_spins + 1]).collect();
58
59 let mut overlap_csd_accum: Vec<Vec<Vec<u64>>> = (0..n_modes)
60 .map(|_| (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect())
61 .collect();
62 let mut overlap_csd_buf: Vec<Vec<u64>> = (0..n_temps * n_pairs)
63 .map(|_| vec![0u64; n_spins + 1])
64 .collect();
65
66 let collect_top = config
67 .overlap_cluster
68 .as_ref()
69 .is_some_and(|h| h.collect_stats)
70 && n_pairs > 0;
71
72 let mut top4_accum: Vec<Vec<[f64; 4]>> =
73 (0..n_modes).map(|_| vec![[0.0; 4]; n_temps]).collect();
74 let mut top4_n: Vec<usize> = vec![0; n_modes];
75 let mut top4_buf: Vec<[u32; 4]> = if collect_top {
76 vec![[0u32; 4]; n_temps * n_pairs]
77 } else {
78 vec![]
79 };
80
81 let mut overlap_call_count: usize = 0;
82
83 let mut mags_stat = Statistics::new(n_temps, 1);
84 let mut mags2_stat = Statistics::new(n_temps, 1);
85 let mut mags4_stat = Statistics::new(n_temps, 1);
86 let mut energies_stat = Statistics::new(n_temps, 1);
87 let mut energies2_stat = Statistics::new(n_temps, 2);
88 let n_measurement_sweeps = n_sweeps.saturating_sub(warmup_sweeps);
89 let ac_max_lag = config
90 .autocorrelation_max_lag
91 .map(|k| k.min(n_measurement_sweeps / 4).max(1));
92 let mut m2_accum = ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps));
93 let mut q2_accum = if ac_max_lag.is_some() && n_pairs > 0 {
94 ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps))
95 } else {
96 None
97 };
98 let collect_ac = ac_max_lag.is_some();
99 let collect_q2_ac = q2_accum.is_some();
100 let mut m2_ac_buf = if collect_ac {
101 vec![0.0f64; n_temps]
102 } else {
103 vec![]
104 };
105
106 let equil_diag = config.equilibration_diagnostic;
107 let mut equil_accum = if equil_diag {
108 Some(EquilDiagnosticAccum::new(n_temps, n_sweeps))
109 } else {
110 None
111 };
112 let mut diag_e_buf = if equil_diag {
113 vec![0.0f32; n_temps]
114 } else {
115 vec![]
116 };
117
118 let mut ov_accum = OverlapAccum::new(
119 n_temps,
120 n_spins,
121 n_pairs,
122 lattice.n_neighbors,
123 equil_diag,
124 collect_q2_ac,
125 );
126
127 let mut mags_buf = vec![0.0f32; n_temps];
128 let mut mags2_buf = vec![0.0f32; n_temps];
129 let mut mags4_buf = vec![0.0f32; n_temps];
130 let mut energies_buf = vec![0.0f32; n_temps];
131
132 for sweep_id in 0..n_sweeps {
133 if interrupted.load(Ordering::Relaxed) {
134 return Err("interrupted".to_string());
135 }
136 on_sweep();
137 let record = sweep_id >= warmup_sweeps;
138
139 match config.sweep_mode {
140 SweepMode::Metropolis => mcmc::sweep::metropolis_sweep(
141 lattice,
142 &mut real.spins,
143 &real.couplings,
144 &real.temperatures,
145 &real.system_ids,
146 &mut real.rngs,
147 config.sequential,
148 ),
149 SweepMode::Gibbs => mcmc::sweep::gibbs_sweep(
150 lattice,
151 &mut real.spins,
152 &real.couplings,
153 &real.temperatures,
154 &real.system_ids,
155 &mut real.rngs,
156 config.sequential,
157 ),
158 }
159
160 let do_cluster = config
161 .cluster_update
162 .as_ref()
163 .is_some_and(|c| sweep_id % c.interval == 0);
164
165 if do_cluster {
166 let cluster_cfg = config.cluster_update.as_ref().unwrap();
167 let wolff = cluster_cfg.mode == crate::config::ClusterMode::Wolff;
168 let csd_out = if cluster_cfg.collect_stats && record {
169 for buf in sw_csd_buf.iter_mut() {
170 buf.fill(0);
171 }
172 Some(sw_csd_buf.as_mut_slice())
173 } else {
174 None
175 };
176
177 clusters::fk_update(
178 lattice,
179 &mut real.spins,
180 &real.couplings,
181 &real.temperatures,
182 &real.system_ids,
183 &mut real.rngs,
184 wolff,
185 csd_out,
186 config.sequential,
187 );
188
189 if cluster_cfg.collect_stats && record {
190 for (slot, buf) in sw_csd_buf.iter().enumerate() {
191 let accum = &mut fk_csd_accum[slot % n_temps];
192 for (a, &b) in accum.iter_mut().zip(buf.iter()) {
193 *a += b;
194 }
195 }
196 }
197 }
198
199 let pt_this_sweep = config
200 .pt_interval
201 .is_some_and(|interval| sweep_id % interval == 0);
202
203 if record || pt_this_sweep || equil_diag {
204 (real.energies, _) = spins::energy::compute_energies(
205 lattice,
206 &real.spins,
207 &real.couplings,
208 n_systems,
209 false,
210 );
211 }
212
213 if equil_diag {
214 diag_e_buf.fill(0.0);
215 #[allow(clippy::needless_range_loop)]
216 for r in 0..n_replicas {
217 let offset = r * n_temps;
218 for t in 0..n_temps {
219 let system_id = real.system_ids[offset + t];
220 diag_e_buf[t] += real.energies[system_id];
221 }
222 }
223 let inv = 1.0 / n_replicas as f32;
224 for v in diag_e_buf.iter_mut() {
225 *v *= inv;
226 }
227 }
228
229 if (equil_diag || record) && n_pairs > 0 {
230 ov_accum.collect(lattice, &real.spins, &real.system_ids, record);
231 }
232
233 if equil_diag {
234 if n_pairs > 0 {
235 equil_accum
236 .as_mut()
237 .unwrap()
238 .push(&diag_e_buf, &ov_accum.diag_ql_buf);
239 } else {
240 let zeros = vec![0.0f32; n_temps];
241 equil_accum.as_mut().unwrap().push(&diag_e_buf, &zeros);
242 }
243 }
244
245 if record {
246 for t in 0..n_temps {
247 mags_buf[t] = 0.0;
248 mags2_buf[t] = 0.0;
249 mags4_buf[t] = 0.0;
250 energies_buf[t] = 0.0;
251 }
252
253 if collect_ac {
254 m2_ac_buf.fill(0.0);
255 }
256
257 for r in 0..n_replicas {
258 let offset = r * n_temps;
259 for t in 0..n_temps {
260 let system_id = real.system_ids[offset + t];
261 let spin_base = system_id * n_spins;
262 let mut sum = 0i64;
263 for j in 0..n_spins {
264 sum += real.spins[spin_base + j] as i64;
265 }
266 let mag = sum as f32 / n_spins as f32;
267 let m2 = mag * mag;
268 mags_buf[t] = mag;
269 mags2_buf[t] = m2;
270 mags4_buf[t] = m2 * m2;
271 energies_buf[t] = real.energies[system_id];
272 }
273
274 if collect_ac {
275 for t in 0..n_temps {
276 m2_ac_buf[t] += mags2_buf[t] as f64;
277 }
278 }
279
280 mags_stat.update(&mags_buf);
281 mags2_stat.update(&mags2_buf);
282 mags4_stat.update(&mags4_buf);
283 energies_stat.update(&energies_buf);
284 energies2_stat.update(&energies_buf);
285 }
286
287 if let Some(ref mut acc) = m2_accum {
288 let inv = 1.0 / n_replicas as f64;
289 for v in m2_ac_buf.iter_mut() {
290 *v *= inv;
291 }
292 acc.push(&m2_ac_buf);
293 }
294
295 if let Some(ref mut acc) = q2_accum {
296 let inv = 1.0 / n_pairs as f64;
297 for v in ov_accum.q2_ac_buf.iter_mut() {
298 *v *= inv;
299 }
300 acc.push(&ov_accum.q2_ac_buf);
301 }
302 }
303
304 if let Some(ref oc_cfg) = config.overlap_cluster {
305 if sweep_id % oc_cfg.interval == 0 {
306 let mode_idx = overlap_call_count % n_modes;
307 let mode = &oc_cfg.modes[mode_idx];
308
309 let ov_csd_out = if oc_cfg.collect_stats && record {
310 for buf in overlap_csd_buf.iter_mut() {
311 buf.fill(0);
312 }
313 Some(overlap_csd_buf.as_mut_slice())
314 } else {
315 None
316 };
317
318 let top4_out = if collect_top && record {
319 for slot in top4_buf.iter_mut() {
320 *slot = [0u32; 4];
321 }
322 Some(top4_buf.as_mut_slice())
323 } else {
324 None
325 };
326
327 clusters::overlap_update(
328 lattice,
329 &mut real.spins,
330 &real.couplings,
331 &real.temperatures,
332 &real.system_ids,
333 n_replicas,
334 n_temps,
335 &mut real.pair_rngs,
336 mode,
337 oc_cfg.cluster_mode,
338 ov_csd_out,
339 top4_out,
340 config.sequential,
341 );
342
343 if oc_cfg.collect_stats && record {
344 for (slot, buf) in overlap_csd_buf.iter().enumerate() {
345 let accum = &mut overlap_csd_accum[mode_idx][slot / n_pairs];
346 for (a, &b) in accum.iter_mut().zip(buf.iter()) {
347 *a += b;
348 }
349 }
350 }
351
352 if collect_top && record {
353 for t in 0..n_temps {
354 for p in 0..n_pairs {
355 let raw = top4_buf[t * n_pairs + p];
356 for (k, &v) in raw.iter().enumerate() {
357 top4_accum[mode_idx][t][k] += v as f64 / n_spins as f64;
358 }
359 }
360 }
361 top4_n[mode_idx] += 1;
362 }
363
364 overlap_call_count += 1;
365 }
366 }
367
368 if pt_this_sweep {
369 if config.overlap_cluster.is_some() {
370 (real.energies, _) = spins::energy::compute_energies(
371 lattice,
372 &real.spins,
373 &real.couplings,
374 n_systems,
375 false,
376 );
377 }
378 for r in 0..n_replicas {
379 let offset = r * n_temps;
380 let sid_slice = &mut real.system_ids[offset..offset + n_temps];
381 let temp_slice = &real.temperatures[offset..offset + n_temps];
382 mcmc::tempering::parallel_tempering(
383 &real.energies,
384 temp_slice,
385 sid_slice,
386 n_spins,
387 &mut real.rngs[offset],
388 );
389 }
390 }
391 }
392
393 let top_cluster_sizes: Vec<Vec<[f64; 4]>> = if collect_top {
394 top4_accum
395 .iter()
396 .zip(top4_n.iter())
397 .map(|(mode_accum, &count)| {
398 if count == 0 {
399 return vec![];
400 }
401 let denom = (count * n_pairs) as f64;
402 mode_accum
403 .iter()
404 .map(|arr| {
405 [
406 arr[0] / denom,
407 arr[1] / denom,
408 arr[2] / denom,
409 arr[3] / denom,
410 ]
411 })
412 .collect()
413 })
414 .collect()
415 } else {
416 vec![]
417 };
418
419 let mags2_tau = m2_accum
420 .as_ref()
421 .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
422 .unwrap_or_default();
423 let overlap2_tau = q2_accum
424 .as_ref()
425 .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
426 .unwrap_or_default();
427
428 let equil_checkpoints = equil_accum.map(|acc| acc.finish()).unwrap_or_default();
429
430 Ok(SweepResult {
431 mags: mags_stat.average(),
432 mags2: mags2_stat.average(),
433 mags4: mags4_stat.average(),
434 energies: energies_stat.average(),
435 energies2: energies2_stat.average(),
436 overlap_stats: ov_accum.finish(),
437 cluster_stats: ClusterStats {
438 fk_csd: fk_csd_accum,
439 overlap_csd: overlap_csd_accum,
440 top_cluster_sizes,
441 },
442 diagnostics: Diagnostics {
443 mags2_tau,
444 overlap2_tau,
445 equil_checkpoints,
446 },
447 })
448}
449
450pub fn run_sweep_parallel(
456 lattice: &Lattice,
457 realizations: &mut [Realization],
458 n_replicas: usize,
459 n_temps: usize,
460 config: &SimConfig,
461 interrupted: &AtomicBool,
462 on_sweep: &(dyn Fn() + Sync),
463) -> Result<SweepResult, String> {
464 if realizations.len() == 1 {
465 return run_sweep_loop(
466 lattice,
467 &mut realizations[0],
468 n_replicas,
469 n_temps,
470 config,
471 interrupted,
472 on_sweep,
473 );
474 }
475
476 let results: Vec<Result<SweepResult, String>> = realizations
477 .par_iter_mut()
478 .map(|real| {
479 run_sweep_loop(
480 lattice,
481 real,
482 n_replicas,
483 n_temps,
484 config,
485 interrupted,
486 on_sweep,
487 )
488 })
489 .collect();
490
491 let results: Vec<SweepResult> = results.into_iter().collect::<Result<Vec<_>, _>>()?;
492 Ok(SweepResult::aggregate(&results))
493}