th_rust/simulator/
mcmc.rs1use crate::circuit::Circuit;
2use crate::simulator::Plotter;
3use rand::{Rng, thread_rng};
4use rand::distributions::Uniform;
5
/// Tunable parameters for a single [`MCMC`] run.
///
/// All fields are plain values, so the type derives `Debug`, `Clone` and
/// `PartialEq` to make configurations easy to log, duplicate and compare.
#[derive(Debug, Clone, PartialEq)]
pub struct MCMCConfig {
    /// Controls how many of the initial sweeps are run before samples start
    /// being recorded.
    pub warmup: usize,
    /// Total number of sweeps to run, warm-up included.
    pub sweeps: usize,
    /// When `true`, the collected samples are plotted once the run finishes.
    pub plot: bool,
    /// Scale factor applied to the synaptic input before the `tanh`
    /// non-linearity (presumably an inverse temperature — confirm with the
    /// model's author).
    pub beta: f64,
}
12
13pub struct MCMC {
14 pub circuit: Circuit,
15 pub config: MCMCConfig,
16 pub activations: Vec<i32>,
17}
18
19impl MCMC {
20 pub fn new(circuit: Circuit) -> MCMC {
21 MCMC {
22 circuit,
23 config: MCMCConfig {
24 warmup: 20,
25 sweeps: 10000,
26 plot: true,
27 beta: 1.0
28 },
29 activations: Vec::new()
30 }
31 }
32
33 pub fn from_config(circuit: Circuit, config: MCMCConfig) -> MCMC {
34 MCMC { circuit, config, activations: Vec::new() }
35 }
36
37 pub fn run(&mut self) {
38 self.activations = vec![0; self.circuit.weight.shape().0];
40 let mut plotter = Plotter::new(1 << (self.activations.len() - 1));
41
42 for sweep in 0..self.config.sweeps {
43
44 let sweep_size = self.activations.len();
46 for _ in 0..sweep_size {
47 for i in 0..sweep_size {
49 let synapse = self.synapse(i);
51 let activation = self.activation(synapse);
52 self.activations[i] = activation;
53 }
54 }
55
56 if sweep > self.config.warmup {
58 let _energy_sample = self.energy();
59
60 plotter.add_sample(&self.activations);
61 }
62 }
63
64 if self.config.plot {
65 plotter.plot("mcmc.png").unwrap();
66 }
67 }
68
69 fn synapse(&self, pbit_index: usize) -> i32 {
70 let mut synapse = 0;
71 let weight = self.circuit.weight.clone();
72 let bias = self.circuit.bias.clone();
73
74 for j in 0..self.activations.len() {
75 if pbit_index != j {
76 synapse += weight[(pbit_index, j)] * self.activations[j];
78 }
79 }
80
81 synapse + bias[pbit_index]
82 }
83
84 fn activation(&self, synapse: i32) -> i32 {
85 let mut rng = thread_rng();
86 let activation_dist = Uniform::new(-1.0, 1.0);
87
88 let mut raw_activ = (self.config.beta * synapse as f64).tanh();
89 raw_activ -= rng.sample(activation_dist);
90
91 if raw_activ > 0.0 {
92 1
93 } else {
94 -1
95 }
96 }
97
98 fn energy(&self) -> f64 {
99 let weight = self.circuit.weight.clone();
100 let bias = self.circuit.bias.clone();
101
102 let mut energy = 0.0;
103 for j in 0..self.activations.len() {
104 for i in 0..j {
105 energy += (weight[(i, j)] * self.activations[i] * self.activations[j]) as f64;
106 }
107 }
108
109 for i in 0..self.activations.len() {
110 energy += (bias[i] * self.activations[i]) as f64;
111 }
112
113 -energy
114 }
115}