Skip to main content

sci_form/ani/
api.rs

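//! Public API for ANI machine-learning potentials.
//!
//! Provides `compute_ani()` for energy (and optional force) evaluation from
//! molecular geometries, `compute_ani_batch()` for evaluating many molecules
//! at once, and `compute_ani_test()` for exercising the pipeline with
//! synthetic weights.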
5
6use super::aev::compute_aevs;
7use super::aev_params::{default_ani2x_params, species_index};
8use super::gradients::compute_forces;
9use super::neighbor::CellList;
10use super::nn::FeedForwardNet;
11use nalgebra::DVector;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15/// Configuration for ANI computation.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct AniConfig {
18    /// Cutoff radius (Å). Default: 5.2.
19    pub cutoff: f64,
20    /// Whether to compute forces.
21    pub compute_forces: bool,
22    /// Whether to return the computed AEVs for validation.
23    pub output_aevs: bool,
24}
25
26impl Default for AniConfig {
27    fn default() -> Self {
28        AniConfig {
29            cutoff: 5.2,
30            compute_forces: true,
31            output_aevs: false,
32        }
33    }
34}
35
36/// Result of an ANI energy/force calculation.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct AniResult {
39    /// Total potential energy (Hartree).
40    pub energy: f64,
41    /// Atomic forces [x,y,z] per atom (Hartree/Å). Empty if not requested.
42    pub forces: Vec<[f64; 3]>,
43    /// Atomic species (atomic numbers).
44    pub species: Vec<u8>,
45    /// Per-atom energy contributions.
46    pub atomic_energies: Vec<f64>,
47    /// Optional AEVs for each atom.
48    pub aevs: Option<Vec<Vec<f64>>>,
49}
50
51/// Compute ANI energy (and optionally forces) for a molecular geometry.
52///
53/// `elements`: atomic numbers.
54/// `positions`: [x,y,z] coordinates in Ångström.
55/// `config`: computation parameters.
56/// `models`: pre-loaded element→network map (from `weights::load_weights`).
57pub fn compute_ani(
58    elements: &[u8],
59    positions: &[[f64; 3]],
60    config: &AniConfig,
61    models: &HashMap<u8, FeedForwardNet>,
62) -> Result<AniResult, String> {
63    if elements.len() != positions.len() {
64        return Err(format!(
65            "elements ({}) and positions ({}) length mismatch",
66            elements.len(),
67            positions.len()
68        ));
69    }
70
71    // Validate species support
72    for &z in elements {
73        if species_index(z).is_none() {
74            return Err(format!("Unsupported element Z={z} for ANI potential"));
75        }
76        if !models.contains_key(&z) {
77            return Err(format!("No model weights for element Z={z}"));
78        }
79    }
80
81    let params = default_ani2x_params();
82
83    // Build neighbor list
84    let cell_list = CellList::new(positions, config.cutoff);
85    let neighbors = cell_list.find_neighbors(positions);
86
87    // Compute AEVs
88    let aevs = compute_aevs(elements, positions, &neighbors, &params);
89
90    // Neural network inference: per-atom energies
91    let mut atomic_energies = Vec::with_capacity(elements.len());
92    for (i, &z) in elements.iter().enumerate() {
93        let net = &models[&z];
94        let input = DVector::from_vec(aevs[i].clone());
95        let e_atom = net.forward(&input);
96        atomic_energies.push(e_atom);
97    }
98
99    let energy: f64 = atomic_energies.iter().sum();
100
101    // Forces
102    let forces = if config.compute_forces {
103        compute_forces(elements, positions, &neighbors, &params, models)
104    } else {
105        Vec::new()
106    };
107
108    Ok(AniResult {
109        energy,
110        forces,
111        species: elements.to_vec(),
112        atomic_energies,
113        aevs: if config.output_aevs { Some(aevs) } else { None },
114    })
115}
116
117/// Compute ANI energy using internally-generated test weights.
118///
119/// This is for testing and demonstration only — not physically meaningful.
120pub fn compute_ani_test(elements: &[u8], positions: &[[f64; 3]]) -> Result<AniResult, String> {
121    let params = default_ani2x_params();
122    let aev_len = params.total_aev_length();
123
124    let mut models = HashMap::new();
125    for &z in elements {
126        models
127            .entry(z)
128            .or_insert_with(|| super::weights::make_test_model(aev_len));
129    }
130
131    compute_ani(elements, positions, &AniConfig::default(), &models)
132}
133
/// Evaluate [`compute_ani`] for many molecules, spreading the work across
/// rayon's thread pool.
///
/// Each entry of `molecules` is an `(elements, positions)` pair. The returned
/// vector holds one result per molecule, in the same order as `molecules`. A
/// failure for one molecule does not affect the others.
#[cfg(feature = "parallel")]
pub fn compute_ani_batch(
    molecules: &[(&[u8], &[[f64; 3]])],
    config: &AniConfig,
    models: &HashMap<u8, FeedForwardNet>,
) -> Vec<Result<AniResult, String>> {
    use rayon::prelude::*;

    let results: Vec<Result<AniResult, String>> = molecules
        .par_iter()
        .map(|&(species, coords)| compute_ani(species, coords, config, models))
        .collect();
    results
}
147
148/// Batch-compute ANI energies for multiple molecules sequentially.
149#[cfg(not(feature = "parallel"))]
150pub fn compute_ani_batch(
151    molecules: &[(&[u8], &[[f64; 3]])],
152    config: &AniConfig,
153    models: &HashMap<u8, FeedForwardNet>,
154) -> Vec<Result<AniResult, String>> {
155    molecules
156        .iter()
157        .map(|(els, pos)| compute_ani(els, pos, config, models))
158        .collect()
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Water geometry (O, H, H) in Ångström shared by several tests.
    const WATER_ELEMENTS: [u8; 3] = [8, 1, 1];
    const WATER_POSITIONS: [[f64; 3]; 3] = [
        [0.0, 0.0, 0.117],
        [0.0, 0.757, -0.469],
        [0.0, -0.757, -0.469],
    ];

    /// One synthetic network per distinct element, built the same way as in
    /// `compute_ani_test`.
    fn test_models(elements: &[u8]) -> HashMap<u8, FeedForwardNet> {
        let aev_len = default_ani2x_params().total_aev_length();
        let mut models = HashMap::new();
        for &z in elements {
            models
                .entry(z)
                .or_insert_with(|| super::super::weights::make_test_model(aev_len));
        }
        models
    }

    #[test]
    fn test_api_water() {
        let result = compute_ani_test(&WATER_ELEMENTS, &WATER_POSITIONS).unwrap();
        assert_eq!(result.species, vec![8, 1, 1]);
        assert_eq!(result.atomic_energies.len(), 3);
        assert_eq!(result.forces.len(), 3);
        assert!(result.aevs.is_none());
        assert!(result.energy.is_finite());
    }

    #[test]
    fn test_unsupported_element() {
        let elements = [26u8]; // Fe not in ANI
        let positions = [[0.0, 0.0, 0.0]];
        let result = compute_ani_test(&elements, &positions);
        assert!(result.is_err());
    }

    #[test]
    fn test_length_mismatch() {
        let models = test_models(&WATER_ELEMENTS);
        let result = compute_ani(
            &WATER_ELEMENTS,
            &WATER_POSITIONS[..2],
            &AniConfig::default(),
            &models,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_missing_model_weights() {
        // Only oxygen has weights, so the first hydrogen must be rejected.
        let models = test_models(&[8]);
        let err = compute_ani(
            &WATER_ELEMENTS,
            &WATER_POSITIONS,
            &AniConfig::default(),
            &models,
        )
        .unwrap_err();
        assert!(err.contains("Z=1"), "unexpected error: {err}");
    }

    #[test]
    fn test_config_flags() {
        let models = test_models(&WATER_ELEMENTS);
        let config = AniConfig {
            compute_forces: false,
            output_aevs: true,
            ..AniConfig::default()
        };
        let result = compute_ani(&WATER_ELEMENTS, &WATER_POSITIONS, &config, &models).unwrap();
        assert!(result.forces.is_empty());
        let aevs = result.aevs.expect("AEVs were requested");
        assert_eq!(aevs.len(), WATER_ELEMENTS.len());
    }

    #[test]
    fn test_energy_is_sum_of_atomic_energies() {
        let result = compute_ani_test(&WATER_ELEMENTS, &WATER_POSITIONS).unwrap();
        let total: f64 = result.atomic_energies.iter().sum();
        assert!((result.energy - total).abs() <= 1e-12 * total.abs().max(1.0));
    }

    #[test]
    fn test_translation_invariance() {
        // Reuse one model map so both evaluations share identical weights.
        let models = test_models(&WATER_ELEMENTS);
        let config = AniConfig::default();
        let shift = [10.0, -3.0, 2.5];
        let shifted: Vec<[f64; 3]> = WATER_POSITIONS
            .iter()
            .map(|p| [p[0] + shift[0], p[1] + shift[1], p[2] + shift[2]])
            .collect();
        let e0 = compute_ani(&WATER_ELEMENTS, &WATER_POSITIONS, &config, &models)
            .unwrap()
            .energy;
        let e1 = compute_ani(&WATER_ELEMENTS, &shifted, &config, &models)
            .unwrap()
            .energy;
        assert!((e0 - e1).abs() <= 1e-8 * e0.abs().max(1.0), "{e0} vs {e1}");
    }
}