Skip to main content

sci_form/
dynamics.rs

1use rand::rngs::StdRng;
2use rand::{Rng, SeedableRng};
3use serde::{Deserialize, Serialize};
4
/// Kinetic-energy conversion: 1 amu·Å²/fs² per particle equals ~2390.057 kcal/mol.
const AMU_ANGFS2_TO_KCAL_MOL: f64 = 2390.057_361_533_49;
/// Molar gas constant in kcal/(mol·K).
const R_GAS_KCAL_MOLK: f64 = 1.987_204_258_640_83e-3;
/// Electron-volt to kcal/mol.
const EV_TO_KCAL_MOL: f64 = 23.0605;
/// Hartree to kcal/mol.
const HARTREE_TO_KCAL_MOL: f64 = 627.5094740631;
9
10/// Force backend for molecular dynamics.
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum MdBackend {
13    /// UFF force field (fast, approximate).
14    Uff,
15    /// PM3 semi-empirical (slower, more accurate).
16    Pm3,
17    /// GFN-xTB tight-binding (moderate speed, good for metals).
18    Xtb,
19}
20
21/// Energy backend for NEB path calculations.
22///
23/// Supports all methods that can provide energy and gradients (analytical
24/// or numerical) for NEB image relaxation.
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case")]
27pub enum NebBackend {
28    /// UFF force field — fast exploratory paths.
29    Uff,
30    /// MMFF94 force field — better organic coverage than UFF.
31    Mmff94,
32    /// PM3 semi-empirical — analytical gradients.
33    Pm3,
34    /// GFN0-xTB tight-binding — analytical gradients.
35    Xtb,
36    /// GFN1-xTB — numerical gradients (energy-only + finite differences).
37    Gfn1,
38    /// GFN2-xTB — numerical gradients (energy-only + finite differences).
39    Gfn2,
40    /// HF-3c minimal-basis Hartree-Fock — numerical gradients (expensive).
41    Hf3c,
42}
43
44impl NebBackend {
45    /// Parse a method string into a `NebBackend`.
46    pub fn from_method(s: &str) -> Result<Self, String> {
47        match s.trim().to_ascii_lowercase().as_str() {
48            "uff" => Ok(Self::Uff),
49            "mmff94" | "mmff" => Ok(Self::Mmff94),
50            "pm3" => Ok(Self::Pm3),
51            "xtb" | "gfn0" | "gfn0-xtb" | "gfn0_xtb" => Ok(Self::Xtb),
52            "gfn1" | "gfn1-xtb" | "gfn1_xtb" => Ok(Self::Gfn1),
53            "gfn2" | "gfn2-xtb" | "gfn2_xtb" => Ok(Self::Gfn2),
54            "hf3c" | "hf-3c" => Ok(Self::Hf3c),
55            other => Err(format!(
56                "Unknown NEB backend '{}'. Expected: uff, mmff94, pm3, xtb, gfn1, gfn2, hf3c",
57                other
58            )),
59        }
60    }
61
62    /// method label for human display.
63    pub fn as_str(self) -> &'static str {
64        match self {
65            Self::Uff => "uff",
66            Self::Mmff94 => "mmff94",
67            Self::Pm3 => "pm3",
68            Self::Xtb => "xtb",
69            Self::Gfn1 => "gfn1",
70            Self::Gfn2 => "gfn2",
71            Self::Hf3c => "hf3c",
72        }
73    }
74}
75
76/// One trajectory frame for molecular-dynamics sampling.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct MdFrame {
79    /// Integration step index.
80    pub step: usize,
81    /// Elapsed simulation time in femtoseconds.
82    pub time_fs: f64,
83    /// Flat xyz coordinates in angstroms.
84    pub coords: Vec<f64>,
85    /// Potential energy from UFF in kcal/mol.
86    pub potential_energy_kcal_mol: f64,
87    /// Kinetic energy in kcal/mol.
88    pub kinetic_energy_kcal_mol: f64,
89    /// Instantaneous temperature estimate in K.
90    pub temperature_k: f64,
91}
92
93/// Full trajectory output for exploratory molecular dynamics.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct MdTrajectory {
96    /// Stored MD frames.
97    pub frames: Vec<MdFrame>,
98    /// Timestep in femtoseconds.
99    pub dt_fs: f64,
100    /// Notes and caveats for interpretation.
101    pub notes: Vec<String>,
102    /// Energy conservation drift: (E_total_final - E_total_initial) / |E_total_initial| * 100%.
103    /// Only meaningful for NVE (no thermostat) simulations.
104    pub energy_drift_percent: Option<f64>,
105}
106
107/// One image (node) on a simplified NEB path.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct NebImage {
110    /// Image index from reactant (0) to product (n-1).
111    pub index: usize,
112    /// Flat xyz coordinates in angstroms.
113    pub coords: Vec<f64>,
114    /// UFF potential energy in kcal/mol.
115    pub potential_energy_kcal_mol: f64,
116}
117
118/// Simplified NEB output for low-cost pathway exploration.
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct NebPathResult {
121    /// Ordered path images from reactant to product.
122    pub images: Vec<NebImage>,
123    /// Notes and caveats.
124    pub notes: Vec<String>,
125}
126
/// Standard atomic weight in unified atomic mass units (amu) for atomic number `z`.
///
/// Covers H through Rn (Z = 1–86) using conventional standard atomic weights. For
/// elements without a stable isotope (Tc, Pm, Po, At, Rn), the mass number of a
/// long-lived isotope is used. Any other `z` (0 or > 86) falls back to 12.0, so
/// callers always receive a positive mass.
pub fn atomic_mass_amu(z: u8) -> f64 {
    match z {
        // Period 1
        1 => 1.008,
        2 => 4.0026,
        // Period 2
        3 => 6.94,
        4 => 9.0122,
        5 => 10.81,
        6 => 12.011,
        7 => 14.007,
        8 => 15.999,
        9 => 18.998,
        10 => 20.180,
        // Period 3
        11 => 22.990,
        12 => 24.305,
        13 => 26.982,
        14 => 28.085,
        15 => 30.974,
        16 => 32.06,
        17 => 35.45,
        18 => 39.948,
        // Period 4
        19 => 39.098,
        20 => 40.078,
        21 => 44.956,
        22 => 47.867,
        23 => 50.942,
        24 => 51.996,
        25 => 54.938,
        26 => 55.845,
        27 => 58.933,
        28 => 58.693,
        29 => 63.546,
        30 => 65.38,
        31 => 69.723,
        32 => 72.630,
        33 => 74.922,
        34 => 78.971,
        35 => 79.904,
        36 => 83.798,
        // Period 5
        37 => 85.468,
        38 => 87.62,
        39 => 88.906,
        40 => 91.224,
        41 => 92.906,
        42 => 95.95,
        43 => 98.0,
        44 => 101.07,
        45 => 102.91,
        46 => 106.42,
        47 => 107.87,
        48 => 112.41,
        49 => 114.82,
        50 => 118.71,
        51 => 121.76,
        52 => 127.60,
        53 => 126.904,
        54 => 131.29,
        // Period 6 (including lanthanides)
        55 => 132.91,
        56 => 137.33,
        57 => 138.91,
        58 => 140.12,
        59 => 140.91,
        60 => 144.24,
        61 => 145.0,
        62 => 150.36,
        63 => 151.96,
        64 => 157.25,
        65 => 158.93,
        66 => 162.50,
        67 => 164.93,
        68 => 167.26,
        69 => 168.93,
        70 => 173.05,
        71 => 174.97,
        72 => 178.49,
        73 => 180.95,
        74 => 183.84,
        75 => 186.21,
        76 => 190.23,
        77 => 192.22,
        78 => 195.084,
        79 => 196.97,
        80 => 200.59,
        81 => 204.38,
        82 => 207.2,
        83 => 208.98,
        84 => 209.0,
        85 => 210.0,
        86 => 222.0,
        // Unknown / unsupported atomic numbers keep the historical carbon-like default.
        _ => 12.0,
    }
}
147
148fn sample_standard_normal(rng: &mut StdRng) -> f64 {
149    let u1 = (1.0 - rng.gen::<f64>()).max(1e-12);
150    let u2 = rng.gen::<f64>();
151    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
152}
153
154fn kinetic_energy_and_temperature(
155    velocities: &[f64],
156    masses_amu: &[f64],
157    n_atoms: usize,
158) -> (f64, f64) {
159    let mut ke = 0.0;
160    for i in 0..n_atoms {
161        let vx = velocities[3 * i];
162        let vy = velocities[3 * i + 1];
163        let vz = velocities[3 * i + 2];
164        let v2 = vx * vx + vy * vy + vz * vz;
165        ke += 0.5 * masses_amu[i] * v2 * AMU_ANGFS2_TO_KCAL_MOL;
166    }
167    let dof = (3 * n_atoms).saturating_sub(6).max(1) as f64;
168    let t = 2.0 * ke / (dof * R_GAS_KCAL_MOLK);
169    (ke, t)
170}
171
172/// Run short exploratory molecular dynamics using Velocity Verlet and optional Berendsen NVT.
173pub fn simulate_velocity_verlet_uff(
174    smiles: &str,
175    coords: &[f64],
176    n_steps: usize,
177    dt_fs: f64,
178    seed: u64,
179    target_temp_and_tau: Option<(f64, f64)>,
180) -> Result<MdTrajectory, String> {
181    if n_steps == 0 {
182        return Err("n_steps must be > 0".to_string());
183    }
184    if dt_fs <= 0.0 {
185        return Err("dt_fs must be > 0".to_string());
186    }
187
188    let mol = crate::graph::Molecule::from_smiles(smiles)?;
189    let n_atoms = mol.graph.node_count();
190    if coords.len() != n_atoms * 3 {
191        return Err(format!(
192            "coords length {} != 3 * atoms {}",
193            coords.len(),
194            n_atoms
195        ));
196    }
197
198    let masses_amu: Vec<f64> = (0..n_atoms)
199        .map(petgraph::graph::NodeIndex::new)
200        .map(|idx| atomic_mass_amu(mol.graph[idx].element))
201        .collect();
202
203    let ff = crate::forcefield::builder::build_uff_force_field(&mol);
204    let mut x = coords.to_vec();
205    let mut grad = vec![0.0f64; n_atoms * 3];
206    let mut potential = ff.compute_system_energy_and_gradients(&x, &mut grad);
207
208    let mut rng = StdRng::seed_from_u64(seed);
209    let mut v = vec![0.0f64; n_atoms * 3];
210
211    if let Some((target_temp_k, _tau_fs)) = target_temp_and_tau {
212        for i in 0..n_atoms {
213            let sigma = ((R_GAS_KCAL_MOLK * target_temp_k)
214                / (masses_amu[i] * AMU_ANGFS2_TO_KCAL_MOL))
215                .sqrt();
216            v[3 * i] = sigma * sample_standard_normal(&mut rng);
217            v[3 * i + 1] = sigma * sample_standard_normal(&mut rng);
218            v[3 * i + 2] = sigma * sample_standard_normal(&mut rng);
219        }
220    }
221
222    let (ke0, t0) = kinetic_energy_and_temperature(&v, &masses_amu, n_atoms);
223    let mut frames = vec![MdFrame {
224        step: 0,
225        time_fs: 0.0,
226        coords: x.clone(),
227        potential_energy_kcal_mol: potential,
228        kinetic_energy_kcal_mol: ke0,
229        temperature_k: t0,
230    }];
231
232    for step in 1..=n_steps {
233        let mut a = vec![0.0f64; n_atoms * 3];
234        for i in 0..n_atoms {
235            let m = masses_amu[i];
236            a[3 * i] = -grad[3 * i] / (m * AMU_ANGFS2_TO_KCAL_MOL);
237            a[3 * i + 1] = -grad[3 * i + 1] / (m * AMU_ANGFS2_TO_KCAL_MOL);
238            a[3 * i + 2] = -grad[3 * i + 2] / (m * AMU_ANGFS2_TO_KCAL_MOL);
239        }
240
241        for i in 0..(n_atoms * 3) {
242            x[i] += v[i] * dt_fs + 0.5 * a[i] * dt_fs * dt_fs;
243        }
244
245        potential = ff.compute_system_energy_and_gradients(&x, &mut grad);
246
247        let mut a_new = vec![0.0f64; n_atoms * 3];
248        for i in 0..n_atoms {
249            let m = masses_amu[i];
250            a_new[3 * i] = -grad[3 * i] / (m * AMU_ANGFS2_TO_KCAL_MOL);
251            a_new[3 * i + 1] = -grad[3 * i + 1] / (m * AMU_ANGFS2_TO_KCAL_MOL);
252            a_new[3 * i + 2] = -grad[3 * i + 2] / (m * AMU_ANGFS2_TO_KCAL_MOL);
253        }
254
255        for i in 0..(n_atoms * 3) {
256            v[i] += 0.5 * (a[i] + a_new[i]) * dt_fs;
257        }
258
259        let (mut ke, mut temp_k) = kinetic_energy_and_temperature(&v, &masses_amu, n_atoms);
260        if let Some((target_temp_k, tau_fs)) = target_temp_and_tau {
261            let tau = tau_fs.max(1e-6);
262            let lambda = (1.0 + (dt_fs / tau) * (target_temp_k / temp_k.max(1e-6) - 1.0)).sqrt();
263            let lambda = lambda.clamp(0.5, 2.0);
264            for vi in &mut v {
265                *vi *= lambda;
266            }
267            let kt = kinetic_energy_and_temperature(&v, &masses_amu, n_atoms);
268            ke = kt.0;
269            temp_k = kt.1;
270        }
271
272        if !x.iter().all(|v| v.is_finite()) || !potential.is_finite() || !ke.is_finite() {
273            return Err(format!(
274                "MD diverged at step {} (non-finite coordinates/energy)",
275                step
276            ));
277        }
278
279        frames.push(MdFrame {
280            step,
281            time_fs: step as f64 * dt_fs,
282            coords: x.clone(),
283            potential_energy_kcal_mol: potential,
284            kinetic_energy_kcal_mol: ke,
285            temperature_k: temp_k,
286        });
287    }
288
289    let mut notes = vec![
290        "Velocity Verlet integration over UFF force-field gradients for short exploratory trajectories."
291            .to_string(),
292    ];
293    if target_temp_and_tau.is_some() {
294        notes.push(
295            "Berendsen thermostat rescaling applied for approximate constant-temperature sampling."
296                .to_string(),
297        );
298    } else {
299        notes.push(
300            "No thermostat applied (NVE-like propagation with current numerical approximations)."
301                .to_string(),
302        );
303    }
304
305    Ok(MdTrajectory {
306        energy_drift_percent: if target_temp_and_tau.is_none() {
307            let e0 = frames[0].potential_energy_kcal_mol + frames[0].kinetic_energy_kcal_mol;
308            let ef = frames
309                .last()
310                .map(|f| f.potential_energy_kcal_mol + f.kinetic_energy_kcal_mol)
311                .unwrap_or(e0);
312            if e0.abs() > 1e-10 {
313                Some((ef - e0) / e0.abs() * 100.0)
314            } else {
315                None
316            }
317        } else {
318            None
319        },
320        frames,
321        dt_fs,
322        notes,
323    })
324}
325
326/// Build a simplified NEB-like path using linear interpolation and spring-coupled relaxation.
327pub fn compute_simplified_neb_path(
328    smiles: &str,
329    start_coords: &[f64],
330    end_coords: &[f64],
331    n_images: usize,
332    n_iter: usize,
333    spring_k: f64,
334    step_size: f64,
335) -> Result<NebPathResult, String> {
336    if n_images < 2 {
337        return Err("n_images must be >= 2".to_string());
338    }
339    if n_iter == 0 {
340        return Err("n_iter must be > 0".to_string());
341    }
342    if step_size <= 0.0 {
343        return Err("step_size must be > 0".to_string());
344    }
345
346    let mol = crate::graph::Molecule::from_smiles(smiles)?;
347    let n_atoms = mol.graph.node_count();
348    let n_xyz = n_atoms * 3;
349    if start_coords.len() != n_xyz || end_coords.len() != n_xyz {
350        return Err(format!(
351            "start/end coords must each have length {} (3 * n_atoms)",
352            n_xyz
353        ));
354    }
355
356    let ff = crate::forcefield::builder::build_uff_force_field(&mol);
357
358    let mut images = vec![vec![0.0f64; n_xyz]; n_images];
359    for (img_idx, img) in images.iter_mut().enumerate() {
360        let t = img_idx as f64 / (n_images - 1) as f64;
361        for k in 0..n_xyz {
362            img[k] = (1.0 - t) * start_coords[k] + t * end_coords[k];
363        }
364    }
365
366    for _ in 0..n_iter {
367        let prev = images.clone();
368
369        #[cfg(feature = "parallel")]
370        {
371            use rayon::prelude::*;
372            let updated: Vec<(usize, Vec<f64>)> = (1..(n_images - 1))
373                .into_par_iter()
374                .map(|i| {
375                    let ff_local = crate::forcefield::builder::build_uff_force_field(&mol);
376                    let mut grad = vec![0.0f64; n_xyz];
377                    let _ = ff_local.compute_system_energy_and_gradients(&prev[i], &mut grad);
378                    let mut new_img = prev[i].clone();
379                    for k in 0..n_xyz {
380                        let spring_force =
381                            spring_k * (prev[i + 1][k] - 2.0 * prev[i][k] + prev[i - 1][k]);
382                        let total_force = -grad[k] + spring_force;
383                        new_img[k] = prev[i][k] + step_size * total_force;
384                    }
385                    (i, new_img)
386                })
387                .collect();
388            for (i, img) in updated {
389                images[i] = img;
390            }
391        }
392
393        #[cfg(not(feature = "parallel"))]
394        {
395            for i in 1..(n_images - 1) {
396                let mut grad = vec![0.0f64; n_xyz];
397                let _ = ff.compute_system_energy_and_gradients(&prev[i], &mut grad);
398                for k in 0..n_xyz {
399                    let spring_force =
400                        spring_k * (prev[i + 1][k] - 2.0 * prev[i][k] + prev[i - 1][k]);
401                    let total_force = -grad[k] + spring_force;
402                    images[i][k] = prev[i][k] + step_size * total_force;
403                }
404            }
405        }
406    }
407
408    let mut out_images = Vec::with_capacity(n_images);
409    for (i, coords) in images.into_iter().enumerate() {
410        let mut grad = vec![0.0f64; n_xyz];
411        let e = ff.compute_system_energy_and_gradients(&coords, &mut grad);
412        out_images.push(NebImage {
413            index: i,
414            coords,
415            potential_energy_kcal_mol: e,
416        });
417    }
418
419    Ok(NebPathResult {
420        images: out_images,
421        notes: vec![
422            "Simplified NEB: linear interpolation + spring-coupled UFF gradient relaxation on internal images."
423                .to_string(),
424            "This is a low-cost exploratory path tool and not a full climbing-image / tangent-projected NEB implementation."
425                .to_string(),
426        ],
427    })
428}
429
430// ─── Configurable NEB backend dispatch ──────────────────────────────────────
431
432/// Compute energy (kcal/mol) for a single point using a backend that only
433/// exposes energy (no analytical gradients). Used by the numerical gradient
434/// fallback for GFN1, GFN2, and HF-3c.
435fn neb_point_energy_kcal(
436    backend: NebBackend,
437    elements: &[u8],
438    coords: &[f64],
439) -> Result<f64, String> {
440    let positions: Vec<[f64; 3]> = coords.chunks(3).map(|c| [c[0], c[1], c[2]]).collect();
441    match backend {
442        NebBackend::Gfn1 => {
443            let r = crate::xtb::gfn1::solve_gfn1(elements, &positions)?;
444            Ok(r.total_energy * EV_TO_KCAL_MOL)
445        }
446        NebBackend::Gfn2 => {
447            let r = crate::xtb::gfn2::solve_gfn2(elements, &positions)?;
448            Ok(r.total_energy * EV_TO_KCAL_MOL)
449        }
450        NebBackend::Hf3c => {
451            let r = crate::hf::solve_hf3c(elements, &positions, &crate::hf::HfConfig::default())?;
452            Ok(r.energy * HARTREE_TO_KCAL_MOL)
453        }
454        _ => unreachable!("neb_point_energy_kcal only for energy-only backends"),
455    }
456}
457
458/// Compute energy and numerical gradient (central finite differences) for
459/// backends without analytical gradients.
460fn neb_numerical_gradient(
461    backend: NebBackend,
462    elements: &[u8],
463    coords: &[f64],
464    grad: &mut [f64],
465) -> Result<f64, String> {
466    let delta = 1e-5; // Å
467    let e0 = neb_point_energy_kcal(backend, elements, coords)?;
468    let mut displaced = coords.to_vec();
469    for i in 0..coords.len() {
470        displaced[i] = coords[i] + delta;
471        let e_plus = neb_point_energy_kcal(backend, elements, &displaced)?;
472        displaced[i] = coords[i] - delta;
473        let e_minus = neb_point_energy_kcal(backend, elements, &displaced)?;
474        displaced[i] = coords[i]; // restore
475        grad[i] = (e_plus - e_minus) / (2.0 * delta);
476    }
477    Ok(e0)
478}
479
480/// Compute energy (kcal/mol) and gradients (kcal/mol/Å) for a NEB image
481/// using the specified backend.
482pub fn neb_energy_and_gradient(
483    backend: NebBackend,
484    _smiles: &str,
485    elements: &[u8],
486    mol: &crate::graph::Molecule,
487    coords: &[f64],
488    grad: &mut [f64],
489) -> Result<f64, String> {
490    match backend {
491        NebBackend::Uff => {
492            let ff = crate::forcefield::builder::build_uff_force_field(mol);
493            Ok(ff.compute_system_energy_and_gradients(coords, grad))
494        }
495        NebBackend::Mmff94 => {
496            let bonds: Vec<(usize, usize, u8)> = mol
497                .graph
498                .edge_indices()
499                .map(|e| {
500                    let (a, b) = mol.graph.edge_endpoints(e).unwrap();
501                    let order = match mol.graph[e].order {
502                        crate::graph::BondOrder::Single => 1u8,
503                        crate::graph::BondOrder::Double => 2,
504                        crate::graph::BondOrder::Triple => 3,
505                        crate::graph::BondOrder::Aromatic => 2,
506                        crate::graph::BondOrder::Unknown => 1,
507                    };
508                    (a.index(), b.index(), order)
509                })
510                .collect();
511            let terms = crate::forcefield::mmff94::Mmff94Builder::build(elements, &bonds);
512            let (energy, g) =
513                crate::forcefield::mmff94::Mmff94Builder::total_energy(&terms, coords);
514            grad[..g.len()].copy_from_slice(&g);
515            Ok(energy)
516        }
517        NebBackend::Pm3 => {
518            let positions: Vec<[f64; 3]> = coords.chunks(3).map(|c| [c[0], c[1], c[2]]).collect();
519            let r = crate::pm3::gradients::compute_pm3_gradient(elements, &positions)?;
520            let energy_kcal = r.energy * EV_TO_KCAL_MOL;
521            for (a, g) in r.gradients.iter().enumerate() {
522                for d in 0..3 {
523                    grad[a * 3 + d] = g[d] * EV_TO_KCAL_MOL;
524                }
525            }
526            Ok(energy_kcal)
527        }
528        NebBackend::Xtb => {
529            let positions: Vec<[f64; 3]> = coords.chunks(3).map(|c| [c[0], c[1], c[2]]).collect();
530            let r = crate::xtb::gradients::compute_xtb_gradient(elements, &positions)?;
531            let energy_kcal = r.energy * EV_TO_KCAL_MOL;
532            for (a, g) in r.gradients.iter().enumerate() {
533                for d in 0..3 {
534                    grad[a * 3 + d] = g[d] * EV_TO_KCAL_MOL;
535                }
536            }
537            Ok(energy_kcal)
538        }
539        NebBackend::Gfn1 | NebBackend::Gfn2 | NebBackend::Hf3c => {
540            neb_numerical_gradient(backend, elements, coords, grad)
541        }
542    }
543}
544
545/// Build a simplified NEB path with a configurable energy backend.
546///
547/// This is the multi-method version of [`compute_simplified_neb_path`].
548/// Supply `method` as one of: `"uff"`, `"mmff94"`, `"pm3"`, `"xtb"`,
549/// `"gfn1"`, `"gfn2"`, `"hf3c"`.
550pub fn compute_simplified_neb_path_configurable(
551    smiles: &str,
552    start_coords: &[f64],
553    end_coords: &[f64],
554    n_images: usize,
555    n_iter: usize,
556    spring_k: f64,
557    step_size: f64,
558    method: &str,
559) -> Result<NebPathResult, String> {
560    let backend = NebBackend::from_method(method)?;
561    if n_images < 2 {
562        return Err("n_images must be >= 2".to_string());
563    }
564    if n_iter == 0 {
565        return Err("n_iter must be > 0".to_string());
566    }
567    if step_size <= 0.0 {
568        return Err("step_size must be > 0".to_string());
569    }
570
571    let mol = crate::graph::Molecule::from_smiles(smiles)?;
572    let n_atoms = mol.graph.node_count();
573    let n_xyz = n_atoms * 3;
574    if start_coords.len() != n_xyz || end_coords.len() != n_xyz {
575        return Err(format!(
576            "start/end coords must each have length {} (3 * n_atoms)",
577            n_xyz
578        ));
579    }
580
581    let elements: Vec<u8> = (0..n_atoms)
582        .map(|i| mol.graph[petgraph::graph::NodeIndex::new(i)].element)
583        .collect();
584
585    // Linear interpolation
586    let mut images = vec![vec![0.0f64; n_xyz]; n_images];
587    for (img_idx, img) in images.iter_mut().enumerate() {
588        let t = img_idx as f64 / (n_images - 1) as f64;
589        for k in 0..n_xyz {
590            img[k] = (1.0 - t) * start_coords[k] + t * end_coords[k];
591        }
592    }
593
594    // Spring-coupled NEB relaxation
595    for _ in 0..n_iter {
596        let prev = images.clone();
597        for i in 1..(n_images - 1) {
598            let mut grad = vec![0.0f64; n_xyz];
599            let _ = neb_energy_and_gradient(backend, smiles, &elements, &mol, &prev[i], &mut grad)?;
600            for k in 0..n_xyz {
601                let spring_force = spring_k * (prev[i + 1][k] - 2.0 * prev[i][k] + prev[i - 1][k]);
602                let total_force = -grad[k] + spring_force;
603                images[i][k] = prev[i][k] + step_size * total_force;
604            }
605        }
606    }
607
608    // Evaluate final energies
609    let mut out_images = Vec::with_capacity(n_images);
610    for (i, coords) in images.into_iter().enumerate() {
611        let mut grad = vec![0.0f64; n_xyz];
612        let e = neb_energy_and_gradient(backend, smiles, &elements, &mol, &coords, &mut grad)?;
613        out_images.push(NebImage {
614            index: i,
615            coords,
616            potential_energy_kcal_mol: e,
617        });
618    }
619
620    Ok(NebPathResult {
621        images: out_images,
622        notes: vec![
623            format!(
624                "Simplified NEB ({}) with spring-coupled gradient relaxation on {} internal images.",
625                backend.as_str(),
626                n_images.saturating_sub(2),
627            ),
628            "Low-cost exploratory path; not a full climbing-image / tangent-projected NEB."
629                .to_string(),
630        ],
631    })
632}
633
634/// Compute energy-only for any NEB backend (used for single-point comparisons).
635///
636/// Returns energy in kcal/mol.
637pub fn neb_backend_energy_kcal(method: &str, smiles: &str, coords: &[f64]) -> Result<f64, String> {
638    let backend = NebBackend::from_method(method)?;
639    let mol = crate::graph::Molecule::from_smiles(smiles)?;
640    let n_atoms = mol.graph.node_count();
641    let n_xyz = n_atoms * 3;
642    if coords.len() != n_xyz {
643        return Err(format!(
644            "coords length {} != 3 * atoms {}",
645            coords.len(),
646            n_atoms
647        ));
648    }
649    let elements: Vec<u8> = (0..n_atoms)
650        .map(|i| mol.graph[petgraph::graph::NodeIndex::new(i)].element)
651        .collect();
652    let mut grad = vec![0.0f64; n_xyz];
653    neb_energy_and_gradient(backend, smiles, &elements, &mol, coords, &mut grad)
654}
655
656/// Compute energy and return both energy (kcal/mol) and flat gradient (kcal/mol/Å).
657pub fn neb_backend_energy_and_gradient(
658    method: &str,
659    smiles: &str,
660    coords: &[f64],
661) -> Result<(f64, Vec<f64>), String> {
662    let backend = NebBackend::from_method(method)?;
663    let mol = crate::graph::Molecule::from_smiles(smiles)?;
664    let n_atoms = mol.graph.node_count();
665    let n_xyz = n_atoms * 3;
666    if coords.len() != n_xyz {
667        return Err(format!(
668            "coords length {} != 3 * atoms {}",
669            coords.len(),
670            n_atoms
671        ));
672    }
673    let elements: Vec<u8> = (0..n_atoms)
674        .map(|i| mol.graph[petgraph::graph::NodeIndex::new(i)].element)
675        .collect();
676    let mut grad = vec![0.0f64; n_xyz];
677    let energy = neb_energy_and_gradient(backend, smiles, &elements, &mol, coords, &mut grad)?;
678    Ok((energy, grad))
679}
680
681/// Compute energy and gradients using the specified backend.
682///
683/// Returns (energy_kcal_mol, gradients) or an error.
684pub fn compute_backend_energy_and_gradients(
685    backend: MdBackend,
686    smiles: &str,
687    elements: &[u8],
688    coords: &[f64],
689    grad: &mut [f64],
690) -> Result<f64, String> {
691    match backend {
692        MdBackend::Uff => {
693            let mol = crate::graph::Molecule::from_smiles(smiles)?;
694            let ff = crate::forcefield::builder::build_uff_force_field(&mol);
695            let e = ff.compute_system_energy_and_gradients(coords, grad);
696            Ok(e)
697        }
698        MdBackend::Pm3 => {
699            let positions: Vec<[f64; 3]> = coords.chunks(3).map(|c| [c[0], c[1], c[2]]).collect();
700            let grad_result = crate::pm3::gradients::compute_pm3_gradient(elements, &positions)?;
701            // PM3 gradient returns eV/Å; convert to kcal/mol/Å
702            let energy_kcal = grad_result.energy * 23.0605;
703            for (a, g) in grad_result.gradients.iter().enumerate() {
704                for d in 0..3 {
705                    grad[a * 3 + d] = g[d] * 23.0605;
706                }
707            }
708            Ok(energy_kcal)
709        }
710        MdBackend::Xtb => {
711            let positions: Vec<[f64; 3]> = coords.chunks(3).map(|c| [c[0], c[1], c[2]]).collect();
712            let grad_result = crate::xtb::gradients::compute_xtb_gradient(elements, &positions)?;
713            let energy_kcal = grad_result.energy * 23.0605;
714            for (a, g) in grad_result.gradients.iter().enumerate() {
715                for d in 0..3 {
716                    grad[a * 3 + d] = g[d] * 23.0605;
717                }
718            }
719            Ok(energy_kcal)
720        }
721    }
722}
723
724/// Run molecular dynamics using Velocity Verlet with a configurable force backend.
725///
726/// Supports UFF, PM3, and xTB backends. Note that PM3/xTB use numerical gradients
727/// and are significantly slower than UFF.
728pub fn simulate_velocity_verlet(
729    smiles: &str,
730    coords: &[f64],
731    elements: &[u8],
732    n_steps: usize,
733    dt_fs: f64,
734    seed: u64,
735    target_temp_and_tau: Option<(f64, f64)>,
736    backend: MdBackend,
737) -> Result<MdTrajectory, String> {
738    if n_steps == 0 {
739        return Err("n_steps must be > 0".to_string());
740    }
741    let n_atoms = elements.len();
742    if coords.len() != n_atoms * 3 {
743        return Err(format!(
744            "coords length {} != 3*atoms {}",
745            coords.len(),
746            n_atoms
747        ));
748    }
749
750    let masses_amu: Vec<f64> = elements.iter().map(|&z| atomic_mass_amu(z)).collect();
751
752    let mut x = coords.to_vec();
753    let mut grad = vec![0.0f64; n_atoms * 3];
754    let mut potential =
755        compute_backend_energy_and_gradients(backend, smiles, elements, &x, &mut grad)?;
756
757    let mut rng = StdRng::seed_from_u64(seed);
758    let mut v = vec![0.0f64; n_atoms * 3];
759
760    if let Some((target_temp_k, _)) = target_temp_and_tau {
761        for i in 0..n_atoms {
762            let sigma = ((R_GAS_KCAL_MOLK * target_temp_k)
763                / (masses_amu[i] * AMU_ANGFS2_TO_KCAL_MOL))
764                .sqrt();
765            v[3 * i] = sigma * sample_standard_normal(&mut rng);
766            v[3 * i + 1] = sigma * sample_standard_normal(&mut rng);
767            v[3 * i + 2] = sigma * sample_standard_normal(&mut rng);
768        }
769    }
770
771    let (ke0, t0) = kinetic_energy_and_temperature(&v, &masses_amu, n_atoms);
772    let mut frames = vec![MdFrame {
773        step: 0,
774        time_fs: 0.0,
775        coords: x.clone(),
776        potential_energy_kcal_mol: potential,
777        kinetic_energy_kcal_mol: ke0,
778        temperature_k: t0,
779    }];
780
781    for step in 1..=n_steps {
782        let mut a = vec![0.0f64; n_atoms * 3];
783        for i in 0..n_atoms {
784            let m = masses_amu[i];
785            for k in 0..3 {
786                a[3 * i + k] = -grad[3 * i + k] / (m * AMU_ANGFS2_TO_KCAL_MOL);
787            }
788        }
789
790        for i in 0..(n_atoms * 3) {
791            x[i] += v[i] * dt_fs + 0.5 * a[i] * dt_fs * dt_fs;
792        }
793
794        potential = compute_backend_energy_and_gradients(backend, smiles, elements, &x, &mut grad)?;
795
796        let mut a_new = vec![0.0f64; n_atoms * 3];
797        for i in 0..n_atoms {
798            let m = masses_amu[i];
799            for k in 0..3 {
800                a_new[3 * i + k] = -grad[3 * i + k] / (m * AMU_ANGFS2_TO_KCAL_MOL);
801            }
802        }
803
804        for i in 0..(n_atoms * 3) {
805            v[i] += 0.5 * (a[i] + a_new[i]) * dt_fs;
806        }
807
808        let (mut ke, mut temp_k) = kinetic_energy_and_temperature(&v, &masses_amu, n_atoms);
809        if let Some((target_temp_k, tau_fs)) = target_temp_and_tau {
810            let tau = tau_fs.max(1e-6);
811            let lambda = (1.0 + (dt_fs / tau) * (target_temp_k / temp_k.max(1e-6) - 1.0)).sqrt();
812            let lambda = lambda.clamp(0.5, 2.0);
813            for vi in &mut v {
814                *vi *= lambda;
815            }
816            let kt = kinetic_energy_and_temperature(&v, &masses_amu, n_atoms);
817            ke = kt.0;
818            temp_k = kt.1;
819        }
820
821        if !x.iter().all(|v| v.is_finite()) || !potential.is_finite() {
822            return Err(format!("MD diverged at step {}", step));
823        }
824
825        frames.push(MdFrame {
826            step,
827            time_fs: step as f64 * dt_fs,
828            coords: x.clone(),
829            potential_energy_kcal_mol: potential,
830            kinetic_energy_kcal_mol: ke,
831            temperature_k: temp_k,
832        });
833    }
834
835    let energy_drift_percent = if target_temp_and_tau.is_none() {
836        let e0 = frames[0].potential_energy_kcal_mol + frames[0].kinetic_energy_kcal_mol;
837        let ef = frames
838            .last()
839            .map(|f| f.potential_energy_kcal_mol + f.kinetic_energy_kcal_mol)
840            .unwrap_or(e0);
841        if e0.abs() > 1e-10 {
842            Some((ef - e0) / e0.abs() * 100.0)
843        } else {
844            None
845        }
846    } else {
847        None
848    };
849
850    Ok(MdTrajectory {
851        frames,
852        dt_fs,
853        notes: vec![format!("Velocity Verlet with {:?} backend.", backend)],
854        energy_drift_percent,
855    })
856}
857
858/// Nosé-Hoover chain thermostat for rigorous NVT sampling.
859///
860/// Implements a chain of `chain_length` thermostats coupled to velocities.
861/// Produces canonical (NVT) ensemble with correct fluctuations.
862pub fn simulate_nose_hoover(
863    smiles: &str,
864    coords: &[f64],
865    elements: &[u8],
866    n_steps: usize,
867    dt_fs: f64,
868    target_temp_k: f64,
869    thermostat_mass: f64,
870    chain_length: usize,
871    seed: u64,
872    backend: MdBackend,
873) -> Result<MdTrajectory, String> {
874    if n_steps == 0 {
875        return Err("n_steps must be > 0".to_string());
876    }
877    let n_atoms = elements.len();
878    if coords.len() != n_atoms * 3 {
879        return Err("coords length mismatch".to_string());
880    }
881
882    let masses_amu: Vec<f64> = elements.iter().map(|&z| atomic_mass_amu(z)).collect();
883    let dof = (3 * n_atoms).saturating_sub(6).max(1) as f64;
884    let target_ke = 0.5 * dof * R_GAS_KCAL_MOLK * target_temp_k;
885
886    let mut x = coords.to_vec();
887    let mut grad = vec![0.0f64; n_atoms * 3];
888    let mut potential =
889        compute_backend_energy_and_gradients(backend, smiles, elements, &x, &mut grad)?;
890
891    let mut rng = StdRng::seed_from_u64(seed);
892    let mut v = vec![0.0f64; n_atoms * 3];
893    for i in 0..n_atoms {
894        let sigma =
895            ((R_GAS_KCAL_MOLK * target_temp_k) / (masses_amu[i] * AMU_ANGFS2_TO_KCAL_MOL)).sqrt();
896        v[3 * i] = sigma * sample_standard_normal(&mut rng);
897        v[3 * i + 1] = sigma * sample_standard_normal(&mut rng);
898        v[3 * i + 2] = sigma * sample_standard_normal(&mut rng);
899    }
900
901    // Nosé-Hoover chain variables
902    let chain_len = chain_length.max(1);
903    let q = vec![thermostat_mass; chain_len]; // thermostat masses
904    let mut xi = vec![0.0f64; chain_len]; // thermostat positions
905    let mut v_xi = vec![0.0f64; chain_len]; // thermostat velocities
906
907    let (ke0, t0) = kinetic_energy_and_temperature(&v, &masses_amu, n_atoms);
908    let mut frames = vec![MdFrame {
909        step: 0,
910        time_fs: 0.0,
911        coords: x.clone(),
912        potential_energy_kcal_mol: potential,
913        kinetic_energy_kcal_mol: ke0,
914        temperature_k: t0,
915    }];
916
917    let dt2 = dt_fs * 0.5;
918    let dt4 = dt_fs * 0.25;
919
920    for step in 1..=n_steps {
921        let (ke, _) = kinetic_energy_and_temperature(&v, &masses_amu, n_atoms);
922
923        // Nosé-Hoover chain: propagate thermostat (Yoshida-Suzuki-like)
924        // Force on first thermostat
925        let mut g_xi = vec![0.0f64; chain_len];
926        g_xi[0] = (2.0 * ke - 2.0 * target_ke) / q[0];
927        for j in 1..chain_len {
928            g_xi[j] =
929                (q[j - 1] * v_xi[j - 1] * v_xi[j - 1] - R_GAS_KCAL_MOLK * target_temp_k) / q[j];
930        }
931
932        // Update thermostat velocities (outside-in)
933        if chain_len > 1 {
934            v_xi[chain_len - 1] += g_xi[chain_len - 1] * dt4;
935        }
936        for j in (0..chain_len.saturating_sub(1)).rev() {
937            let exp_factor = (-v_xi[j + 1] * dt4).exp();
938            v_xi[j] = v_xi[j] * exp_factor + g_xi[j] * dt4;
939        }
940
941        // Scale velocities
942        let scale = (-v_xi[0] * dt2).exp();
943        for vi in v.iter_mut() {
944            *vi *= scale;
945        }
946
947        // Update thermostat positions
948        for j in 0..chain_len {
949            xi[j] += v_xi[j] * dt2;
950        }
951
952        // Recompute KE after scaling
953        let (ke_scaled, _) = kinetic_energy_and_temperature(&v, &masses_amu, n_atoms);
954        g_xi[0] = (2.0 * ke_scaled - 2.0 * target_ke) / q[0];
955
956        // Update thermostat velocities again (inside-out)
957        for j in 0..chain_len.saturating_sub(1) {
958            let exp_factor = (-v_xi[j + 1] * dt4).exp();
959            v_xi[j] = v_xi[j] * exp_factor + g_xi[j] * dt4;
960        }
961        if chain_len > 1 {
962            // Recompute force on last thermostat
963            g_xi[chain_len - 1] = (q[chain_len - 2] * v_xi[chain_len - 2] * v_xi[chain_len - 2]
964                - R_GAS_KCAL_MOLK * target_temp_k)
965                / q[chain_len - 1];
966            v_xi[chain_len - 1] += g_xi[chain_len - 1] * dt4;
967        }
968
969        // Velocity Verlet: half-step velocity
970        for i in 0..n_atoms {
971            let m = masses_amu[i];
972            for k in 0..3 {
973                v[3 * i + k] -= 0.5 * dt_fs * grad[3 * i + k] / (m * AMU_ANGFS2_TO_KCAL_MOL);
974            }
975        }
976
977        // Position update
978        for i in 0..(n_atoms * 3) {
979            x[i] += v[i] * dt_fs;
980        }
981
982        // New forces
983        potential = compute_backend_energy_and_gradients(backend, smiles, elements, &x, &mut grad)?;
984
985        // Velocity Verlet: second half-step velocity
986        for i in 0..n_atoms {
987            let m = masses_amu[i];
988            for k in 0..3 {
989                v[3 * i + k] -= 0.5 * dt_fs * grad[3 * i + k] / (m * AMU_ANGFS2_TO_KCAL_MOL);
990            }
991        }
992
993        let (ke_final, temp_k) = kinetic_energy_and_temperature(&v, &masses_amu, n_atoms);
994
995        if !x.iter().all(|v| v.is_finite()) || !potential.is_finite() {
996            return Err(format!("NH-MD diverged at step {}", step));
997        }
998
999        frames.push(MdFrame {
1000            step,
1001            time_fs: step as f64 * dt_fs,
1002            coords: x.clone(),
1003            potential_energy_kcal_mol: potential,
1004            kinetic_energy_kcal_mol: ke_final,
1005            temperature_k: temp_k,
1006        });
1007    }
1008
1009    // Temperature analysis
1010    let temps: Vec<f64> = frames
1011        .iter()
1012        .skip(frames.len() / 5)
1013        .map(|f| f.temperature_k)
1014        .collect();
1015    let avg_temp = temps.iter().sum::<f64>() / temps.len() as f64;
1016    let temp_variance =
1017        temps.iter().map(|t| (t - avg_temp).powi(2)).sum::<f64>() / temps.len() as f64;
1018
1019    Ok(MdTrajectory {
1020        frames,
1021        dt_fs,
1022        notes: vec![
1023            format!(
1024                "Nosé-Hoover chain thermostat (chain={}) with {:?} backend.",
1025                chain_len, backend
1026            ),
1027            format!(
1028                "Target T = {:.1} K, avg T = {:.1} K, σ(T) = {:.1} K",
1029                target_temp_k,
1030                avg_temp,
1031                temp_variance.sqrt()
1032            ),
1033        ],
1034        energy_drift_percent: None,
1035    })
1036}
1037
1038/// Export an MD trajectory to XYZ format string.
1039///
1040/// Each frame becomes a block with atom count, comment line (step/energy/temp),
1041/// and atom coordinates.
1042pub fn trajectory_to_xyz(trajectory: &MdTrajectory, elements: &[u8]) -> String {
1043    let n_atoms = elements.len();
1044    let mut output = String::new();
1045
1046    for frame in &trajectory.frames {
1047        output.push_str(&format!("{}\n", n_atoms));
1048        output.push_str(&format!(
1049            "step={} t={:.1}fs E_pot={:.4} E_kin={:.4} T={:.1}K\n",
1050            frame.step,
1051            frame.time_fs,
1052            frame.potential_energy_kcal_mol,
1053            frame.kinetic_energy_kcal_mol,
1054            frame.temperature_k,
1055        ));
1056        for i in 0..n_atoms {
1057            let sym = element_symbol(elements[i]);
1058            output.push_str(&format!(
1059                "{} {:.6} {:.6} {:.6}\n",
1060                sym,
1061                frame.coords[3 * i],
1062                frame.coords[3 * i + 1],
1063                frame.coords[3 * i + 2],
1064            ));
1065        }
1066    }
1067    output
1068}
1069
/// Chemical symbol for atomic number `z` (1 = H … 118 = Og).
///
/// Returns `"X"` (the usual dummy-atom symbol in XYZ files) for `z == 0` or
/// `z > 118`. Using the full periodic table means no real element is written
/// out as a dummy atom.
fn element_symbol(z: u8) -> &'static str {
    // Index 0 holds the dummy "X" so that SYMBOLS[z] is the symbol for atomic number z.
    const SYMBOLS: [&str; 119] = [
        "X", "H", "He", "Li", "Be", "B", "C", "N", "O", "F", //
        "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", //
        "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", //
        "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", //
        "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", //
        "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", //
        "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", //
        "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", //
        "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", //
        "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", //
        "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", //
        "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og",
    ];
    SYMBOLS.get(usize::from(z)).copied().unwrap_or("X")
}
1110
1111// ═══════════════════════════════════════════════════════════════════════════
1112//  Reaction dynamics: SMILES → embed → complex → NEB → full frame path
1113// ═══════════════════════════════════════════════════════════════════════════
1114
/// Configuration for a full reaction dynamics computation (consumed by
/// `compute_reaction_dynamics`).
///
/// Lengths are in Å; the NEB spring constant is in kcal/mol/Ų.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReactionDynamicsConfig {
    /// Number of NEB images for the reactive region.
    pub n_neb_images: usize,
    /// NEB spring-coupled relaxation iterations.
    pub neb_iterations: usize,
    /// NEB spring constant (kcal/mol/Ų).
    pub spring_k: f64,
    /// NEB optimisation step size (Å).
    pub step_size: f64,
    /// Number of approach frames (reactant fragments sliding together),
    /// emitted before the NEB frames.
    pub n_approach_frames: usize,
    /// Number of departure frames (products separating), emitted after the
    /// NEB frames.
    pub n_departure_frames: usize,
    /// Far distance (Å) at start/end of approach/departure. In the approach
    /// phase this is the distance of each reactant fragment's centroid from
    /// the complex centroid in the first frame (when more than one approach
    /// frame is requested).
    pub far_distance: f64,
    /// Target distance between reactive atoms (Å) used when assembling the
    /// oriented product complex and the provisional reactant complex.
    pub reactive_distance: f64,
    /// Random seed passed to conformer embedding for every fragment.
    pub seed: u64,
}
1137
1138impl Default for ReactionDynamicsConfig {
1139    fn default() -> Self {
1140        Self {
1141            n_neb_images: 30,
1142            neb_iterations: 100,
1143            spring_k: 0.1,
1144            step_size: 0.01,
1145            n_approach_frames: 15,
1146            n_departure_frames: 15,
1147            far_distance: 8.0,
1148            reactive_distance: 2.0,
1149            seed: 42,
1150        }
1151    }
1152}
1153
/// A single frame along the reaction coordinate.
///
/// NOTE(review): `phase` is a free-form string rather than an enum, and it is
/// part of the serialized form, so its literal values must stay stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReactionDynamicsFrame {
    /// Frame index (0-based) within the full approach → reaction → departure sequence.
    pub index: usize,
    /// Flat xyz coordinates `[x0,y0,z0, x1,y1,z1, ...]` in Å.
    pub coords: Vec<f64>,
    /// Energy at this frame (kcal/mol) from the selected method.
    /// `compute_reaction_dynamics` records `0.0` for an approach frame whose
    /// energy evaluation fails.
    pub energy_kcal_mol: f64,
    /// Phase label: `"approach"`, `"reaction"`, or `"departure"`.
    pub phase: String,
}
1166
/// Full result of a reaction dynamics computation, as returned by
/// `compute_reaction_dynamics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReactionDynamicsResult {
    /// All frames ordered along the reaction coordinate
    /// (approach frames, then NEB reaction frames, then departure frames).
    pub frames: Vec<ReactionDynamicsFrame>,
    /// Atomic numbers for every atom in each frame (same for all frames).
    pub elements: Vec<u8>,
    /// Frame index of the transition state (highest energy).
    pub ts_frame_index: usize,
    /// Activation energy (kcal/mol) = E(TS) − E(first frame).
    pub activation_energy_kcal_mol: f64,
    /// Reaction energy (kcal/mol) = E(last frame) − E(first frame).
    pub reaction_energy_kcal_mol: f64,
    /// Method used for energy/gradient evaluation.
    pub method: String,
    /// Number of atoms per frame (each frame's `coords` holds `3 * n_atoms` values).
    pub n_atoms: usize,
    /// Informational notes.
    pub notes: Vec<String>,
}
1187
/// Unweighted centroid (geometric centre) of a flat `[x0, y0, z0, x1, …]` array.
///
/// Despite the name, this is **not** mass-weighted: every atom counts equally.
/// A trailing partial triple (when `coords.len()` is not a multiple of 3) is
/// ignored. An input with no complete triple yields `[0.0; 3]`.
fn com_flat(coords: &[f64]) -> [f64; 3] {
    let n = coords.len() / 3;
    if n == 0 {
        return [0.0; 3];
    }
    let mut sum = [0.0f64; 3];
    for atom in coords.chunks_exact(3) {
        sum[0] += atom[0];
        sum[1] += atom[1];
        sum[2] += atom[2];
    }
    let nf = n as f64;
    [sum[0] / nf, sum[1] / nf, sum[2] / nf]
}

/// Translate `coords` in place so their unweighted centroid ([`com_flat`]) sits
/// at the origin.
///
/// Only complete `[x, y, z]` triples are shifted. A trailing partial triple is
/// left untouched rather than indexed out of bounds.
fn centre_at_origin(coords: &mut [f64]) {
    let centre = com_flat(coords);
    for atom in coords.chunks_exact_mut(3) {
        atom[0] -= centre[0];
        atom[1] -= centre[1];
        atom[2] -= centre[2];
    }
}
1215
/// Rotate flat `[x, y, z, …]` coords about the origin so that direction `from`
/// is carried onto direction `to` (Rodrigues' rotation formula).
///
/// Only the directions of `from` and `to` matter. The coords are left unchanged
/// when either vector has (near-)zero length or when the directions already
/// coincide.
///
/// When the directions are exactly opposite, a proper 180° rotation about an
/// axis perpendicular to `from` is applied. A mirror reflection is not used
/// because it would invert molecular chirality.
///
/// The transform is always a rigid rotation, so bond lengths, angles and
/// handedness are preserved. A trailing partial triple is left untouched.
fn rotate_to_align(coords: &mut [f64], from: [f64; 3], to: [f64; 3]) {
    let fl = (from[0] * from[0] + from[1] * from[1] + from[2] * from[2]).sqrt();
    let tl = (to[0] * to[0] + to[1] * to[1] + to[2] * to[2]).sqrt();
    if fl < 1e-10 || tl < 1e-10 {
        return;
    }
    let fx = from[0] / fl;
    let fy = from[1] / fl;
    let fz = from[2] / fl;
    let tx = to[0] / tl;
    let ty = to[1] / tl;
    let tz = to[2] / tl;

    // k = from × to: its length is sin θ, and from · to is cos θ.
    let kx = fy * tz - fz * ty;
    let ky = fz * tx - fx * tz;
    let kz = fx * ty - fy * tx;
    let sin_a = (kx * kx + ky * ky + kz * kz).sqrt();
    let cos_a = fx * tx + fy * ty + fz * tz;

    if sin_a < 1e-10 {
        if cos_a > 0.0 {
            return; // already aligned
        }
        // Anti-aligned: the cross product gives no axis, so build one.
        // Cross `from` with the cardinal axis it is least parallel to; the result
        // has length >= sqrt(2/3), so the normalisation is safe.
        let (ax, ay, az) = if fx.abs() <= fy.abs() && fx.abs() <= fz.abs() {
            (1.0, 0.0, 0.0)
        } else if fy.abs() <= fz.abs() {
            (0.0, 1.0, 0.0)
        } else {
            (0.0, 0.0, 1.0)
        };
        let ux = fy * az - fz * ay;
        let uy = fz * ax - fx * az;
        let uz = fx * ay - fy * ax;
        let ul = (ux * ux + uy * uy + uz * uz).sqrt();
        let (ux, uy, uz) = (ux / ul, uy / ul, uz / ul);
        // Half-turn about unit axis u: p' = 2 (u·p) u − p.
        for p in coords.chunks_exact_mut(3) {
            let d = ux * p[0] + uy * p[1] + uz * p[2];
            p[0] = 2.0 * d * ux - p[0];
            p[1] = 2.0 * d * uy - p[1];
            p[2] = 2.0 * d * uz - p[2];
        }
        return;
    }

    let nkx = kx / sin_a;
    let nky = ky / sin_a;
    let nkz = kz / sin_a;
    let c = cos_a;
    let s = sin_a;
    let t1 = 1.0 - c;

    // v' = v cos θ + (k̂ × v) sin θ + k̂ (k̂ · v)(1 − cos θ)
    for p in coords.chunks_exact_mut(3) {
        let x = p[0];
        let y = p[1];
        let z = p[2];
        let dot = nkx * x + nky * y + nkz * z;
        p[0] = x * c + (nky * z - nkz * y) * s + nkx * dot * t1;
        p[1] = y * c + (nkz * x - nkx * z) * s + nky * dot * t1;
        p[2] = z * c + (nkx * y - nky * x) * s + nkz * dot * t1;
    }
}
1264
/// Build a reactive complex from separately-embedded molecule conformers.
///
/// Each molecule is first translated so its unweighted centroid (`com_flat`)
/// sits at the origin. With a single conformer, that centred geometry is
/// returned as-is. With two or more:
/// 1. One atom of molecule 0 and one of molecule 1 are picked as the reactive
///    pair: the closest pair after a trial +4 Å X offset of molecule 1.
/// 2. Molecule 0 is rotated so its reactive atom points along +X, and
///    molecule 1 so its reactive atom points along −X.
/// 3. Molecule 1 is shifted along X so the reactive atoms' X coordinates
///    differ by `reactive_dist` Å. When both rotations were applied, the atoms
///    lie on the X axis, so this is their separation.
/// 4. Molecules with index ≥ 2 are appended as spectators, shifted along +X in
///    4 Å steps beyond molecule 1's offset, with no clash checking.
///
/// Returns `(flat_coords, elements)`. Atoms are ordered molecule by molecule in
/// input order, and the combined (unweighted) centroid is at the origin. An
/// empty `conformers` slice yields two empty vectors.
///
/// NOTE(review): if molecule 0 or 1 has no atoms, `best_i`/`best_j` stay 0 and
/// the indexing below panics — presumably callers never pass empty conformers.
fn build_reaction_complex(
    conformers: &[crate::ConformerResult],
    reactive_dist: f64,
) -> (Vec<f64>, Vec<u8>) {
    if conformers.is_empty() {
        return (vec![], vec![]);
    }

    // Centre each molecule at its own COM
    let mut mols: Vec<(Vec<f64>, Vec<u8>)> = conformers
        .iter()
        .map(|c| {
            let mut coords = c.coords.clone();
            centre_at_origin(&mut coords);
            (coords, c.elements.clone())
        })
        .collect();

    if mols.len() == 1 {
        // Already centred above; this second centring only repeats that step.
        let (mut coords, elems) = mols.remove(0);
        centre_at_origin(&mut coords);
        return (coords, elems);
    }

    // Find closest inter-molecular pair between mol 0 and mol 1
    let n0 = mols[0].1.len();
    let n1 = mols[1].1.len();
    let mut best_i = 0usize;
    let mut best_j = 0usize;
    let mut best_d2 = f64::INFINITY;
    // Use a reference offset so we rank "face-to-face" proximity: both molecules
    // are centred at the origin, so pairs are ranked as if mol 1 were displaced
    // +4 Å along X. This favours atoms on facing sides over atoms that are
    // merely near each centroid.
    let ref_off = 4.0;
    for i in 0..n0 {
        let xi = mols[0].0[i * 3];
        let yi = mols[0].0[i * 3 + 1];
        let zi = mols[0].0[i * 3 + 2];
        for j in 0..n1 {
            let dx = mols[1].0[j * 3] + ref_off - xi;
            let dy = mols[1].0[j * 3 + 1] - yi;
            let dz = mols[1].0[j * 3 + 2] - zi;
            let d2 = dx * dx + dy * dy + dz * dz;
            if d2 < best_d2 {
                best_d2 = d2;
                best_i = i;
                best_j = j;
            }
        }
    }

    // Orient mol 0 reactive atom → +X
    // (skipped when the atom is within 0.05 Å of the centroid: no usable direction)
    {
        let rx = mols[0].0[best_i * 3];
        let ry = mols[0].0[best_i * 3 + 1];
        let rz = mols[0].0[best_i * 3 + 2];
        let rl = (rx * rx + ry * ry + rz * rz).sqrt();
        if rl > 0.05 {
            rotate_to_align(&mut mols[0].0, [rx / rl, ry / rl, rz / rl], [1.0, 0.0, 0.0]);
        }
    }

    // Orient mol 1 reactive atom → −X (same 0.05 Å guard)
    {
        let rx = mols[1].0[best_j * 3];
        let ry = mols[1].0[best_j * 3 + 1];
        let rz = mols[1].0[best_j * 3 + 2];
        let rl = (rx * rx + ry * ry + rz * rz).sqrt();
        if rl > 0.05 {
            rotate_to_align(
                &mut mols[1].0,
                [rx / rl, ry / rl, rz / rl],
                [-1.0, 0.0, 0.0],
            );
        }
    }

    // Place mol 1 so reactive atoms are `reactive_dist` apart along X:
    // after the shift, mol 1's reactive atom has x = ra_x + reactive_dist.
    let ra_x = mols[0].0[best_i * 3];
    let rb_x = mols[1].0[best_j * 3];
    let offset_x = ra_x - rb_x + reactive_dist;

    // Assemble combined coords
    let total_atoms: usize = mols.iter().map(|(_, e)| e.len()).sum();
    let mut all_coords = Vec::with_capacity(total_atoms * 3);
    let mut all_elements = Vec::with_capacity(total_atoms);

    // Mol 0 — no offset
    all_coords.extend_from_slice(&mols[0].0);
    all_elements.extend_from_slice(&mols[0].1);

    // Mol 1 — shifted along X
    for k in 0..n1 {
        all_coords.push(mols[1].0[k * 3] + offset_x);
        all_coords.push(mols[1].0[k * 3 + 1]);
        all_coords.push(mols[1].0[k * 3 + 2]);
    }
    all_elements.extend_from_slice(&mols[1].1);

    // Spectator molecules (m ≥ 2) — stacked further along +X in fixed 4 Å
    // steps regardless of molecular size (no overlap check).
    let mut extra = offset_x + 4.0;
    for mol in mols.iter().skip(2) {
        for k in 0..mol.1.len() {
            all_coords.push(mol.0[k * 3] + extra);
            all_coords.push(mol.0[k * 3 + 1]);
            all_coords.push(mol.0[k * 3 + 2]);
        }
        all_elements.extend_from_slice(&mol.1);
        extra += 4.0;
    }

    // Move the combined centroid to the origin.
    centre_at_origin(&mut all_coords);
    (all_coords, all_elements)
}
1382
/// Greedy atom mapping between two complexes by element and distance.
///
/// Returns `mapping` where `mapping[i]` is the index in the product complex of
/// the atom matched to reactant atom `i`. Within each element, all
/// reactant/product pairs are sorted by squared distance, and the closest
/// still-unmatched pair is taken repeatedly until every atom is matched.
/// Distances compare by IEEE total order (`f64::total_cmp`), so NaN coordinates
/// cannot break the sort. Elements are visited in ascending atomic number, so
/// the error reported for inconsistent inputs is deterministic.
///
/// # Errors
/// * The two complexes have different atom counts.
/// * An element is missing from the products.
/// * An element's count differs between reactants and products.
fn map_atoms_greedy(
    r_elements: &[u8],
    r_coords: &[f64],
    p_elements: &[u8],
    p_coords: &[f64],
) -> Result<Vec<usize>, String> {
    let n = r_elements.len();
    if n != p_elements.len() {
        return Err(format!(
            "Atom count mismatch: {} reactant atoms vs {} product atoms",
            n,
            p_elements.len(),
        ));
    }

    // Group atom indices by element (ordered by Z for reproducible iteration).
    let mut r_by_elem: std::collections::BTreeMap<u8, Vec<usize>> =
        std::collections::BTreeMap::new();
    let mut p_by_elem: std::collections::BTreeMap<u8, Vec<usize>> =
        std::collections::BTreeMap::new();
    for i in 0..n {
        r_by_elem.entry(r_elements[i]).or_default().push(i);
        p_by_elem.entry(p_elements[i]).or_default().push(i);
    }

    let sq_dist = |ri: usize, pi: usize| {
        let dx = r_coords[ri * 3] - p_coords[pi * 3];
        let dy = r_coords[ri * 3 + 1] - p_coords[pi * 3 + 1];
        let dz = r_coords[ri * 3 + 2] - p_coords[pi * 3 + 2];
        dx * dx + dy * dy + dz * dz
    };

    let mut mapping = vec![0usize; n];
    // Element groups are disjoint, so one pair of "used" masks serves every group.
    let mut used_r = vec![false; n];
    let mut used_p = vec![false; n];

    for (&elem, r_indices) in &r_by_elem {
        let p_indices = p_by_elem
            .get(&elem)
            .ok_or_else(|| format!("Element Z={} present in reactants but not products", elem))?;
        if r_indices.len() != p_indices.len() {
            return Err(format!(
                "Element Z={}: {} in reactants vs {} in products",
                elem,
                r_indices.len(),
                p_indices.len(),
            ));
        }

        // All same-element pairs, closest first (stable sort keeps ties in build order).
        let mut pairs: Vec<(usize, usize, f64)> =
            Vec::with_capacity(r_indices.len() * p_indices.len());
        for &ri in r_indices {
            for &pi in p_indices {
                pairs.push((ri, pi, sq_dist(ri, pi)));
            }
        }
        pairs.sort_by(|a, b| a.2.total_cmp(&b.2));

        for (ri, pi, _) in pairs {
            if used_r[ri] || used_p[pi] {
                continue;
            }
            mapping[ri] = pi;
            used_r[ri] = true;
            used_p[pi] = true;
        }
    }

    Ok(mapping)
}
1452
/// Reorder product coords into reactant atom order.
///
/// Atom `i` of the result takes the `[x, y, z]` of atom `mapping[i]` in
/// `coords`. The output has the same length as `coords`. Slots not named by
/// `mapping` stay `0.0`.
fn reorder_coords(coords: &[f64], mapping: &[usize]) -> Vec<f64> {
    let mut reordered = vec![0.0; coords.len()];
    for (dst_atom, &src_atom) in mapping.iter().enumerate() {
        let dst = 3 * dst_atom;
        let src = 3 * src_atom;
        reordered[dst..dst + 3].copy_from_slice(&coords[src..src + 3]);
    }
    reordered
}
1463
1464/// Rigid-body slide: move each molecule's atoms toward/away from the global COM.
1465///
1466/// `mol_ranges[m] = (start_atom, end_atom)` (exclusive end) per molecule.
1467/// `alpha` = 0 → atoms at `far_dist` from global COM; `alpha` = 1 → at complex position.
1468fn slide_molecules(
1469    complex_coords: &[f64],
1470    mol_ranges: &[(usize, usize)],
1471    alpha: f64,
1472    far_dist: f64,
1473) -> Vec<f64> {
1474    let mut out = complex_coords.to_vec();
1475    let [gcx, gcy, gcz] = com_flat(complex_coords);
1476
1477    for &(start, end) in mol_ranges {
1478        let n = end - start;
1479        if n == 0 {
1480            continue;
1481        }
1482        let mut mx = 0.0;
1483        let mut my = 0.0;
1484        let mut mz = 0.0;
1485        for a in start..end {
1486            mx += complex_coords[a * 3];
1487            my += complex_coords[a * 3 + 1];
1488            mz += complex_coords[a * 3 + 2];
1489        }
1490        let nf = n as f64;
1491        mx /= nf;
1492        my /= nf;
1493        mz /= nf;
1494
1495        let dx = mx - gcx;
1496        let dy = my - gcy;
1497        let dz = mz - gcz;
1498        let d = (dx * dx + dy * dy + dz * dz).sqrt();
1499        if d < 0.01 {
1500            continue;
1501        }
1502        let target = d + (far_dist - d) * (1.0 - alpha);
1503        let shift = target - d;
1504        let nx = dx / d;
1505        let ny = dy / d;
1506        let nz = dz / d;
1507        for a in start..end {
1508            out[a * 3] += nx * shift;
1509            out[a * 3 + 1] += ny * shift;
1510            out[a * 3 + 2] += nz * shift;
1511        }
1512    }
1513    out
1514}
1515
1516/// Build a reactant complex guided by the product geometry.
1517///
1518/// For each reactant fragment, positions its embedded 3D geometry at the
1519/// centre-of-mass of the corresponding atoms in the **reordered product**
1520/// (product coords in reactant atom order).  Each fragment is Kabsch-aligned
1521/// to its product target positions and then translated to the product-derived
1522/// COM.  This ensures that the approach direction is chemically meaningful:
1523/// e.g. for H₂ + C₂H₂ → C₂H₄, H₂ approaches from beside the C≡C bond
1524/// (where the H atoms end up in ethylene), not end-on.
1525fn build_product_guided_reactant_complex(
1526    r_confs: &[crate::ConformerResult],
1527    p_reordered_coords: &[f64],
1528) -> Vec<f64> {
1529    let n_total: usize = r_confs.iter().map(|c| c.num_atoms).sum();
1530    let mut all_coords = vec![0.0f64; n_total * 3];
1531    let mut atom_off = 0usize;
1532
1533    for conf in r_confs {
1534        let n = conf.num_atoms;
1535        // Product-side positions for this fragment's atoms
1536        let p_frag: Vec<f64> = (atom_off..atom_off + n)
1537            .flat_map(|a| {
1538                [
1539                    p_reordered_coords[a * 3],
1540                    p_reordered_coords[a * 3 + 1],
1541                    p_reordered_coords[a * 3 + 2],
1542                ]
1543            })
1544            .collect();
1545        let p_com = com_flat(&p_frag);
1546
1547        // Reactant fragment's own geometry, centred at origin
1548        let mut r_frag = conf.coords.clone();
1549        centre_at_origin(&mut r_frag);
1550
1551        // Kabsch-align reactant fragment onto the product fragment positions
1552        // (both centred at their own COMs) to get the best rotation.
1553        if n >= 2 {
1554            let aligned = crate::alignment::kabsch::align_coordinates(&r_frag, &p_frag);
1555            // The aligned_coords are translated onto reference centroid,
1556            // so we must re-centre them and use the product COM.
1557            let ac = com_flat(&aligned.aligned_coords);
1558            for a in 0..n {
1559                all_coords[(atom_off + a) * 3] = aligned.aligned_coords[a * 3] - ac[0] + p_com[0];
1560                all_coords[(atom_off + a) * 3 + 1] =
1561                    aligned.aligned_coords[a * 3 + 1] - ac[1] + p_com[1];
1562                all_coords[(atom_off + a) * 3 + 2] =
1563                    aligned.aligned_coords[a * 3 + 2] - ac[2] + p_com[2];
1564            }
1565        } else {
1566            // Single atom — just place at product position
1567            for a in 0..n {
1568                all_coords[(atom_off + a) * 3] = r_frag[a * 3] + p_com[0];
1569                all_coords[(atom_off + a) * 3 + 1] = r_frag[a * 3 + 1] + p_com[1];
1570                all_coords[(atom_off + a) * 3 + 2] = r_frag[a * 3 + 2] + p_com[2];
1571            }
1572        }
1573
1574        atom_off += n;
1575    }
1576
1577    centre_at_origin(&mut all_coords);
1578    all_coords
1579}
1580
1581/// Compute a full reaction dynamics path: embed reactants + products,
1582/// build oriented complexes, run NEB for the reactive region, and generate
1583/// approach/departure frames — all energies computed in Rust with the
1584/// chosen quantum-chemistry method.
1585///
1586/// The reactant complex orientation is **derived from the product geometry**:
1587/// each reactant fragment is positioned where its atoms end up in the product,
1588/// ensuring a chemically meaningful approach direction.  For example, in
1589/// H₂ + C₂H₂ → C₂H₄ the H₂ approaches side-on to the π system rather than
1590/// end-on along the C≡C axis.
1591///
1592/// # Arguments
1593///
1594/// * `reactant_smiles` — SMILES of each reactant fragment (e.g. `["CC(=O)O", "N"]`).
1595/// * `product_smiles`  — SMILES of each product fragment (e.g. `["CC(=O)N", "O"]`).
1596/// * `method`          — NEB backend: `"uff"`, `"mmff94"`, `"pm3"`, `"xtb"`, `"gfn1"`, `"gfn2"`, `"hf3c"`.
1597/// * `config`          — [`ReactionDynamicsConfig`] with NEB/path parameters.
1598///
1599/// Returns [`ReactionDynamicsResult`] with all frames, energies, and TS info.
1600pub fn compute_reaction_dynamics(
1601    reactant_smiles: &[&str],
1602    product_smiles: &[&str],
1603    method: &str,
1604    config: &ReactionDynamicsConfig,
1605) -> Result<ReactionDynamicsResult, String> {
1606    if reactant_smiles.is_empty() {
1607        return Err("At least one reactant SMILES is required".into());
1608    }
1609    if product_smiles.is_empty() {
1610        return Err("At least one product SMILES is required".into());
1611    }
1612
1613    let backend = NebBackend::from_method(method)?;
1614
1615    // ── 1. Embed all fragments ─────────────────────────────────────────
1616    let r_confs: Vec<crate::ConformerResult> = reactant_smiles
1617        .iter()
1618        .map(|s| crate::embed(s, config.seed))
1619        .collect();
1620    for (i, c) in r_confs.iter().enumerate() {
1621        if let Some(ref e) = c.error {
1622            return Err(format!(
1623                "Failed to embed reactant '{}': {}",
1624                reactant_smiles[i], e
1625            ));
1626        }
1627    }
1628
1629    let p_confs: Vec<crate::ConformerResult> = product_smiles
1630        .iter()
1631        .map(|s| crate::embed(s, config.seed))
1632        .collect();
1633    for (i, c) in p_confs.iter().enumerate() {
1634        if let Some(ref e) = c.error {
1635            return Err(format!(
1636                "Failed to embed product '{}': {}",
1637                product_smiles[i], e
1638            ));
1639        }
1640    }
1641
1642    // Validate atom conservation
1643    let r_total: usize = r_confs.iter().map(|c| c.num_atoms).sum();
1644    let p_total: usize = p_confs.iter().map(|c| c.num_atoms).sum();
1645    if r_total != p_total {
1646        return Err(format!(
1647            "Atom count mismatch: {} atoms in reactants vs {} in products — \
1648             atoms must be conserved",
1649            r_total, p_total,
1650        ));
1651    }
1652
1653    // ── 2. Build product complex and collect elements ──────────────────
1654    let (p_coords, p_elements) = build_reaction_complex(&p_confs, config.reactive_distance);
1655
1656    // Reactant elements in concatenation order
1657    let r_elements: Vec<u8> = r_confs
1658        .iter()
1659        .flat_map(|c| c.elements.iter().copied())
1660        .collect();
1661
1662    // ── 3. Atom mapping: greedy by element + distance ──────────────────
1663    //  We need a temporary reactant complex just for the distance mapping.
1664    //  Use product-agnostic placement first, then rebuild with guidance.
1665    let (r_coords_tmp, _) = build_reaction_complex(&r_confs, config.reactive_distance);
1666    let mapping = map_atoms_greedy(&r_elements, &r_coords_tmp, &p_elements, &p_coords)?;
1667    let p_reordered = reorder_coords(&p_coords, &mapping);
1668
1669    // ── 4. Build reactant complex guided by product geometry ───────────
1670    let r_coords = build_product_guided_reactant_complex(&r_confs, &p_reordered);
1671
1672    // ── 5. Combined SMILES (dot-separated reactants) for NEB topology ──
1673    let combined_smiles = reactant_smiles.join(".");
1674
1675    // ── 6. NEB between product-guided reactant complex and product ─────
1676    let neb = compute_simplified_neb_path_configurable(
1677        &combined_smiles,
1678        &r_coords,
1679        &p_reordered,
1680        config.n_neb_images,
1681        config.neb_iterations,
1682        config.spring_k,
1683        config.step_size,
1684        method,
1685    )?;
1686
1687    // ── 7. Compute molecule ranges for approach sliding ────────────────
1688    let r_mol_ranges: Vec<(usize, usize)> = {
1689        let mut ranges = Vec::new();
1690        let mut off = 0usize;
1691        for c in &r_confs {
1692            ranges.push((off, off + c.num_atoms));
1693            off += c.num_atoms;
1694        }
1695        ranges
1696    };
1697
1698    // ── 8. Build approach frames with energies ─────────────────────────
1699    let mut frames = Vec::new();
1700    let mol = crate::graph::Molecule::from_smiles(&combined_smiles)?;
1701    let n_xyz = r_total * 3;
1702
1703    let na = config.n_approach_frames;
1704    for i in 0..na {
1705        let alpha = if na > 1 {
1706            i as f64 / (na - 1) as f64
1707        } else {
1708            1.0
1709        };
1710        let coords = slide_molecules(&r_coords, &r_mol_ranges, alpha, config.far_distance);
1711        let mut grad = vec![0.0; n_xyz];
1712        let energy = neb_energy_and_gradient(
1713            backend,
1714            &combined_smiles,
1715            &r_elements,
1716            &mol,
1717            &coords,
1718            &mut grad,
1719        )
1720        .unwrap_or(0.0);
1721        frames.push(ReactionDynamicsFrame {
1722            index: frames.len(),
1723            coords,
1724            energy_kcal_mol: energy,
1725            phase: "approach".into(),
1726        });
1727    }
1728
1729    // ── 9. Reaction frames from NEB ────────────────────────────────────
1730    for img in &neb.images {
1731        frames.push(ReactionDynamicsFrame {
1732            index: frames.len(),
1733            coords: img.coords.clone(),
1734            energy_kcal_mol: img.potential_energy_kcal_mol,
1735            phase: "reaction".into(),
1736        });
1737    }
1738
1739    // ── 10. Departure frames with energies ─────────────────────────────
1740    //   If all products form a single molecule (e.g. C2H2 + H2 → C2H4),
1741    //   the departure phase holds the product geometry steady — no sliding.
1742    //   When there are multiple product molecules, slide them apart.
1743    let nd = config.n_departure_frames;
1744    let single_product = p_confs.len() == 1;
1745
1746    // Product molecule ranges (after reordering into reactant atom order)
1747    let p_mol_ranges: Vec<(usize, usize)> = if single_product {
1748        vec![(0, r_total)]
1749    } else {
1750        let p_mol_assign: Vec<usize> = {
1751            let mut assign = vec![0usize; p_total];
1752            let mut off = 0usize;
1753            for (m, c) in p_confs.iter().enumerate() {
1754                for a in off..off + c.num_atoms {
1755                    assign[a] = m;
1756                }
1757                off += c.num_atoms;
1758            }
1759            mapping.iter().map(|&pi| assign[pi]).collect()
1760        };
1761        let n_mols = p_confs.len();
1762        let mut ranges = Vec::with_capacity(n_mols);
1763        for m in 0..n_mols {
1764            let atoms: Vec<usize> = p_mol_assign
1765                .iter()
1766                .enumerate()
1767                .filter(|(_, &a)| a == m)
1768                .map(|(i, _)| i)
1769                .collect();
1770            if let (Some(&lo), Some(&hi)) = (atoms.first(), atoms.last()) {
1771                ranges.push((lo, hi + 1));
1772            }
1773        }
1774        ranges
1775    };
1776
1777    for i in 0..nd {
1778        let alpha = if nd > 1 {
1779            1.0 - i as f64 / (nd - 1) as f64
1780        } else {
1781            0.0
1782        };
1783        let coords = if single_product {
1784            // Single product — no separation; hold the product geometry
1785            p_reordered.clone()
1786        } else {
1787            slide_molecules(&p_reordered, &p_mol_ranges, alpha, config.far_distance)
1788        };
1789        let mut grad = vec![0.0; n_xyz];
1790        let energy = neb_energy_and_gradient(
1791            backend,
1792            &combined_smiles,
1793            &r_elements,
1794            &mol,
1795            &coords,
1796            &mut grad,
1797        )
1798        .unwrap_or(0.0);
1799        frames.push(ReactionDynamicsFrame {
1800            index: frames.len(),
1801            coords,
1802            energy_kcal_mol: energy,
1803            phase: "departure".into(),
1804        });
1805    }
1806
1807    // ── 11. Find TS and compute energetics ─────────────────────────────
1808    let ts_idx = frames
1809        .iter()
1810        .enumerate()
1811        .max_by(|(_, a), (_, b)| {
1812            a.energy_kcal_mol
1813                .partial_cmp(&b.energy_kcal_mol)
1814                .unwrap_or(std::cmp::Ordering::Equal)
1815        })
1816        .map(|(i, _)| i)
1817        .unwrap_or(0);
1818
1819    let e_first = frames.first().map(|f| f.energy_kcal_mol).unwrap_or(0.0);
1820    let e_last = frames.last().map(|f| f.energy_kcal_mol).unwrap_or(0.0);
1821    let e_ts = frames[ts_idx].energy_kcal_mol;
1822
1823    let mut notes = neb.notes;
1824    notes.push(format!(
1825        "Full reaction path: {} approach + {} NEB + {} departure = {} total frames.",
1826        na,
1827        neb.images.len(),
1828        nd,
1829        frames.len(),
1830    ));
1831    notes.push(
1832        "Reactant complex oriented using product geometry for chemically meaningful approach."
1833            .to_string(),
1834    );
1835
1836    Ok(ReactionDynamicsResult {
1837        frames,
1838        elements: r_elements,
1839        ts_frame_index: ts_idx,
1840        activation_energy_kcal_mol: e_ts - e_first,
1841        reaction_energy_kcal_mol: e_last - e_first,
1842        method: method.to_string(),
1843        n_atoms: r_total,
1844        notes,
1845    })
1846}