Skip to main content

sci_form/ani/
api.rs

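//! Public API for ANI machine-learning potentials.
//!
//! Provides `compute_ani()` for energy (and optional force) evaluation from
//! molecular geometries, `compute_ani_batch()` for evaluating many molecules
//! with shared models, and `compute_ani_test()` for runs with generated test weights.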
5
6use super::aev::compute_aevs;
7use super::aev_params::{default_ani2x_params, species_index};
8use super::gradients::compute_forces;
9use super::neighbor::CellList;
10use super::nn::FeedForwardNet;
11use nalgebra::DVector;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15/// Configuration for ANI computation.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct AniConfig {
18    /// Cutoff radius (Å). Default: 5.2.
19    pub cutoff: f64,
20    /// Whether to compute forces.
21    pub compute_forces: bool,
22    /// Whether to return the computed AEVs for validation.
23    pub output_aevs: bool,
24}
25
26impl Default for AniConfig {
27    fn default() -> Self {
28        AniConfig {
29            cutoff: 5.2,
30            compute_forces: true,
31            output_aevs: false,
32        }
33    }
34}
35
36/// Result of an ANI energy/force calculation.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct AniResult {
39    /// Total potential energy (Hartree).
40    pub energy: f64,
41    /// Atomic forces [x,y,z] per atom (Hartree/Å). Empty if not requested.
42    pub forces: Vec<[f64; 3]>,
43    /// Atomic species (atomic numbers).
44    pub species: Vec<u8>,
45    /// Per-atom energy contributions.
46    pub atomic_energies: Vec<f64>,
47    /// Optional AEVs for each atom.
48    pub aevs: Option<Vec<Vec<f64>>>,
49}
50
51/// Compute ANI energy (and optionally forces) for a molecular geometry.
52///
53/// `elements`: atomic numbers.
54/// `positions`: [x,y,z] coordinates in Ångström.
55/// `config`: computation parameters.
56/// `models`: pre-loaded element→network map (from `weights::load_weights`).
57pub fn compute_ani(
58    elements: &[u8],
59    positions: &[[f64; 3]],
60    config: &AniConfig,
61    models: &HashMap<u8, FeedForwardNet>,
62) -> Result<AniResult, String> {
63    if elements.len() != positions.len() {
64        return Err(format!(
65            "elements ({}) and positions ({}) length mismatch",
66            elements.len(),
67            positions.len()
68        ));
69    }
70
71    // Validate species support
72    for &z in elements {
73        if species_index(z).is_none() {
74            return Err(format!("Unsupported element Z={z} for ANI potential"));
75        }
76        if !models.contains_key(&z) {
77            return Err(format!("No model weights for element Z={z}"));
78        }
79    }
80
81    let params = default_ani2x_params();
82
83    // Build neighbor list
84    let cell_list = CellList::new(positions, config.cutoff);
85    let neighbors = cell_list.find_neighbors(positions);
86
87    // Compute AEVs
88    let aevs = compute_aevs(elements, positions, &neighbors, &params);
89
90    // Neural network inference: per-atom energies
91    let mut atomic_energies = Vec::with_capacity(elements.len());
92    for (i, &z) in elements.iter().enumerate() {
93        let net = &models[&z];
94        if aevs[i].len() != net.input_dim() {
95            return Err(format!(
96                "AEV dimension {} for atom {} (Z={}) does not match model input dimension {}",
97                aevs[i].len(),
98                i,
99                z,
100                net.input_dim()
101            ));
102        }
103        let input = DVector::from_vec(aevs[i].clone());
104        let e_atom = net.forward(&input);
105        atomic_energies.push(e_atom);
106    }
107
108    let energy: f64 = atomic_energies.iter().sum();
109
110    // Forces
111    let forces = if config.compute_forces {
112        compute_forces(elements, positions, &neighbors, &params, models)
113    } else {
114        Vec::new()
115    };
116
117    Ok(AniResult {
118        energy,
119        forces,
120        species: elements.to_vec(),
121        atomic_energies,
122        aevs: if config.output_aevs { Some(aevs) } else { None },
123    })
124}
125
126/// Compute ANI energy using internally-generated test weights.
127///
128/// This is for testing and demonstration only — not physically meaningful.
129pub fn compute_ani_test(elements: &[u8], positions: &[[f64; 3]]) -> Result<AniResult, String> {
130    let params = default_ani2x_params();
131    let aev_len = params.total_aev_length();
132
133    let mut models = HashMap::new();
134    for &z in elements {
135        models
136            .entry(z)
137            .or_insert_with(|| super::weights::make_test_model(aev_len));
138    }
139
140    compute_ani(elements, positions, &AniConfig::default(), &models)
141}
142
143/// Batch-compute ANI energies for multiple molecules in parallel.
144#[cfg(feature = "parallel")]
145pub fn compute_ani_batch(
146    molecules: &[(&[u8], &[[f64; 3]])],
147    config: &AniConfig,
148    models: &HashMap<u8, FeedForwardNet>,
149) -> Vec<Result<AniResult, String>> {
150    use rayon::prelude::*;
151    molecules
152        .par_iter()
153        .map(|(els, pos)| compute_ani(els, pos, config, models))
154        .collect()
155}
156
157/// Batch-compute ANI energies for multiple molecules sequentially.
158#[cfg(not(feature = "parallel"))]
159pub fn compute_ani_batch(
160    molecules: &[(&[u8], &[[f64; 3]])],
161    config: &AniConfig,
162    models: &HashMap<u8, FeedForwardNet>,
163) -> Vec<Result<AniResult, String>> {
164    molecules
165        .iter()
166        .map(|(els, pos)| compute_ani(els, pos, config, models))
167        .collect()
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Water geometry (O, H, H) with coordinates in Ångström.
    fn water() -> ([u8; 3], [[f64; 3]; 3]) {
        (
            [8, 1, 1],
            [
                [0.0, 0.0, 0.117],
                [0.0, 0.757, -0.469],
                [0.0, -0.757, -0.469],
            ],
        )
    }

    #[test]
    fn test_api_water() {
        let (elements, positions) = water();
        let result = compute_ani_test(&elements, &positions).unwrap();
        assert_eq!(result.species, vec![8, 1, 1]);
        assert_eq!(result.atomic_energies.len(), 3);
        assert_eq!(result.forces.len(), 3);
        assert!(result.aevs.is_none());
        assert!(result.energy.is_finite());
        // The total is defined as the sum of the per-atom contributions.
        let summed: f64 = result.atomic_energies.iter().sum();
        assert_eq!(result.energy, summed);
    }

    #[test]
    fn test_unsupported_element() {
        let elements = [26u8]; // Fe not in ANI
        let positions = [[0.0, 0.0, 0.0]];
        let result = compute_ani_test(&elements, &positions);
        assert!(result.is_err());
    }

    #[test]
    fn test_length_mismatch() {
        let (elements, positions) = water();
        let err = compute_ani_test(&elements, &positions[..2]).unwrap_err();
        assert!(err.contains("length mismatch"), "unexpected error: {err}");
    }

    #[test]
    fn test_batch_matches_single() {
        let (elements, positions) = water();
        let aev_len = default_ani2x_params().total_aev_length();
        let mut models = HashMap::new();
        for z in [1u8, 8] {
            models.insert(z, super::super::weights::make_test_model(aev_len));
        }
        let config = AniConfig::default();
        let molecules: Vec<(&[u8], &[[f64; 3]])> = vec![
            (&elements[..], &positions[..]),
            (&elements[..1], &positions[..1]),
        ];

        let batch = compute_ani_batch(&molecules, &config, &models);
        assert_eq!(batch.len(), molecules.len());
        for (res, &(els, pos)) in batch.iter().zip(&molecules) {
            let single = compute_ani(els, pos, &config, &models).unwrap();
            assert_eq!(res.as_ref().unwrap().energy, single.energy);
        }
    }
}