1use crate::errors::RocheError;
2use crate::x_l1;
3use crate::{Vec3, vel_transform};
4use bulirsch::{self, Integrator};
5use pyo3::prelude::*;
6
7#[pyfunction]
21pub fn strinit(q: f64) -> Result<(Vec3, Vec3), RocheError> {
22 const SMALL: f64 = 1.0e-5;
23 let rl1: f64 = x_l1(q)?;
24 let mu: f64 = q / (1.0 + q);
25 let a: f64 = (1.0 - mu) / rl1.powi(3) + mu / (1.0 - rl1).powi(3);
26 let lambda1: f64 = (((a - 2.0) + (a * (9.0 * a - 8.0)).sqrt()) / 2.0).sqrt();
27 let m1: f64 = (lambda1 * lambda1 - 2.0 * a - 1.0) / 2.0 / lambda1;
28
29 let r: Vec3 = Vec3::new(rl1 - SMALL, -m1 * SMALL, 0.0);
30 let v: Vec3 = Vec3::new(-lambda1 * SMALL, -lambda1 * m1 * SMALL, 0.0);
31
32 Ok((r, v))
33}
34
35#[pyfunction]
54#[pyo3(signature = (q, step, n_points=200))]
55pub fn stream(q: f64, step: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
56 if n_points < 2 {
57 return Err(RocheError::ParameterError(
58 "Need at least 2 points in the stream.".to_string(),
59 ));
60 }
61
62 if step <= 0.0 || step > 1.0 {
63 return Err(RocheError::ParameterError(
64 "Step size must be between 0.0 and 1.0".to_string(),
65 ));
66 }
67
68 if q <= 0.0 {
69 return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
70 }
71
72 let mut x_arr: Vec<f64> = vec![];
73 let mut y_arr: Vec<f64> = vec![];
74
75 let rl1: f64 = x_l1(q)?;
77 let (mut r, mut v) = strinit(q)?;
78
79 x_arr.push(rl1);
81 y_arr.push(0.0);
82
83 let mut lp: usize = 0;
84
85 let mut dist: f64 = (r.x - rl1).hypot(r.y);
89
90 let frac: f64;
91
92 if dist > step {
93 frac = step / dist;
94 x_arr.push(rl1 + (r.x - rl1) * frac);
95 y_arr.push(r.y * frac);
96 lp += 1;
97 }
98
99 let system = OrbitalSystem { q };
101 let mut integrator = Integrator::default()
102 .with_abs_tol(1.0e-8)
103 .with_rel_tol(1.0e-8)
104 .into_adaptive();
105 let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
107 let mut y_next = ndarray::Array::zeros(y.raw_dim());
108
109 let mut delta_t: f64 = 1.0e-3;
110 let smax: f64 = (1.0e-3_f64).min(step / 2.0);
111
112 let mut vel: f64;
113 while lp < n_points - 1 {
114 integrator
115 .step(&system, delta_t, y.view(), y_next.view_mut())
116 .unwrap();
117 y.assign(&y_next);
118
119 r.set(y[0], y[1], y[2]);
120 v.set(y[3], y[4], y[5]);
121 dist = (r.x - x_arr[lp]).hypot(r.y - y_arr[lp]);
122 if dist > step {
123 let frac: f64 = step / dist;
124 x_arr.push(x_arr[lp] + (r.x - x_arr[lp]) * frac);
125 y_arr.push(y_arr[lp] + (r.y - y_arr[lp]) * frac);
126 lp += 1;
127 }
128 vel = v.x.hypot(v.y);
129 delta_t = (smax / vel).min(delta_t);
130 }
131
132 Ok((x_arr, y_arr))
133}
134
135pub fn strmnx(q: f64, r: &mut Vec3, v: &mut Vec3, acc: f64) -> Result<(), RocheError> {
148 let mut dir: f64;
149 let mut lo: f64;
150 let mut hi: f64;
151 let mut ro: Vec3 = *r;
152 let mut vo: Vec3 = *v;
153
154 let mut delta_t: f64 = 1.0e-2;
155
156 dir = r.dot(v);
158 let dir1: f64 = dir;
159
160 let system = OrbitalSystem { q };
162 let mut integrator = Integrator::default()
163 .with_abs_tol(1.0e-8)
164 .with_rel_tol(1.0e-8)
165 .into_adaptive();
166 let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
168 let mut y_next = ndarray::Array::zeros(y.raw_dim());
169 let mut yo = y.clone();
170
171 while (dir > 0.0 && dir1 > 0.0) || (dir < 0.0 && dir1 < 0.0) {
172 ro = *r;
173 vo = *v;
174 yo = y.clone();
175 integrator
176 .step(&system, delta_t, y.view(), y_next.view_mut())
177 .unwrap();
178 y.assign(&y_next);
179 r.set(y[0], y[1], y[2]);
180 v.set(y[3], y[4], y[5]);
181 dir = r.dot(v);
182 }
183
184 lo = 0.0;
188 hi = delta_t;
189 while (hi - lo).abs() > acc {
190 delta_t = (lo + hi) / 2.0;
191 y = yo.clone();
192 *r = ro;
193 *v = vo;
194 integrator
195 .step(&system, delta_t, y.view(), y_next.view_mut())
196 .unwrap();
197 y.assign(&y_next);
198
199 r.set(y[0], y[1], y[2]);
200 v.set(y[3], y[4], y[5]);
201 dir = r.dot(v);
202 if (dir > 0.0 && dir1 < 0.0) || (dir < 0.0 && dir1 > 0.0) {
203 hi = delta_t;
204 } else {
205 lo = delta_t;
206 }
207 }
208
209 Ok(())
210}
211
212#[pyfunction]
230#[pyo3(name = "strmnx")]
231#[pyo3(signature = (q, n=1, acc=1.0e-7))]
232pub fn strmnx_wrapper(
233 q: f64,
234 n: usize,
235 acc: f64,
236) -> Result<(f64, f64, f64, f64, f64, f64), RocheError> {
237 let (mut r, mut v) = strinit(q)?;
238 for _ in 0..n {
239 strmnx(q, &mut r, &mut v, acc)?
240 }
241 let (tvx1, tvy1) = vel_transform(q, 1, r.x, r.y, v.x, v.y)?;
242 let (tvx2, tvy2) = vel_transform(q, 2, r.x, r.y, v.x, v.y)?;
243 Ok((r.x, r.y, tvx1, tvy1, tvx2, tvy2))
244}
245
246#[pyfunction]
263#[pyo3(signature = (q, rad, n_points=200))]
264pub fn streamr(q: f64, rad: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
265 if n_points < 2 {
266 return Err(RocheError::ParameterError(
267 "Need at least 2 points in the stream.".to_string(),
268 ));
269 }
270
271 if q <= 0.0 {
272 return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
273 }
274
275 const EPS: f64 = 1.0e-8;
276
277 let mut x_arr: Vec<f64> = vec![];
278 let mut y_arr: Vec<f64> = vec![];
279
280 let rl1: f64 = x_l1(q)?;
282 let (mut r, mut v) = strinit(q)?;
283 let rs = r;
284 let vs = v;
285 strmnx(q, &mut r, &mut v, EPS)?;
286 let rmin = if r.length() > rad { r.length() } else { rad };
287
288 r = rs;
289 v = vs;
290 x_arr.push(r.x);
291 y_arr.push(r.y);
292 let mut rnext: f64;
293 for i in 1..n_points {
294 rnext = rl1 + (rmin - rl1) * (i as f64) / (n_points as f64 - 1.0);
295 stradv(q, &mut r, &mut v, rnext, 1.0e-6, 1.0e-4);
296 x_arr.push(r.x);
297 y_arr.push(r.y);
298 }
299
300 Ok((x_arr, y_arr))
301}
302
303pub fn stradv(q: f64, r: &mut Vec3, v: &mut Vec3, rad: f64, acc: f64, smax: f64) -> f64 {
326 const TMAX: f64 = 10.0;
327 let t_next: f64 = 1.0e-2;
328
329 let mut time: f64 = 0.0;
330
331 let mut ro = *r;
333 let mut vo = *v;
334
335 let rinit: f64 = r.length();
337 let mut rnow: f64 = rinit;
338
339 let system = OrbitalSystem { q };
341 let mut integrator = Integrator::default()
342 .with_abs_tol(1.0e-8)
343 .with_rel_tol(1.0e-8)
344 .into_adaptive();
345 let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
347 let mut y_next = ndarray::Array::zeros(y.raw_dim());
348
349 let mut yo = y.clone();
350 let mut delta_t = t_next.min(smax);
351 while (rinit > rad && rnow > rad) || (rinit < rad && rnow < rad) {
353 ro = *r;
354 vo = *v;
355 yo = y.clone();
356 integrator
357 .step(&system, delta_t, y.view(), y_next.view_mut())
358 .unwrap();
359 y.assign(&y_next);
360 r.set(y[0], y[1], y[2]);
361 v.set(y[3], y[4], y[5]);
362 rnow = r.length();
363 time += delta_t;
364
365 if time > TMAX {
366 panic!("roche::stradv taken too long without crossing given radius.")
367 }
368 }
369
370 let mut lo: f64 = 0.0;
374 let mut hi: f64 = delta_t;
375 let mut rlo: f64 = ro.length();
376 let mut rhi: f64 = rnow;
377 let to: f64 = time;
378
379 while (rhi - rlo).abs() > acc {
380 delta_t = (lo + hi) / 2.0;
381 y = yo.clone();
382 *r = ro;
383 *v = vo;
384 time = to;
385
386 integrator
387 .step(&system, delta_t, y.view(), y_next.view_mut())
388 .unwrap();
389 y.assign(&y_next);
390
391 r.set(y[0], y[1], y[2]);
392 v.set(y[3], y[4], y[5]);
393 rnow = r.length();
394
395 if (rhi > rad && rnow > rad) || (rhi < rad && rnow < rad) {
396 rhi = rnow;
397 hi = delta_t;
398 } else {
399 rlo = rnow;
400 lo = delta_t;
401 }
402 }
403
404 time
405}
406
407#[pyfunction]
427#[pyo3(name = "stradv")]
428pub fn stradv_wrapper(
429 q: f64,
430 r: &Vec3,
431 v: &Vec3,
432 rad: f64,
433 acc: f64,
434 smax: f64,
435) -> (f64, Vec3, Vec3) {
436 let mut r_mut = *r;
437 let mut v_mut = *v;
438 let timestep = stradv(q, &mut r_mut, &mut v_mut, rad, acc, smax);
439 (timestep, r_mut, v_mut)
440}
441
442#[pyfunction]
451pub fn rocacc(q: f64, r: &Vec3, v: &Vec3) -> (f64, f64, f64) {
452 let f1: f64 = 1.0 / (1.0 + q);
453 let f2: f64 = f1 * q;
454
455 let yzsq: f64 = r.y * r.y + r.z * r.z;
456 let r1sq: f64 = r.x * r.x + yzsq;
457 let r2sq: f64 = (r.x - 1.0) * (r.x - 1.0) + yzsq;
458 let fm1: f64 = f1 / (r1sq * (r1sq.sqrt()));
459 let fm2: f64 = f2 / (r2sq * (r2sq.sqrt()));
460 let fm3: f64 = fm1 + fm2;
461
462 let x: f64 = -fm3 * r.x + fm2 + 2.0 * v.y + r.x - f2;
463 let y: f64 = -fm3 * r.y - 2.0 * v.x + r.y;
464 let z: f64 = -fm3 * r.z;
465 (x, y, z)
466}
467
468#[pyfunction]
483#[pyo3(signature = (q, rad, acc=1.0e-7, smax=1.0e-2))]
484pub fn brightspot_position(q: f64, rad: f64, acc: f64, smax: f64) -> Result<Vec3, RocheError> {
485 let (mut r, mut v) = strinit(q)?;
486 let _ = stradv(q, &mut r, &mut v, rad, acc, smax);
487
488 Ok(r)
489}
490
/// ODE right-hand side for a test particle in a binary with mass ratio `q`, for use with the
/// `bulirsch` integrator.
#[derive(Debug, Clone, Copy)]
pub struct OrbitalSystem {
    /// Mass ratio M2/M1.
    pub q: f64,
}
494
495impl bulirsch::System for OrbitalSystem {
496 type Float = f64;
497
498 fn system(
499 &self,
500 y: bulirsch::ArrayView1<Self::Float>,
501 mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
502 ) {
503 dydt[[0]] = y[[3]];
504 dydt[[1]] = y[[4]];
505 dydt[[2]] = y[[5]];
506 let r = Vec3::new(y[[0]], y[[1]], y[[2]]);
507 let v = Vec3::new(y[[3]], y[[4]], y[[5]]);
508 (dydt[[3]], dydt[[4]], dydt[[5]]) = rocacc(self.q, &r, &v);
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn strinit_stradv_test() -> Result<(), RocheError> {
        let (mut r, mut v) = strinit(0.2)?;
        let _time: f64 = stradv(0.2, &mut r, &mut v, 0.3, 1.0e-7, 1.0e-3);
        assert!((r - Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0)).length() < 1.0e-7);
        assert!((v - Vec3::new(-1.4769457229627583, 0.31712381217252994, 0.0)).length() < 1.0e-7);
        Ok(())
    }

    #[test]
    fn stream_test() -> Result<(), RocheError> {
        let (x, y) = stream(0.2, 0.01, 200)?;
        assert!((x[0] - 0.6585557).hypot(y[0] - 0.0) < 1.0e-4);
        assert!((x[50] - 0.18384902).hypot(y[50] - 0.15145306) < 1.0e-4);
        assert!((x[100] - -0.100431986).hypot(y[100] - -0.13697079) < 1.0e-4);
        assert!((x[150] - 0.21720248).hypot(y[150] - -0.4577784) < 1.0e-4);
        assert!((x[y.len() - 1] - 0.15403406).hypot(y[y.len() - 1] - 0.016731631) < 1.0e-4);
        assert!(stream(-0.2, 0.0001, 200).is_err());
        assert!(stream(0.2, 1.1, 200).is_err());
        assert!(stream(0.2, -0.1, 200).is_err());
        assert!(stream(0.2, 0.0001, 1).is_err());
        Ok(())
    }

    #[test]
    fn streamr_test() -> Result<(), RocheError> {
        // For q = 0.2 the stream's minimum radius (~0.10) is below 0.3, so the stream is
        // followed all the way to rad = 0.3.
        let (x, y) = streamr(0.2, 0.3, 20)?;
        assert_eq!(x.len(), 20);
        assert_eq!(y.len(), 20);
        assert!((x[19].hypot(y[19]) - 0.3).abs() < 1.0e-5);
        // The last point is where the stream first reaches r = 0.3, i.e. the bright spot
        // (compare brightspot_position_test).
        assert!((x[19] - 0.2660591412807423).hypot(y[19] - 0.13860932478255575) < 1.0e-4);
        assert!(streamr(-0.2, 0.3, 200).is_err());
        assert!(streamr(0.2, 0.3, 1).is_err());
        Ok(())
    }

    #[test]
    fn strmnx_test() -> Result<(), RocheError> {
        let (x, y, vx1, vy1, vx2, vy2) = strmnx_wrapper(0.2, 1, 1.0e-7)?;
        assert!(
            (x - -0.08613947462186848).hypot(y - 0.05411592729509131)
                / (-0.08613947462186848_f64).hypot(0.05411592729509131)
                < 1.0e-6
        );
        assert!(
            (vx1 - -1.9727409465489645).hypot(vy1 - -3.30679322752132)
                / (-1.9727409465489645_f64).hypot(-3.30679322752132)
                < 1.0e-6
        );
        assert!(
            (vx2 - -1.5225623467338747).hypot(vy2 - -2.5902178683586605)
                / (-1.5225623467338747_f64).hypot(-2.5902178683586605)
                < 1.0e-6
        );
        Ok(())
    }

    #[test]
    fn brightspot_position_test() -> Result<(), RocheError> {
        let r = brightspot_position(0.2, 0.3, 1.0e-7, 1.0e-3)?;
        assert!((r - Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0)).length() < 1.0e-7);
        Ok(())
    }
}