1use serde::{Deserialize, Serialize};
15
/// Minimisation algorithm selected by `FrameworkOptConfig::method`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptMethod {
    /// Move each free atom a fixed distance along its own force, growing the
    /// step after accepted (energy-lowering) moves and halving it otherwise.
    SteepestDescent,
    /// Quasi-Newton BFGS on a dense inverse-Hessian approximation with a
    /// backtracking (Armijo) line search.
    Bfgs,
}
24
25#[derive(Debug, Clone)]
27pub struct FrameworkOptConfig {
28 pub method: OptMethod,
30 pub max_iter: usize,
32 pub force_tol: f64,
34 pub energy_tol: f64,
36 pub max_step: f64,
38 pub fixed_atoms: Vec<usize>,
40 pub lattice: Option<[[f64; 3]; 3]>,
42}
43
44impl Default for FrameworkOptConfig {
45 fn default() -> Self {
46 Self {
47 method: OptMethod::Bfgs,
48 max_iter: 200,
49 force_tol: 0.05,
50 energy_tol: 1e-6,
51 max_step: 0.2,
52 fixed_atoms: vec![],
53 lattice: None,
54 }
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct FrameworkOptResult {
61 pub positions: Vec<[f64; 3]>,
63 pub energy: f64,
65 pub forces: Vec<[f64; 3]>,
67 pub max_force: f64,
69 pub n_iterations: usize,
71 pub converged: bool,
73 pub energy_history: Vec<f64>,
75}
76
/// Energy/force callback used by the optimisers.
///
/// Receives the per-atom element codes and the current Cartesian positions and
/// returns `(energy, forces)`. The optimisers expect exactly one force vector
/// per atom. An `Err` aborts the optimisation and is propagated to the caller.
pub type EnergyForceFn = dyn Fn(&[u8], &[[f64; 3]]) -> Result<(f64, Vec<[f64; 3]>), String>;
80
81pub fn optimize_framework(
89 elements: &[u8],
90 initial_positions: &[[f64; 3]],
91 energy_force_fn: &EnergyForceFn,
92 config: &FrameworkOptConfig,
93) -> Result<FrameworkOptResult, String> {
94 match config.method {
95 OptMethod::SteepestDescent => {
96 optimize_steepest_descent(elements, initial_positions, energy_force_fn, config)
97 }
98 OptMethod::Bfgs => optimize_bfgs(elements, initial_positions, energy_force_fn, config),
99 }
100}
101
102fn optimize_steepest_descent(
104 elements: &[u8],
105 initial_positions: &[[f64; 3]],
106 energy_force_fn: &EnergyForceFn,
107 config: &FrameworkOptConfig,
108) -> Result<FrameworkOptResult, String> {
109 let n = elements.len();
110 let mut positions: Vec<[f64; 3]> = initial_positions.to_vec();
111 let mut energy_history = Vec::new();
112 let mut step_size = config.max_step;
113
114 let (mut energy, mut forces) = energy_force_fn(elements, &positions)?;
115 energy_history.push(energy);
116
117 let mut converged = false;
118 let mut n_iter = 0;
119
120 for iter in 0..config.max_iter {
121 n_iter = iter + 1;
122
123 zero_fixed_forces(&mut forces, &config.fixed_atoms);
125
126 let max_force = max_force_magnitude(&forces);
127 if max_force < config.force_tol {
128 converged = true;
129 break;
130 }
131
132 let mut new_positions = positions.clone();
134 for i in 0..n {
135 if config.fixed_atoms.contains(&i) {
136 continue;
137 }
138 let f_mag = (forces[i][0].powi(2) + forces[i][1].powi(2) + forces[i][2].powi(2)).sqrt();
139 if f_mag < 1e-12 {
140 continue;
141 }
142 let scale = step_size / f_mag;
143 for d in 0..3 {
144 new_positions[i][d] += forces[i][d] * scale;
145 }
146 }
147
148 if let Some(ref lattice) = config.lattice {
150 apply_pbc(&mut new_positions, lattice);
151 }
152
153 let (new_energy, new_forces) = energy_force_fn(elements, &new_positions)?;
154
155 if new_energy < energy {
156 positions = new_positions;
157 energy = new_energy;
158 forces = new_forces;
159 step_size = (step_size * 1.2).min(config.max_step);
160 } else {
161 step_size *= 0.5;
162 if step_size < 1e-10 {
163 break;
164 }
165 }
166
167 energy_history.push(energy);
168
169 if energy_history.len() > 1 {
170 let de = (energy_history[energy_history.len() - 2] - energy).abs();
171 if de < config.energy_tol {
172 converged = true;
173 break;
174 }
175 }
176 }
177
178 zero_fixed_forces(&mut forces, &config.fixed_atoms);
179 let max_force = max_force_magnitude(&forces);
180
181 Ok(FrameworkOptResult {
182 positions,
183 energy,
184 forces,
185 max_force,
186 n_iterations: n_iter,
187 converged,
188 energy_history,
189 })
190}
191
192fn optimize_bfgs(
194 elements: &[u8],
195 initial_positions: &[[f64; 3]],
196 energy_force_fn: &EnergyForceFn,
197 config: &FrameworkOptConfig,
198) -> Result<FrameworkOptResult, String> {
199 let n = elements.len();
200 let ndim = n * 3;
201 let mut positions: Vec<[f64; 3]> = initial_positions.to_vec();
202 let mut energy_history = Vec::new();
203
204 let (mut energy, mut forces) = energy_force_fn(elements, &positions)?;
206 zero_fixed_forces(&mut forces, &config.fixed_atoms);
207 energy_history.push(energy);
208
209 let mut grad = flatten_neg_forces(&forces);
211
212 let mut h_inv = vec![vec![0.0f64; ndim]; ndim];
214 for i in 0..ndim {
215 h_inv[i][i] = 1.0;
216 }
217
218 for &fixed in &config.fixed_atoms {
220 for d in 0..3 {
221 let idx = fixed * 3 + d;
222 if idx < ndim {
223 for j in 0..ndim {
224 h_inv[idx][j] = 0.0;
225 h_inv[j][idx] = 0.0;
226 }
227 }
228 }
229 }
230
231 let mut converged = false;
232 let mut n_iter = 0;
233
234 for iter in 0..config.max_iter {
235 n_iter = iter + 1;
236
237 let max_force = max_force_magnitude(&forces);
238 if max_force < config.force_tol {
239 converged = true;
240 break;
241 }
242
243 let mut p = vec![0.0f64; ndim];
245 for i in 0..ndim {
246 for j in 0..ndim {
247 p[i] -= h_inv[i][j] * grad[j];
248 }
249 }
250
251 let p_norm: f64 = p.iter().map(|x| x * x).sum::<f64>().sqrt();
253 if p_norm > config.max_step {
254 let scale = config.max_step / p_norm;
255 for x in &mut p {
256 *x *= scale;
257 }
258 }
259
260 let directional_deriv: f64 = p.iter().zip(grad.iter()).map(|(a, b)| a * b).sum();
262 let c_armijo = 1e-4;
263 let mut alpha = 1.0;
264 let mut new_positions;
265 let mut new_energy;
266 let mut new_forces;
267
268 loop {
269 new_positions = positions.clone();
270 for i in 0..n {
271 if config.fixed_atoms.contains(&i) {
272 continue;
273 }
274 for d in 0..3 {
275 new_positions[i][d] += alpha * p[i * 3 + d];
276 }
277 }
278
279 if let Some(ref lattice) = config.lattice {
280 apply_pbc(&mut new_positions, lattice);
281 }
282
283 let result = energy_force_fn(elements, &new_positions)?;
284 new_energy = result.0;
285 new_forces = result.1;
286
287 if new_energy <= energy + c_armijo * alpha * directional_deriv || alpha < 0.1 {
289 break;
290 }
291 alpha *= 0.5;
292 }
293
294 zero_fixed_forces(&mut new_forces, &config.fixed_atoms);
295
296 let new_grad = flatten_neg_forces(&new_forces);
297
298 let s: Vec<f64> = p; let y: Vec<f64> = (0..ndim).map(|i| new_grad[i] - grad[i]).collect();
301
302 let sy: f64 = s.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
303
304 if sy > 1e-12 {
305 let mut hy = vec![0.0f64; ndim];
307 for i in 0..ndim {
308 for j in 0..ndim {
309 hy[i] += h_inv[i][j] * y[j];
310 }
311 }
312
313 let yhy: f64 = y.iter().zip(hy.iter()).map(|(a, b)| a * b).sum();
314 let rho = 1.0 / sy;
315
316 for i in 0..ndim {
317 for j in 0..ndim {
318 h_inv[i][j] +=
319 rho * ((1.0 + yhy * rho) * s[i] * s[j] - hy[i] * s[j] - s[i] * hy[j]);
320 }
321 }
322
323 let has_negative_diag = (0..ndim).any(|i| h_inv[i][i] <= 0.0);
325 if has_negative_diag {
326 for i in 0..ndim {
327 for j in 0..ndim {
328 h_inv[i][j] = if i == j { 1.0 } else { 0.0 };
329 }
330 }
331 }
332 }
333
334 positions = new_positions;
335 energy = new_energy;
336 forces = new_forces;
337 grad = new_grad;
338 energy_history.push(energy);
339
340 if energy_history.len() > 1 {
341 let de = (energy_history[energy_history.len() - 2] - energy).abs();
342 if de < config.energy_tol {
343 converged = true;
344 break;
345 }
346 }
347 }
348
349 let max_force = max_force_magnitude(&forces);
350
351 Ok(FrameworkOptResult {
352 positions,
353 energy,
354 forces,
355 max_force,
356 n_iterations: n_iter,
357 converged,
358 energy_history,
359 })
360}
361
/// Zero the force on every atom listed in `fixed`; indices past the end of
/// `forces` are skipped without error.
fn zero_fixed_forces(forces: &mut [[f64; 3]], fixed: &[usize]) {
    fixed.iter().for_each(|&atom| {
        if let Some(slot) = forces.get_mut(atom) {
            *slot = [0.0; 3];
        }
    });
}
369
/// Largest Euclidean norm among the per-atom force vectors (0.0 for no atoms).
///
/// A NaN anywhere makes the result NaN. `f64::max` silently discards NaN
/// operands, and the optimisers compare this value to a tolerance with `<`.
/// A NaN-poisoned force field would therefore otherwise be reported as
/// converged.
fn max_force_magnitude(forces: &[[f64; 3]]) -> f64 {
    forces
        .iter()
        .map(|f| (f[0] * f[0] + f[1] * f[1] + f[2] * f[2]).sqrt())
        .fold(0.0f64, |acc, m| {
            if acc.is_nan() || m.is_nan() {
                f64::NAN
            } else {
                acc.max(m)
            }
        })
}
376
/// Flatten per-atom forces into a 3N gradient vector
/// `[-Fx0, -Fy0, -Fz0, -Fx1, ...]` (the energy gradient is minus the force).
fn flatten_neg_forces(forces: &[[f64; 3]]) -> Vec<f64> {
    forces
        .iter()
        .flat_map(|atom| atom.iter().map(|&component| -component))
        .collect()
}
386
387fn apply_pbc(positions: &mut [[f64; 3]], lattice: &[[f64; 3]; 3]) {
389 let inv = invert_3x3_lattice(lattice);
391
392 for pos in positions.iter_mut() {
393 let frac = [
395 inv[0][0] * pos[0] + inv[0][1] * pos[1] + inv[0][2] * pos[2],
396 inv[1][0] * pos[0] + inv[1][1] * pos[1] + inv[1][2] * pos[2],
397 inv[2][0] * pos[0] + inv[2][1] * pos[1] + inv[2][2] * pos[2],
398 ];
399
400 let wrapped = [
402 frac[0] - frac[0].floor(),
403 frac[1] - frac[1].floor(),
404 frac[2] - frac[2].floor(),
405 ];
406
407 pos[0] =
409 lattice[0][0] * wrapped[0] + lattice[1][0] * wrapped[1] + lattice[2][0] * wrapped[2];
410 pos[1] =
411 lattice[0][1] * wrapped[0] + lattice[1][1] * wrapped[1] + lattice[2][1] * wrapped[2];
412 pos[2] =
413 lattice[0][2] * wrapped[0] + lattice[1][2] * wrapped[1] + lattice[2][2] * wrapped[2];
414 }
415}
416
/// Invert a 3x3 matrix via its adjugate.
///
/// Returns the all-zero matrix when `|det| < 1e-30`; since a genuine inverse
/// is never all zeros, callers can treat that as a "singular" marker.
fn invert_3x3_lattice(m: &[[f64; 3]; 3]) -> [[f64; 3]; 3] {
    // Entry (i, j) of the adjugate (transposed cofactor matrix). With cyclic
    // indices the checkerboard cofactor signs come out automatically.
    let adjugate = |i: usize, j: usize| -> f64 {
        let (j1, j2) = ((j + 1) % 3, (j + 2) % 3);
        let (i1, i2) = ((i + 1) % 3, (i + 2) % 3);
        m[j1][i1] * m[j2][i2] - m[j1][i2] * m[j2][i1]
    };

    // Laplace expansion along the first row of `m`.
    let det = m[0][0] * adjugate(0, 0) + m[0][1] * adjugate(1, 0) + m[0][2] * adjugate(2, 0);
    if det.abs() < 1e-30 {
        return [[0.0; 3]; 3];
    }

    let inv_det = 1.0 / det;
    let mut inv = [[0.0f64; 3]; 3];
    for (i, row) in inv.iter_mut().enumerate() {
        for (j, cell) in row.iter_mut().enumerate() {
            *cell = adjugate(i, j) * inv_det;
        }
    }
    inv
}
445
/// Convert fractional coordinates to Cartesian ones, treating each row of
/// `lattice` as a lattice vector: `cart = Σ_k frac[k] * lattice[k]`.
pub fn frac_to_cart(frac: &[f64; 3], lattice: &[[f64; 3]; 3]) -> [f64; 3] {
    let component = |axis: usize| {
        lattice[0][axis] * frac[0] + lattice[1][axis] * frac[1] + lattice[2][axis] * frac[2]
    };
    [component(0), component(1), component(2)]
}
454
455pub fn cart_to_frac(cart: &[f64; 3], lattice: &[[f64; 3]; 3]) -> [f64; 3] {
457 let inv = invert_3x3_lattice(lattice);
458 [
459 inv[0][0] * cart[0] + inv[0][1] * cart[1] + inv[0][2] * cart[2],
460 inv[1][0] * cart[0] + inv[1][1] * cart[1] + inv[1][2] * cart[2],
461 inv[2][0] * cart[0] + inv[2][1] * cart[1] + inv[2][2] * cart[2],
462 ]
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Isotropic harmonic well centred on the origin: E = ½ Σ|r|², F = -r.
    fn simple_harmonic_energy(
        _elements: &[u8],
        positions: &[[f64; 3]],
    ) -> Result<(f64, Vec<[f64; 3]>), String> {
        let energy: f64 = positions
            .iter()
            .map(|p| 0.5 * (p[0] * p[0] + p[1] * p[1] + p[2] * p[2]))
            .sum();
        let forces: Vec<[f64; 3]> = positions.iter().map(|p| [-p[0], -p[1], -p[2]]).collect();
        Ok((energy, forces))
    }

    /// Shared optimiser settings: force tolerance 0.01, everything else not
    /// passed in keeps its default.
    fn config_for(method: OptMethod, max_iter: usize, fixed_atoms: Vec<usize>) -> FrameworkOptConfig {
        FrameworkOptConfig {
            method,
            max_iter,
            force_tol: 0.01,
            fixed_atoms,
            ..Default::default()
        }
    }

    #[test]
    fn test_steepest_descent() {
        let config = config_for(OptMethod::SteepestDescent, 100, Vec::new());
        let start = [[1.0, 0.5, 0.2]];
        let result =
            optimize_framework(&[6u8], &start, &simple_harmonic_energy, &config).unwrap();
        assert!(result.converged);
        let [x, y, _] = result.positions[0];
        assert!(x.abs() < 0.1);
        assert!(y.abs() < 0.1);
    }

    #[test]
    fn test_bfgs() {
        let config = config_for(OptMethod::Bfgs, 50, Vec::new());
        let start = [[1.0, 0.5, 0.2]];
        let result =
            optimize_framework(&[6u8], &start, &simple_harmonic_energy, &config).unwrap();
        assert!(result.converged);
        assert!(result.n_iterations < 20);
    }

    #[test]
    fn test_fixed_atoms() {
        let config = config_for(OptMethod::Bfgs, 50, vec![0]);
        let start = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let result =
            optimize_framework(&[6u8, 8u8], &start, &simple_harmonic_energy, &config).unwrap();
        // The pinned atom must not have moved.
        assert!((result.positions[0][0] - 1.0).abs() < 1e-10);
        // The free atom relaxes towards the origin.
        assert!(result.positions[1][1].abs() < 0.2);
    }

    #[test]
    fn test_frac_cart_conversion() {
        let cubic = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]];
        let cart = frac_to_cart(&[0.5, 0.25, 0.1], &cubic);
        for (got, want) in cart.iter().zip([5.0, 2.5, 1.0].iter()) {
            assert!((got - want).abs() < 1e-10);
        }

        let back = cart_to_frac(&cart, &cubic);
        assert!((back[0] - 0.5).abs() < 1e-10);
    }
}