1use super::aev::compute_aevs;
7use super::aev_params::{default_ani2x_params, species_index};
8use super::gradients::compute_forces;
9use super::neighbor::CellList;
10use super::nn::FeedForwardNet;
11use nalgebra::DVector;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct AniConfig {
18 pub cutoff: f64,
20 pub compute_forces: bool,
22 pub output_aevs: bool,
24}
25
26impl Default for AniConfig {
27 fn default() -> Self {
28 AniConfig {
29 cutoff: 5.2,
30 compute_forces: true,
31 output_aevs: false,
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct AniResult {
39 pub energy: f64,
41 pub forces: Vec<[f64; 3]>,
43 pub species: Vec<u8>,
45 pub atomic_energies: Vec<f64>,
47 pub aevs: Option<Vec<Vec<f64>>>,
49}
50
51pub fn compute_ani(
58 elements: &[u8],
59 positions: &[[f64; 3]],
60 config: &AniConfig,
61 models: &HashMap<u8, FeedForwardNet>,
62) -> Result<AniResult, String> {
63 if elements.len() != positions.len() {
64 return Err(format!(
65 "elements ({}) and positions ({}) length mismatch",
66 elements.len(),
67 positions.len()
68 ));
69 }
70
71 for &z in elements {
73 if species_index(z).is_none() {
74 return Err(format!("Unsupported element Z={z} for ANI potential"));
75 }
76 if !models.contains_key(&z) {
77 return Err(format!("No model weights for element Z={z}"));
78 }
79 }
80
81 let params = default_ani2x_params();
82
83 let cell_list = CellList::new(positions, config.cutoff);
85 let neighbors = cell_list.find_neighbors(positions);
86
87 let aevs = compute_aevs(elements, positions, &neighbors, ¶ms);
89
90 let mut atomic_energies = Vec::with_capacity(elements.len());
92 for (i, &z) in elements.iter().enumerate() {
93 let net = &models[&z];
94 if aevs[i].len() != net.input_dim() {
95 return Err(format!(
96 "AEV dimension {} for atom {} (Z={}) does not match model input dimension {}",
97 aevs[i].len(),
98 i,
99 z,
100 net.input_dim()
101 ));
102 }
103 let input = DVector::from_vec(aevs[i].clone());
104 let e_atom = net.forward(&input);
105 atomic_energies.push(e_atom);
106 }
107
108 let energy: f64 = atomic_energies.iter().sum();
109
110 let forces = if config.compute_forces {
112 compute_forces(elements, positions, &neighbors, ¶ms, models)
113 } else {
114 Vec::new()
115 };
116
117 Ok(AniResult {
118 energy,
119 forces,
120 species: elements.to_vec(),
121 atomic_energies,
122 aevs: if config.output_aevs { Some(aevs) } else { None },
123 })
124}
125
126pub fn compute_ani_test(elements: &[u8], positions: &[[f64; 3]]) -> Result<AniResult, String> {
130 let params = default_ani2x_params();
131 let aev_len = params.total_aev_length();
132
133 let mut models = HashMap::new();
134 for &z in elements {
135 models
136 .entry(z)
137 .or_insert_with(|| super::weights::make_test_model(aev_len));
138 }
139
140 compute_ani(elements, positions, &AniConfig::default(), &models)
141}
142
/// Evaluates many molecules with a shared configuration and model set,
/// distributing the work over rayon's parallel iterator.
///
/// Each entry of `molecules` is an `(elements, positions)` pair. The returned
/// vector holds one `compute_ani` result per input molecule, in input order;
/// a failure for one molecule does not affect the others.
#[cfg(feature = "parallel")]
pub fn compute_ani_batch(
    molecules: &[(&[u8], &[[f64; 3]])],
    config: &AniConfig,
    models: &HashMap<u8, FeedForwardNet>,
) -> Vec<Result<AniResult, String>> {
    use rayon::prelude::*;

    molecules
        .par_iter()
        .map(|&(species, coords)| compute_ani(species, coords, config, models))
        .collect::<Vec<_>>()
}
156
157#[cfg(not(feature = "parallel"))]
159pub fn compute_ani_batch(
160 molecules: &[(&[u8], &[[f64; 3]])],
161 config: &AniConfig,
162 models: &HashMap<u8, FeedForwardNet>,
163) -> Vec<Result<AniResult, String>> {
164 molecules
165 .iter()
166 .map(|(els, pos)| compute_ani(els, pos, config, models))
167 .collect()
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// A water molecule evaluated with test models yields one energy and one
    /// force per atom, echoes the species, and omits AEVs by default.
    #[test]
    fn test_api_water() {
        let species = [8u8, 1, 1];
        let coords = [
            [0.0, 0.0, 0.117],
            [0.0, 0.757, -0.469],
            [0.0, -0.757, -0.469],
        ];

        let out = compute_ani_test(&species, &coords).unwrap();

        assert_eq!(out.species, vec![8, 1, 1]);
        assert_eq!(out.atomic_energies.len(), 3);
        assert_eq!(out.forces.len(), 3);
        assert!(out.aevs.is_none());
        assert!(out.energy.is_finite());
    }

    /// Iron (Z=26) must be rejected with an error.
    #[test]
    fn test_unsupported_element() {
        let species = [26u8];
        let coords = [[0.0, 0.0, 0.0]];

        assert!(compute_ani_test(&species, &coords).is_err());
    }
}