Skip to main content

sci_form/
conformer.rs

1use crate::distgeom::{
2    calculate_bounds_matrix_opts, check_chiral_centers, check_double_bond_geometry,
3    check_tetrahedral_centers, compute_initial_coords_rdkit, identify_chiral_sets,
4    identify_tetrahedral_centers, pick_rdkit_distances, triangle_smooth_tol, MinstdRand,
5    MAX_MINIMIZED_E_PER_ATOM,
6};
7use crate::forcefield::bounds_ff::minimize_bfgs_rdkit;
8use crate::forcefield::etkdg_3d::{build_etkdg_3d_ff_with_torsions, minimize_etkdg_3d_bfgs};
9use crate::graph::Molecule;
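// Pipeline overview (module notes). Written as plain comments: a `///` doc comment in
// this position attaches to the `use` declaration below rather than to the module.
//
// End-to-end 3D conformer generation, following RDKit's ETKDG embedPoints()
// retry-on-failure loop:
//   1. Build the distance-bounds matrix and triangle-smooth it (Floyd-Warshall style);
//      if strict smoothing fails, retry without set15, then fall back to a soft smooth
//      (tolerance 0.05).
//   2. Identify chiral sets + tetrahedral centers (used for validation).
//   3. Retry loop (up to 10×N attempts, N = atom count):
//      a. Sample random distances → metric matrix → 3D/4D embedding (4D only if chiral);
//         after repeated embedding failures, switch to random box coordinates
//      b. First minimization: bounds FF (chiral_w=1.0, 4d_w=0.1, basin=5.0), repeated
//         until converged or the restart cap is reached
//      c. Energy/atom check: reject if energy/N ≥ MAX_MINIMIZED_E_PER_ATOM
//      d. Tetrahedral center volume check
//      e. Chiral center sign check
//      f. Second minimization (4D embeddings only): bounds FF (chiral_w=0.2, 4d_w=1.0),
//         repeated until converged or the restart cap is reached
//      g. Drop to 3D → ETKDG 3D FF minimization (300 iters, single pass)
//      h. Planarity check (out-of-plane / improper energy)
//      i. Double bond geometry check
//   The first attempt that passes every check is returned.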
25use nalgebra::DMatrix;
26
/// Maximum bound-window width (`upper - lower`) for a pair to contribute a distance
/// term to the bounds force field (RDKit's basin threshold).
const BASIN_THRESH: f32 = 5.0_f32;
/// Gradient tolerance for the BFGS minimizers. Stored as `f32`, so widening it to `f64`
/// gives the nearest `f32` value to 1e-3, not 1e-3 exactly.
const FORCE_TOL: f32 = 1.0e-3_f32;
/// Allowed planarity (improper) energy per improper centre before an attempt is rejected.
const PLANARITY_TOLERANCE: f32 = 0.7_f32;
/// Energy at or below which a minimization pass is skipped (RDKit's `ERROR_TOL` pre-check).
const ERROR_TOL: f64 = 1.0e-5_f64;
31
32/// Generate a 3D conformer from a SMILES string.
33pub fn generate_3d_conformer_from_smiles(smiles: &str, seed: u64) -> Result<DMatrix<f32>, String> {
34    let mol = Molecule::from_smiles(smiles)?;
35    generate_3d_conformer(&mol, seed)
36}
37
38/// Generate a 3D conformer for an already-parsed `Molecule`.
39///
40/// Implements RDKit's embedPoints() retry-on-failure loop.
41/// Returns the first valid 3D conformer, or an error if all attempts fail.
42pub fn generate_3d_conformer(mol: &Molecule, seed: u64) -> Result<DMatrix<f32>, String> {
43    let csd_torsions = crate::smarts::match_experimental_torsions(mol);
44    generate_3d_conformer_with_torsions(mol, seed, &csd_torsions)
45}
46
47/// Generate multiple conformers with different seeds and return the one
48/// with the lowest ETKDG 3D force field energy (best geometry).
49pub fn generate_3d_conformer_best_of_k(
50    mol: &Molecule,
51    seed: u64,
52    csd_torsions: &[crate::forcefield::etkdg_3d::M6TorsionContrib],
53    num_seeds: usize,
54) -> Result<DMatrix<f32>, String> {
55    if num_seeds <= 1 {
56        return generate_3d_conformer_with_torsions(mol, seed, csd_torsions);
57    }
58
59    let mut best: Option<(DMatrix<f32>, f64)> = None;
60    let mut last_err = String::new();
61
62    // Pre-compute bounds + FF scaffold once (topology-dependent, not seed-dependent)
63    let bounds = {
64        let raw = calculate_bounds_matrix_opts(mol, true);
65        let mut b = raw;
66        if triangle_smooth_tol(&mut b, 0.0) {
67            b
68        } else {
69            let raw2 = calculate_bounds_matrix_opts(mol, false);
70            let mut b2 = raw2.clone();
71            if triangle_smooth_tol(&mut b2, 0.0) {
72                b2
73            } else {
74                let mut b3 = raw2;
75                triangle_smooth_tol(&mut b3, 0.05);
76                b3
77            }
78        }
79    };
80
81    for k in 0..num_seeds {
82        let s = seed.wrapping_add(k as u64 * 1000);
83        match generate_3d_conformer_with_torsions(mol, s, csd_torsions) {
84            Ok(coords) => {
85                // Score using ETKDG 3D energy (topology-dependent, comparable across seeds)
86                let n = mol.graph.node_count();
87                let coords_f64 = coords.map(|v| v as f64);
88                let ff = build_etkdg_3d_ff_with_torsions(mol, &coords_f64, &bounds, csd_torsions);
89                let mut flat = vec![0.0f64; n * 3];
90                for a in 0..n {
91                    flat[a * 3] = coords[(a, 0)] as f64;
92                    flat[a * 3 + 1] = coords[(a, 1)] as f64;
93                    flat[a * 3 + 2] = coords[(a, 2)] as f64;
94                }
95                let energy = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(&flat, n, mol, &ff);
96                match &best {
97                    Some((_, best_e)) if energy >= *best_e => {}
98                    _ => {
99                        best = Some((coords, energy));
100                    }
101                }
102            }
103            Err(e) => {
104                last_err = e;
105            }
106        }
107    }
108
109    match best {
110        Some((coords, _)) => Ok(coords),
111        None => Err(last_err),
112    }
113}
114
115/// Generate a 3D conformer with optional CSD torsion overrides.
116///
117/// If `csd_torsions` is non-empty, they replace the default torsion terms
118/// in the 3D force field, providing much better torsion angle quality.
119pub fn generate_3d_conformer_with_torsions(
120    mol: &Molecule,
121    seed: u64,
122    csd_torsions: &[crate::forcefield::etkdg_3d::M6TorsionContrib],
123) -> Result<DMatrix<f32>, String> {
124    let n = mol.graph.node_count();
125    if n == 0 {
126        return Err("Empty molecule".to_string());
127    }
128
129    // RDKit-style two-pass bounds: try with set15 + strict smoothing,
130    // fall back to no set15 if triangle smoothing fails.
131    let bounds = {
132        let raw = calculate_bounds_matrix_opts(mol, true);
133        let mut b = raw;
134        if triangle_smooth_tol(&mut b, 0.0) {
135            b
136        } else {
137            #[cfg(test)]
138            eprintln!("  [FALLBACK] strict smoothing failed, retrying without set15");
139            let raw2 = calculate_bounds_matrix_opts(mol, false);
140            let mut b2 = raw2.clone();
141            if triangle_smooth_tol(&mut b2, 0.0) {
142                b2
143            } else {
144                #[cfg(test)]
145                eprintln!("  [FALLBACK] second smoothing also failed, using soft smooth");
146                let mut b3 = raw2;
147                triangle_smooth_tol(&mut b3, 0.05);
148                b3
149            }
150        }
151    };
152    let chiral_sets = identify_chiral_sets(mol);
153    let tetrahedral_centers = identify_tetrahedral_centers(mol);
154
155    let max_iterations = 10 * n;
156    let mut rng = MinstdRand::new(seed as u32);
157
158    // RDKit: 4D embedding only when chiral centers (CW/CCW) are present
159    // Otherwise 3D embedding (no 4th dimension overhead)
160    let use_4d = !chiral_sets.is_empty();
161    let embed_dim = if use_4d { 4 } else { 3 };
162
163    // Track consecutive embedding failures for random coord fallback
164    // For large molecules (100+ atoms), eigendecomposition is O(N³) and fails more often,
165    // so we switch to random coords sooner. For small molecules, we allow more eigen attempts.
166    let mut consecutive_embed_fails = 0u32;
167    let embed_fail_threshold = if n > 100 {
168        (n as u32 / 8).max(10)
169    } else {
170        (n as u32 / 4).max(20)
171    };
172    let mut random_coord_attempts = 0u32;
173    let max_random_coord_attempts = if n > 100 { 50u32 } else { 100u32 };
174    // Scale BFGS restart limit: large molecules converge with fewer restarts
175    let bfgs_restart_limit = if n > 100 { 20 } else { 50 };
176
177    for _iter in 0..max_iterations {
178        // Log attempt number if requested (works in both lib and integration tests)
179        let _log_attempts = std::env::var("LOG_ATTEMPTS").is_ok();
180
181        // Step 1: Generate initial coords
182        let use_random_coords = consecutive_embed_fails >= embed_fail_threshold
183            && random_coord_attempts < max_random_coord_attempts;
184        let (mut coords, basin_thresh) = if use_random_coords {
185            random_coord_attempts += 1;
186            // Random coordinate fallback: RDKit uses boxSizeMult * cube_root(N)
187            let box_size = 2.0 * (n as f64).cbrt().max(2.5);
188            let mut c = DMatrix::from_element(n, embed_dim, 0.0f64);
189            for i in 0..n {
190                for d in 0..embed_dim {
191                    c[(i, d)] = box_size * (rng.next_double() - 0.5);
192                }
193            }
194            // RDKit disables basin threshold for random coords
195            (c, 1e8f64)
196        } else {
197            let dists = pick_rdkit_distances(&mut rng, &bounds);
198            let coords_opt = compute_initial_coords_rdkit(&mut rng, &dists, embed_dim);
199            match coords_opt {
200                Some(c) => {
201                    consecutive_embed_fails = 0;
202                    (c, BASIN_THRESH as f64)
203                }
204                None => {
205                    consecutive_embed_fails += 1;
206                    // If we've exhausted random coord attempts and still failing, give up
207                    if random_coord_attempts >= max_random_coord_attempts {
208                        break;
209                    }
210                    if _log_attempts && consecutive_embed_fails == embed_fail_threshold {
211                        eprintln!(
212                            "  attempt {} → switching to random coords after {} failures",
213                            _iter, embed_fail_threshold
214                        );
215                    } else if _log_attempts {
216                        eprintln!("  attempt {} → embedding failed", _iter);
217                    }
218                    continue;
219                }
220            }
221        };
222
223        // Step 2: First minimization (bounds FF with chiral_w=1.0, 4d_w=0.1)
224        // RDKit: while(needMore) { needMore = field->minimize(400, forceTol); }
225        // Safety limit: max 50 restarts to prevent infinite loops (RDKit typically needs < 10)
226        {
227            let bt = basin_thresh as f32;
228            let initial_energy =
229                compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
230            if initial_energy > ERROR_TOL {
231                let mut need_more = 1;
232                let mut restarts = 0;
233                while need_more != 0 && restarts < bfgs_restart_limit {
234                    need_more = minimize_bfgs_rdkit(
235                        &mut coords,
236                        &bounds,
237                        &chiral_sets,
238                        400,
239                        FORCE_TOL as f64,
240                        bt,
241                        0.1,
242                        1.0,
243                    );
244                    restarts += 1;
245                }
246            }
247        }
248
249        // Step 3: Energy per atom check
250        let bt = basin_thresh as f32;
251        let energy = compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
252        if energy / n as f64 >= MAX_MINIMIZED_E_PER_ATOM as f64 {
253            if _log_attempts {
254                eprintln!(
255                    "  attempt {} → energy check failed: {:.6}/atom",
256                    _iter,
257                    energy / n as f64
258                );
259            }
260            continue;
261        }
262
263        // Step 4: Check tetrahedral centers (f64 coords matching RDKit's Point3D)
264        if !check_tetrahedral_centers(&coords, &tetrahedral_centers) {
265            if _log_attempts {
266                eprintln!("  attempt {} → tetrahedral check failed", _iter);
267            }
268            continue;
269        }
270
271        // Step 5: Check chiral center volumes (f64 coords)
272        if !chiral_sets.is_empty() && !check_chiral_centers(&coords, &chiral_sets) {
273            if _log_attempts {
274                eprintln!("  attempt {} → chiral check failed", _iter);
275            }
276            continue;
277        }
278
279        // Step 6: Second minimization (chiral_w=0.2, 4d_w=1.0) — only if 4D embedding
280        // RDKit: while(needMore) { needMore = field2->minimize(200, forceTol); }
281        if use_4d {
282            let energy2 =
283                compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 1.0, 0.2);
284            if energy2 > ERROR_TOL {
285                let mut need_more = 1;
286                let mut restarts = 0;
287                while need_more != 0 && restarts < bfgs_restart_limit {
288                    need_more = minimize_bfgs_rdkit(
289                        &mut coords,
290                        &bounds,
291                        &chiral_sets,
292                        200,
293                        FORCE_TOL as f64,
294                        bt,
295                        1.0,
296                        0.2,
297                    );
298                    restarts += 1;
299                }
300            }
301        }
302
303        // Step 7: Drop to 3D (no-op for 3D embedding)
304        let coords3d = coords.columns(0, 3).into_owned();
305
306        // Step 8: ETKDG 3D FF minimization — single pass of 300 iterations (matching RDKit)
307        // Build FF using f64 coords for distance computation (matching RDKit's Point3D)
308        let ff = build_etkdg_3d_ff_with_torsions(mol, &coords3d, &bounds, csd_torsions);
309        // RDKit: only minimize if energy > ERROR_TOL
310        let e3d = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(
311            &{
312                let n = mol.graph.node_count();
313                let mut flat = vec![0.0f64; n * 3];
314                for a in 0..n {
315                    flat[a * 3] = coords3d[(a, 0)];
316                    flat[a * 3 + 1] = coords3d[(a, 1)];
317                    flat[a * 3 + 2] = coords3d[(a, 2)];
318                }
319                flat
320            },
321            mol.graph.node_count(),
322            mol,
323            &ff,
324        );
325        let refined = if e3d > ERROR_TOL {
326            minimize_etkdg_3d_bfgs(mol, &coords3d, &ff, 300, FORCE_TOL)
327        } else {
328            coords3d
329        };
330
331        // Step 9: Planarity check — matching RDKit's construct3DImproperForceField
332        // Uses UFF inversion energy + SP angle constraint energy (k=10.0) in f64
333        {
334            let n_improper_atoms = ff.inversion_contribs.len() / 3;
335            let flat_f64: Vec<f64> = {
336                let nr = refined.nrows();
337                let mut flat = vec![0.0f64; nr * 3];
338                for a in 0..nr {
339                    flat[a * 3] = refined[(a, 0)];
340                    flat[a * 3 + 1] = refined[(a, 1)];
341                    flat[a * 3 + 2] = refined[(a, 2)];
342                }
343                flat
344            };
345            let planarity_energy =
346                crate::forcefield::etkdg_3d::planarity_check_energy_f64(&flat_f64, n, &ff);
347            if planarity_energy > n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64 {
348                if _log_attempts {
349                    eprintln!(
350                        "  attempt {} → planarity check failed (energy={:.4} > threshold={:.4})",
351                        _iter,
352                        planarity_energy,
353                        n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64
354                    );
355                }
356                continue;
357            }
358        }
359
360        // Step 10: Double bond geometry check (f64 coords)
361        if !check_double_bond_geometry(mol, &refined) {
362            if _log_attempts {
363                eprintln!("  attempt {} → double bond check failed", _iter);
364            }
365            continue;
366        }
367
368        if _log_attempts {
369            eprintln!("  attempt {} → SUCCESS", _iter);
370        }
371
372        let refined_f32 = refined.map(|v| v as f32);
373        return Ok(refined_f32);
374    }
375
376    Err(format!(
377        "Failed to generate valid conformer after {} iterations",
378        max_iterations
379    ))
380}
381
382/// Compute bounds FF energy in f64, matching RDKit's field->calcEnergy() exactly.
383/// Used for the energy pre-check before minimization.
384pub fn compute_total_bounds_energy_f64(
385    coords: &DMatrix<f64>,
386    bounds: &DMatrix<f64>,
387    chiral_sets: &[crate::forcefield::bounds_ff::ChiralSet],
388    basin_thresh: f32,
389    weight_4d: f32,
390    weight_chiral: f32,
391) -> f64 {
392    let n = coords.nrows();
393    let dim_coords = coords.ncols();
394    let basin_thresh_f64 = basin_thresh as f64;
395    let weight_4d_f64 = weight_4d as f64;
396    let weight_chiral_f64 = weight_chiral as f64;
397
398    let mut energy = 0.0f64;
399    for i in 1..n {
400        for j in 0..i {
401            let ub = bounds[(j, i)];
402            let lb = bounds[(i, j)];
403            if ub - lb > basin_thresh_f64 {
404                continue;
405            }
406            let mut d2 = 0.0f64;
407            for d in 0..dim_coords {
408                let diff = coords[(i, d)] - coords[(j, d)];
409                d2 += diff * diff;
410            }
411            let ub2 = ub * ub;
412            let lb2 = lb * lb;
413            let val = if d2 > ub2 {
414                d2 / ub2 - 1.0
415            } else if d2 < lb2 {
416                2.0 * lb2 / (lb2 + d2) - 1.0
417            } else {
418                0.0
419            };
420            if val > 0.0 {
421                energy += val * val;
422            }
423        }
424    }
425    if !chiral_sets.is_empty() {
426        // Flatten coords to flat array for f64 chiral energy
427        let mut flat = vec![0.0f64; n * dim_coords];
428        for i in 0..n {
429            for d in 0..dim_coords {
430                flat[i * dim_coords + d] = coords[(i, d)];
431            }
432        }
433        energy += weight_chiral_f64
434            * crate::forcefield::bounds_ff::chiral_violation_energy_f64(
435                &flat,
436                dim_coords,
437                chiral_sets,
438            );
439    }
440    if dim_coords == 4 {
441        for i in 0..n {
442            let x4 = coords[(i, 3)];
443            energy += weight_4d_f64 * x4 * x4;
444        }
445    }
446    energy
447}