wham/
histogram.rs

1use std::fmt;
2
/// A single histogram: the binned samples collected in one simulation window.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Total number of data points stored in the histogram.
    pub num_points: u32,

    /// Per-bin counts of the histogram.
    pub bins: Vec<f64>,
}

13impl Histogram {
14    pub fn new(num_points: u32, bins: Vec<f64>) -> Histogram {
15        Histogram {num_points, bins}
16    }
17}
18
19// a set of histograms
20#[derive(Debug,Clone)]
21pub struct Dataset {
22    // number of histogram windows (number of simulations)
23    pub num_windows: usize,
24
25    // total number of bins
26    pub num_bins: usize,
27
28    // number of bins in each dimension
29    pub dimens_lengths: Vec<usize>,
30
31    // min values of the histogram in each dimension
32    hist_min: Vec<f64>,
33
34    // max values of the histogram in each dimension
35    hist_max: Vec<f64>,
36
37    // width of a bin in unit of its dimension
38    bin_width: Vec<f64>,
39
40    // value of kT
41    pub kT: f64,
42
43    // histogram for each window
44    pub histograms: Vec<Histogram>,
45
46    // flag for cyclic reaction coordinates
47    pub cyclic: bool,
48
49    // locations of biases
50    bias_pos: Vec<f64>,
51
52    // force constants of biases
53    bias_fc: Vec<f64>,
54
55    // bias value cache
56    bias: Vec<f64>,
57
58    // histogram weight
59    pub weights: Vec<f64>,
60}
61
62impl Dataset {
63
64    pub fn new(num_bins: usize, dimens_lengths: Vec<usize>, bin_width: Vec<f64>,
65        hist_min: Vec<f64>, hist_max: Vec<f64>, bias_pos: Vec<f64>,
66        bias_fc: Vec<f64>, kT: f64, histograms: Vec<Histogram>, cyclic: bool) -> Dataset {
67        let num_windows = histograms.len();
68        let bias: Vec<f64> = vec![0.0; num_bins*num_windows];
69        let weights = vec![1.0; num_windows];
70        let mut ds = Dataset{
71            num_windows,
72            num_bins,
73            dimens_lengths,
74            bin_width,
75            hist_min,
76            hist_max,
77            kT,
78            histograms,
79            cyclic,
80            bias_pos,
81            bias_fc,
82            bias,
83            weights
84        };
85        for window in 0..num_windows {
86            for bin in 0..num_bins {
87                let ndx = window * num_bins + bin;
88                ds.bias[ndx] = ds.calc_bias(bin, window);
89            }
90        }
91        ds
92
93    }
94
95    pub fn new_weighted(ds: Dataset, weights: Vec<f64>) -> Dataset {
96        Dataset {
97            weights,
98            ..ds
99        }
100    }
101
102    pub fn get_weighted_bin_count(&self, bin: usize) -> f64 {
103        self.histograms.iter().enumerate().map(|(idx,h)| self.weights[idx]*h.bins[bin]).sum()
104    }
105
106    fn expand_index(&self, bin: usize, lengths: &[usize]) -> Vec<usize> {
107        let mut tmp = bin;
108        let mut idx = vec![0; lengths.len()];
109        for dimen in (1..lengths.len()).rev() {
110            let denom: usize = lengths.iter().take(dimen).product();
111            idx[dimen] = tmp / denom;
112            tmp %= denom;
113        }
114        idx[0] = tmp;
115        idx
116    }
117
118    // get center x value for a bin
119    pub fn get_coords_for_bin(&self, bin: usize) -> Vec<f64> {
120        self.expand_index(bin, &self.dimens_lengths).iter().enumerate().map(|(i, dimen_bin)| {
121            self.hist_min[i] + self.bin_width[i]*(*dimen_bin as f64 + 0.5)
122        }).collect()
123    }
124
125    pub fn get_bias(&self, bin: usize, window: usize) -> f64 {
126        let ndx = window * self.num_bins + bin;
127        self.bias[ndx]
128    }
129
130    // Harmonic bias calculation: bias = 0.5*k(dx)^2
131    // if cyclic is true, lowest and highest bins are assumed to be
132    // neighbors. This returns exp(U/kT) instead of U for better performance.
133    fn calc_bias(&self, bin: usize, window: usize) -> f64 {
134        let dimens = self.dimens_lengths.len();
135        // index of the bias value depends on the window und dimension
136        let bias_ndx: Vec<usize> = (0..dimens)
137            .map(|dimen| { window * dimens + dimen }).collect();
138
139        // find the N coords, force constants and bias coords
140        let coord = self.get_coords_for_bin(bin);
141        let bias_fc: Vec<f64> = bias_ndx.iter().map(|ndx| { self.bias_fc[*ndx] }).collect();
142        let bias_pos: Vec<f64> = bias_ndx.iter().map(|ndx| { self.bias_pos[*ndx] }).collect();
143
144        let mut bias_sum = 0.0;
145        for i in 0..dimens {
146            let mut dist = (coord[i] - bias_pos[i]).abs();
147            if self.cyclic { // periodic conditions
148                let hist_len = self.hist_max[i] - self.hist_min[i];
149                if dist > 0.5 * hist_len {
150                    dist -= hist_len;
151                }
152            }
153            // store exp(U/kT) for better performance
154            bias_sum += 0.5 * bias_fc[i] * dist * dist
155        }
156        (-bias_sum/self.kT).exp()
157    }
158}
159
160impl fmt::Display for Dataset {
161    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
162        let mut datapoints: u32 = 0;
163        for h in &self.histograms {
164            datapoints += h.num_points;
165        }
166        write!(f, "{} windows, {} datapoints", self.num_windows, datapoints)
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::k_B;

    macro_rules! assert_delta {
        ($x:expr, $y:expr, $d:expr) => {
            assert!(($x-$y).abs() < $d, "{} != {}", $x, $y)
        }
    }

    fn build_hist() -> Histogram {
        Histogram::new(
            22,                             // num_points
            vec![1.0, 1.0, 3.0, 5.0, 12.0], // bins
        )
    }

    /// One 1-D window: 5 bins of width 1 starting at 0 (centres 0.5 ..= 4.5),
    /// bias centre 4.5 with k = 10. The cyclic period (hist max - hist min) is 9,
    /// so no bin centre is more than half a period from the bias centre and the
    /// cyclic flag does not change any value for this set.
    fn build_hist_set() -> Dataset {
        let h = build_hist();
        Dataset::new(
            5,           // num bins
            vec![5],     // bins per dimension
            vec![1.0],   // bin width
            vec![0.0],   // hist min
            vec![9.0],   // hist max
            vec![4.5],   // x0
            vec![10.0],  // fc
            300.0 * k_B, // kT
            vec![h],     // hists
            false,       // cyclic
        )
    }

    #[test]
    fn calc_bias() {
        let ds = build_hist_set();

        // bin 3 -> x = 3.5, x0 = 4.5: U = 0.5 * 10 * 1^2 = 5
        assert_delta!(0.134_722_337_796, ds.calc_bias(3, 0), 0.000_000_01);

        // bin 4 -> x = 4.5 = x0: no bias
        assert_delta!(1.0, ds.calc_bias(4, 0), 0.000_000_01);

        // bin 0 -> x = 0.5, x0 = 4.5: U = 80, factor ~1e-14 (non-cyclic)
        assert_delta!(0.0, ds.calc_bias(0, 0), 0.000_000_1);
    }

    #[test]
    fn calc_biascyclic() {
        let mut ds = build_hist_set();
        ds.cyclic = true;

        // bin 3 -> x = 3.5, x0 = 4.5: U = 5
        assert_delta!(0.134_722_337_796, ds.calc_bias(3, 0), 0.000_000_01);

        // bin 4 -> x = 4.5 = x0: no bias
        assert_delta!(1.0, ds.calc_bias(4, 0), 0.000_000_01);

        // bin 0 -> x = 0.5: distance 4 is below half the period (4.5), so there is
        // no wrap-around and U = 80
        assert_delta!(0.000_000_000_000_011_776_9, ds.calc_bias(0, 0), 0.000_000_01);

        // bin 1 -> x = 1.5: distance 3, U = 45, factor ~1.5e-8
        assert_delta!(0.000_000_01, ds.calc_bias(1, 0), 0.000_000_01);
    }

    #[test]
    fn calc_bias_cyclic_wraps() {
        // Period 5 with centre 4.5: bin 0 (x = 0.5) is 4 away directly but only 1
        // across the boundary, so it must match bin 3 (x = 3.5, distance 1).
        let ds = Dataset::new(
            5,           // num bins
            vec![5],     // bins per dimension
            vec![1.0],   // bin width
            vec![0.0],   // hist min
            vec![5.0],   // hist max
            vec![4.5],   // x0
            vec![10.0],  // fc
            300.0 * k_B, // kT
            vec![build_hist()], // hists
            true,        // cyclic
        );
        assert_delta!(0.134_722_337_796, ds.calc_bias(0, 0), 0.000_000_01);
        assert_delta!(ds.calc_bias(3, 0), ds.calc_bias(0, 0), 0.000_000_000_001);

        // bin 1 (x = 1.5) wraps to distance 2, the same as bin 2 (x = 2.5)
        assert_delta!(ds.calc_bias(2, 0), ds.calc_bias(1, 0), 0.000_000_000_001);
    }

    #[test]
    fn get_x_for_bin() {
        let ds = build_hist_set();
        // bins of width 1 starting at 0 have their centres at i + 0.5
        for i in 0..9 {
            let expected = i as f64 + 0.5;
            let coords = ds.get_coords_for_bin(i);
            assert_approx_eq!(expected, coords[0]);
        }
    }

    #[test]
    fn get_bin_count() {
        // two identical 1-D windows, both with the default weight 1.0
        let ds = Dataset::new(
            5,                // num bins
            vec![5],          // bins per dimension
            vec![1.0],        // bin width
            vec![0.0],        // hist min
            vec![5.0],        // hist max
            vec![7.5, 7.5],   // x0, one per window
            vec![10.0, 10.0], // fc, one per window
            300.0 * k_B,      // kT
            vec![build_hist(), build_hist()], // hists
            false,            // cyclic
        );
        assert_delta!(2.0, ds.get_weighted_bin_count(0), 0.000_000_000_1);
        assert_delta!(2.0, ds.get_weighted_bin_count(1), 0.000_000_000_1);
        assert_delta!(6.0, ds.get_weighted_bin_count(2), 0.000_000_000_1);
        assert_delta!(10.0, ds.get_weighted_bin_count(3), 0.000_000_000_1);
        assert_delta!(24.0, ds.get_weighted_bin_count(4), 0.000_000_000_1);
    }
}