1use crate::distgeom::{
2 calculate_bounds_matrix_opts, check_chiral_centers, check_double_bond_geometry,
3 check_tetrahedral_centers, compute_initial_coords_rdkit, identify_chiral_sets,
4 identify_tetrahedral_centers, pick_rdkit_distances, triangle_smooth_tol, MinstdRand,
5 MAX_MINIMIZED_E_PER_ATOM,
6};
7use crate::forcefield::bounds_ff::minimize_bfgs_rdkit;
8use crate::forcefield::etkdg_3d::{build_etkdg_3d_ff_with_torsions, minimize_etkdg_3d_bfgs};
9use crate::graph::Molecule;
10use nalgebra::DMatrix;
26
/// Basin threshold used for conformers seeded from the metric-matrix embedding: atom pairs whose
/// `upper - lower` bound gap exceeds this value are left out of the distance-violation energy.
const BASIN_THRESH: f32 = 5.0;
/// Tolerance argument handed to the BFGS minimisers in this file.
const FORCE_TOL: f32 = 1e-3;
/// Planarity-energy allowance per improper-torsion atom; a candidate is rejected when its
/// planarity energy exceeds `n_improper_atoms * PLANARITY_TOLERANCE`.
const PLANARITY_TOLERANCE: f32 = 1.0;
/// Energies at or below this value are treated as already minimised, so minimisation is skipped.
const ERROR_TOL: f64 = 1e-5;

/// A user-supplied distance restraint between two atoms.
///
/// When applied to the bounds matrix, the restraint narrows the allowed distance interval of the
/// pair to a window centred on `target_distance` whose half-width is
/// `max(0.1 / sqrt(force_constant), 0.05)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DistanceRestraint {
    /// Index of the first restrained atom.
    pub atom_i: usize,
    /// Index of the second restrained atom.
    pub atom_j: usize,
    /// Desired distance between the two atoms (same units as the bounds matrix).
    pub target_distance: f64,
    /// Restraint strength; larger values give a narrower window, down to a half-width of 0.05.
    pub force_constant: f64,
}
44
45fn apply_restraints_to_bounds(bounds: &mut DMatrix<f64>, restraints: &[DistanceRestraint]) {
47 for r in restraints {
48 let i = r.atom_i;
49 let j = r.atom_j;
50 let (lo, hi) = if i > j { (i, j) } else { (j, i) };
51 let window = (0.1 / r.force_constant.sqrt()).max(0.05);
53 let new_lb = (r.target_distance - window).max(0.0);
54 let new_ub = r.target_distance + window;
55 if new_lb > bounds[(lo, hi)] {
57 bounds[(lo, hi)] = new_lb;
58 }
59 if new_ub < bounds[(hi, lo)] || bounds[(hi, lo)] == 0.0 {
60 bounds[(hi, lo)] = new_ub;
61 }
62 }
63}
64
65fn build_trivial_conformer(bounds: &DMatrix<f64>) -> DMatrix<f32> {
66 let n = bounds.nrows();
67 let mut coords = DMatrix::from_element(n, 3, 0.0f32);
68
69 if n == 2 {
70 let lower = bounds[(1, 0)].max(0.0);
71 let upper = bounds[(0, 1)].max(lower);
72 let distance = if upper > 0.0 {
73 if lower > 0.0 {
74 0.5 * (lower + upper)
75 } else {
76 upper
77 }
78 } else if lower > 0.0 {
79 lower
80 } else {
81 1.0
82 } as f32;
83 let half_distance = 0.5 * distance;
84 coords[(0, 0)] = -half_distance;
85 coords[(1, 0)] = half_distance;
86 }
87
88 coords
89}
90
91pub fn generate_3d_conformer_restrained(
96 mol: &Molecule,
97 seed: u64,
98 restraints: &[DistanceRestraint],
99) -> Result<DMatrix<f32>, String> {
100 let n = mol.graph.node_count();
101 if n == 0 {
102 return Err("Empty molecule".to_string());
103 }
104
105 let csd_torsions = crate::smarts::match_experimental_torsions(mol);
106
107 let bounds = {
109 let raw = calculate_bounds_matrix_opts(mol, true);
110 let mut b = raw;
111 apply_restraints_to_bounds(&mut b, restraints);
112 if triangle_smooth_tol(&mut b, 0.0) {
113 b
114 } else {
115 let raw2 = calculate_bounds_matrix_opts(mol, false);
116 let mut b2 = raw2.clone();
117 apply_restraints_to_bounds(&mut b2, restraints);
118 if triangle_smooth_tol(&mut b2, 0.0) {
119 b2
120 } else {
121 let mut b3 = raw2;
122 apply_restraints_to_bounds(&mut b3, restraints);
123 triangle_smooth_tol(&mut b3, 0.05);
124 b3
125 }
126 }
127 };
128
129 if n <= 2 {
130 return Ok(build_trivial_conformer(&bounds));
131 }
132
133 let chiral_sets = identify_chiral_sets(mol);
135 let tetrahedral_centers = identify_tetrahedral_centers(mol);
136 let use_4d = !chiral_sets.is_empty();
137 let embed_dim = if use_4d { 4 } else { 3 };
138 let max_iterations = 10 * n;
139 let mut rng = MinstdRand::new(seed as u32);
140 let mut consecutive_embed_fails = 0u32;
141 let embed_fail_threshold = if n > 100 {
142 (n as u32 / 8).max(10)
143 } else {
144 (n as u32 / 4).max(20)
145 };
146 let mut random_coord_attempts = 0u32;
147 let max_random_coord_attempts = if n > 100 { 80u32 } else { 150u32 };
148 let bfgs_restart_limit = if n > 100 { 20 } else { 50 };
149 let mut energy_check_failures = 0u32;
150 let energy_relax_threshold = (max_iterations as f64 * 0.3) as u32;
151
152 for _iter in 0..max_iterations {
153 let use_random_coords = consecutive_embed_fails >= embed_fail_threshold
154 && random_coord_attempts < max_random_coord_attempts;
155 let (mut coords, basin_thresh) = if use_random_coords {
156 random_coord_attempts += 1;
157 let box_size = 2.0 * (n as f64).cbrt().max(2.5);
158 let mut c = DMatrix::from_element(n, embed_dim, 0.0f64);
159 for i in 0..n {
160 for d in 0..embed_dim {
161 c[(i, d)] = box_size * (rng.next_double() - 0.5);
162 }
163 }
164 (c, 1e8f64)
165 } else {
166 let dists = pick_rdkit_distances(&mut rng, &bounds);
167 match compute_initial_coords_rdkit(&mut rng, &dists, embed_dim) {
168 Some(c) => {
169 consecutive_embed_fails = 0;
170 (c, BASIN_THRESH as f64)
171 }
172 None => {
173 consecutive_embed_fails += 1;
174 continue;
175 }
176 }
177 };
178
179 {
181 let bt = basin_thresh as f32;
182 let initial_energy =
183 compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
184 if initial_energy > ERROR_TOL {
185 let mut need_more = 1;
186 let mut restarts = 0;
187 while need_more != 0 && restarts < bfgs_restart_limit {
188 need_more = minimize_bfgs_rdkit(
189 &mut coords,
190 &bounds,
191 &chiral_sets,
192 400,
193 FORCE_TOL as f64,
194 bt,
195 0.1,
196 1.0,
197 );
198 restarts += 1;
199 }
200 }
201 }
202
203 let bt = basin_thresh as f32;
204 let energy = compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
205 let effective_e_thresh = if energy_check_failures >= energy_relax_threshold {
206 MAX_MINIMIZED_E_PER_ATOM as f64 * 2.5
207 } else {
208 MAX_MINIMIZED_E_PER_ATOM as f64
209 };
210 if energy / n as f64 >= effective_e_thresh {
211 energy_check_failures += 1;
212 continue;
213 }
214 if !check_tetrahedral_centers(&coords, &tetrahedral_centers) {
215 continue;
216 }
217 if !chiral_sets.is_empty() && !check_chiral_centers(&coords, &chiral_sets) {
218 continue;
219 }
220
221 if use_4d {
222 let energy2 =
223 compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 1.0, 0.2);
224 if energy2 > ERROR_TOL {
225 let mut need_more = 1;
226 let mut restarts = 0;
227 while need_more != 0 && restarts < bfgs_restart_limit {
228 need_more = minimize_bfgs_rdkit(
229 &mut coords,
230 &bounds,
231 &chiral_sets,
232 200,
233 FORCE_TOL as f64,
234 bt,
235 1.0,
236 0.2,
237 );
238 restarts += 1;
239 }
240 }
241 }
242
243 let coords3d = coords.columns(0, 3).into_owned();
244 let ff = build_etkdg_3d_ff_with_torsions(mol, &coords3d, &bounds, &csd_torsions);
245 let e3d = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(
246 &{
247 let mut flat = vec![0.0f64; n * 3];
248 for a in 0..n {
249 flat[a * 3] = coords3d[(a, 0)];
250 flat[a * 3 + 1] = coords3d[(a, 1)];
251 flat[a * 3 + 2] = coords3d[(a, 2)];
252 }
253 flat
254 },
255 n,
256 mol,
257 &ff,
258 );
259 let refined = if e3d > ERROR_TOL {
260 minimize_etkdg_3d_bfgs(mol, &coords3d, &ff, 300, FORCE_TOL)
261 } else {
262 coords3d
263 };
264
265 {
266 let n_improper_atoms = ff.inversion_contribs.len() / 3;
267 let flat_f64: Vec<f64> = {
268 let nr = refined.nrows();
269 let mut flat = vec![0.0f64; nr * 3];
270 for a in 0..nr {
271 flat[a * 3] = refined[(a, 0)];
272 flat[a * 3 + 1] = refined[(a, 1)];
273 flat[a * 3 + 2] = refined[(a, 2)];
274 }
275 flat
276 };
277 let planarity_energy =
278 crate::forcefield::etkdg_3d::planarity_check_energy_f64(&flat_f64, n, &ff);
279 if planarity_energy > n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64 {
280 continue;
281 }
282 }
283
284 if !check_double_bond_geometry(mol, &refined) {
285 continue;
286 }
287 return Ok(refined.map(|v| v as f32));
288 }
289
290 Err(format!(
291 "Failed to generate restrained conformer after {} iterations",
292 max_iterations
293 ))
294}
295
296pub fn generate_3d_conformer_from_smiles(smiles: &str, seed: u64) -> Result<DMatrix<f32>, String> {
298 let mol = Molecule::from_smiles(smiles)?;
299 generate_3d_conformer(&mol, seed)
300}
301
302pub fn generate_3d_conformer(mol: &Molecule, seed: u64) -> Result<DMatrix<f32>, String> {
307 let csd_torsions = crate::smarts::match_experimental_torsions(mol);
308 generate_3d_conformer_with_torsions(mol, seed, &csd_torsions)
309}
310
311pub fn generate_3d_conformer_best_of_k(
314 mol: &Molecule,
315 seed: u64,
316 csd_torsions: &[crate::forcefield::etkdg_3d::M6TorsionContrib],
317 num_seeds: usize,
318) -> Result<DMatrix<f32>, String> {
319 if num_seeds <= 1 {
320 return generate_3d_conformer_with_torsions(mol, seed, csd_torsions);
321 }
322
323 let mut best: Option<(DMatrix<f32>, f64)> = None;
324 let mut last_err = String::new();
325
326 let bounds = {
328 let raw = calculate_bounds_matrix_opts(mol, true);
329 let mut b = raw;
330 if triangle_smooth_tol(&mut b, 0.0) {
331 b
332 } else {
333 let raw2 = calculate_bounds_matrix_opts(mol, false);
334 let mut b2 = raw2.clone();
335 if triangle_smooth_tol(&mut b2, 0.0) {
336 b2
337 } else {
338 let mut b3 = raw2;
339 triangle_smooth_tol(&mut b3, 0.05);
340 b3
341 }
342 }
343 };
344
345 for k in 0..num_seeds {
346 let s = seed.wrapping_add(k as u64 * 1000);
347 match generate_3d_conformer_with_torsions(mol, s, csd_torsions) {
348 Ok(coords) => {
349 let n = mol.graph.node_count();
351 let coords_f64 = coords.map(|v| v as f64);
352 let ff = build_etkdg_3d_ff_with_torsions(mol, &coords_f64, &bounds, csd_torsions);
353 let mut flat = vec![0.0f64; n * 3];
354 for a in 0..n {
355 flat[a * 3] = coords[(a, 0)] as f64;
356 flat[a * 3 + 1] = coords[(a, 1)] as f64;
357 flat[a * 3 + 2] = coords[(a, 2)] as f64;
358 }
359 let energy = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(&flat, n, mol, &ff);
360 match &best {
361 Some((_, best_e)) if energy >= *best_e => {}
362 _ => {
363 best = Some((coords, energy));
364 }
365 }
366 }
367 Err(e) => {
368 last_err = e;
369 }
370 }
371 }
372
373 match best {
374 Some((coords, _)) => Ok(coords),
375 None => Err(last_err),
376 }
377}
378
379pub fn generate_3d_conformer_with_torsions(
384 mol: &Molecule,
385 seed: u64,
386 csd_torsions: &[crate::forcefield::etkdg_3d::M6TorsionContrib],
387) -> Result<DMatrix<f32>, String> {
388 let n = mol.graph.node_count();
389 if n == 0 {
390 return Err("Empty molecule".to_string());
391 }
392
393 let bounds = {
396 let raw = calculate_bounds_matrix_opts(mol, true);
397 let mut b = raw;
398 if triangle_smooth_tol(&mut b, 0.0) {
399 b
400 } else {
401 #[cfg(test)]
402 eprintln!(" [FALLBACK] strict smoothing failed, retrying without set15");
403 let raw2 = calculate_bounds_matrix_opts(mol, false);
404 let mut b2 = raw2.clone();
405 if triangle_smooth_tol(&mut b2, 0.0) {
406 b2
407 } else {
408 #[cfg(test)]
409 eprintln!(" [FALLBACK] second smoothing also failed, using soft smooth");
410 let mut b3 = raw2;
411 triangle_smooth_tol(&mut b3, 0.05);
412 b3
413 }
414 }
415 };
416
417 if n <= 2 {
418 return Ok(build_trivial_conformer(&bounds));
419 }
420
421 let chiral_sets = identify_chiral_sets(mol);
422 let tetrahedral_centers = identify_tetrahedral_centers(mol);
423
424 let max_iterations = 10 * n;
425 let mut rng = MinstdRand::new(seed as u32);
426
427 let use_4d = !chiral_sets.is_empty();
430 let embed_dim = if use_4d { 4 } else { 3 };
431
432 let mut consecutive_embed_fails = 0u32;
436 let embed_fail_threshold = if n > 100 {
437 (n as u32 / 8).max(10)
438 } else {
439 (n as u32 / 4).max(20)
440 };
441 let mut random_coord_attempts = 0u32;
442 let max_random_coord_attempts = if n > 100 { 80u32 } else { 150u32 };
443 let bfgs_restart_limit = if n > 100 { 20 } else { 50 };
445
446 let mut energy_check_failures = 0u32;
448 let energy_relax_threshold = (max_iterations as f64 * 0.3) as u32;
450
451 for _iter in 0..max_iterations {
452 let _log_attempts = std::env::var("LOG_ATTEMPTS").is_ok();
454
455 let use_random_coords = consecutive_embed_fails >= embed_fail_threshold
457 && random_coord_attempts < max_random_coord_attempts;
458 let (mut coords, basin_thresh) = if use_random_coords {
459 random_coord_attempts += 1;
460 let box_size = 2.0 * (n as f64).cbrt().max(2.5);
462 let mut c = DMatrix::from_element(n, embed_dim, 0.0f64);
463 for i in 0..n {
464 for d in 0..embed_dim {
465 c[(i, d)] = box_size * (rng.next_double() - 0.5);
466 }
467 }
468 (c, 1e8f64)
470 } else {
471 let dists = pick_rdkit_distances(&mut rng, &bounds);
472 let coords_opt = compute_initial_coords_rdkit(&mut rng, &dists, embed_dim);
473 match coords_opt {
474 Some(c) => {
475 consecutive_embed_fails = 0;
476 (c, BASIN_THRESH as f64)
477 }
478 None => {
479 consecutive_embed_fails += 1;
480 if random_coord_attempts >= max_random_coord_attempts {
482 break;
483 }
484 if _log_attempts && consecutive_embed_fails == embed_fail_threshold {
485 eprintln!(
486 " attempt {} → switching to random coords after {} failures",
487 _iter, embed_fail_threshold
488 );
489 } else if _log_attempts {
490 eprintln!(" attempt {} → embedding failed", _iter);
491 }
492 continue;
493 }
494 }
495 };
496
497 {
501 let bt = basin_thresh as f32;
502 let initial_energy =
503 compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
504 if initial_energy > ERROR_TOL {
505 let mut need_more = 1;
506 let mut restarts = 0;
507 while need_more != 0 && restarts < bfgs_restart_limit {
508 need_more = minimize_bfgs_rdkit(
509 &mut coords,
510 &bounds,
511 &chiral_sets,
512 400,
513 FORCE_TOL as f64,
514 bt,
515 0.1,
516 1.0,
517 );
518 restarts += 1;
519 }
520 }
521 }
522
523 let bt = basin_thresh as f32;
526 let energy = compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
527 let effective_e_thresh = if energy_check_failures >= energy_relax_threshold {
528 MAX_MINIMIZED_E_PER_ATOM as f64 * 2.5 } else {
530 MAX_MINIMIZED_E_PER_ATOM as f64
531 };
532 if energy / n as f64 >= effective_e_thresh {
533 energy_check_failures += 1;
534 if _log_attempts {
535 eprintln!(
536 " attempt {} → energy check failed: {:.6}/atom",
537 _iter,
538 energy / n as f64
539 );
540 }
541 continue;
542 }
543
544 if !check_tetrahedral_centers(&coords, &tetrahedral_centers) {
546 if _log_attempts {
547 eprintln!(" attempt {} → tetrahedral check failed", _iter);
548 }
549 continue;
550 }
551
552 if !chiral_sets.is_empty() && !check_chiral_centers(&coords, &chiral_sets) {
554 if _log_attempts {
555 eprintln!(" attempt {} → chiral check failed", _iter);
556 }
557 continue;
558 }
559
560 if use_4d {
563 let energy2 =
564 compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 1.0, 0.2);
565 if energy2 > ERROR_TOL {
566 let mut need_more = 1;
567 let mut restarts = 0;
568 while need_more != 0 && restarts < bfgs_restart_limit {
569 need_more = minimize_bfgs_rdkit(
570 &mut coords,
571 &bounds,
572 &chiral_sets,
573 200,
574 FORCE_TOL as f64,
575 bt,
576 1.0,
577 0.2,
578 );
579 restarts += 1;
580 }
581 }
582 }
583
584 let coords3d = coords.columns(0, 3).into_owned();
586
587 let ff = build_etkdg_3d_ff_with_torsions(mol, &coords3d, &bounds, csd_torsions);
590 let e3d = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(
592 &{
593 let n = mol.graph.node_count();
594 let mut flat = vec![0.0f64; n * 3];
595 for a in 0..n {
596 flat[a * 3] = coords3d[(a, 0)];
597 flat[a * 3 + 1] = coords3d[(a, 1)];
598 flat[a * 3 + 2] = coords3d[(a, 2)];
599 }
600 flat
601 },
602 mol.graph.node_count(),
603 mol,
604 &ff,
605 );
606 let refined = if e3d > ERROR_TOL {
607 minimize_etkdg_3d_bfgs(mol, &coords3d, &ff, 300, FORCE_TOL)
608 } else {
609 coords3d
610 };
611
612 {
615 let n_improper_atoms = ff.inversion_contribs.len() / 3;
616 let flat_f64: Vec<f64> = {
617 let nr = refined.nrows();
618 let mut flat = vec![0.0f64; nr * 3];
619 for a in 0..nr {
620 flat[a * 3] = refined[(a, 0)];
621 flat[a * 3 + 1] = refined[(a, 1)];
622 flat[a * 3 + 2] = refined[(a, 2)];
623 }
624 flat
625 };
626 let planarity_energy =
627 crate::forcefield::etkdg_3d::planarity_check_energy_f64(&flat_f64, n, &ff);
628 if planarity_energy > n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64 {
629 if _log_attempts {
630 eprintln!(
631 " attempt {} → planarity check failed (energy={:.4} > threshold={:.4})",
632 _iter,
633 planarity_energy,
634 n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64
635 );
636 }
637 continue;
638 }
639 }
640
641 if !check_double_bond_geometry(mol, &refined) {
643 if _log_attempts {
644 eprintln!(" attempt {} → double bond check failed", _iter);
645 }
646 continue;
647 }
648
649 if _log_attempts {
650 eprintln!(" attempt {} → SUCCESS", _iter);
651 }
652
653 let refined_f32 = refined.map(|v| v as f32);
654 return Ok(refined_f32);
655 }
656
657 Err(format!(
658 "Failed to generate valid conformer after {} iterations",
659 max_iterations
660 ))
661}
662
663pub fn compute_total_bounds_energy_f64(
666 coords: &DMatrix<f64>,
667 bounds: &DMatrix<f64>,
668 chiral_sets: &[crate::forcefield::bounds_ff::ChiralSet],
669 basin_thresh: f32,
670 weight_4d: f32,
671 weight_chiral: f32,
672) -> f64 {
673 let n = coords.nrows();
674 let dim_coords = coords.ncols();
675 let basin_thresh_f64 = basin_thresh as f64;
676 let weight_4d_f64 = weight_4d as f64;
677 let weight_chiral_f64 = weight_chiral as f64;
678
679 let mut energy = 0.0f64;
680 for i in 1..n {
681 for j in 0..i {
682 let ub = bounds[(j, i)];
683 let lb = bounds[(i, j)];
684 if ub - lb > basin_thresh_f64 {
685 continue;
686 }
687 let mut d2 = 0.0f64;
688 for d in 0..dim_coords {
689 let diff = coords[(i, d)] - coords[(j, d)];
690 d2 += diff * diff;
691 }
692 let ub2 = ub * ub;
693 let lb2 = lb * lb;
694 let val = if d2 > ub2 {
695 d2 / ub2 - 1.0
696 } else if d2 < lb2 {
697 2.0 * lb2 / (lb2 + d2) - 1.0
698 } else {
699 0.0
700 };
701 if val > 0.0 {
702 energy += val * val;
703 }
704 }
705 }
706 if !chiral_sets.is_empty() {
707 let mut flat = vec![0.0f64; n * dim_coords];
709 for i in 0..n {
710 for d in 0..dim_coords {
711 flat[i * dim_coords + d] = coords[(i, d)];
712 }
713 }
714 energy += weight_chiral_f64
715 * crate::forcefield::bounds_ff::chiral_violation_energy_f64(
716 &flat,
717 dim_coords,
718 chiral_sets,
719 );
720 }
721 if dim_coords == 4 {
722 for i in 0..n {
723 let x4 = coords[(i, 3)];
724 energy += weight_4d_f64 * x4 * x4;
725 }
726 }
727 energy
728}
729
#[cfg(test)]
mod tests {
    /// Single-heavy-atom halogen SMILES must embed as two atoms with a non-degenerate separation.
    #[test]
    fn embed_handles_hydrogen_halides() {
        for smiles in ["F", "Cl"] {
            let result = crate::embed(smiles, 42);

            assert!(
                result.error.is_none(),
                "{smiles} embed failed: {:?}",
                result.error
            );
            assert_eq!(result.num_atoms, 2, "{smiles} should expand to a diatomic");
            assert_eq!(
                result.coords.len(),
                6,
                "{smiles} should return 2 x 3 coordinates"
            );

            // Coordinates are laid out as [x0, y0, z0, x1, y1, z1].
            let [dx, dy, dz] = [0usize, 1, 2].map(|axis| result.coords[axis + 3] - result.coords[axis]);
            let distance = (dx * dx + dy * dy + dz * dz).sqrt();
            assert!(distance > 0.5, "{smiles} distance should be positive");
        }
    }
}