// spin_sim/simulation/mod.rs

pub mod realization;

pub use realization::Realization;

use std::sync::atomic::{AtomicBool, Ordering};

use crate::config::{SimConfig, SweepMode};
use crate::geometry::Lattice;
use crate::statistics::{
    sokal_tau, AutocorrAccum, ClusterStats, Diagnostics, EquilDiagnosticAccum, OverlapAccum,
    Statistics, SweepResult,
};
use crate::{clusters, mcmc, spins};
use rayon::prelude::*;
use validator::Validate;
16
17/// Run the full Monte Carlo loop (warmup + measurement) for one [`Realization`].
18///
19/// Each sweep consists of:
20/// 1. A full single-spin pass (`sweep_mode`: Metropolis or Gibbs)
21/// 2. An optional cluster update (every `cluster_update.interval` sweeps)
22/// 3. Measurement (after `warmup_sweeps`)
23/// 4. Optional overlap cluster move (every `overlap_cluster.interval` sweeps, requires `n_replicas ≥ 2`)
24/// 5. Optional parallel tempering (every `pt_interval` sweeps)
25///
26/// `on_sweep` is called once per sweep (useful for progress bars).
27pub fn run_sweep_loop(
28    lattice: &Lattice,
29    real: &mut Realization,
30    n_replicas: usize,
31    n_temps: usize,
32    config: &SimConfig,
33    interrupted: &AtomicBool,
34    on_sweep: &(dyn Fn() + Sync),
35) -> Result<SweepResult, String> {
36    config.validate().map_err(|e| format!("{e}"))?;
37
38    let n_spins = lattice.n_spins;
39    let n_systems = n_replicas * n_temps;
40    let n_sweeps = config.n_sweeps;
41    let warmup_sweeps = config.warmup_sweeps;
42
43    let n_modes = config.overlap_cluster.as_ref().map_or(0, |h| h.modes.len());
44
45    if let Some(ref oc_cfg) = config.overlap_cluster {
46        let max_gs = oc_cfg.max_group_size();
47        if n_replicas < max_gs {
48            return Err(format!(
49                "overlap cluster requires n_replicas >= max group_size ({n_replicas} < {max_gs})"
50            ));
51        }
52    }
53
54    let n_pairs = n_replicas / 2;
55
56    let mut fk_csd_accum: Vec<Vec<u64>> = (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
57    let mut sw_csd_buf: Vec<Vec<u64>> = (0..n_systems).map(|_| vec![0u64; n_spins + 1]).collect();
58
59    let mut overlap_csd_accum: Vec<Vec<Vec<u64>>> = (0..n_modes)
60        .map(|_| (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect())
61        .collect();
62    let mut overlap_csd_buf: Vec<Vec<u64>> = (0..n_temps * n_pairs)
63        .map(|_| vec![0u64; n_spins + 1])
64        .collect();
65
66    let collect_top = config
67        .overlap_cluster
68        .as_ref()
69        .is_some_and(|h| h.collect_stats)
70        && n_pairs > 0;
71
72    let mut top4_accum: Vec<Vec<[f64; 4]>> =
73        (0..n_modes).map(|_| vec![[0.0; 4]; n_temps]).collect();
74    let mut top4_n: Vec<usize> = vec![0; n_modes];
75    let mut top4_buf: Vec<[u32; 4]> = if collect_top {
76        vec![[0u32; 4]; n_temps * n_pairs]
77    } else {
78        vec![]
79    };
80
81    let mut overlap_call_count: usize = 0;
82
83    let mut mags_stat = Statistics::new(n_temps, 1);
84    let mut mags2_stat = Statistics::new(n_temps, 1);
85    let mut mags4_stat = Statistics::new(n_temps, 1);
86    let mut energies_stat = Statistics::new(n_temps, 1);
87    let mut energies2_stat = Statistics::new(n_temps, 2);
88    let n_measurement_sweeps = n_sweeps.saturating_sub(warmup_sweeps);
89    let ac_max_lag = config
90        .autocorrelation_max_lag
91        .map(|k| k.min(n_measurement_sweeps / 4).max(1));
92    let mut m2_accum = ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps));
93    let mut q2_accum = if ac_max_lag.is_some() && n_pairs > 0 {
94        ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps))
95    } else {
96        None
97    };
98    let collect_ac = ac_max_lag.is_some();
99    let collect_q2_ac = q2_accum.is_some();
100    let mut m2_ac_buf = if collect_ac {
101        vec![0.0f64; n_temps]
102    } else {
103        vec![]
104    };
105
106    let equil_diag = config.equilibration_diagnostic;
107    let mut equil_accum = if equil_diag {
108        Some(EquilDiagnosticAccum::new(n_temps, n_sweeps))
109    } else {
110        None
111    };
112    let mut diag_e_buf = if equil_diag {
113        vec![0.0f32; n_temps]
114    } else {
115        vec![]
116    };
117
118    let mut ov_accum = OverlapAccum::new(
119        n_temps,
120        n_spins,
121        n_pairs,
122        lattice.n_neighbors,
123        equil_diag,
124        collect_q2_ac,
125    );
126
127    let mut mags_buf = vec![0.0f32; n_temps];
128    let mut mags2_buf = vec![0.0f32; n_temps];
129    let mut mags4_buf = vec![0.0f32; n_temps];
130    let mut energies_buf = vec![0.0f32; n_temps];
131
132    for sweep_id in 0..n_sweeps {
133        if interrupted.load(Ordering::Relaxed) {
134            return Err("interrupted".to_string());
135        }
136        on_sweep();
137        let record = sweep_id >= warmup_sweeps;
138
139        match config.sweep_mode {
140            SweepMode::Metropolis => mcmc::sweep::metropolis_sweep(
141                lattice,
142                &mut real.spins,
143                &real.couplings,
144                &real.temperatures,
145                &real.system_ids,
146                &mut real.rngs,
147                config.sequential,
148            ),
149            SweepMode::Gibbs => mcmc::sweep::gibbs_sweep(
150                lattice,
151                &mut real.spins,
152                &real.couplings,
153                &real.temperatures,
154                &real.system_ids,
155                &mut real.rngs,
156                config.sequential,
157            ),
158        }
159
160        let do_cluster = config
161            .cluster_update
162            .as_ref()
163            .is_some_and(|c| sweep_id % c.interval == 0);
164
165        if do_cluster {
166            let cluster_cfg = config.cluster_update.as_ref().unwrap();
167            let wolff = cluster_cfg.mode == crate::config::ClusterMode::Wolff;
168            let csd_out = if cluster_cfg.collect_stats && record {
169                for buf in sw_csd_buf.iter_mut() {
170                    buf.fill(0);
171                }
172                Some(sw_csd_buf.as_mut_slice())
173            } else {
174                None
175            };
176
177            clusters::fk_update(
178                lattice,
179                &mut real.spins,
180                &real.couplings,
181                &real.temperatures,
182                &real.system_ids,
183                &mut real.rngs,
184                wolff,
185                csd_out,
186                config.sequential,
187            );
188
189            if cluster_cfg.collect_stats && record {
190                for (slot, buf) in sw_csd_buf.iter().enumerate() {
191                    let accum = &mut fk_csd_accum[slot % n_temps];
192                    for (a, &b) in accum.iter_mut().zip(buf.iter()) {
193                        *a += b;
194                    }
195                }
196            }
197        }
198
199        let pt_this_sweep = config
200            .pt_interval
201            .is_some_and(|interval| sweep_id % interval == 0);
202
203        if record || pt_this_sweep || equil_diag {
204            (real.energies, _) = spins::energy::compute_energies(
205                lattice,
206                &real.spins,
207                &real.couplings,
208                n_systems,
209                false,
210            );
211        }
212
213        if equil_diag {
214            diag_e_buf.fill(0.0);
215            #[allow(clippy::needless_range_loop)]
216            for r in 0..n_replicas {
217                let offset = r * n_temps;
218                for t in 0..n_temps {
219                    let system_id = real.system_ids[offset + t];
220                    diag_e_buf[t] += real.energies[system_id];
221                }
222            }
223            let inv = 1.0 / n_replicas as f32;
224            for v in diag_e_buf.iter_mut() {
225                *v *= inv;
226            }
227        }
228
229        if (equil_diag || record) && n_pairs > 0 {
230            ov_accum.collect(lattice, &real.spins, &real.system_ids, record);
231        }
232
233        if equil_diag {
234            if n_pairs > 0 {
235                equil_accum
236                    .as_mut()
237                    .unwrap()
238                    .push(&diag_e_buf, &ov_accum.diag_ql_buf);
239            } else {
240                let zeros = vec![0.0f32; n_temps];
241                equil_accum.as_mut().unwrap().push(&diag_e_buf, &zeros);
242            }
243        }
244
245        if record {
246            for t in 0..n_temps {
247                mags_buf[t] = 0.0;
248                mags2_buf[t] = 0.0;
249                mags4_buf[t] = 0.0;
250                energies_buf[t] = 0.0;
251            }
252
253            if collect_ac {
254                m2_ac_buf.fill(0.0);
255            }
256
257            for r in 0..n_replicas {
258                let offset = r * n_temps;
259                for t in 0..n_temps {
260                    let system_id = real.system_ids[offset + t];
261                    let spin_base = system_id * n_spins;
262                    let mut sum = 0i64;
263                    for j in 0..n_spins {
264                        sum += real.spins[spin_base + j] as i64;
265                    }
266                    let mag = sum as f32 / n_spins as f32;
267                    let m2 = mag * mag;
268                    mags_buf[t] = mag;
269                    mags2_buf[t] = m2;
270                    mags4_buf[t] = m2 * m2;
271                    energies_buf[t] = real.energies[system_id];
272                }
273
274                if collect_ac {
275                    for t in 0..n_temps {
276                        m2_ac_buf[t] += mags2_buf[t] as f64;
277                    }
278                }
279
280                mags_stat.update(&mags_buf);
281                mags2_stat.update(&mags2_buf);
282                mags4_stat.update(&mags4_buf);
283                energies_stat.update(&energies_buf);
284                energies2_stat.update(&energies_buf);
285            }
286
287            if let Some(ref mut acc) = m2_accum {
288                let inv = 1.0 / n_replicas as f64;
289                for v in m2_ac_buf.iter_mut() {
290                    *v *= inv;
291                }
292                acc.push(&m2_ac_buf);
293            }
294
295            if let Some(ref mut acc) = q2_accum {
296                let inv = 1.0 / n_pairs as f64;
297                for v in ov_accum.q2_ac_buf.iter_mut() {
298                    *v *= inv;
299                }
300                acc.push(&ov_accum.q2_ac_buf);
301            }
302        }
303
304        if let Some(ref oc_cfg) = config.overlap_cluster {
305            if sweep_id % oc_cfg.interval == 0 {
306                let mode_idx = overlap_call_count % n_modes;
307                let mode = &oc_cfg.modes[mode_idx];
308
309                let ov_csd_out = if oc_cfg.collect_stats && record {
310                    for buf in overlap_csd_buf.iter_mut() {
311                        buf.fill(0);
312                    }
313                    Some(overlap_csd_buf.as_mut_slice())
314                } else {
315                    None
316                };
317
318                let top4_out = if collect_top && record {
319                    for slot in top4_buf.iter_mut() {
320                        *slot = [0u32; 4];
321                    }
322                    Some(top4_buf.as_mut_slice())
323                } else {
324                    None
325                };
326
327                clusters::overlap_update(
328                    lattice,
329                    &mut real.spins,
330                    &real.couplings,
331                    &real.temperatures,
332                    &real.system_ids,
333                    n_replicas,
334                    n_temps,
335                    &mut real.pair_rngs,
336                    mode,
337                    oc_cfg.cluster_mode,
338                    ov_csd_out,
339                    top4_out,
340                    config.sequential,
341                );
342
343                if oc_cfg.collect_stats && record {
344                    for (slot, buf) in overlap_csd_buf.iter().enumerate() {
345                        let accum = &mut overlap_csd_accum[mode_idx][slot / n_pairs];
346                        for (a, &b) in accum.iter_mut().zip(buf.iter()) {
347                            *a += b;
348                        }
349                    }
350                }
351
352                if collect_top && record {
353                    for t in 0..n_temps {
354                        for p in 0..n_pairs {
355                            let raw = top4_buf[t * n_pairs + p];
356                            for (k, &v) in raw.iter().enumerate() {
357                                top4_accum[mode_idx][t][k] += v as f64 / n_spins as f64;
358                            }
359                        }
360                    }
361                    top4_n[mode_idx] += 1;
362                }
363
364                overlap_call_count += 1;
365            }
366        }
367
368        if pt_this_sweep {
369            if config.overlap_cluster.is_some() {
370                (real.energies, _) = spins::energy::compute_energies(
371                    lattice,
372                    &real.spins,
373                    &real.couplings,
374                    n_systems,
375                    false,
376                );
377            }
378            for r in 0..n_replicas {
379                let offset = r * n_temps;
380                let sid_slice = &mut real.system_ids[offset..offset + n_temps];
381                let temp_slice = &real.temperatures[offset..offset + n_temps];
382                mcmc::tempering::parallel_tempering(
383                    &real.energies,
384                    temp_slice,
385                    sid_slice,
386                    n_spins,
387                    &mut real.rngs[offset],
388                );
389            }
390        }
391    }
392
393    let top_cluster_sizes: Vec<Vec<[f64; 4]>> = if collect_top {
394        top4_accum
395            .iter()
396            .zip(top4_n.iter())
397            .map(|(mode_accum, &count)| {
398                if count == 0 {
399                    return vec![];
400                }
401                let denom = (count * n_pairs) as f64;
402                mode_accum
403                    .iter()
404                    .map(|arr| {
405                        [
406                            arr[0] / denom,
407                            arr[1] / denom,
408                            arr[2] / denom,
409                            arr[3] / denom,
410                        ]
411                    })
412                    .collect()
413            })
414            .collect()
415    } else {
416        vec![]
417    };
418
419    let mags2_tau = m2_accum
420        .as_ref()
421        .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
422        .unwrap_or_default();
423    let overlap2_tau = q2_accum
424        .as_ref()
425        .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
426        .unwrap_or_default();
427
428    let equil_checkpoints = equil_accum.map(|acc| acc.finish()).unwrap_or_default();
429
430    Ok(SweepResult {
431        mags: mags_stat.average(),
432        mags2: mags2_stat.average(),
433        mags4: mags4_stat.average(),
434        energies: energies_stat.average(),
435        energies2: energies2_stat.average(),
436        overlap_stats: ov_accum.finish(),
437        cluster_stats: ClusterStats {
438            fk_csd: fk_csd_accum,
439            overlap_csd: overlap_csd_accum,
440            top_cluster_sizes,
441        },
442        diagnostics: Diagnostics {
443            mags2_tau,
444            overlap2_tau,
445            equil_checkpoints,
446        },
447    })
448}
449
450/// Run the sweep loop in parallel over multiple disorder realizations.
451///
452/// Each realization is processed by [`run_sweep_loop`], then results are
453/// averaged via [`SweepResult::aggregate`]. For a single realization the
454/// call is made directly, skipping rayon thread-pool overhead.
455pub fn run_sweep_parallel(
456    lattice: &Lattice,
457    realizations: &mut [Realization],
458    n_replicas: usize,
459    n_temps: usize,
460    config: &SimConfig,
461    interrupted: &AtomicBool,
462    on_sweep: &(dyn Fn() + Sync),
463) -> Result<SweepResult, String> {
464    if realizations.len() == 1 {
465        return run_sweep_loop(
466            lattice,
467            &mut realizations[0],
468            n_replicas,
469            n_temps,
470            config,
471            interrupted,
472            on_sweep,
473        );
474    }
475
476    let results: Vec<Result<SweepResult, String>> = realizations
477        .par_iter_mut()
478        .map(|real| {
479            run_sweep_loop(
480                lattice,
481                real,
482                n_replicas,
483                n_temps,
484                config,
485                interrupted,
486                on_sweep,
487            )
488        })
489        .collect();
490
491    let results: Vec<SweepResult> = results.into_iter().collect::<Result<Vec<_>, _>>()?;
492    Ok(SweepResult::aggregate(&results))
493}