1pub mod realization;
2
3pub use realization::Realization;
4
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use crate::config::{OverlapClusterBuildMode, SimConfig, SweepMode};
8use crate::geometry::Lattice;
9use crate::statistics::{
10 sokal_tau, AutocorrAccum, ClusterStats, Diagnostics, EquilDiagnosticAccum, Statistics,
11 SweepResult, OVERLAP_HIST_BINS,
12};
13use crate::{clusters, mcmc, spins};
14use rayon::prelude::*;
15use validator::Validate;
16
17pub fn run_sweep_loop(
28 lattice: &Lattice,
29 real: &mut Realization,
30 n_replicas: usize,
31 n_temps: usize,
32 config: &SimConfig,
33 interrupted: &AtomicBool,
34 on_sweep: &(dyn Fn() + Sync),
35) -> Result<SweepResult, String> {
36 config.validate().map_err(|e| format!("{e}"))?;
37
38 let n_spins = lattice.n_spins;
39 let n_systems = n_replicas * n_temps;
40 let n_sweeps = config.n_sweeps;
41 let warmup_sweeps = config.warmup_sweeps;
42
43 let overlap_wolff = config
44 .overlap_cluster
45 .as_ref()
46 .is_some_and(|h| h.cluster_mode == crate::config::ClusterMode::Wolff);
47
48 let (stochastic, restrict_to_negative) =
49 config
50 .overlap_cluster
51 .as_ref()
52 .map_or((false, true), |h| match h.mode {
53 OverlapClusterBuildMode::Houdayer => (false, true),
54 OverlapClusterBuildMode::Jorg => (true, true),
55 OverlapClusterBuildMode::Cmr(_) => (true, false),
56 });
57
58 let group_size = config
59 .overlap_cluster
60 .as_ref()
61 .map_or(2, |h| h.mode.group_size());
62
63 if config.overlap_cluster.is_some() && n_replicas < group_size {
64 return Err(format!(
65 "overlap cluster requires n_replicas >= group_size ({n_replicas} < {group_size})"
66 ));
67 }
68
69 let n_pairs = n_replicas / 2;
70
71 let mut fk_csd_accum: Vec<Vec<u64>> = (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
72 let mut sw_csd_buf: Vec<Vec<u64>> = (0..n_systems).map(|_| vec![0u64; n_spins + 1]).collect();
73
74 let mut overlap_csd_accum: Vec<Vec<u64>> =
75 (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
76 let mut overlap_csd_buf: Vec<Vec<u64>> = (0..n_temps * n_pairs)
77 .map(|_| vec![0u64; n_spins + 1])
78 .collect();
79
80 let collect_top = config
81 .overlap_cluster
82 .as_ref()
83 .is_some_and(|h| h.collect_top_clusters)
84 && n_pairs > 0;
85
86 let mut top4_accum: Vec<[f64; 4]> = vec![[0.0; 4]; n_temps];
87 let mut top4_n: usize = 0;
88 let mut top4_buf: Vec<[u32; 4]> = if collect_top {
89 vec![[0u32; 4]; n_temps * n_pairs]
90 } else {
91 vec![]
92 };
93
94 let mut mags_stat = Statistics::new(n_temps, 1);
95 let mut mags2_stat = Statistics::new(n_temps, 1);
96 let mut mags4_stat = Statistics::new(n_temps, 1);
97 let mut energies_stat = Statistics::new(n_temps, 1);
98 let mut energies2_stat = Statistics::new(n_temps, 2);
99 let mut overlap_stat = Statistics::new(n_temps, 1);
100 let mut overlap2_stat = Statistics::new(n_temps, 1);
101 let mut overlap4_stat = Statistics::new(n_temps, 1);
102
103 let n_measurement_sweeps = n_sweeps.saturating_sub(warmup_sweeps);
104 let ac_max_lag = config
105 .autocorrelation_max_lag
106 .map(|k| k.min(n_measurement_sweeps / 4).max(1));
107 let mut m2_accum = ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps));
108 let mut q2_accum = if ac_max_lag.is_some() && n_pairs > 0 {
109 ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps))
110 } else {
111 None
112 };
113 let collect_ac = ac_max_lag.is_some();
114 let collect_q2_ac = q2_accum.is_some();
115 let mut m2_ac_buf = if collect_ac {
116 vec![0.0f64; n_temps]
117 } else {
118 vec![]
119 };
120 let mut q2_ac_buf = if collect_q2_ac {
121 vec![0.0f64; n_temps]
122 } else {
123 vec![]
124 };
125
126 let equil_diag = config.equilibration_diagnostic;
127 let mut equil_accum = if equil_diag {
128 Some(EquilDiagnosticAccum::new(n_temps, n_sweeps))
129 } else {
130 None
131 };
132 let mut diag_e_buf = if equil_diag {
133 vec![0.0f32; n_temps]
134 } else {
135 vec![]
136 };
137
138 let mut overlap_hist: Vec<Vec<u64>> = if n_pairs > 0 {
139 (0..n_temps)
140 .map(|_| vec![0u64; OVERLAP_HIST_BINS])
141 .collect()
142 } else {
143 vec![]
144 };
145
146 let mut mags_buf = vec![0.0f32; n_temps];
147 let mut mags2_buf = vec![0.0f32; n_temps];
148 let mut mags4_buf = vec![0.0f32; n_temps];
149 let mut energies_buf = vec![0.0f32; n_temps];
150 let mut overlaps_buf = vec![0.0f32; n_temps];
151 let mut overlaps2_buf = vec![0.0f32; n_temps];
152 let mut overlaps4_buf = vec![0.0f32; n_temps];
153
154 for sweep_id in 0..n_sweeps {
155 if interrupted.load(Ordering::Relaxed) {
156 return Err("interrupted".to_string());
157 }
158 on_sweep();
159 let record = sweep_id >= warmup_sweeps;
160
161 match config.sweep_mode {
162 SweepMode::Metropolis => mcmc::sweep::metropolis_sweep(
163 lattice,
164 &mut real.spins,
165 &real.couplings,
166 &real.temperatures,
167 &real.system_ids,
168 &mut real.rngs,
169 config.sequential,
170 ),
171 SweepMode::Gibbs => mcmc::sweep::gibbs_sweep(
172 lattice,
173 &mut real.spins,
174 &real.couplings,
175 &real.temperatures,
176 &real.system_ids,
177 &mut real.rngs,
178 config.sequential,
179 ),
180 }
181
182 let do_cluster = config
183 .cluster_update
184 .as_ref()
185 .is_some_and(|c| sweep_id % c.interval == 0);
186
187 if do_cluster {
188 let cluster_cfg = config.cluster_update.as_ref().unwrap();
189 let wolff = cluster_cfg.mode == crate::config::ClusterMode::Wolff;
190 let csd_out = if cluster_cfg.collect_csd && record {
191 for buf in sw_csd_buf.iter_mut() {
192 buf.fill(0);
193 }
194 Some(sw_csd_buf.as_mut_slice())
195 } else {
196 None
197 };
198
199 clusters::fk_update(
200 lattice,
201 &mut real.spins,
202 &real.couplings,
203 &real.temperatures,
204 &real.system_ids,
205 &mut real.rngs,
206 wolff,
207 csd_out,
208 config.sequential,
209 );
210
211 if cluster_cfg.collect_csd && record {
212 for (slot, buf) in sw_csd_buf.iter().enumerate() {
213 let accum = &mut fk_csd_accum[slot % n_temps];
214 for (a, &b) in accum.iter_mut().zip(buf.iter()) {
215 *a += b;
216 }
217 }
218 }
219 }
220
221 let pt_this_sweep = config
222 .pt_interval
223 .is_some_and(|interval| sweep_id % interval == 0);
224
225 if record || pt_this_sweep || equil_diag {
226 (real.energies, _) = spins::energy::compute_energies(
227 lattice,
228 &real.spins,
229 &real.couplings,
230 n_systems,
231 false,
232 );
233 }
234
235 if equil_diag {
236 diag_e_buf.fill(0.0);
237 #[allow(clippy::needless_range_loop)]
238 for r in 0..n_replicas {
239 let offset = r * n_temps;
240 for t in 0..n_temps {
241 let system_id = real.system_ids[offset + t];
242 diag_e_buf[t] += real.energies[system_id];
243 }
244 }
245 let inv = 1.0 / n_replicas as f32;
246 for v in diag_e_buf.iter_mut() {
247 *v *= inv;
248 }
249
250 let link_overlaps = if n_pairs > 0 {
251 spins::energy::compute_link_overlaps(
252 lattice,
253 &real.spins,
254 &real.system_ids,
255 n_replicas,
256 n_temps,
257 )
258 } else {
259 vec![0.0f32; n_temps]
260 };
261
262 equil_accum
263 .as_mut()
264 .unwrap()
265 .push(&diag_e_buf, &link_overlaps);
266 }
267
268 if record {
269 for t in 0..n_temps {
270 mags_buf[t] = 0.0;
271 mags2_buf[t] = 0.0;
272 mags4_buf[t] = 0.0;
273 energies_buf[t] = 0.0;
274 }
275
276 if collect_ac {
277 m2_ac_buf.fill(0.0);
278 }
279
280 for r in 0..n_replicas {
281 let offset = r * n_temps;
282 for t in 0..n_temps {
283 let system_id = real.system_ids[offset + t];
284 let spin_base = system_id * n_spins;
285 let mut sum = 0i64;
286 for j in 0..n_spins {
287 sum += real.spins[spin_base + j] as i64;
288 }
289 let mag = sum as f32 / n_spins as f32;
290 let m2 = mag * mag;
291 mags_buf[t] = mag;
292 mags2_buf[t] = m2;
293 mags4_buf[t] = m2 * m2;
294 energies_buf[t] = real.energies[system_id];
295 }
296
297 if collect_ac {
298 for t in 0..n_temps {
299 m2_ac_buf[t] += mags2_buf[t] as f64;
300 }
301 }
302
303 mags_stat.update(&mags_buf);
304 mags2_stat.update(&mags2_buf);
305 mags4_stat.update(&mags4_buf);
306 energies_stat.update(&energies_buf);
307 energies2_stat.update(&energies_buf);
308 }
309
310 if let Some(ref mut acc) = m2_accum {
311 let inv = 1.0 / n_replicas as f64;
312 for v in m2_ac_buf.iter_mut() {
313 *v *= inv;
314 }
315 acc.push(&m2_ac_buf);
316 }
317
318 if collect_q2_ac {
319 q2_ac_buf.fill(0.0);
320 }
321
322 for pair_idx in 0..n_pairs {
323 let r_a = 2 * pair_idx;
324 let r_b = 2 * pair_idx + 1;
325 for t in 0..n_temps {
326 overlaps_buf[t] = 0.0;
327 overlaps2_buf[t] = 0.0;
328 overlaps4_buf[t] = 0.0;
329 }
330
331 for t in 0..n_temps {
332 let sys_a = real.system_ids[r_a * n_temps + t];
333 let sys_b = real.system_ids[r_b * n_temps + t];
334 let base_a = sys_a * n_spins;
335 let base_b = sys_b * n_spins;
336 let mut dot = 0i64;
337 for j in 0..n_spins {
338 dot += (real.spins[base_a + j] as i64) * (real.spins[base_b + j] as i64);
339 }
340 let q = dot as f32 / n_spins as f32;
341 let q2 = q * q;
342 overlaps_buf[t] = q;
343 overlaps2_buf[t] = q2;
344 overlaps4_buf[t] = q2 * q2;
345 let bin = (((q + 1.0) * 0.5 * OVERLAP_HIST_BINS as f32) as usize)
346 .min(OVERLAP_HIST_BINS - 1);
347 overlap_hist[t][bin] += 1;
348 }
349
350 if collect_q2_ac {
351 for t in 0..n_temps {
352 q2_ac_buf[t] += overlaps2_buf[t] as f64;
353 }
354 }
355
356 overlap_stat.update(&overlaps_buf);
357 overlap2_stat.update(&overlaps2_buf);
358 overlap4_stat.update(&overlaps4_buf);
359 }
360
361 if let Some(ref mut acc) = q2_accum {
362 let inv = 1.0 / n_pairs as f64;
363 for v in q2_ac_buf.iter_mut() {
364 *v *= inv;
365 }
366 acc.push(&q2_ac_buf);
367 }
368 }
369
370 if let Some(ref oc_cfg) = config.overlap_cluster {
371 if sweep_id % oc_cfg.interval == 0 {
372 let ov_csd_out = if oc_cfg.collect_csd && record {
373 for buf in overlap_csd_buf.iter_mut() {
374 buf.fill(0);
375 }
376 Some(overlap_csd_buf.as_mut_slice())
377 } else {
378 None
379 };
380
381 let top4_out = if collect_top && record {
382 for slot in top4_buf.iter_mut() {
383 *slot = [0u32; 4];
384 }
385 Some(top4_buf.as_mut_slice())
386 } else {
387 None
388 };
389
390 clusters::overlap_update(
391 lattice,
392 &mut real.spins,
393 &real.couplings,
394 &real.temperatures,
395 &real.system_ids,
396 n_replicas,
397 n_temps,
398 &mut real.pair_rngs,
399 stochastic,
400 restrict_to_negative,
401 overlap_wolff,
402 group_size,
403 ov_csd_out,
404 top4_out,
405 config.sequential,
406 );
407
408 if oc_cfg.collect_csd && record {
409 for (slot, buf) in overlap_csd_buf.iter().enumerate() {
410 let accum = &mut overlap_csd_accum[slot / n_pairs];
411 for (a, &b) in accum.iter_mut().zip(buf.iter()) {
412 *a += b;
413 }
414 }
415 }
416
417 if collect_top && record {
418 for t in 0..n_temps {
419 for p in 0..n_pairs {
420 let raw = top4_buf[t * n_pairs + p];
421 for (k, &v) in raw.iter().enumerate() {
422 top4_accum[t][k] += v as f64 / n_spins as f64;
423 }
424 }
425 }
426 top4_n += 1;
427 }
428 }
429 }
430
431 if pt_this_sweep {
432 if config.overlap_cluster.is_some() {
433 (real.energies, _) = spins::energy::compute_energies(
434 lattice,
435 &real.spins,
436 &real.couplings,
437 n_systems,
438 false,
439 );
440 }
441 for r in 0..n_replicas {
442 let offset = r * n_temps;
443 let sid_slice = &mut real.system_ids[offset..offset + n_temps];
444 let temp_slice = &real.temperatures[offset..offset + n_temps];
445 mcmc::tempering::parallel_tempering(
446 &real.energies,
447 temp_slice,
448 sid_slice,
449 n_spins,
450 &mut real.rngs[offset],
451 );
452 }
453 }
454 }
455
456 let top_cluster_sizes = if collect_top && top4_n > 0 {
457 let denom = (top4_n * n_pairs) as f64;
458 top4_accum
459 .iter()
460 .map(|arr| {
461 [
462 arr[0] / denom,
463 arr[1] / denom,
464 arr[2] / denom,
465 arr[3] / denom,
466 ]
467 })
468 .collect()
469 } else {
470 vec![]
471 };
472
473 let mags2_tau = m2_accum
474 .as_ref()
475 .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
476 .unwrap_or_default();
477 let overlap2_tau = q2_accum
478 .as_ref()
479 .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
480 .unwrap_or_default();
481
482 let equil_checkpoints = equil_accum.map(|acc| acc.finish()).unwrap_or_default();
483
484 Ok(SweepResult {
485 mags: mags_stat.average(),
486 mags2: mags2_stat.average(),
487 mags4: mags4_stat.average(),
488 energies: energies_stat.average(),
489 energies2: energies2_stat.average(),
490 overlap: if n_pairs > 0 {
491 overlap_stat.average()
492 } else {
493 vec![]
494 },
495 overlap2: if n_pairs > 0 {
496 overlap2_stat.average()
497 } else {
498 vec![]
499 },
500 overlap4: if n_pairs > 0 {
501 overlap4_stat.average()
502 } else {
503 vec![]
504 },
505 overlap_histogram: overlap_hist,
506 cluster_stats: ClusterStats {
507 fk_csd: fk_csd_accum,
508 overlap_csd: overlap_csd_accum,
509 top_cluster_sizes,
510 },
511 diagnostics: Diagnostics {
512 mags2_tau,
513 overlap2_tau,
514 equil_checkpoints,
515 },
516 })
517}
518
519pub fn run_sweep_parallel(
525 lattice: &Lattice,
526 realizations: &mut [Realization],
527 n_replicas: usize,
528 n_temps: usize,
529 config: &SimConfig,
530 interrupted: &AtomicBool,
531 on_sweep: &(dyn Fn() + Sync),
532) -> Result<SweepResult, String> {
533 if realizations.len() == 1 {
534 return run_sweep_loop(
535 lattice,
536 &mut realizations[0],
537 n_replicas,
538 n_temps,
539 config,
540 interrupted,
541 on_sweep,
542 );
543 }
544
545 let results: Vec<Result<SweepResult, String>> = realizations
546 .par_iter_mut()
547 .map(|real| {
548 run_sweep_loop(
549 lattice,
550 real,
551 n_replicas,
552 n_temps,
553 config,
554 interrupted,
555 on_sweep,
556 )
557 })
558 .collect();
559
560 let results: Vec<SweepResult> = results.into_iter().collect::<Result<Vec<_>, _>>()?;
561 Ok(SweepResult::aggregate(&results))
562}