1use nalgebra::DMatrix;
2use petgraph::visit::EdgeRef;
3use rand::rngs::StdRng;
4use rand::{Rng, SeedableRng};
5use std::collections::{HashSet, VecDeque};
6
/// A rotatable single bond together with everything needed to drive its torsion
/// with [`rotate_atoms`].
#[derive(Debug, Clone, PartialEq)]
pub struct RotatableBond {
    /// Atom indices `[a, b, c, d]` of the reference dihedral; `b–c` is the rotatable bond,
    /// `a` is a neighbour of `b` and `d` a neighbour of `c`.
    pub dihedral: [usize; 4],
    /// Atoms on the `c` side of the bond (excluding `b` and `c` themselves) that move when
    /// the torsion is rotated about the `b → c` axis.
    pub mobile_atoms: Vec<usize>,
    /// Candidate torsion angles in radians (within `[0, 2π)`).
    pub preferred_angles: Vec<f32>,
}
21
22pub fn find_rotatable_bonds(mol: &crate::graph::Molecule) -> Vec<RotatableBond> {
27 let mut bonds = Vec::new();
28 let n = mol.graph.node_count();
29
30 for edge in mol.graph.edge_references() {
31 let u = edge.source();
32 let v = edge.target();
33
34 if edge.weight().order != crate::graph::BondOrder::Single {
36 continue;
37 }
38
39 let deg_u = mol.graph.neighbors(u).count();
41 let deg_v = mol.graph.neighbors(v).count();
42 if deg_u < 2 || deg_v < 2 {
43 continue;
44 }
45
46 if crate::graph::min_path_excluding2(mol, u, v, u, v, 7).is_some() {
48 continue;
49 }
50
51 let a = mol.graph.neighbors(u).find(|&x| x != v).unwrap();
53 let d = mol.graph.neighbors(v).find(|&x| x != u).unwrap();
54
55 let mut mobile = Vec::new();
57 let mut visited = HashSet::new();
58 visited.insert(u.index());
59 visited.insert(v.index());
60 let mut queue = VecDeque::new();
61 for nb in mol.graph.neighbors(v) {
62 if nb != u {
63 queue.push_back(nb.index());
64 visited.insert(nb.index());
65 }
66 }
67 while let Some(curr) = queue.pop_front() {
68 mobile.push(curr);
69 let ni = petgraph::graph::NodeIndex::new(curr);
70 for nb in mol.graph.neighbors(ni) {
71 if !visited.contains(&nb.index()) {
72 visited.insert(nb.index());
73 queue.push_back(nb.index());
74 }
75 }
76 }
77
78 let other_count = n - mobile.len() - 2; if mobile.len() > other_count {
81 let mut mobile_u = Vec::new();
83 let mut visited_u = HashSet::new();
84 visited_u.insert(u.index());
85 visited_u.insert(v.index());
86 let mut queue_u = VecDeque::new();
87 for nb in mol.graph.neighbors(u) {
88 if nb != v {
89 queue_u.push_back(nb.index());
90 visited_u.insert(nb.index());
91 }
92 }
93 while let Some(curr) = queue_u.pop_front() {
94 mobile_u.push(curr);
95 let ni = petgraph::graph::NodeIndex::new(curr);
96 for nb in mol.graph.neighbors(ni) {
97 if !visited_u.contains(&nb.index()) {
98 visited_u.insert(nb.index());
99 queue_u.push_back(nb.index());
100 }
101 }
102 }
103 let preferred = get_preferred_angles(mol, v, u);
105 bonds.push(RotatableBond {
106 dihedral: [d.index(), v.index(), u.index(), a.index()],
107 mobile_atoms: mobile_u,
108 preferred_angles: preferred,
109 });
110 } else {
111 let preferred = get_preferred_angles(mol, u, v);
112 bonds.push(RotatableBond {
113 dihedral: [a.index(), u.index(), v.index(), d.index()],
114 mobile_atoms: mobile,
115 preferred_angles: preferred,
116 });
117 }
118 }
119 bonds
120}
121
122fn get_preferred_angles(
124 mol: &crate::graph::Molecule,
125 u: petgraph::graph::NodeIndex,
126 v: petgraph::graph::NodeIndex,
127) -> Vec<f32> {
128 use crate::graph::Hybridization::*;
129 use std::f32::consts::PI;
130
131 let hyb_u = mol.graph[u].hybridization;
132 let hyb_v = mol.graph[v].hybridization;
133
134 match (hyb_u, hyb_v) {
135 (SP3, SP3) => {
136 vec![PI / 3.0, PI, 5.0 * PI / 3.0]
138 }
139 (SP2, SP2) => {
140 vec![0.0, PI]
142 }
143 (SP2, SP3) | (SP3, SP2) => {
144 vec![
146 0.0,
147 PI / 3.0,
148 2.0 * PI / 3.0,
149 PI,
150 4.0 * PI / 3.0,
151 5.0 * PI / 3.0,
152 ]
153 }
154 _ => {
155 (0..12).map(|i| i as f32 * PI / 6.0).collect()
157 }
158 }
159}
160
161pub fn compute_dihedral(coords: &DMatrix<f32>, i: usize, j: usize, k: usize, l: usize) -> f32 {
163 let b1 = nalgebra::Vector3::new(
164 coords[(j, 0)] - coords[(i, 0)],
165 coords[(j, 1)] - coords[(i, 1)],
166 coords[(j, 2)] - coords[(i, 2)],
167 );
168 let b2 = nalgebra::Vector3::new(
169 coords[(k, 0)] - coords[(j, 0)],
170 coords[(k, 1)] - coords[(j, 1)],
171 coords[(k, 2)] - coords[(j, 2)],
172 );
173 let b3 = nalgebra::Vector3::new(
174 coords[(l, 0)] - coords[(k, 0)],
175 coords[(l, 1)] - coords[(k, 1)],
176 coords[(l, 2)] - coords[(k, 2)],
177 );
178
179 let n1 = b1.cross(&b2).normalize();
180 let n2 = b2.cross(&b3).normalize();
181 let m1 = n1.cross(&b2.normalize());
182 let x = n1.dot(&n2);
183 let y = m1.dot(&n2);
184 y.atan2(x)
185}
186
187pub fn rotate_atoms(coords: &mut DMatrix<f32>, mobile: &[usize], j: usize, k: usize, angle: f32) {
189 if angle.abs() < 1e-8 {
190 return;
191 }
192 let axis = nalgebra::Vector3::new(
194 coords[(k, 0)] - coords[(j, 0)],
195 coords[(k, 1)] - coords[(j, 1)],
196 coords[(k, 2)] - coords[(j, 2)],
197 );
198 let axis_len = axis.norm();
199 if axis_len < 1e-8 {
200 return;
201 }
202 let axis = axis / axis_len;
203
204 let cos_a = angle.cos();
206 let sin_a = angle.sin();
207
208 let px = coords[(j, 0)];
210 let py = coords[(j, 1)];
211 let pz = coords[(j, 2)];
212
213 for &idx in mobile {
214 let vx = coords[(idx, 0)] - px;
215 let vy = coords[(idx, 1)] - py;
216 let vz = coords[(idx, 2)] - pz;
217
218 let dot = axis[0] * vx + axis[1] * vy + axis[2] * vz;
219 let cx = axis[1] * vz - axis[2] * vy;
220 let cy = axis[2] * vx - axis[0] * vz;
221 let cz = axis[0] * vy - axis[1] * vx;
222
223 coords[(idx, 0)] = px + vx * cos_a + cx * sin_a + axis[0] * dot * (1.0 - cos_a);
224 coords[(idx, 1)] = py + vy * cos_a + cy * sin_a + axis[1] * dot * (1.0 - cos_a);
225 coords[(idx, 2)] = pz + vz * cos_a + cz * sin_a + axis[2] * dot * (1.0 - cos_a);
226 }
227}
228
229pub fn snap_torsions_to_preferred(
232 coords: &mut DMatrix<f32>,
233 mol: &crate::graph::Molecule,
234) -> usize {
235 let rotatable = find_rotatable_bonds(mol);
236 let num_rotatable = rotatable.len();
237
238 for rb in &rotatable {
239 let [a, b, c, d] = rb.dihedral;
240 let current = compute_dihedral(coords, a, b, c, d);
241
242 let mut best_delta_abs = f32::MAX;
244 let mut best_rotation = 0.0f32;
245 for &target in &rb.preferred_angles {
246 let mut delta = target - current;
247 delta = (delta + std::f32::consts::PI).rem_euclid(2.0 * std::f32::consts::PI)
249 - std::f32::consts::PI;
250 if delta.abs() < best_delta_abs {
251 best_delta_abs = delta.abs();
252 best_rotation = delta;
253 }
254 }
255
256 if best_rotation.abs() > 0.05 {
257 rotate_atoms(coords, &rb.mobile_atoms, b, c, best_rotation);
258 }
259 }
260 num_rotatable
261}
262
263pub fn optimize_torsions_greedy(
267 coords: &mut DMatrix<f32>,
268 mol: &crate::graph::Molecule,
269 bounds: &DMatrix<f64>,
270 passes: usize,
271) -> usize {
272 let rotatable = find_rotatable_bonds(mol);
273 let num_rotatable = rotatable.len();
274 if rotatable.is_empty() {
275 return 0;
276 }
277
278 let params = super::energy::FFParams {
279 kb: 300.0,
280 k_theta: 200.0,
281 k_omega: 10.0,
282 k_oop: 20.0,
283 k_bounds: 100.0,
284 k_chiral: 0.0,
285 k_vdw: 0.0,
286 };
287
288 for _pass in 0..passes {
289 for rb in &rotatable {
290 let [a, b, c, d] = rb.dihedral;
291 let current_angle = compute_dihedral(coords, a, b, c, d);
292
293 let current_energy =
294 super::energy::calculate_total_energy(coords, mol, ¶ms, bounds);
295 let mut best_energy = current_energy;
296 let mut best_rotation = 0.0f32;
297
298 for &target_angle in &rb.preferred_angles {
299 let delta = target_angle - current_angle;
300 let delta = ((delta + std::f32::consts::PI) % (2.0 * std::f32::consts::PI))
301 - std::f32::consts::PI;
302
303 rotate_atoms(coords, &rb.mobile_atoms, b, c, delta);
304
305 let e = super::energy::calculate_total_energy(coords, mol, ¶ms, bounds);
306 if e < best_energy {
307 best_energy = e;
308 best_rotation = delta;
309 }
310
311 rotate_atoms(coords, &rb.mobile_atoms, b, c, -delta);
312 }
313
314 if best_rotation.abs() > 1e-6 {
315 rotate_atoms(coords, &rb.mobile_atoms, b, c, best_rotation);
316 }
317 }
318 }
319 num_rotatable
320}
321
322pub fn optimize_torsions_bounds(
326 coords: &mut DMatrix<f32>,
327 mol: &crate::graph::Molecule,
328 bounds: &DMatrix<f64>,
329 passes: usize,
330) -> usize {
331 let rotatable = find_rotatable_bonds(mol);
332 let num_rotatable = rotatable.len();
333 if rotatable.is_empty() {
334 return 0;
335 }
336
337 for _pass in 0..passes {
338 for rb in &rotatable {
339 let [a, b, c, d] = rb.dihedral;
340 let current_angle = compute_dihedral(coords, a, b, c, d);
341
342 let current_energy = super::bounds_ff::bounds_violation_energy(coords, bounds);
343 let mut best_energy = current_energy;
344 let mut best_rotation = 0.0f32;
345
346 for &target_angle in &rb.preferred_angles {
347 let delta = target_angle - current_angle;
348 let delta = ((delta + std::f32::consts::PI) % (2.0 * std::f32::consts::PI))
349 - std::f32::consts::PI;
350
351 rotate_atoms(coords, &rb.mobile_atoms, b, c, delta);
352
353 let e = super::bounds_ff::bounds_violation_energy(coords, bounds);
354 if e < best_energy {
355 best_energy = e;
356 best_rotation = delta;
357 }
358
359 rotate_atoms(coords, &rb.mobile_atoms, b, c, -delta);
360 }
361
362 if best_rotation.abs() > 1e-6 {
363 rotate_atoms(coords, &rb.mobile_atoms, b, c, best_rotation);
364 }
365 }
366 }
367 num_rotatable
368}
369
370pub fn optimize_torsions_monte_carlo_bounds(
375 coords: &mut DMatrix<f32>,
376 mol: &crate::graph::Molecule,
377 bounds: &DMatrix<f64>,
378 seed: u64,
379 n_steps: usize,
380 temperature: f32,
381) -> usize {
382 let rotatable = find_rotatable_bonds(mol);
383 let num_rotatable = rotatable.len();
384 if rotatable.is_empty() || n_steps == 0 {
385 return num_rotatable;
386 }
387
388 let mut rng = StdRng::seed_from_u64(seed);
389 let temp = temperature.max(1e-6);
390 let two_pi = 2.0 * std::f32::consts::PI;
391 let mut current_energy = super::bounds_ff::bounds_violation_energy(coords, bounds);
392
393 for _ in 0..n_steps {
394 let rb = &rotatable[rng.gen_range(0..rotatable.len())];
395 let [a, b, c, d] = rb.dihedral;
396 let current_angle = compute_dihedral(coords, a, b, c, d);
397 let target = rng.gen_range(-std::f32::consts::PI..std::f32::consts::PI);
398 let mut delta = target - current_angle;
399 delta = (delta + std::f32::consts::PI).rem_euclid(two_pi) - std::f32::consts::PI;
400
401 rotate_atoms(coords, &rb.mobile_atoms, b, c, delta);
402 let trial_energy = super::bounds_ff::bounds_violation_energy(coords, bounds);
403 let d_e = trial_energy - current_energy;
404
405 let accept = if d_e <= 0.0 {
406 true
407 } else {
408 let p_accept = (-d_e / temp).exp();
409 rng.gen::<f32>() < p_accept
410 };
411
412 if accept {
413 current_energy = trial_energy;
414 } else {
415 rotate_atoms(coords, &rb.mobile_atoms, b, c, -delta);
416 }
417 }
418
419 num_rotatable
420}
421
422pub fn systematic_rotor_search(
430 smiles: &str,
431 coords: &[f64],
432 max_rotors: usize,
433) -> Result<Vec<(Vec<f64>, f64)>, String> {
434 use std::f32::consts::PI;
435
436 let mol = crate::graph::Molecule::from_smiles(smiles)?;
437 let n_atoms = mol.graph.node_count();
438 if coords.len() != n_atoms * 3 {
439 return Err("coords length mismatch".to_string());
440 }
441
442 let rotatable = find_rotatable_bonds(&mol);
443 let n_rot = rotatable.len().min(max_rotors);
444
445 if n_rot == 0 {
446 let e = crate::compute_uff_energy(smiles, coords).unwrap_or(0.0);
447 return Ok(vec![(coords.to_vec(), e)]);
448 }
449
450 let ff = super::builder::build_uff_force_field(&mol);
451 let angles: Vec<f32> = (0..12).map(|i| i as f32 * PI / 6.0).collect();
452
453 let total: usize = 12usize.pow(n_rot as u32);
455
456 let base_matrix = flat_to_matrix_internal(coords, n_atoms);
457
458 let eval_combo = |combo_idx: usize| -> Option<(Vec<f64>, f64)> {
459 let mut matrix = base_matrix.clone();
460
461 let mut idx = combo_idx;
462 for r in 0..n_rot {
463 let angle_idx = idx % 12;
464 idx /= 12;
465
466 let rb = &rotatable[r];
467 let [a, b, c, d] = rb.dihedral;
468 let current = compute_dihedral(&matrix, a, b, c, d);
469 let target = angles[angle_idx];
470 let mut delta = target - current;
471 delta = (delta + PI).rem_euclid(2.0 * PI) - PI;
472 rotate_atoms(&mut matrix, &rb.mobile_atoms, b, c, delta);
473 }
474
475 let flat = matrix_to_flat(&matrix, n_atoms);
476 let mut grad = vec![0.0f64; n_atoms * 3];
477 let energy = ff.compute_system_energy_and_gradients(&flat, &mut grad);
478
479 if energy.is_finite() {
480 Some((flat, energy))
481 } else {
482 None
483 }
484 };
485
486 #[cfg(feature = "parallel")]
487 let mut results: Vec<(Vec<f64>, f64)> = {
488 use rayon::prelude::*;
489 (0..total).into_par_iter().filter_map(eval_combo).collect()
490 };
491
492 #[cfg(not(feature = "parallel"))]
493 let mut results: Vec<(Vec<f64>, f64)> = (0..total).filter_map(eval_combo).collect();
494
495 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
497 Ok(results)
498}
499
500pub fn simulated_annealing_torsion_search(
505 smiles: &str,
506 coords: &[f64],
507 n_steps: usize,
508 t_start: f64,
509 t_end: f64,
510 seed: u64,
511) -> Result<Vec<(Vec<f64>, f64)>, String> {
512 let mol = crate::graph::Molecule::from_smiles(smiles)?;
513 let n_atoms = mol.graph.node_count();
514 if coords.len() != n_atoms * 3 {
515 return Err("coords length mismatch".to_string());
516 }
517
518 let rotatable = find_rotatable_bonds(&mol);
519 if rotatable.is_empty() {
520 let e = crate::compute_uff_energy(smiles, coords).unwrap_or(0.0);
521 return Ok(vec![(coords.to_vec(), e)]);
522 }
523
524 let ff = super::builder::build_uff_force_field(&mol);
525 let mut rng = StdRng::seed_from_u64(seed);
526 let mut matrix = flat_to_matrix_internal(coords, n_atoms);
527
528 let flat = matrix_to_flat(&matrix, n_atoms);
529 let mut grad = vec![0.0f64; n_atoms * 3];
530 let mut current_energy = ff.compute_system_energy_and_gradients(&flat, &mut grad);
531
532 let mut best_coords = flat;
533 let mut best_energy = current_energy;
534
535 let cooling_rate = if n_steps > 1 {
536 (t_end / t_start).powf(1.0 / (n_steps - 1) as f64)
537 } else {
538 1.0
539 };
540
541 let mut collected = Vec::new();
542 let collect_interval = (n_steps / 50).max(1); let mut temp = t_start;
545 for step in 0..n_steps {
546 let rb_idx = rng.gen_range(0..rotatable.len());
547 let rb = &rotatable[rb_idx];
548 let [a, b, c, d] = rb.dihedral;
549 let current_angle = compute_dihedral(&matrix, a, b, c, d);
550 let perturbation: f32 = rng.gen_range(-std::f32::consts::PI..std::f32::consts::PI);
551 let mut delta = perturbation - current_angle;
552 delta = (delta + std::f32::consts::PI).rem_euclid(2.0 * std::f32::consts::PI)
553 - std::f32::consts::PI;
554
555 rotate_atoms(&mut matrix, &rb.mobile_atoms, b, c, delta);
556
557 let trial_flat = matrix_to_flat(&matrix, n_atoms);
558 let trial_energy = ff.compute_system_energy_and_gradients(&trial_flat, &mut grad);
559
560 let d_e = trial_energy - current_energy;
561 let accept = if d_e <= 0.0 {
562 true
563 } else {
564 let p = (-d_e / temp).exp();
565 rng.gen::<f64>() < p
566 };
567
568 if accept && trial_energy.is_finite() {
569 current_energy = trial_energy;
570 if current_energy < best_energy {
571 best_energy = current_energy;
572 best_coords = trial_flat;
573 }
574 } else {
575 rotate_atoms(&mut matrix, &rb.mobile_atoms, b, c, -delta);
576 }
577
578 if step % collect_interval == 0 {
579 let snap = matrix_to_flat(&matrix, n_atoms);
580 collected.push((snap, current_energy));
581 }
582
583 temp *= cooling_rate;
584 }
585
586 collected.push((best_coords, best_energy));
588 collected.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
589 Ok(collected)
590}
591
592pub fn torsional_sampling_diverse(
597 smiles: &str,
598 coords: &[f64],
599 rmsd_cutoff: f64,
600 seed: u64,
601) -> Result<Vec<(Vec<f64>, f64)>, String> {
602 let mol = crate::graph::Molecule::from_smiles(smiles)?;
603 let rotatable = find_rotatable_bonds(&mol);
604 let n_rot = rotatable.len();
605
606 let conformers = if n_rot <= 4 {
607 systematic_rotor_search(smiles, coords, 4)?
608 } else {
609 simulated_annealing_torsion_search(smiles, coords, 500, 5.0, 0.1, seed)?
610 };
611
612 if conformers.len() <= 1 {
613 return Ok(conformers);
614 }
615
616 let coords_vecs: Vec<Vec<f64>> = conformers.iter().map(|(c, _)| c.clone()).collect();
617 let cluster_result = crate::clustering::butina_cluster(&coords_vecs, rmsd_cutoff);
618
619 let diverse: Vec<(Vec<f64>, f64)> = cluster_result
620 .centroid_indices
621 .iter()
622 .map(|&ci| conformers[ci].clone())
623 .collect();
624
625 Ok(diverse)
626}
627
628fn flat_to_matrix_internal(coords: &[f64], n_atoms: usize) -> DMatrix<f32> {
629 let mut m = DMatrix::<f32>::zeros(n_atoms, 3);
630 for i in 0..n_atoms {
631 m[(i, 0)] = coords[3 * i] as f32;
632 m[(i, 1)] = coords[3 * i + 1] as f32;
633 m[(i, 2)] = coords[3 * i + 2] as f32;
634 }
635 m
636}
637
638fn matrix_to_flat(matrix: &DMatrix<f32>, n_atoms: usize) -> Vec<f64> {
639 let mut flat = vec![0.0f64; n_atoms * 3];
640 for i in 0..n_atoms {
641 flat[3 * i] = matrix[(i, 0)] as f64;
642 flat[3 * i + 1] = matrix[(i, 1)] as f64;
643 flat[3 * i + 2] = matrix[(i, 2)] as f64;
644 }
645 flat
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_monte_carlo_torsion_optimizer_runs_for_butane() {
        let smiles = "CCCC";
        let mol = crate::graph::Molecule::from_smiles(smiles).unwrap();
        let conf = crate::embed(smiles, 42);
        assert!(conf.error.is_none());

        let bounds =
            crate::distgeom::smooth_bounds_matrix(crate::distgeom::calculate_bounds_matrix(&mol));
        // Reuse the module's own converter instead of a duplicated test-only copy.
        let mut coords = flat_to_matrix_internal(&conf.coords, mol.graph.node_count());
        let rot = optimize_torsions_monte_carlo_bounds(&mut coords, &mol, &bounds, 123, 64, 0.3);

        assert!(rot >= 1);
        assert!(coords.iter().all(|v| v.is_finite()));
    }

    /// Pins the sign relationship the torsion drivers rely on: rotating the `l` side by
    /// `+θ` about `j → k` lowers `compute_dihedral(i, j, k, l)` by `θ`.
    #[test]
    fn test_rotate_atoms_lowers_compute_dihedral() {
        // i = (1,0,0), j = origin, k = (0,0,1), l = (1,0,1): a cis (0 rad) arrangement.
        let mut coords = DMatrix::<f32>::from_row_slice(
            4,
            3,
            &[
                1.0, 0.0, 0.0, //
                0.0, 0.0, 0.0, //
                0.0, 0.0, 1.0, //
                1.0, 0.0, 1.0,
            ],
        );
        let before = compute_dihedral(&coords, 0, 1, 2, 3);
        assert!(before.abs() < 1e-6);

        rotate_atoms(&mut coords, &[3], 1, 2, 0.5);
        let after = compute_dihedral(&coords, 0, 1, 2, 3);
        assert!((after + 0.5).abs() < 1e-5, "got {after}");
    }
}