sci_form/ani/
aev_params.rs1use serde::{Deserialize, Serialize};
7
/// Atomic numbers of the elements handled by the ANI featurisation, listed in
/// species-index order: H, C, N, O, F, S, Cl.
///
/// NOTE(review): TorchANI's ANI-2x model orders its species H, C, N, O, S, F, Cl.
/// If pretrained ANI-2x network weights are loaded, confirm they expect this order.
pub const ANI_ELEMENTS: [u8; 7] = [
    1,  // H
    6,  // C
    7,  // N
    8,  // O
    9,  // F
    16, // S
    17, // Cl
];

/// Number of chemical species, i.e. the length of [`ANI_ELEMENTS`].
pub const N_SPECIES: usize = ANI_ELEMENTS.len();
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AevParams {
17 pub radial_cutoff: f64,
19 pub angular_cutoff: f64,
21 pub radial_eta: Vec<f64>,
23 pub radial_rs: Vec<f64>,
25 pub angular_eta: Vec<f64>,
27 pub angular_rs: Vec<f64>,
29 pub angular_zeta: Vec<f64>,
31 pub angular_theta_s: Vec<f64>,
33}
34
35impl AevParams {
36 pub fn radial_length(&self) -> usize {
38 self.radial_eta.len() * self.radial_rs.len()
39 }
40
41 pub fn angular_length(&self) -> usize {
43 self.angular_eta.len()
44 * self.angular_rs.len()
45 * self.angular_zeta.len()
46 * self.angular_theta_s.len()
47 }
48
49 pub fn total_aev_length(&self) -> usize {
51 let n_rad = N_SPECIES * self.radial_length();
54 let n_ang = N_SPECIES * (N_SPECIES + 1) / 2 * self.angular_length();
55 n_rad + n_ang
56 }
57}
58
59pub fn species_index(z: u8) -> Option<usize> {
61 ANI_ELEMENTS.iter().position(|&e| e == z)
62}
63
64pub fn default_ani2x_params() -> AevParams {
69 use std::f64::consts::PI;
70
71 let radial_eta = vec![19.7; 8];
73 let radial_rs: Vec<f64> = (0..8).map(|i| 0.8 + 0.5625 * i as f64).collect();
74
75 let angular_eta = vec![12.5; 4];
77 let angular_rs: Vec<f64> = (0..4).map(|i| 0.8 + 0.95 * i as f64).collect();
78 let angular_zeta = vec![14.1; 1];
79 let angular_theta_s: Vec<f64> = (0..8).map(|i| PI * i as f64 / 8.0).collect();
80
81 AevParams {
82 radial_cutoff: 5.2,
83 angular_cutoff: 3.5,
84 radial_eta,
85 radial_rs,
86 angular_eta,
87 angular_rs,
88 angular_zeta,
89 angular_theta_s,
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_species_index() {
        // (atomic number, expected species index)
        let expectations = [(1u8, Some(0usize)), (6, Some(1)), (8, Some(3)), (26, None)];
        for &(z, expected) in expectations.iter() {
            assert_eq!(species_index(z), expected);
        }
    }

    #[test]
    fn test_aev_dimensions() {
        let params = default_ani2x_params();
        let lengths = [
            params.radial_length(),
            params.angular_length(),
            params.total_aev_length(),
        ];
        assert!(lengths.iter().all(|&len| len > 0));
    }
}