1use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct JCoupling {
13 pub h1_index: usize,
15 pub h2_index: usize,
17 pub j_hz: f64,
19 pub n_bonds: usize,
21 pub coupling_type: String,
23}
24
/// `A` coefficient (Hz) of the default H–C–C–H Karplus relation
/// `³J(φ) = A·cos²φ + B·cosφ + C`.
///
/// The same coefficients are the fallback for every pathway without its own
/// parameter set in [`get_karplus_params`].
const KARPLUS_A: f64 = 7.76;
/// `B` coefficient (Hz) of the default H–C–C–H Karplus relation.
const KARPLUS_B: f64 = -1.10;
/// `C` coefficient (Hz) of the default H–C–C–H Karplus relation.
const KARPLUS_C: f64 = 1.40;

/// Coefficients of a Karplus relation `³J(φ) = a·cos²φ + b·cosφ + c`, in Hz.
#[derive(Debug, Clone, Copy)]
pub struct KarplusParams {
    /// Coefficient of `cos²φ` (Hz).
    pub a: f64,
    /// Coefficient of `cosφ` (Hz).
    pub b: f64,
    /// Constant term (Hz).
    pub c: f64,
}

impl KarplusParams {
    /// H–C–C–H parameter set, built from the module constants so there is a
    /// single source of truth for these values.
    const HCCH: Self = Self {
        a: KARPLUS_A,
        b: KARPLUS_B,
        c: KARPLUS_C,
    };

    /// Evaluates the Karplus relation at the dihedral angle `phi_rad` (radians).
    ///
    /// Only `cos φ` enters the formula, so the result does not depend on the
    /// sign convention of the dihedral.
    pub fn evaluate(&self, phi_rad: f64) -> f64 {
        let cos_phi = phi_rad.cos();
        self.a * cos_phi * cos_phi + self.b * cos_phi + self.c
    }
}

/// Karplus coefficients for an H–X–Y–H pathway, where `x_elem` and `y_elem`
/// are the atomic numbers of the two central atoms.
///
/// The lookup is symmetric (`(6, 7)` and `(7, 6)` give the same set). Pathways
/// without dedicated coefficients, including every pair not involving carbon,
/// fall back to the H–C–C–H set.
fn get_karplus_params(x_elem: u8, y_elem: u8) -> KarplusParams {
    match (x_elem, y_elem) {
        // H–C–N–H
        (6, 7) | (7, 6) => KarplusParams {
            a: 6.40,
            b: -1.40,
            c: 1.90,
        },
        // H–C–O–H
        (6, 8) | (8, 6) => KarplusParams {
            a: 5.80,
            b: -1.20,
            c: 1.50,
        },
        // H–C–S–H
        (6, 16) | (16, 6) => KarplusParams {
            a: 6.00,
            b: -1.00,
            c: 1.30,
        },
        // H–C–C–H, and every pathway without its own parameters.
        _ => KarplusParams::HCCH,
    }
}
83
/// Signed torsion angle, in radians within `[-π, π]`, defined by four points.
///
/// `p2 → p3` is the central axis. The result equals
/// `atan2(b̂₂·(n₁×n₂), n₁·n₂)`, where `bᵢ` are the consecutive bond vectors and
/// `n₁ = b₁×b₂`, `n₂ = b₂×b₃`. Syn (cis) arrangements give 0 and anti (trans)
/// arrangements give ±π. Only directions matter, so coordinate units are
/// irrelevant. For degenerate geometry (coincident or collinear points) the
/// angle is undefined and the returned value should not be relied on.
fn dihedral_angle(p1: &[f64; 3], p2: &[f64; 3], p3: &[f64; 3], p4: &[f64; 3]) -> f64 {
    // Vector pointing from `from` to `to`.
    let bond = |from: &[f64; 3], to: &[f64; 3]| {
        [to[0] - from[0], to[1] - from[1], to[2] - from[2]]
    };
    let (b1, b2, b3) = (bond(p1, p2), bond(p2, p3), bond(p3, p4));

    // Normals of the (p1, p2, p3) and (p2, p3, p4) planes.
    let plane1 = cross(&b1, &b2);
    let plane2 = cross(&b2, &b3);
    // Vector lying in plane 1, perpendicular to the axis. Together with
    // `plane1` it spans the frame in which `plane2` is measured.
    let frame = cross(&plane1, &normalize(&b2));

    let cos_like = dot(&plane1, &plane2);
    let sin_like = -dot(&frame, &plane2);
    sin_like.atan2(cos_like)
}

/// Cross product `a × b`.
fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = *a;
    let [bx, by, bz] = *b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}

/// Dot product `a · b`.
fn dot(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let [ax, ay, az] = *a;
    let [bx, by, bz] = *b;
    ax * bx + ay * by + az * bz
}

/// Unit vector along `v`. Vectors shorter than `1e-10` map to the zero vector
/// instead of dividing by a near-zero length.
fn normalize(v: &[f64; 3]) -> [f64; 3] {
    let norm = dot(v, v).sqrt();
    if norm < 1e-10 {
        [0.0, 0.0, 0.0]
    } else {
        let [x, y, z] = *v;
        [x / norm, y / norm, z / norm]
    }
}
122
123pub fn predict_j_couplings(mol: &crate::graph::Molecule, positions: &[[f64; 3]]) -> Vec<JCoupling> {
133 let n = mol.graph.node_count();
134 let has_3d = positions.len() == n;
135 let mut couplings = Vec::new();
136
137 let h_atoms: Vec<usize> = (0..n)
139 .filter(|&i| mol.graph[petgraph::graph::NodeIndex::new(i)].element == 1)
140 .collect();
141
142 for i in 0..h_atoms.len() {
144 for j in (i + 1)..h_atoms.len() {
145 let h1 = h_atoms[i];
146 let h2 = h_atoms[j];
147 let h1_idx = petgraph::graph::NodeIndex::new(h1);
148 let h2_idx = petgraph::graph::NodeIndex::new(h2);
149
150 let parent1: Vec<petgraph::graph::NodeIndex> = mol
152 .graph
153 .neighbors(h1_idx)
154 .filter(|n| mol.graph[*n].element != 1)
155 .collect();
156 let parent2: Vec<petgraph::graph::NodeIndex> = mol
157 .graph
158 .neighbors(h2_idx)
159 .filter(|n| mol.graph[*n].element != 1)
160 .collect();
161
162 if parent1.is_empty() || parent2.is_empty() {
163 continue;
164 }
165
166 let p1 = parent1[0];
167 let p2 = parent2[0];
168
169 if p1 == p2 {
170 let j_hz: f64 = -12.0; couplings.push(JCoupling {
173 h1_index: h1,
174 h2_index: h2,
175 j_hz: j_hz.abs(), n_bonds: 2,
177 coupling_type: "geminal_2J".to_string(),
178 });
179 } else if mol.graph.find_edge(p1, p2).is_some() {
180 let x_elem = mol.graph[p1].element;
182 let y_elem = mol.graph[p2].element;
183 let params = get_karplus_params(x_elem, y_elem);
184
185 let j_hz = if has_3d {
186 let phi = dihedral_angle(
187 &positions[h1],
188 &positions[p1.index()],
189 &positions[p2.index()],
190 &positions[h2],
191 );
192 params.evaluate(phi)
193 } else {
194 7.0
196 };
197
198 couplings.push(JCoupling {
199 h1_index: h1,
200 h2_index: h2,
201 j_hz,
202 n_bonds: 3,
203 coupling_type: format!(
204 "vicinal_3J_H-{}-{}-H",
205 element_symbol(x_elem),
206 element_symbol(y_elem)
207 ),
208 });
209 }
210 }
212 }
213
214 couplings
215}
216
/// Chemical symbol used in coupling labels for atomic number `z`.
///
/// Only carbon, nitrogen, oxygen and sulfur are named. Every other atomic
/// number is rendered as `"X"`.
fn element_symbol(z: u8) -> &'static str {
    const SYMBOLS: [(u8, &str); 4] = [(6, "C"), (7, "N"), (8, "O"), (16, "S")];
    SYMBOLS
        .iter()
        .find(|&&(number, _)| number == z)
        .map_or("X", |&(_, symbol)| symbol)
}
226
227pub fn ensemble_averaged_j_couplings(
240 mol: &crate::graph::Molecule,
241 conformer_positions: &[Vec<[f64; 3]>],
242 energies_kcal: &[f64],
243 temperature_k: f64,
244) -> Vec<JCoupling> {
245 if conformer_positions.is_empty() {
246 return Vec::new();
247 }
248 if conformer_positions.len() != energies_kcal.len() {
249 return predict_j_couplings(mol, &conformer_positions[0]);
250 }
251
252 const KB_KCAL: f64 = 0.001987204;
254 let beta = 1.0 / (KB_KCAL * temperature_k);
255
256 let e_min = energies_kcal.iter().cloned().fold(f64::INFINITY, f64::min);
258
259 let weights: Vec<f64> = energies_kcal
261 .iter()
262 .map(|&e| (-(e - e_min) * beta).exp())
263 .collect();
264 let weight_sum: f64 = weights.iter().sum();
265
266 if weight_sum < 1e-30 {
267 return predict_j_couplings(mol, &conformer_positions[0]);
268 }
269
270 let all_couplings: Vec<Vec<JCoupling>> = conformer_positions
272 .iter()
273 .map(|pos| predict_j_couplings(mol, pos))
274 .collect();
275
276 if all_couplings.is_empty() {
278 return Vec::new();
279 }
280
281 let n_couplings = all_couplings[0].len();
282 let mut averaged = all_couplings[0].clone();
283
284 for k in 0..n_couplings {
285 let mut weighted_j = 0.0;
286 for (conf_idx, couplings) in all_couplings.iter().enumerate() {
287 if k < couplings.len() {
288 weighted_j += couplings[k].j_hz * weights[conf_idx];
289 }
290 }
291 averaged[k].j_hz = weighted_j / weight_sum;
292 }
293
294 averaged
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, PI};

    /// Reference H–C–C–H Karplus curve written directly from the module
    /// constants, independently of `get_karplus_params`.
    fn karplus_3j(phi_rad: f64) -> f64 {
        let cos_phi = phi_rad.cos();
        KARPLUS_A * cos_phi * cos_phi + KARPLUS_B * cos_phi + KARPLUS_C
    }

    #[test]
    fn test_karplus_equation_values() {
        let j_0 = karplus_3j(0.0);
        assert!(
            j_0 > 8.0 && j_0 < 10.0,
            "³J(0°) = {} Hz, expected 8–10 Hz",
            j_0
        );

        let j_90 = karplus_3j(FRAC_PI_2);
        assert!(
            j_90 > 0.0 && j_90 < 3.0,
            "³J(90°) = {} Hz, expected ~1.4 Hz",
            j_90
        );

        let j_180 = karplus_3j(PI);
        assert!(
            j_180 > 6.0 && j_180 < 12.0,
            "³J(180°) = {} Hz, expected ~10 Hz",
            j_180
        );
    }

    #[test]
    fn test_karplus_hcch_matches_reference() {
        let cc = get_karplus_params(6, 6);
        for &phi in &[0.0, 0.5, 1.0, FRAC_PI_2, 2.0, PI] {
            assert!(
                (cc.evaluate(phi) - karplus_3j(phi)).abs() < 1e-12,
                "H-C-C-H parameters disagree with the module constants at φ = {}",
                phi
            );
        }
    }

    #[test]
    fn test_dihedral_angle_basic() {
        let p1 = [1.0, 0.0, 0.0];
        let p2 = [0.0, 0.0, 0.0];
        let p3 = [0.0, 1.0, 0.0];

        // p1 and p4 on opposite sides of the p2→p3 axis in one plane: anti.
        let anti = dihedral_angle(&p1, &p2, &p3, &[-1.0, 1.0, 0.0]);
        assert!(
            (anti.abs() - PI).abs() < 1e-9,
            "anti dihedral = {} rad, expected ±π",
            anti
        );

        // Same side: syn.
        let syn = dihedral_angle(&p1, &p2, &p3, &[1.0, 1.0, 0.0]);
        assert!(syn.abs() < 1e-9, "syn dihedral = {} rad, expected 0", syn);

        // p4 rotated 90° out of the p1-p2-p3 plane.
        let perpendicular = dihedral_angle(&p1, &p2, &p3, &[0.0, 1.0, 1.0]);
        assert!(
            (perpendicular.abs() - FRAC_PI_2).abs() < 1e-9,
            "perpendicular dihedral = {} rad, expected ±π/2",
            perpendicular
        );
    }

    #[test]
    fn test_ethane_j_couplings() {
        let mol = crate::graph::Molecule::from_smiles("CC").unwrap();
        let couplings = predict_j_couplings(&mol, &[]);

        assert!(
            !couplings.is_empty(),
            "Ethane should have J-coupling predictions"
        );

        let vicinal: Vec<&JCoupling> = couplings.iter().filter(|c| c.n_bonds == 3).collect();
        assert!(
            !vicinal.is_empty(),
            "Ethane should have ³J vicinal couplings"
        );

        for c in &vicinal {
            assert_eq!(
                c.coupling_type, "vicinal_3J_H-C-C-H",
                "unexpected vicinal label"
            );
            // No geometry supplied → fixed fallback value.
            assert_eq!(c.j_hz, 7.0);
        }
        for c in couplings.iter().filter(|c| c.n_bonds == 2) {
            assert_eq!(c.coupling_type, "geminal_2J");
            assert_eq!(c.j_hz, 12.0);
        }
        for c in &couplings {
            assert!(c.h1_index < c.h2_index);
        }
    }

    #[test]
    fn test_karplus_pathway_specific() {
        let cc_params = get_karplus_params(6, 6);
        let cn_params = get_karplus_params(6, 7);

        let j_cc = cc_params.evaluate(0.0);
        let j_cn = cn_params.evaluate(0.0);
        assert!(
            (j_cc - j_cn).abs() > 0.1,
            "H-C-C-H and H-C-N-H should have different J at φ=0: {} vs {}",
            j_cc,
            j_cn
        );
    }

    #[test]
    fn test_ensemble_averaging() {
        let mol = crate::graph::Molecule::from_smiles("CC").unwrap();
        let n = mol.graph.node_count();
        let positions = vec![[0.0, 0.0, 0.0]; n];
        let result =
            ensemble_averaged_j_couplings(&mol, &[positions.clone()], &[0.0], 298.15);
        assert!(!result.is_empty());

        // Identical geometries must average to the single-conformer values,
        // whatever their energies.
        let single = predict_j_couplings(&mol, &positions);
        let ensemble = ensemble_averaged_j_couplings(
            &mol,
            &[positions.clone(), positions.clone()],
            &[0.0, 1.5],
            298.15,
        );
        assert_eq!(ensemble.len(), single.len());
        for (avg, one) in ensemble.iter().zip(&single) {
            assert_eq!((avg.h1_index, avg.h2_index), (one.h1_index, one.h2_index));
            assert!((avg.j_hz - one.j_hz).abs() < 1e-9);
        }
    }

    #[test]
    fn test_methane_j_couplings() {
        let mol = crate::graph::Molecule::from_smiles("C").unwrap();
        let couplings = predict_j_couplings(&mol, &[]);

        // Every hydrogen shares the single carbon, so every pair is geminal.
        assert!(
            !couplings.is_empty(),
            "Methane should have geminal H-H couplings"
        );
        for c in &couplings {
            assert_eq!(c.n_bonds, 2, "Methane H-H should be 2-bond (geminal)");
            assert_eq!(c.coupling_type, "geminal_2J");
        }
    }
}