Skip to main content

spin_sim/simulation/
mod.rs

1pub mod realization;
2
3pub use realization::Realization;
4
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use crate::config::{OverlapClusterBuildMode, SimConfig, SweepMode};
8use crate::geometry::Lattice;
9use crate::statistics::{
10    sokal_tau, AutocorrAccum, ClusterStats, Diagnostics, EquilDiagnosticAccum, Statistics,
11    SweepResult, OVERLAP_HIST_BINS,
12};
13use crate::{clusters, mcmc, spins};
14use rayon::prelude::*;
15use validator::Validate;
16
17/// Run the full Monte Carlo loop (warmup + measurement) for one [`Realization`].
18///
19/// Each sweep consists of:
20/// 1. A full single-spin pass (`sweep_mode`: Metropolis or Gibbs)
21/// 2. An optional cluster update (every `cluster_update.interval` sweeps)
22/// 3. Measurement (after `warmup_sweeps`)
23/// 4. Optional overlap cluster move (every `overlap_cluster.interval` sweeps, requires `n_replicas ≥ 2`)
24/// 5. Optional parallel tempering (every `pt_interval` sweeps)
25///
26/// `on_sweep` is called once per sweep (useful for progress bars).
27pub fn run_sweep_loop(
28    lattice: &Lattice,
29    real: &mut Realization,
30    n_replicas: usize,
31    n_temps: usize,
32    config: &SimConfig,
33    interrupted: &AtomicBool,
34    on_sweep: &(dyn Fn() + Sync),
35) -> Result<SweepResult, String> {
36    config.validate().map_err(|e| format!("{e}"))?;
37
38    let n_spins = lattice.n_spins;
39    let n_systems = n_replicas * n_temps;
40    let n_sweeps = config.n_sweeps;
41    let warmup_sweeps = config.warmup_sweeps;
42
43    let overlap_wolff = config
44        .overlap_cluster
45        .as_ref()
46        .is_some_and(|h| h.cluster_mode == crate::config::ClusterMode::Wolff);
47
48    let (stochastic, restrict_to_negative) =
49        config
50            .overlap_cluster
51            .as_ref()
52            .map_or((false, true), |h| match h.mode {
53                OverlapClusterBuildMode::Houdayer => (false, true),
54                OverlapClusterBuildMode::Jorg => (true, true),
55                OverlapClusterBuildMode::Cmr(_) => (true, false),
56            });
57
58    let group_size = config
59        .overlap_cluster
60        .as_ref()
61        .map_or(2, |h| h.mode.group_size());
62
63    if config.overlap_cluster.is_some() && n_replicas < group_size {
64        return Err(format!(
65            "overlap cluster requires n_replicas >= group_size ({n_replicas} < {group_size})"
66        ));
67    }
68
69    let n_pairs = n_replicas / 2;
70
71    let mut fk_csd_accum: Vec<Vec<u64>> = (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
72    let mut sw_csd_buf: Vec<Vec<u64>> = (0..n_systems).map(|_| vec![0u64; n_spins + 1]).collect();
73
74    let mut overlap_csd_accum: Vec<Vec<u64>> =
75        (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
76    let mut overlap_csd_buf: Vec<Vec<u64>> = (0..n_temps * n_pairs)
77        .map(|_| vec![0u64; n_spins + 1])
78        .collect();
79
80    let collect_top = config
81        .overlap_cluster
82        .as_ref()
83        .is_some_and(|h| h.collect_top_clusters)
84        && n_pairs > 0;
85
86    let mut top4_accum: Vec<[f64; 4]> = vec![[0.0; 4]; n_temps];
87    let mut top4_n: usize = 0;
88    let mut top4_buf: Vec<[u32; 4]> = if collect_top {
89        vec![[0u32; 4]; n_temps * n_pairs]
90    } else {
91        vec![]
92    };
93
94    let mut mags_stat = Statistics::new(n_temps, 1);
95    let mut mags2_stat = Statistics::new(n_temps, 1);
96    let mut mags4_stat = Statistics::new(n_temps, 1);
97    let mut energies_stat = Statistics::new(n_temps, 1);
98    let mut energies2_stat = Statistics::new(n_temps, 2);
99    let mut overlap_stat = Statistics::new(n_temps, 1);
100    let mut overlap2_stat = Statistics::new(n_temps, 1);
101    let mut overlap4_stat = Statistics::new(n_temps, 1);
102
103    let n_measurement_sweeps = n_sweeps.saturating_sub(warmup_sweeps);
104    let ac_max_lag = config
105        .autocorrelation_max_lag
106        .map(|k| k.min(n_measurement_sweeps / 4).max(1));
107    let mut m2_accum = ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps));
108    let mut q2_accum = if ac_max_lag.is_some() && n_pairs > 0 {
109        ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps))
110    } else {
111        None
112    };
113    let collect_ac = ac_max_lag.is_some();
114    let collect_q2_ac = q2_accum.is_some();
115    let mut m2_ac_buf = if collect_ac {
116        vec![0.0f64; n_temps]
117    } else {
118        vec![]
119    };
120    let mut q2_ac_buf = if collect_q2_ac {
121        vec![0.0f64; n_temps]
122    } else {
123        vec![]
124    };
125
126    let equil_diag = config.equilibration_diagnostic;
127    let mut equil_accum = if equil_diag {
128        Some(EquilDiagnosticAccum::new(n_temps, n_sweeps))
129    } else {
130        None
131    };
132    let mut diag_e_buf = if equil_diag {
133        vec![0.0f32; n_temps]
134    } else {
135        vec![]
136    };
137
138    let mut overlap_hist: Vec<Vec<u64>> = if n_pairs > 0 {
139        (0..n_temps)
140            .map(|_| vec![0u64; OVERLAP_HIST_BINS])
141            .collect()
142    } else {
143        vec![]
144    };
145
146    let mut mags_buf = vec![0.0f32; n_temps];
147    let mut mags2_buf = vec![0.0f32; n_temps];
148    let mut mags4_buf = vec![0.0f32; n_temps];
149    let mut energies_buf = vec![0.0f32; n_temps];
150    let mut overlaps_buf = vec![0.0f32; n_temps];
151    let mut overlaps2_buf = vec![0.0f32; n_temps];
152    let mut overlaps4_buf = vec![0.0f32; n_temps];
153
154    for sweep_id in 0..n_sweeps {
155        if interrupted.load(Ordering::Relaxed) {
156            return Err("interrupted".to_string());
157        }
158        on_sweep();
159        let record = sweep_id >= warmup_sweeps;
160
161        match config.sweep_mode {
162            SweepMode::Metropolis => mcmc::sweep::metropolis_sweep(
163                lattice,
164                &mut real.spins,
165                &real.couplings,
166                &real.temperatures,
167                &real.system_ids,
168                &mut real.rngs,
169                config.sequential,
170            ),
171            SweepMode::Gibbs => mcmc::sweep::gibbs_sweep(
172                lattice,
173                &mut real.spins,
174                &real.couplings,
175                &real.temperatures,
176                &real.system_ids,
177                &mut real.rngs,
178                config.sequential,
179            ),
180        }
181
182        let do_cluster = config
183            .cluster_update
184            .as_ref()
185            .is_some_and(|c| sweep_id % c.interval == 0);
186
187        if do_cluster {
188            let cluster_cfg = config.cluster_update.as_ref().unwrap();
189            let wolff = cluster_cfg.mode == crate::config::ClusterMode::Wolff;
190            let csd_out = if cluster_cfg.collect_csd && record {
191                for buf in sw_csd_buf.iter_mut() {
192                    buf.fill(0);
193                }
194                Some(sw_csd_buf.as_mut_slice())
195            } else {
196                None
197            };
198
199            clusters::fk_update(
200                lattice,
201                &mut real.spins,
202                &real.couplings,
203                &real.temperatures,
204                &real.system_ids,
205                &mut real.rngs,
206                wolff,
207                csd_out,
208                config.sequential,
209            );
210
211            if cluster_cfg.collect_csd && record {
212                for (slot, buf) in sw_csd_buf.iter().enumerate() {
213                    let accum = &mut fk_csd_accum[slot % n_temps];
214                    for (a, &b) in accum.iter_mut().zip(buf.iter()) {
215                        *a += b;
216                    }
217                }
218            }
219        }
220
221        let pt_this_sweep = config
222            .pt_interval
223            .is_some_and(|interval| sweep_id % interval == 0);
224
225        if record || pt_this_sweep || equil_diag {
226            (real.energies, _) = spins::energy::compute_energies(
227                lattice,
228                &real.spins,
229                &real.couplings,
230                n_systems,
231                false,
232            );
233        }
234
235        if equil_diag {
236            diag_e_buf.fill(0.0);
237            #[allow(clippy::needless_range_loop)]
238            for r in 0..n_replicas {
239                let offset = r * n_temps;
240                for t in 0..n_temps {
241                    let system_id = real.system_ids[offset + t];
242                    diag_e_buf[t] += real.energies[system_id];
243                }
244            }
245            let inv = 1.0 / n_replicas as f32;
246            for v in diag_e_buf.iter_mut() {
247                *v *= inv;
248            }
249
250            let link_overlaps = if n_pairs > 0 {
251                spins::energy::compute_link_overlaps(
252                    lattice,
253                    &real.spins,
254                    &real.system_ids,
255                    n_replicas,
256                    n_temps,
257                )
258            } else {
259                vec![0.0f32; n_temps]
260            };
261
262            equil_accum
263                .as_mut()
264                .unwrap()
265                .push(&diag_e_buf, &link_overlaps);
266        }
267
268        if record {
269            for t in 0..n_temps {
270                mags_buf[t] = 0.0;
271                mags2_buf[t] = 0.0;
272                mags4_buf[t] = 0.0;
273                energies_buf[t] = 0.0;
274            }
275
276            if collect_ac {
277                m2_ac_buf.fill(0.0);
278            }
279
280            for r in 0..n_replicas {
281                let offset = r * n_temps;
282                for t in 0..n_temps {
283                    let system_id = real.system_ids[offset + t];
284                    let spin_base = system_id * n_spins;
285                    let mut sum = 0i64;
286                    for j in 0..n_spins {
287                        sum += real.spins[spin_base + j] as i64;
288                    }
289                    let mag = sum as f32 / n_spins as f32;
290                    let m2 = mag * mag;
291                    mags_buf[t] = mag;
292                    mags2_buf[t] = m2;
293                    mags4_buf[t] = m2 * m2;
294                    energies_buf[t] = real.energies[system_id];
295                }
296
297                if collect_ac {
298                    for t in 0..n_temps {
299                        m2_ac_buf[t] += mags2_buf[t] as f64;
300                    }
301                }
302
303                mags_stat.update(&mags_buf);
304                mags2_stat.update(&mags2_buf);
305                mags4_stat.update(&mags4_buf);
306                energies_stat.update(&energies_buf);
307                energies2_stat.update(&energies_buf);
308            }
309
310            if let Some(ref mut acc) = m2_accum {
311                let inv = 1.0 / n_replicas as f64;
312                for v in m2_ac_buf.iter_mut() {
313                    *v *= inv;
314                }
315                acc.push(&m2_ac_buf);
316            }
317
318            if collect_q2_ac {
319                q2_ac_buf.fill(0.0);
320            }
321
322            for pair_idx in 0..n_pairs {
323                let r_a = 2 * pair_idx;
324                let r_b = 2 * pair_idx + 1;
325                for t in 0..n_temps {
326                    overlaps_buf[t] = 0.0;
327                    overlaps2_buf[t] = 0.0;
328                    overlaps4_buf[t] = 0.0;
329                }
330
331                for t in 0..n_temps {
332                    let sys_a = real.system_ids[r_a * n_temps + t];
333                    let sys_b = real.system_ids[r_b * n_temps + t];
334                    let base_a = sys_a * n_spins;
335                    let base_b = sys_b * n_spins;
336                    let mut dot = 0i64;
337                    for j in 0..n_spins {
338                        dot += (real.spins[base_a + j] as i64) * (real.spins[base_b + j] as i64);
339                    }
340                    let q = dot as f32 / n_spins as f32;
341                    let q2 = q * q;
342                    overlaps_buf[t] = q;
343                    overlaps2_buf[t] = q2;
344                    overlaps4_buf[t] = q2 * q2;
345                    let bin = (((q + 1.0) * 0.5 * OVERLAP_HIST_BINS as f32) as usize)
346                        .min(OVERLAP_HIST_BINS - 1);
347                    overlap_hist[t][bin] += 1;
348                }
349
350                if collect_q2_ac {
351                    for t in 0..n_temps {
352                        q2_ac_buf[t] += overlaps2_buf[t] as f64;
353                    }
354                }
355
356                overlap_stat.update(&overlaps_buf);
357                overlap2_stat.update(&overlaps2_buf);
358                overlap4_stat.update(&overlaps4_buf);
359            }
360
361            if let Some(ref mut acc) = q2_accum {
362                let inv = 1.0 / n_pairs as f64;
363                for v in q2_ac_buf.iter_mut() {
364                    *v *= inv;
365                }
366                acc.push(&q2_ac_buf);
367            }
368        }
369
370        if let Some(ref oc_cfg) = config.overlap_cluster {
371            if sweep_id % oc_cfg.interval == 0 {
372                let ov_csd_out = if oc_cfg.collect_csd && record {
373                    for buf in overlap_csd_buf.iter_mut() {
374                        buf.fill(0);
375                    }
376                    Some(overlap_csd_buf.as_mut_slice())
377                } else {
378                    None
379                };
380
381                let top4_out = if collect_top && record {
382                    for slot in top4_buf.iter_mut() {
383                        *slot = [0u32; 4];
384                    }
385                    Some(top4_buf.as_mut_slice())
386                } else {
387                    None
388                };
389
390                clusters::overlap_update(
391                    lattice,
392                    &mut real.spins,
393                    &real.couplings,
394                    &real.temperatures,
395                    &real.system_ids,
396                    n_replicas,
397                    n_temps,
398                    &mut real.pair_rngs,
399                    stochastic,
400                    restrict_to_negative,
401                    overlap_wolff,
402                    group_size,
403                    ov_csd_out,
404                    top4_out,
405                    config.sequential,
406                );
407
408                if oc_cfg.collect_csd && record {
409                    for (slot, buf) in overlap_csd_buf.iter().enumerate() {
410                        let accum = &mut overlap_csd_accum[slot / n_pairs];
411                        for (a, &b) in accum.iter_mut().zip(buf.iter()) {
412                            *a += b;
413                        }
414                    }
415                }
416
417                if collect_top && record {
418                    for t in 0..n_temps {
419                        for p in 0..n_pairs {
420                            let raw = top4_buf[t * n_pairs + p];
421                            for (k, &v) in raw.iter().enumerate() {
422                                top4_accum[t][k] += v as f64 / n_spins as f64;
423                            }
424                        }
425                    }
426                    top4_n += 1;
427                }
428            }
429        }
430
431        if pt_this_sweep {
432            if config.overlap_cluster.is_some() {
433                (real.energies, _) = spins::energy::compute_energies(
434                    lattice,
435                    &real.spins,
436                    &real.couplings,
437                    n_systems,
438                    false,
439                );
440            }
441            for r in 0..n_replicas {
442                let offset = r * n_temps;
443                let sid_slice = &mut real.system_ids[offset..offset + n_temps];
444                let temp_slice = &real.temperatures[offset..offset + n_temps];
445                mcmc::tempering::parallel_tempering(
446                    &real.energies,
447                    temp_slice,
448                    sid_slice,
449                    n_spins,
450                    &mut real.rngs[offset],
451                );
452            }
453        }
454    }
455
456    let top_cluster_sizes = if collect_top && top4_n > 0 {
457        let denom = (top4_n * n_pairs) as f64;
458        top4_accum
459            .iter()
460            .map(|arr| {
461                [
462                    arr[0] / denom,
463                    arr[1] / denom,
464                    arr[2] / denom,
465                    arr[3] / denom,
466                ]
467            })
468            .collect()
469    } else {
470        vec![]
471    };
472
473    let mags2_tau = m2_accum
474        .as_ref()
475        .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
476        .unwrap_or_default();
477    let overlap2_tau = q2_accum
478        .as_ref()
479        .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
480        .unwrap_or_default();
481
482    let equil_checkpoints = equil_accum.map(|acc| acc.finish()).unwrap_or_default();
483
484    Ok(SweepResult {
485        mags: mags_stat.average(),
486        mags2: mags2_stat.average(),
487        mags4: mags4_stat.average(),
488        energies: energies_stat.average(),
489        energies2: energies2_stat.average(),
490        overlap: if n_pairs > 0 {
491            overlap_stat.average()
492        } else {
493            vec![]
494        },
495        overlap2: if n_pairs > 0 {
496            overlap2_stat.average()
497        } else {
498            vec![]
499        },
500        overlap4: if n_pairs > 0 {
501            overlap4_stat.average()
502        } else {
503            vec![]
504        },
505        overlap_histogram: overlap_hist,
506        cluster_stats: ClusterStats {
507            fk_csd: fk_csd_accum,
508            overlap_csd: overlap_csd_accum,
509            top_cluster_sizes,
510        },
511        diagnostics: Diagnostics {
512            mags2_tau,
513            overlap2_tau,
514            equil_checkpoints,
515        },
516    })
517}
518
519/// Run the sweep loop in parallel over multiple disorder realizations.
520///
521/// Each realization is processed by [`run_sweep_loop`], then results are
522/// averaged via [`SweepResult::aggregate`]. For a single realization the
523/// call is made directly, skipping rayon thread-pool overhead.
524pub fn run_sweep_parallel(
525    lattice: &Lattice,
526    realizations: &mut [Realization],
527    n_replicas: usize,
528    n_temps: usize,
529    config: &SimConfig,
530    interrupted: &AtomicBool,
531    on_sweep: &(dyn Fn() + Sync),
532) -> Result<SweepResult, String> {
533    if realizations.len() == 1 {
534        return run_sweep_loop(
535            lattice,
536            &mut realizations[0],
537            n_replicas,
538            n_temps,
539            config,
540            interrupted,
541            on_sweep,
542        );
543    }
544
545    let results: Vec<Result<SweepResult, String>> = realizations
546        .par_iter_mut()
547        .map(|real| {
548            run_sweep_loop(
549                lattice,
550                real,
551                n_replicas,
552                n_temps,
553                config,
554                interrupted,
555                on_sweep,
556            )
557        })
558        .collect();
559
560    let results: Vec<SweepResult> = results.into_iter().collect::<Result<Vec<_>, _>>()?;
561    Ok(SweepResult::aggregate(&results))
562}