1pub mod realization;
2
3pub use realization::Realization;
4
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use crate::config::{SimConfig, SweepMode};
8use crate::geometry::Lattice;
9use crate::statistics::{
10 sokal_tau, AutocorrAccum, ClusterSnapshot, ClusterStats, Diagnostics, EquilDiagnosticAccum,
11 OverlapAccum, Statistics, SweepResult,
12};
13use crate::{clusters, mcmc, spins};
14use rayon::prelude::*;
15use validator::Validate;
16
17#[allow(clippy::too_many_arguments)]
28pub fn run_sweep_loop(
29 lattice: &Lattice,
30 real: &mut Realization,
31 n_replicas: usize,
32 n_temps: usize,
33 config: &SimConfig,
34 interrupted: &AtomicBool,
35 on_sweep: &(dyn Fn() + Sync),
36 realization_idx: usize,
37) -> Result<SweepResult, String> {
38 config.validate().map_err(|e| format!("{e}"))?;
39
40 let n_spins = lattice.n_spins;
41 let n_systems = n_replicas * n_temps;
42 let n_sweeps = config.n_sweeps;
43 let warmup_sweeps = config.warmup_sweeps;
44
45 let n_modes = config.overlap_cluster.as_ref().map_or(0, |h| h.modes.len());
46
47 if let Some(ref oc_cfg) = config.overlap_cluster {
48 let max_gs = oc_cfg.max_group_size();
49 if n_replicas < max_gs {
50 return Err(format!(
51 "overlap cluster requires n_replicas >= max group_size ({n_replicas} < {max_gs})"
52 ));
53 }
54 }
55
56 let n_pairs = n_replicas / 2;
57
58 let mut fk_csd_accum: Vec<Vec<u64>> = (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
59 let mut sw_csd_buf: Vec<Vec<u64>> = (0..n_systems).map(|_| vec![0u64; n_spins + 1]).collect();
60
61 let mut overlap_csd_accum: Vec<Vec<Vec<u64>>> = (0..n_modes)
62 .map(|_| (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect())
63 .collect();
64 let mut overlap_csd_buf: Vec<Vec<u64>> = (0..n_temps * n_pairs)
65 .map(|_| vec![0u64; n_spins + 1])
66 .collect();
67
68 let collect_top = config
69 .overlap_cluster
70 .as_ref()
71 .is_some_and(|h| h.collect_stats)
72 && n_pairs > 0;
73
74 let mut top4_accum: Vec<Vec<[f64; 4]>> =
75 (0..n_modes).map(|_| vec![[0.0; 4]; n_temps]).collect();
76 let mut top4_n: Vec<usize> = vec![0; n_modes];
77 let mut top4_buf: Vec<[u32; 4]> = if collect_top {
78 vec![[0u32; 4]; n_temps * n_pairs]
79 } else {
80 vec![]
81 };
82
83 let mut overlap_call_count: usize = 0;
84
85 let snapshot_interval = if realization_idx == 0 {
86 config
87 .overlap_cluster
88 .as_ref()
89 .and_then(|oc| oc.snapshot_interval)
90 } else {
91 None
92 };
93 let n_pair_slots = n_temps * n_pairs;
94 let mut snap_buf: Vec<Vec<u32>> = if snapshot_interval.is_some() {
95 (0..n_pair_slots)
96 .map(|_| Vec::with_capacity(n_spins))
97 .collect()
98 } else {
99 vec![]
100 };
101 let mut blue_snap_buf: Vec<Vec<u32>> = if snapshot_interval.is_some() {
102 (0..n_pair_slots)
103 .map(|_| Vec::with_capacity(n_spins))
104 .collect()
105 } else {
106 vec![]
107 };
108 let mut spin_snap_buf: Vec<Vec<[Vec<i8>; 2]>> = if snapshot_interval.is_some() {
109 (0..n_pair_slots).map(|_| Vec::new()).collect()
110 } else {
111 vec![]
112 };
113 let mut sid_snap_buf: Vec<Vec<[usize; 2]>> = if snapshot_interval.is_some() {
114 (0..n_pair_slots).map(|_| Vec::new()).collect()
115 } else {
116 vec![]
117 };
118 let mut cluster_snapshots: Vec<ClusterSnapshot> = Vec::new();
119
120 let mut mags_stat = Statistics::new(n_temps, 1);
121 let mut mags2_stat = Statistics::new(n_temps, 1);
122 let mut mags4_stat = Statistics::new(n_temps, 1);
123 let mut energies_stat = Statistics::new(n_temps, 1);
124 let mut energies2_stat = Statistics::new(n_temps, 2);
125 let n_measurement_sweeps = n_sweeps.saturating_sub(warmup_sweeps);
126 let ac_max_lag = config
127 .autocorrelation_max_lag
128 .map(|k| k.min(n_measurement_sweeps / 4).max(1));
129 let mut m2_accum = ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps));
130 let mut q2_accum = if ac_max_lag.is_some() && n_pairs > 0 {
131 ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps))
132 } else {
133 None
134 };
135 let collect_ac = ac_max_lag.is_some();
136 let collect_q2_ac = q2_accum.is_some();
137 let mut m2_ac_buf = if collect_ac {
138 vec![0.0f64; n_temps]
139 } else {
140 vec![]
141 };
142
143 let equil_diag = config.equilibration_diagnostic;
144 let mut equil_accum = if equil_diag {
145 Some(EquilDiagnosticAccum::new(n_temps, n_sweeps))
146 } else {
147 None
148 };
149 let mut diag_e_buf = if equil_diag {
150 vec![0.0f32; n_temps]
151 } else {
152 vec![]
153 };
154
155 let mut ov_accum = OverlapAccum::new(
156 n_temps,
157 n_spins,
158 n_pairs,
159 lattice.n_neighbors,
160 equil_diag,
161 collect_q2_ac,
162 );
163
164 let mut mags_buf = vec![0.0f32; n_temps];
165 let mut mags2_buf = vec![0.0f32; n_temps];
166 let mut mags4_buf = vec![0.0f32; n_temps];
167 let mut energies_buf = vec![0.0f32; n_temps];
168
169 for sweep_id in 0..n_sweeps {
170 if interrupted.load(Ordering::Relaxed) {
171 return Err("interrupted".to_string());
172 }
173 on_sweep();
174 let record = sweep_id >= warmup_sweeps;
175
176 match config.sweep_mode {
177 SweepMode::Metropolis => mcmc::sweep::metropolis_sweep(
178 lattice,
179 &mut real.spins,
180 &real.couplings,
181 &real.temperatures,
182 &real.system_ids,
183 &mut real.rngs,
184 config.sequential,
185 ),
186 SweepMode::Gibbs => mcmc::sweep::gibbs_sweep(
187 lattice,
188 &mut real.spins,
189 &real.couplings,
190 &real.temperatures,
191 &real.system_ids,
192 &mut real.rngs,
193 config.sequential,
194 ),
195 }
196
197 let do_cluster = config
198 .cluster_update
199 .as_ref()
200 .is_some_and(|c| sweep_id % c.interval == 0);
201
202 if do_cluster {
203 let cluster_cfg = config.cluster_update.as_ref().unwrap();
204 let wolff = cluster_cfg.mode == crate::config::ClusterMode::Wolff;
205 let csd_out = if cluster_cfg.collect_stats && record {
206 for buf in sw_csd_buf.iter_mut() {
207 buf.fill(0);
208 }
209 Some(sw_csd_buf.as_mut_slice())
210 } else {
211 None
212 };
213
214 clusters::fk_update(
215 lattice,
216 &mut real.spins,
217 &real.couplings,
218 &real.temperatures,
219 &real.system_ids,
220 &mut real.rngs,
221 wolff,
222 csd_out,
223 config.sequential,
224 );
225
226 if cluster_cfg.collect_stats && record {
227 for (slot, buf) in sw_csd_buf.iter().enumerate() {
228 let accum = &mut fk_csd_accum[slot % n_temps];
229 for (a, &b) in accum.iter_mut().zip(buf.iter()) {
230 *a += b;
231 }
232 }
233 }
234 }
235
236 let pt_this_sweep = config
237 .pt_interval
238 .is_some_and(|interval| sweep_id % interval == 0);
239
240 if record || pt_this_sweep || equil_diag {
241 (real.energies, _) = spins::energy::compute_energies(
242 lattice,
243 &real.spins,
244 &real.couplings,
245 n_systems,
246 false,
247 );
248 }
249
250 if equil_diag {
251 diag_e_buf.fill(0.0);
252 #[allow(clippy::needless_range_loop)]
253 for r in 0..n_replicas {
254 let offset = r * n_temps;
255 for t in 0..n_temps {
256 let system_id = real.system_ids[offset + t];
257 diag_e_buf[t] += real.energies[system_id];
258 }
259 }
260 let inv = 1.0 / n_replicas as f32;
261 for v in diag_e_buf.iter_mut() {
262 *v *= inv;
263 }
264 }
265
266 if (equil_diag || record) && n_pairs > 0 {
267 ov_accum.collect(lattice, &real.spins, &real.system_ids, record);
268 }
269
270 if equil_diag {
271 if n_pairs > 0 {
272 equil_accum
273 .as_mut()
274 .unwrap()
275 .push(&diag_e_buf, &ov_accum.diag_ql_buf);
276 } else {
277 let zeros = vec![0.0f32; n_temps];
278 equil_accum.as_mut().unwrap().push(&diag_e_buf, &zeros);
279 }
280 }
281
282 if record {
283 for t in 0..n_temps {
284 mags_buf[t] = 0.0;
285 mags2_buf[t] = 0.0;
286 mags4_buf[t] = 0.0;
287 energies_buf[t] = 0.0;
288 }
289
290 if collect_ac {
291 m2_ac_buf.fill(0.0);
292 }
293
294 for r in 0..n_replicas {
295 let offset = r * n_temps;
296 for t in 0..n_temps {
297 let system_id = real.system_ids[offset + t];
298 let spin_base = system_id * n_spins;
299 let mut sum = 0i64;
300 for j in 0..n_spins {
301 sum += real.spins[spin_base + j] as i64;
302 }
303 let mag = sum as f32 / n_spins as f32;
304 let m2 = mag * mag;
305 mags_buf[t] = mag;
306 mags2_buf[t] = m2;
307 mags4_buf[t] = m2 * m2;
308 energies_buf[t] = real.energies[system_id];
309 }
310
311 if collect_ac {
312 for t in 0..n_temps {
313 m2_ac_buf[t] += mags2_buf[t] as f64;
314 }
315 }
316
317 mags_stat.update(&mags_buf);
318 mags2_stat.update(&mags2_buf);
319 mags4_stat.update(&mags4_buf);
320 energies_stat.update(&energies_buf);
321 energies2_stat.update(&energies_buf);
322 }
323
324 if let Some(ref mut acc) = m2_accum {
325 let inv = 1.0 / n_replicas as f64;
326 for v in m2_ac_buf.iter_mut() {
327 *v *= inv;
328 }
329 acc.push(&m2_ac_buf);
330 }
331
332 if let Some(ref mut acc) = q2_accum {
333 let inv = 1.0 / n_pairs as f64;
334 for v in ov_accum.q2_ac_buf.iter_mut() {
335 *v *= inv;
336 }
337 acc.push(&ov_accum.q2_ac_buf);
338 }
339 }
340
341 if let Some(ref oc_cfg) = config.overlap_cluster {
342 if sweep_id % oc_cfg.interval == 0 {
343 let mode_idx = overlap_call_count % n_modes;
344 let mode = &oc_cfg.modes[mode_idx];
345
346 let ov_csd_out = if oc_cfg.collect_stats && record {
347 for buf in overlap_csd_buf.iter_mut() {
348 buf.fill(0);
349 }
350 Some(overlap_csd_buf.as_mut_slice())
351 } else {
352 None
353 };
354
355 let top4_out = if collect_top && record {
356 for slot in top4_buf.iter_mut() {
357 *slot = [0u32; 4];
358 }
359 Some(top4_buf.as_mut_slice())
360 } else {
361 None
362 };
363
364 let take_snapshot =
365 snapshot_interval.is_some_and(|si| sweep_id % si == 0) && record;
366
367 let is_cmr = matches!(mode, crate::config::OverlapClusterBuildMode::Cmr);
368
369 let snap = if take_snapshot {
370 for buf in spin_snap_buf.iter_mut() {
371 buf.clear();
372 }
373 for buf in sid_snap_buf.iter_mut() {
374 buf.clear();
375 }
376 Some(snap_buf.as_mut_slice())
377 } else {
378 None
379 };
380 let blue_snap = if take_snapshot && is_cmr {
381 Some(blue_snap_buf.as_mut_slice())
382 } else {
383 None
384 };
385 let spin_snap = if take_snapshot {
386 Some(spin_snap_buf.as_mut_slice())
387 } else {
388 None
389 };
390 let sid_snap = if take_snapshot {
391 Some(sid_snap_buf.as_mut_slice())
392 } else {
393 None
394 };
395
396 clusters::overlap_update(
397 lattice,
398 &mut real.spins,
399 &real.couplings,
400 &real.temperatures,
401 &real.system_ids,
402 n_replicas,
403 n_temps,
404 &mut real.pair_rngs,
405 mode,
406 oc_cfg.cluster_mode,
407 ov_csd_out,
408 top4_out,
409 config.sequential,
410 snap,
411 blue_snap,
412 spin_snap,
413 sid_snap,
414 );
415
416 if take_snapshot {
417 let ids: Vec<Vec<u32>> = (0..n_temps)
418 .map(|t| snap_buf[t * n_pairs].clone())
419 .collect();
420 let blue = if is_cmr {
421 Some(
422 (0..n_temps)
423 .map(|t| blue_snap_buf[t * n_pairs].clone())
424 .collect(),
425 )
426 } else {
427 None
428 };
429 let spins: Vec<[Vec<i8>; 2]> = (0..n_temps)
430 .map(|t| {
431 spin_snap_buf[t * n_pairs]
432 .first()
433 .cloned()
434 .unwrap_or_else(|| [vec![], vec![]])
435 })
436 .collect();
437 let sids: Vec<[usize; 2]> = (0..n_temps)
438 .map(|t| sid_snap_buf[t * n_pairs].first().copied().unwrap_or([0, 0]))
439 .collect();
440 cluster_snapshots.push(ClusterSnapshot {
441 sweep_id,
442 mode_idx,
443 cluster_ids: ids,
444 blue_ids: blue,
445 spins,
446 system_ids: sids,
447 });
448 }
449
450 if oc_cfg.collect_stats && record {
451 for (slot, buf) in overlap_csd_buf.iter().enumerate() {
452 let accum = &mut overlap_csd_accum[mode_idx][slot / n_pairs];
453 for (a, &b) in accum.iter_mut().zip(buf.iter()) {
454 *a += b;
455 }
456 }
457 }
458
459 if collect_top && record {
460 for t in 0..n_temps {
461 for p in 0..n_pairs {
462 let raw = top4_buf[t * n_pairs + p];
463 for (k, &v) in raw.iter().enumerate() {
464 top4_accum[mode_idx][t][k] += v as f64 / n_spins as f64;
465 }
466 }
467 }
468 top4_n[mode_idx] += 1;
469 }
470
471 overlap_call_count += 1;
472 }
473 }
474
475 if pt_this_sweep {
476 if config.overlap_cluster.is_some() {
477 (real.energies, _) = spins::energy::compute_energies(
478 lattice,
479 &real.spins,
480 &real.couplings,
481 n_systems,
482 false,
483 );
484 }
485 for r in 0..n_replicas {
486 let offset = r * n_temps;
487 let sid_slice = &mut real.system_ids[offset..offset + n_temps];
488 let temp_slice = &real.temperatures[offset..offset + n_temps];
489 mcmc::tempering::parallel_tempering(
490 &real.energies,
491 temp_slice,
492 sid_slice,
493 n_spins,
494 &mut real.rngs[offset],
495 );
496 }
497 }
498 }
499
500 let top_cluster_sizes: Vec<Vec<[f64; 4]>> = if collect_top {
501 top4_accum
502 .iter()
503 .zip(top4_n.iter())
504 .map(|(mode_accum, &count)| {
505 if count == 0 {
506 return vec![];
507 }
508 let denom = (count * n_pairs) as f64;
509 mode_accum
510 .iter()
511 .map(|arr| {
512 [
513 arr[0] / denom,
514 arr[1] / denom,
515 arr[2] / denom,
516 arr[3] / denom,
517 ]
518 })
519 .collect()
520 })
521 .collect()
522 } else {
523 vec![]
524 };
525
526 let mags2_tau = m2_accum
527 .as_ref()
528 .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
529 .unwrap_or_default();
530 let overlap2_tau = q2_accum
531 .as_ref()
532 .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
533 .unwrap_or_default();
534
535 let equil_checkpoints = equil_accum.map(|acc| acc.finish()).unwrap_or_default();
536
537 Ok(SweepResult {
538 mags: mags_stat.average(),
539 mags2: mags2_stat.average(),
540 mags4: mags4_stat.average(),
541 energies: energies_stat.average(),
542 energies2: energies2_stat.average(),
543 overlap_stats: ov_accum.finish(),
544 cluster_stats: ClusterStats {
545 fk_csd: fk_csd_accum,
546 overlap_csd: overlap_csd_accum,
547 top_cluster_sizes,
548 },
549 diagnostics: Diagnostics {
550 mags2_tau,
551 overlap2_tau,
552 equil_checkpoints,
553 },
554 cluster_snapshots,
555 })
556}
557
558pub fn run_sweep_parallel(
564 lattice: &Lattice,
565 realizations: &mut [Realization],
566 n_replicas: usize,
567 n_temps: usize,
568 config: &SimConfig,
569 interrupted: &AtomicBool,
570 on_sweep: &(dyn Fn() + Sync),
571) -> Result<SweepResult, String> {
572 if realizations.len() == 1 {
573 return run_sweep_loop(
574 lattice,
575 &mut realizations[0],
576 n_replicas,
577 n_temps,
578 config,
579 interrupted,
580 on_sweep,
581 0,
582 );
583 }
584
585 let results: Vec<Result<SweepResult, String>> = realizations
586 .par_iter_mut()
587 .enumerate()
588 .map(|(idx, real)| {
589 run_sweep_loop(
590 lattice,
591 real,
592 n_replicas,
593 n_temps,
594 config,
595 interrupted,
596 on_sweep,
597 idx,
598 )
599 })
600 .collect();
601
602 let mut results: Vec<SweepResult> = results.into_iter().collect::<Result<Vec<_>, _>>()?;
603 let snapshots = std::mem::take(&mut results[0].cluster_snapshots);
604 let mut agg = SweepResult::aggregate(&results);
605 agg.cluster_snapshots = snapshots;
606 Ok(agg)
607}