Skip to main content

sci_form/
conformer.rs

1use crate::distgeom::{
2    calculate_bounds_matrix_opts, check_chiral_centers, check_double_bond_geometry,
3    check_tetrahedral_centers, compute_initial_coords_rdkit, identify_chiral_sets,
4    identify_tetrahedral_centers, pick_rdkit_distances, triangle_smooth_tol, MinstdRand,
5    MAX_MINIMIZED_E_PER_ATOM,
6};
7use crate::forcefield::bounds_ff::minimize_bfgs_rdkit;
8use crate::forcefield::etkdg_3d::{build_etkdg_3d_ff_with_torsions, minimize_etkdg_3d_bfgs};
9use crate::graph::Molecule;
/// End-to-end 3D conformer generation pipeline.
///
/// Algorithm matching RDKit's ETKDG embedPoints() with retry-on-failure:
///   1. Build distance-bounds matrix → Floyd-Warshall smoothing
///      (falls back to bounds without set15, then to a 0.05 smoothing tolerance)
///   2. Identify chiral sets + tetrahedral centers (for validation)
///   3. Retry loop (up to 10×N attempts):
///      a. Sample random distances → metric matrix → 3D/4D embedding (4D only if chiral);
///         after repeated embedding failures, random box coordinates are used instead
///      b. First minimization: bounds FF (chiral_w=1.0, 4d_w=0.1, basin=5.0),
///         repeated until converged or 50 restarts have been made
///      c. Energy/atom check: reject if energy/N ≥ 0.05
///      d. Tetrahedral center volume check
///      e. Chiral center sign check
///      f. Second minimization (4D embeddings only): bounds FF (chiral_w=0.2, 4d_w=1.0),
///         repeated until converged or 50 restarts have been made
///      g. Drop to 3D → ETKDG 3D FF minimization (300 iters, single pass)
///      h. Planarity check (OOP energy)
///      i. Double bond geometry check
25use nalgebra::DMatrix;
26
/// Basin threshold handed to the bounds force field for embedded coordinates.
/// Atom pairs whose (upper − lower) bound gap exceeds it are skipped by the
/// distance-violation energy.
const BASIN_THRESH: f32 = 5.0;
/// Force tolerance passed to the BFGS minimizers.
const FORCE_TOL: f32 = 1e-3;
/// Energy allowance per improper atom used by the planarity check.
const PLANARITY_TOLERANCE: f32 = 0.7;
/// RDKit's ERROR_TOL for the energy pre-check: minimization is skipped unless the
/// starting energy is above this value.
const ERROR_TOL: f64 = 1e-5;
31
32/// Generate a 3D conformer from a SMILES string.
33pub fn generate_3d_conformer_from_smiles(smiles: &str, seed: u64) -> Result<DMatrix<f32>, String> {
34    let mol = Molecule::from_smiles(smiles)?;
35    generate_3d_conformer(&mol, seed)
36}
37
38/// Generate a 3D conformer for an already-parsed `Molecule`.
39///
40/// Implements RDKit's embedPoints() retry-on-failure loop.
41/// Returns the first valid 3D conformer, or an error if all attempts fail.
42pub fn generate_3d_conformer(mol: &Molecule, seed: u64) -> Result<DMatrix<f32>, String> {
43    let csd_torsions = crate::smarts::match_experimental_torsions(mol);
44    generate_3d_conformer_with_torsions(mol, seed, &csd_torsions)
45}
46
47/// Generate multiple conformers with different seeds and return the one
48/// with the lowest ETKDG 3D force field energy (best geometry).
49pub fn generate_3d_conformer_best_of_k(
50    mol: &Molecule,
51    seed: u64,
52    csd_torsions: &[crate::forcefield::etkdg_3d::M6TorsionContrib],
53    num_seeds: usize,
54) -> Result<DMatrix<f32>, String> {
55    if num_seeds <= 1 {
56        return generate_3d_conformer_with_torsions(mol, seed, csd_torsions);
57    }
58
59    let mut best: Option<(DMatrix<f32>, f64)> = None;
60    let mut last_err = String::new();
61
62    // Pre-compute bounds + FF scaffold once (topology-dependent, not seed-dependent)
63    let bounds = {
64        let raw = calculate_bounds_matrix_opts(mol, true);
65        let mut b = raw;
66        if triangle_smooth_tol(&mut b, 0.0) {
67            b
68        } else {
69            let raw2 = calculate_bounds_matrix_opts(mol, false);
70            let mut b2 = raw2.clone();
71            if triangle_smooth_tol(&mut b2, 0.0) {
72                b2
73            } else {
74                let mut b3 = raw2;
75                triangle_smooth_tol(&mut b3, 0.05);
76                b3
77            }
78        }
79    };
80
81    for k in 0..num_seeds {
82        let s = seed.wrapping_add(k as u64 * 1000);
83        match generate_3d_conformer_with_torsions(mol, s, csd_torsions) {
84            Ok(coords) => {
85                // Score using ETKDG 3D energy (topology-dependent, comparable across seeds)
86                let n = mol.graph.node_count();
87                let coords_f64 = coords.map(|v| v as f64);
88                let ff = build_etkdg_3d_ff_with_torsions(mol, &coords_f64, &bounds, csd_torsions);
89                let mut flat = vec![0.0f64; n * 3];
90                for a in 0..n {
91                    flat[a * 3] = coords[(a, 0)] as f64;
92                    flat[a * 3 + 1] = coords[(a, 1)] as f64;
93                    flat[a * 3 + 2] = coords[(a, 2)] as f64;
94                }
95                let energy = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(&flat, n, mol, &ff);
96                match &best {
97                    Some((_, best_e)) if energy >= *best_e => {}
98                    _ => {
99                        best = Some((coords, energy));
100                    }
101                }
102            }
103            Err(e) => {
104                last_err = e;
105            }
106        }
107    }
108
109    match best {
110        Some((coords, _)) => Ok(coords),
111        None => Err(last_err),
112    }
113}
114
115/// Generate a 3D conformer with optional CSD torsion overrides.
116///
117/// If `csd_torsions` is non-empty, they replace the default torsion terms
118/// in the 3D force field, providing much better torsion angle quality.
119pub fn generate_3d_conformer_with_torsions(
120    mol: &Molecule,
121    seed: u64,
122    csd_torsions: &[crate::forcefield::etkdg_3d::M6TorsionContrib],
123) -> Result<DMatrix<f32>, String> {
124    let n = mol.graph.node_count();
125    if n == 0 {
126        return Err("Empty molecule".to_string());
127    }
128
129    // RDKit-style two-pass bounds: try with set15 + strict smoothing,
130    // fall back to no set15 if triangle smoothing fails.
131    let bounds = {
132        let raw = calculate_bounds_matrix_opts(mol, true);
133        let mut b = raw;
134        if triangle_smooth_tol(&mut b, 0.0) {
135            b
136        } else {
137            #[cfg(test)]
138            eprintln!("  [FALLBACK] strict smoothing failed, retrying without set15");
139            let raw2 = calculate_bounds_matrix_opts(mol, false);
140            let mut b2 = raw2.clone();
141            if triangle_smooth_tol(&mut b2, 0.0) {
142                b2
143            } else {
144                #[cfg(test)]
145                eprintln!("  [FALLBACK] second smoothing also failed, using soft smooth");
146                let mut b3 = raw2;
147                triangle_smooth_tol(&mut b3, 0.05);
148                b3
149            }
150        }
151    };
152    let chiral_sets = identify_chiral_sets(mol);
153    let tetrahedral_centers = identify_tetrahedral_centers(mol);
154
155    let max_iterations = 10 * n;
156    let mut rng = MinstdRand::new(seed as u32);
157
158    // RDKit: 4D embedding only when chiral centers (CW/CCW) are present
159    // Otherwise 3D embedding (no 4th dimension overhead)
160    let use_4d = !chiral_sets.is_empty();
161    let embed_dim = if use_4d { 4 } else { 3 };
162
163    // Track consecutive embedding failures for random coord fallback
164    let mut consecutive_embed_fails = 0u32;
165    let embed_fail_threshold = (n as u32 / 4).max(20); // Switch to random coords after N/4 consecutive fails
166    let mut random_coord_attempts = 0u32;
167    let max_random_coord_attempts = 100u32; // Cap random coord attempts to avoid infinite loops
168
169    for _iter in 0..max_iterations {
170        // Log attempt number if requested (works in both lib and integration tests)
171        let _log_attempts = std::env::var("LOG_ATTEMPTS").is_ok();
172
173        // Step 1: Generate initial coords
174        let use_random_coords = consecutive_embed_fails >= embed_fail_threshold
175            && random_coord_attempts < max_random_coord_attempts;
176        let (mut coords, basin_thresh) = if use_random_coords {
177            random_coord_attempts += 1;
178            // Random coordinate fallback: place atoms in [-5, 5] box
179            // Matching RDKit's useRandomCoords mode with boxSizeMult=2.0
180            let box_size = 10.0f64; // 5.0 * 2.0
181            let mut c = DMatrix::from_element(n, embed_dim, 0.0f64);
182            for i in 0..n {
183                for d in 0..embed_dim {
184                    c[(i, d)] = box_size * (rng.next_double() - 0.5);
185                }
186            }
187            // RDKit disables basin threshold for random coords
188            (c, 1e8f64)
189        } else {
190            let dists = pick_rdkit_distances(&mut rng, &bounds);
191            let coords_opt = compute_initial_coords_rdkit(&mut rng, &dists, embed_dim);
192            match coords_opt {
193                Some(c) => {
194                    consecutive_embed_fails = 0;
195                    (c, BASIN_THRESH as f64)
196                }
197                None => {
198                    consecutive_embed_fails += 1;
199                    // If we've exhausted random coord attempts and still failing, give up
200                    if random_coord_attempts >= max_random_coord_attempts {
201                        break;
202                    }
203                    if _log_attempts && consecutive_embed_fails == embed_fail_threshold {
204                        eprintln!(
205                            "  attempt {} → switching to random coords after {} failures",
206                            _iter, embed_fail_threshold
207                        );
208                    } else if _log_attempts {
209                        eprintln!("  attempt {} → embedding failed", _iter);
210                    }
211                    continue;
212                }
213            }
214        };
215
216        // Step 2: First minimization (bounds FF with chiral_w=1.0, 4d_w=0.1)
217        // RDKit: while(needMore) { needMore = field->minimize(400, forceTol); }
218        // Safety limit: max 50 restarts to prevent infinite loops (RDKit typically needs < 10)
219        {
220            let bt = basin_thresh as f32;
221            let initial_energy =
222                compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
223            if initial_energy > ERROR_TOL {
224                let mut need_more = 1;
225                let mut restarts = 0;
226                while need_more != 0 && restarts < 50 {
227                    need_more = minimize_bfgs_rdkit(
228                        &mut coords,
229                        &bounds,
230                        &chiral_sets,
231                        400,
232                        FORCE_TOL as f64,
233                        bt,
234                        0.1,
235                        1.0,
236                    );
237                    restarts += 1;
238                }
239            }
240        }
241
242        // Step 3: Energy per atom check
243        let bt = basin_thresh as f32;
244        let energy = compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
245        if energy / n as f64 >= MAX_MINIMIZED_E_PER_ATOM as f64 {
246            if _log_attempts {
247                eprintln!(
248                    "  attempt {} → energy check failed: {:.6}/atom",
249                    _iter,
250                    energy / n as f64
251                );
252            }
253            continue;
254        }
255
256        // Step 4: Check tetrahedral centers (f64 coords matching RDKit's Point3D)
257        if !check_tetrahedral_centers(&coords, &tetrahedral_centers) {
258            if _log_attempts {
259                eprintln!("  attempt {} → tetrahedral check failed", _iter);
260            }
261            continue;
262        }
263
264        // Step 5: Check chiral center volumes (f64 coords)
265        if !chiral_sets.is_empty() && !check_chiral_centers(&coords, &chiral_sets) {
266            if _log_attempts {
267                eprintln!("  attempt {} → chiral check failed", _iter);
268            }
269            continue;
270        }
271
272        // Step 6: Second minimization (chiral_w=0.2, 4d_w=1.0) — only if 4D embedding
273        // RDKit: while(needMore) { needMore = field2->minimize(200, forceTol); }
274        if use_4d {
275            let energy2 =
276                compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 1.0, 0.2);
277            if energy2 > ERROR_TOL {
278                let mut need_more = 1;
279                let mut restarts = 0;
280                while need_more != 0 && restarts < 50 {
281                    need_more = minimize_bfgs_rdkit(
282                        &mut coords,
283                        &bounds,
284                        &chiral_sets,
285                        200,
286                        FORCE_TOL as f64,
287                        bt,
288                        1.0,
289                        0.2,
290                    );
291                    restarts += 1;
292                }
293            }
294        }
295
296        // Step 7: Drop to 3D (no-op for 3D embedding)
297        let coords3d = coords.columns(0, 3).into_owned();
298
299        // Step 8: ETKDG 3D FF minimization — single pass of 300 iterations (matching RDKit)
300        // Build FF using f64 coords for distance computation (matching RDKit's Point3D)
301        let ff = build_etkdg_3d_ff_with_torsions(mol, &coords3d, &bounds, csd_torsions);
302        // RDKit: only minimize if energy > ERROR_TOL
303        let e3d = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(
304            &{
305                let n = mol.graph.node_count();
306                let mut flat = vec![0.0f64; n * 3];
307                for a in 0..n {
308                    flat[a * 3] = coords3d[(a, 0)];
309                    flat[a * 3 + 1] = coords3d[(a, 1)];
310                    flat[a * 3 + 2] = coords3d[(a, 2)];
311                }
312                flat
313            },
314            mol.graph.node_count(),
315            mol,
316            &ff,
317        );
318        let refined = if e3d > ERROR_TOL {
319            minimize_etkdg_3d_bfgs(mol, &coords3d, &ff, 300, FORCE_TOL)
320        } else {
321            coords3d
322        };
323
324        // Step 9: Planarity check — matching RDKit's construct3DImproperForceField
325        // Uses UFF inversion energy + SP angle constraint energy (k=10.0) in f64
326        {
327            let n_improper_atoms = ff.inversion_contribs.len() / 3;
328            let flat_f64: Vec<f64> = {
329                let nr = refined.nrows();
330                let mut flat = vec![0.0f64; nr * 3];
331                for a in 0..nr {
332                    flat[a * 3] = refined[(a, 0)];
333                    flat[a * 3 + 1] = refined[(a, 1)];
334                    flat[a * 3 + 2] = refined[(a, 2)];
335                }
336                flat
337            };
338            let planarity_energy =
339                crate::forcefield::etkdg_3d::planarity_check_energy_f64(&flat_f64, n, &ff);
340            if planarity_energy > n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64 {
341                if _log_attempts {
342                    eprintln!(
343                        "  attempt {} → planarity check failed (energy={:.4} > threshold={:.4})",
344                        _iter,
345                        planarity_energy,
346                        n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64
347                    );
348                }
349                continue;
350            }
351        }
352
353        // Step 10: Double bond geometry check (f64 coords)
354        if !check_double_bond_geometry(mol, &refined) {
355            if _log_attempts {
356                eprintln!("  attempt {} → double bond check failed", _iter);
357            }
358            continue;
359        }
360
361        if _log_attempts {
362            eprintln!("  attempt {} → SUCCESS", _iter);
363        }
364
365        let refined_f32 = refined.map(|v| v as f32);
366        return Ok(refined_f32);
367    }
368
369    Err(format!(
370        "Failed to generate valid conformer after {} iterations",
371        max_iterations
372    ))
373}
374
375/// Compute bounds FF energy in f64, matching RDKit's field->calcEnergy() exactly.
376/// Used for the energy pre-check before minimization.
377pub fn compute_total_bounds_energy_f64(
378    coords: &DMatrix<f64>,
379    bounds: &DMatrix<f64>,
380    chiral_sets: &[crate::forcefield::bounds_ff::ChiralSet],
381    basin_thresh: f32,
382    weight_4d: f32,
383    weight_chiral: f32,
384) -> f64 {
385    let n = coords.nrows();
386    let dim_coords = coords.ncols();
387    let basin_thresh_f64 = basin_thresh as f64;
388    let weight_4d_f64 = weight_4d as f64;
389    let weight_chiral_f64 = weight_chiral as f64;
390
391    let mut energy = 0.0f64;
392    for i in 1..n {
393        for j in 0..i {
394            let ub = bounds[(j, i)];
395            let lb = bounds[(i, j)];
396            if ub - lb > basin_thresh_f64 {
397                continue;
398            }
399            let mut d2 = 0.0f64;
400            for d in 0..dim_coords {
401                let diff = coords[(i, d)] - coords[(j, d)];
402                d2 += diff * diff;
403            }
404            let ub2 = ub * ub;
405            let lb2 = lb * lb;
406            let val = if d2 > ub2 {
407                d2 / ub2 - 1.0
408            } else if d2 < lb2 {
409                2.0 * lb2 / (lb2 + d2) - 1.0
410            } else {
411                0.0
412            };
413            if val > 0.0 {
414                energy += val * val;
415            }
416        }
417    }
418    if !chiral_sets.is_empty() {
419        // Flatten coords to flat array for f64 chiral energy
420        let mut flat = vec![0.0f64; n * dim_coords];
421        for i in 0..n {
422            for d in 0..dim_coords {
423                flat[i * dim_coords + d] = coords[(i, d)];
424            }
425        }
426        energy += weight_chiral_f64
427            * crate::forcefield::bounds_ff::chiral_violation_energy_f64(
428                &flat,
429                dim_coords,
430                chiral_sets,
431            );
432    }
433    if dim_coords == 4 {
434        for i in 0..n {
435            let x4 = coords[(i, 3)];
436            energy += weight_4d_f64 * x4 * x4;
437        }
438    }
439    energy
440}