1use crate::distgeom::{
2 calculate_bounds_matrix_opts, check_chiral_centers, check_double_bond_geometry,
3 check_tetrahedral_centers, compute_initial_coords_rdkit, identify_chiral_sets,
4 identify_tetrahedral_centers, pick_rdkit_distances, triangle_smooth_tol, MinstdRand,
5 MAX_MINIMIZED_E_PER_ATOM,
6};
7use crate::forcefield::bounds_ff::minimize_bfgs_rdkit;
8use crate::forcefield::etkdg_3d::{build_etkdg_3d_ff_with_torsions, minimize_etkdg_3d_bfgs};
9use crate::graph::Molecule;
10use nalgebra::DMatrix;
26
27const BASIN_THRESH: f32 = 5.0;
28const FORCE_TOL: f32 = 1e-3;
29const PLANARITY_TOLERANCE: f32 = 0.7;
30const ERROR_TOL: f64 = 1e-5; pub fn generate_3d_conformer_from_smiles(smiles: &str, seed: u64) -> Result<DMatrix<f32>, String> {
34 let mol = Molecule::from_smiles(smiles)?;
35 generate_3d_conformer(&mol, seed)
36}
37
38pub fn generate_3d_conformer(mol: &Molecule, seed: u64) -> Result<DMatrix<f32>, String> {
43 let csd_torsions = crate::smarts::match_experimental_torsions(mol);
44 generate_3d_conformer_with_torsions(mol, seed, &csd_torsions)
45}
46
47pub fn generate_3d_conformer_best_of_k(
50 mol: &Molecule,
51 seed: u64,
52 csd_torsions: &[crate::forcefield::etkdg_3d::M6TorsionContrib],
53 num_seeds: usize,
54) -> Result<DMatrix<f32>, String> {
55 if num_seeds <= 1 {
56 return generate_3d_conformer_with_torsions(mol, seed, csd_torsions);
57 }
58
59 let mut best: Option<(DMatrix<f32>, f64)> = None;
60 let mut last_err = String::new();
61
62 let bounds = {
64 let raw = calculate_bounds_matrix_opts(mol, true);
65 let mut b = raw;
66 if triangle_smooth_tol(&mut b, 0.0) {
67 b
68 } else {
69 let raw2 = calculate_bounds_matrix_opts(mol, false);
70 let mut b2 = raw2.clone();
71 if triangle_smooth_tol(&mut b2, 0.0) {
72 b2
73 } else {
74 let mut b3 = raw2;
75 triangle_smooth_tol(&mut b3, 0.05);
76 b3
77 }
78 }
79 };
80
81 for k in 0..num_seeds {
82 let s = seed.wrapping_add(k as u64 * 1000);
83 match generate_3d_conformer_with_torsions(mol, s, csd_torsions) {
84 Ok(coords) => {
85 let n = mol.graph.node_count();
87 let coords_f64 = coords.map(|v| v as f64);
88 let ff = build_etkdg_3d_ff_with_torsions(mol, &coords_f64, &bounds, csd_torsions);
89 let mut flat = vec![0.0f64; n * 3];
90 for a in 0..n {
91 flat[a * 3] = coords[(a, 0)] as f64;
92 flat[a * 3 + 1] = coords[(a, 1)] as f64;
93 flat[a * 3 + 2] = coords[(a, 2)] as f64;
94 }
95 let energy = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(&flat, n, mol, &ff);
96 match &best {
97 Some((_, best_e)) if energy >= *best_e => {}
98 _ => {
99 best = Some((coords, energy));
100 }
101 }
102 }
103 Err(e) => {
104 last_err = e;
105 }
106 }
107 }
108
109 match best {
110 Some((coords, _)) => Ok(coords),
111 None => Err(last_err),
112 }
113}
114
115pub fn generate_3d_conformer_with_torsions(
120 mol: &Molecule,
121 seed: u64,
122 csd_torsions: &[crate::forcefield::etkdg_3d::M6TorsionContrib],
123) -> Result<DMatrix<f32>, String> {
124 let n = mol.graph.node_count();
125 if n == 0 {
126 return Err("Empty molecule".to_string());
127 }
128
129 let bounds = {
132 let raw = calculate_bounds_matrix_opts(mol, true);
133 let mut b = raw;
134 if triangle_smooth_tol(&mut b, 0.0) {
135 b
136 } else {
137 #[cfg(test)]
138 eprintln!(" [FALLBACK] strict smoothing failed, retrying without set15");
139 let raw2 = calculate_bounds_matrix_opts(mol, false);
140 let mut b2 = raw2.clone();
141 if triangle_smooth_tol(&mut b2, 0.0) {
142 b2
143 } else {
144 #[cfg(test)]
145 eprintln!(" [FALLBACK] second smoothing also failed, using soft smooth");
146 let mut b3 = raw2;
147 triangle_smooth_tol(&mut b3, 0.05);
148 b3
149 }
150 }
151 };
152 let chiral_sets = identify_chiral_sets(mol);
153 let tetrahedral_centers = identify_tetrahedral_centers(mol);
154
155 let max_iterations = 10 * n;
156 let mut rng = MinstdRand::new(seed as u32);
157
158 let use_4d = !chiral_sets.is_empty();
161 let embed_dim = if use_4d { 4 } else { 3 };
162
163 let mut consecutive_embed_fails = 0u32;
165 let embed_fail_threshold = (n as u32 / 4).max(20); let mut random_coord_attempts = 0u32;
167 let max_random_coord_attempts = 100u32; for _iter in 0..max_iterations {
170 let _log_attempts = std::env::var("LOG_ATTEMPTS").is_ok();
172
173 let use_random_coords = consecutive_embed_fails >= embed_fail_threshold
175 && random_coord_attempts < max_random_coord_attempts;
176 let (mut coords, basin_thresh) = if use_random_coords {
177 random_coord_attempts += 1;
178 let box_size = 10.0f64; let mut c = DMatrix::from_element(n, embed_dim, 0.0f64);
182 for i in 0..n {
183 for d in 0..embed_dim {
184 c[(i, d)] = box_size * (rng.next_double() - 0.5);
185 }
186 }
187 (c, 1e8f64)
189 } else {
190 let dists = pick_rdkit_distances(&mut rng, &bounds);
191 let coords_opt = compute_initial_coords_rdkit(&mut rng, &dists, embed_dim);
192 match coords_opt {
193 Some(c) => {
194 consecutive_embed_fails = 0;
195 (c, BASIN_THRESH as f64)
196 }
197 None => {
198 consecutive_embed_fails += 1;
199 if random_coord_attempts >= max_random_coord_attempts {
201 break;
202 }
203 if _log_attempts && consecutive_embed_fails == embed_fail_threshold {
204 eprintln!(
205 " attempt {} → switching to random coords after {} failures",
206 _iter, embed_fail_threshold
207 );
208 } else if _log_attempts {
209 eprintln!(" attempt {} → embedding failed", _iter);
210 }
211 continue;
212 }
213 }
214 };
215
216 {
220 let bt = basin_thresh as f32;
221 let initial_energy =
222 compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
223 if initial_energy > ERROR_TOL {
224 let mut need_more = 1;
225 let mut restarts = 0;
226 while need_more != 0 && restarts < 50 {
227 need_more = minimize_bfgs_rdkit(
228 &mut coords,
229 &bounds,
230 &chiral_sets,
231 400,
232 FORCE_TOL as f64,
233 bt,
234 0.1,
235 1.0,
236 );
237 restarts += 1;
238 }
239 }
240 }
241
242 let bt = basin_thresh as f32;
244 let energy = compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
245 if energy / n as f64 >= MAX_MINIMIZED_E_PER_ATOM as f64 {
246 if _log_attempts {
247 eprintln!(
248 " attempt {} → energy check failed: {:.6}/atom",
249 _iter,
250 energy / n as f64
251 );
252 }
253 continue;
254 }
255
256 if !check_tetrahedral_centers(&coords, &tetrahedral_centers) {
258 if _log_attempts {
259 eprintln!(" attempt {} → tetrahedral check failed", _iter);
260 }
261 continue;
262 }
263
264 if !chiral_sets.is_empty() && !check_chiral_centers(&coords, &chiral_sets) {
266 if _log_attempts {
267 eprintln!(" attempt {} → chiral check failed", _iter);
268 }
269 continue;
270 }
271
272 if use_4d {
275 let energy2 =
276 compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 1.0, 0.2);
277 if energy2 > ERROR_TOL {
278 let mut need_more = 1;
279 let mut restarts = 0;
280 while need_more != 0 && restarts < 50 {
281 need_more = minimize_bfgs_rdkit(
282 &mut coords,
283 &bounds,
284 &chiral_sets,
285 200,
286 FORCE_TOL as f64,
287 bt,
288 1.0,
289 0.2,
290 );
291 restarts += 1;
292 }
293 }
294 }
295
296 let coords3d = coords.columns(0, 3).into_owned();
298
299 let ff = build_etkdg_3d_ff_with_torsions(mol, &coords3d, &bounds, csd_torsions);
302 let e3d = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(
304 &{
305 let n = mol.graph.node_count();
306 let mut flat = vec![0.0f64; n * 3];
307 for a in 0..n {
308 flat[a * 3] = coords3d[(a, 0)];
309 flat[a * 3 + 1] = coords3d[(a, 1)];
310 flat[a * 3 + 2] = coords3d[(a, 2)];
311 }
312 flat
313 },
314 mol.graph.node_count(),
315 mol,
316 &ff,
317 );
318 let refined = if e3d > ERROR_TOL {
319 minimize_etkdg_3d_bfgs(mol, &coords3d, &ff, 300, FORCE_TOL)
320 } else {
321 coords3d
322 };
323
324 {
327 let n_improper_atoms = ff.inversion_contribs.len() / 3;
328 let flat_f64: Vec<f64> = {
329 let nr = refined.nrows();
330 let mut flat = vec![0.0f64; nr * 3];
331 for a in 0..nr {
332 flat[a * 3] = refined[(a, 0)];
333 flat[a * 3 + 1] = refined[(a, 1)];
334 flat[a * 3 + 2] = refined[(a, 2)];
335 }
336 flat
337 };
338 let planarity_energy =
339 crate::forcefield::etkdg_3d::planarity_check_energy_f64(&flat_f64, n, &ff);
340 if planarity_energy > n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64 {
341 if _log_attempts {
342 eprintln!(
343 " attempt {} → planarity check failed (energy={:.4} > threshold={:.4})",
344 _iter,
345 planarity_energy,
346 n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64
347 );
348 }
349 continue;
350 }
351 }
352
353 if !check_double_bond_geometry(mol, &refined) {
355 if _log_attempts {
356 eprintln!(" attempt {} → double bond check failed", _iter);
357 }
358 continue;
359 }
360
361 if _log_attempts {
362 eprintln!(" attempt {} → SUCCESS", _iter);
363 }
364
365 let refined_f32 = refined.map(|v| v as f32);
366 return Ok(refined_f32);
367 }
368
369 Err(format!(
370 "Failed to generate valid conformer after {} iterations",
371 max_iterations
372 ))
373}
374
375pub fn compute_total_bounds_energy_f64(
378 coords: &DMatrix<f64>,
379 bounds: &DMatrix<f64>,
380 chiral_sets: &[crate::forcefield::bounds_ff::ChiralSet],
381 basin_thresh: f32,
382 weight_4d: f32,
383 weight_chiral: f32,
384) -> f64 {
385 let n = coords.nrows();
386 let dim_coords = coords.ncols();
387 let basin_thresh_f64 = basin_thresh as f64;
388 let weight_4d_f64 = weight_4d as f64;
389 let weight_chiral_f64 = weight_chiral as f64;
390
391 let mut energy = 0.0f64;
392 for i in 1..n {
393 for j in 0..i {
394 let ub = bounds[(j, i)];
395 let lb = bounds[(i, j)];
396 if ub - lb > basin_thresh_f64 {
397 continue;
398 }
399 let mut d2 = 0.0f64;
400 for d in 0..dim_coords {
401 let diff = coords[(i, d)] - coords[(j, d)];
402 d2 += diff * diff;
403 }
404 let ub2 = ub * ub;
405 let lb2 = lb * lb;
406 let val = if d2 > ub2 {
407 d2 / ub2 - 1.0
408 } else if d2 < lb2 {
409 2.0 * lb2 / (lb2 + d2) - 1.0
410 } else {
411 0.0
412 };
413 if val > 0.0 {
414 energy += val * val;
415 }
416 }
417 }
418 if !chiral_sets.is_empty() {
419 let mut flat = vec![0.0f64; n * dim_coords];
421 for i in 0..n {
422 for d in 0..dim_coords {
423 flat[i * dim_coords + d] = coords[(i, d)];
424 }
425 }
426 energy += weight_chiral_f64
427 * crate::forcefield::bounds_ff::chiral_violation_energy_f64(
428 &flat,
429 dim_coords,
430 chiral_sets,
431 );
432 }
433 if dim_coords == 4 {
434 for i in 0..n {
435 let x4 = coords[(i, 3)];
436 energy += weight_4d_f64 * x4 * x4;
437 }
438 }
439 energy
440}