1use super::aev::compute_aevs;
7use super::aev_params::{default_ani2x_params, species_index};
8use super::gradients::compute_forces;
9use super::neighbor::CellList;
10use super::nn::FeedForwardNet;
11use nalgebra::DVector;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct AniConfig {
18 pub cutoff: f64,
20 pub compute_forces: bool,
22 pub output_aevs: bool,
24}
25
26impl Default for AniConfig {
27 fn default() -> Self {
28 AniConfig {
29 cutoff: 5.2,
30 compute_forces: true,
31 output_aevs: false,
32 }
33 }
34}
35
/// Output of [`compute_ani`] for a single molecule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AniResult {
    /// Total energy, computed as the sum of `atomic_energies`.
    /// NOTE(review): units are whatever the networks output (presumably Hartree for ANI) — confirm.
    pub energy: f64,
    /// Forces as returned by `compute_forces`; presumably one `[x, y, z]` entry per atom
    /// (the water test expects 3 entries for 3 atoms). Empty when
    /// `AniConfig::compute_forces` is `false`. Sign convention and units: TODO confirm.
    pub forces: Vec<[f64; 3]>,
    /// Copy of the input atomic numbers, in input order.
    pub species: Vec<u8>,
    /// Per-atom network outputs, in the same order as `species`.
    pub atomic_energies: Vec<f64>,
    /// Per-atom AEVs, present only when `AniConfig::output_aevs` is `true`.
    pub aevs: Option<Vec<Vec<f64>>>,
}
50
51pub fn compute_ani(
58 elements: &[u8],
59 positions: &[[f64; 3]],
60 config: &AniConfig,
61 models: &HashMap<u8, FeedForwardNet>,
62) -> Result<AniResult, String> {
63 if elements.len() != positions.len() {
64 return Err(format!(
65 "elements ({}) and positions ({}) length mismatch",
66 elements.len(),
67 positions.len()
68 ));
69 }
70
71 for &z in elements {
73 if species_index(z).is_none() {
74 return Err(format!("Unsupported element Z={z} for ANI potential"));
75 }
76 if !models.contains_key(&z) {
77 return Err(format!("No model weights for element Z={z}"));
78 }
79 }
80
81 let params = default_ani2x_params();
82
83 let cell_list = CellList::new(positions, config.cutoff);
85 let neighbors = cell_list.find_neighbors(positions);
86
87 let aevs = compute_aevs(elements, positions, &neighbors, ¶ms);
89
90 let mut atomic_energies = Vec::with_capacity(elements.len());
92 for (i, &z) in elements.iter().enumerate() {
93 let net = &models[&z];
94 let input = DVector::from_vec(aevs[i].clone());
95 let e_atom = net.forward(&input);
96 atomic_energies.push(e_atom);
97 }
98
99 let energy: f64 = atomic_energies.iter().sum();
100
101 let forces = if config.compute_forces {
103 compute_forces(elements, positions, &neighbors, ¶ms, models)
104 } else {
105 Vec::new()
106 };
107
108 Ok(AniResult {
109 energy,
110 forces,
111 species: elements.to_vec(),
112 atomic_energies,
113 aevs: if config.output_aevs { Some(aevs) } else { None },
114 })
115}
116
117pub fn compute_ani_test(elements: &[u8], positions: &[[f64; 3]]) -> Result<AniResult, String> {
121 let params = default_ani2x_params();
122 let aev_len = params.total_aev_length();
123
124 let mut models = HashMap::new();
125 for &z in elements {
126 models
127 .entry(z)
128 .or_insert_with(|| super::weights::make_test_model(aev_len));
129 }
130
131 compute_ani(elements, positions, &AniConfig::default(), &models)
132}
133
/// Evaluates [`compute_ani`] for every `(elements, positions)` pair in `molecules`,
/// spreading the work over rayon's thread pool.
///
/// Each molecule gets its own `Result`, so one invalid molecule does not prevent the
/// others from being evaluated. Results are collected from an indexed parallel
/// iterator, so they appear in the same order as `molecules`.
#[cfg(feature = "parallel")]
pub fn compute_ani_batch(
    molecules: &[(&[u8], &[[f64; 3]])],
    config: &AniConfig,
    models: &HashMap<u8, FeedForwardNet>,
) -> Vec<Result<AniResult, String>> {
    use rayon::prelude::*;

    molecules
        .into_par_iter()
        .map(|&(species, coords)| compute_ani(species, coords, config, models))
        .collect::<Vec<_>>()
}
147
148#[cfg(not(feature = "parallel"))]
150pub fn compute_ani_batch(
151 molecules: &[(&[u8], &[[f64; 3]])],
152 config: &AniConfig,
153 models: &HashMap<u8, FeedForwardNet>,
154) -> Vec<Result<AniResult, String>> {
155 molecules
156 .iter()
157 .map(|(els, pos)| compute_ani(els, pos, config, models))
158 .collect()
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_api_water() {
        let elements = [8u8, 1, 1];
        let positions = [
            [0.0, 0.0, 0.117],
            [0.0, 0.757, -0.469],
            [0.0, -0.757, -0.469],
        ];
        let result = compute_ani_test(&elements, &positions).unwrap();
        assert_eq!(result.species, vec![8, 1, 1]);
        assert_eq!(result.atomic_energies.len(), 3);
        assert_eq!(result.forces.len(), 3);
        assert!(result.aevs.is_none());
        assert!(result.energy.is_finite());

        // The molecular energy is defined as the sum of the per-atom contributions.
        let summed: f64 = result.atomic_energies.iter().sum();
        assert!(
            (result.energy - summed).abs() <= 1e-12 * summed.abs().max(1.0),
            "energy {} != sum of atomic energies {}",
            result.energy,
            summed
        );
    }

    #[test]
    fn test_unsupported_element() {
        // Z=26 is expected to be rejected by `species_index`; assert on the specific
        // error so another failure path cannot make this test pass.
        let err = compute_ani_test(&[26u8], &[[0.0, 0.0, 0.0]]).unwrap_err();
        assert!(
            err.contains("Unsupported element Z=26"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_length_mismatch() {
        let err = compute_ani_test(&[8u8, 1], &[[0.0, 0.0, 0.0]]).unwrap_err();
        assert!(err.contains("length mismatch"), "unexpected error: {err}");
    }

    #[test]
    fn test_missing_model() {
        let err = compute_ani(
            &[1u8],
            &[[0.0, 0.0, 0.0]],
            &AniConfig::default(),
            &HashMap::new(),
        )
        .unwrap_err();
        assert!(
            err.contains("No model weights for element Z=1"),
            "unexpected error: {err}"
        );
    }
}