1use crate::errors::RocheError;
2use crate::x_l1;
3use crate::{Vec3, vel_transform};
4use bulirsch::{self, Integrator};
5use pyo3::prelude::*;
6
7#[pyfunction]
21pub fn strinit(q: f64) -> Result<(Vec3, Vec3), RocheError> {
22 const SMALL: f64 = 1.0e-5;
23 let rl1: f64 = x_l1(q)?;
24 let mu: f64 = q / (1.0 + q);
25 let a: f64 = (1.0 - mu) / rl1.powi(3) + mu / (1.0 - rl1).powi(3);
26 let lambda1: f64 = (((a - 2.0) + (a * (9.0 * a - 8.0)).sqrt()) / 2.0).sqrt();
27 let m1: f64 = (lambda1 * lambda1 - 2.0 * a - 1.0) / 2.0 / lambda1;
28
29 let r: Vec3 = Vec3::new(rl1 - SMALL, -m1 * SMALL, 0.0);
30 let v: Vec3 = Vec3::new(-lambda1 * SMALL, -lambda1 * m1 * SMALL, 0.0);
31
32 Ok((r, v))
33}
34
35#[pyfunction]
54#[pyo3(signature = (q, step, n_points=200))]
55pub fn stream(q: f64, step: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
56 if n_points < 2 {
57 return Err(RocheError::ParameterError(
58 "Need at least 2 points in the stream.".to_string(),
59 ));
60 }
61
62 if step <= 0.0 || step > 1.0 {
63 return Err(RocheError::ParameterError(
64 "Step size must be between 0.0 and 1.0".to_string(),
65 ));
66 }
67
68 if q <= 0.0 {
69 return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
70 }
71
72 let mut x_arr: Vec<f64> = vec![];
73 let mut y_arr: Vec<f64> = vec![];
74
75 let rl1: f64 = x_l1(q)?;
77 let (mut r, mut v) = strinit(q)?;
78
79 x_arr.push(rl1);
81 y_arr.push(0.0);
82
83 let mut lp: usize = 0;
84
85 let mut dist = (r.x - rl1).hypot(r.y);
89
90 let frac: f64;
91
92 if dist > step {
93 frac = step / dist;
94 x_arr.push(rl1 + (r.x - rl1) * frac);
95 y_arr.push(r.y * frac);
96 lp += 1;
97 }
98
99 let system = OrbitalSystem { q: q };
101 let mut integrator = Integrator::default()
102 .with_abs_tol(1.0e-8)
103 .with_rel_tol(1.0e-8)
104 .into_adaptive();
105 let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
107 let mut y_next = ndarray::Array::zeros(y.raw_dim());
108
109 let mut delta_t = 1.0e-3;
110 let smax = (1.0e-3_f64).min(step / 2.0);
111
112 let mut vel: f64;
113 while lp < n_points - 1 {
114 integrator
115 .step(&system, delta_t, y.view(), y_next.view_mut())
116 .unwrap();
117 y.assign(&y_next);
118
119 r.set(y[0], y[1], y[2]);
120 v.set(y[3], y[4], y[5]);
121 dist = (r.x - x_arr[lp]).hypot(r.y - y_arr[lp]);
122 if dist > step {
123 let frac: f64 = step / dist;
124 x_arr.push(x_arr[lp] + (r.x - x_arr[lp]) * frac);
125 y_arr.push(y_arr[lp] + (r.y - y_arr[lp]) * frac);
126 lp += 1;
127 }
128 vel = v.x.hypot(v.y);
129 delta_t = (smax / vel).min(delta_t);
130 }
131
132 Ok((x_arr, y_arr))
133}
134
135pub fn strmnx(q: f64, r: &mut Vec3, v: &mut Vec3, acc: f64) -> Result<(), RocheError> {
148 let mut dir: f64;
149 let dir1: f64;
150 let mut lo: f64;
151 let mut hi: f64;
152 let mut ro: Vec3 = *r;
153 let mut vo: Vec3 = *v;
154
155 let mut delta_t: f64 = 1.0e-2;
156
157 dir = r.dot(v);
159 dir1 = dir;
160
161 let system = OrbitalSystem { q: q };
163 let mut integrator = Integrator::default()
164 .with_abs_tol(1.0e-8)
165 .with_rel_tol(1.0e-8)
166 .into_adaptive();
167 let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
169 let mut y_next = ndarray::Array::zeros(y.raw_dim());
170 let mut yo = y.clone();
171
172 while (dir > 0.0 && dir1 > 0.0) || (dir < 0.0 && dir1 < 0.0) {
173 ro = *r;
174 vo = *v;
175 yo = y.clone();
176 integrator
177 .step(&system, delta_t, y.view(), y_next.view_mut())
178 .unwrap();
179 y.assign(&y_next);
180 r.set(y[0], y[1], y[2]);
181 v.set(y[3], y[4], y[5]);
182 dir = r.dot(v);
183 }
184
185 lo = 0.0;
189 hi = delta_t;
190 while (hi - lo).abs() > acc {
191 delta_t = (lo + hi) / 2.0;
192 y = yo.clone();
193 *r = ro;
194 *v = vo;
195 integrator
196 .step(&system, delta_t, y.view(), y_next.view_mut())
197 .unwrap();
198 y.assign(&y_next);
199
200 r.set(y[0], y[1], y[2]);
201 v.set(y[3], y[4], y[5]);
202 dir = r.dot(v);
203 if (dir > 0.0 && dir1 < 0.0) || (dir < 0.0 && dir1 > 0.0) {
204 hi = delta_t;
205 } else {
206 lo = delta_t;
207 }
208 }
209
210 Ok(())
211}
212
213#[pyfunction]
231#[pyo3(name = "strmnx")]
232#[pyo3(signature = (q, n=1, acc=1.0e-7))]
233pub fn strmnx_wrapper(
234 q: f64,
235 n: usize,
236 acc: f64,
237) -> Result<(f64, f64, f64, f64, f64, f64), RocheError> {
238 let (mut r, mut v) = strinit(q)?;
239 for _ in 0..n {
240 strmnx(q, &mut r, &mut v, acc)?
241 }
242 let (tvx1, tvy1) = vel_transform(q, 1, r.x, r.y, v.x, v.y)?;
243 let (tvx2, tvy2) = vel_transform(q, 2, r.x, r.y, v.x, v.y)?;
244 Ok((r.x, r.y, tvx1, tvy1, tvx2, tvy2))
245}
246
247#[pyfunction]
264#[pyo3(signature = (q, rad, n_points=200))]
265pub fn streamr(q: f64, rad: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
266 if n_points < 2 {
267 return Err(RocheError::ParameterError(
268 "Need at least 2 points in the stream.".to_string(),
269 ));
270 }
271
272 if q <= 0.0 {
273 return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
274 }
275
276 const EPS: f64 = 1.0e-8;
277
278 let mut x_arr: Vec<f64> = vec![];
279 let mut y_arr: Vec<f64> = vec![];
280
281 let rl1: f64 = x_l1(q)?;
283 let (mut r, mut v) = strinit(q)?;
284 let rs = r;
285 let vs = v;
286 strmnx(q, &mut r, &mut v, EPS)?;
287 let rmin = if r.length() > rad { r.length() } else { rad };
288
289 r = rs;
290 v = vs;
291 x_arr.push(r.x);
292 y_arr.push(r.y);
293 let mut rnext: f64;
294 for i in 1..n_points {
295 rnext = rl1 + (rmin - rl1) * (i as f64) / (n_points as f64 - 1.0);
296 stradv(q, &mut r, &mut v, rnext, 1.0e-6, 1.0e-4);
297 x_arr.push(r.x);
298 y_arr.push(r.y);
299 }
300
301 Ok((x_arr, y_arr))
302}
303
304pub fn stradv(q: f64, r: &mut Vec3, v: &mut Vec3, rad: f64, acc: f64, smax: f64) -> f64 {
327 const TMAX: f64 = 10.0;
328 let t_next: f64 = 1.0e-2;
329
330 let mut time: f64 = 0.0;
331
332 let mut ro = *r;
334 let mut vo = *v;
335
336 let rinit: f64 = r.length();
338 let mut rnow: f64 = rinit;
339
340 let system = OrbitalSystem { q: q };
342 let mut integrator = Integrator::default()
343 .with_abs_tol(1.0e-8)
344 .with_rel_tol(1.0e-8)
345 .into_adaptive();
346 let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
348 let mut y_next = ndarray::Array::zeros(y.raw_dim());
349
350 let mut yo = y.clone();
351 let mut delta_t = t_next.min(smax);
352 while (rinit > rad && rnow > rad) || (rinit < rad && rnow < rad) {
354 ro = *r;
355 vo = *v;
356 yo = y.clone();
357 integrator
358 .step(&system, delta_t, y.view(), y_next.view_mut())
359 .unwrap();
360 y.assign(&y_next);
361 r.set(y[0], y[1], y[2]);
362 v.set(y[3], y[4], y[5]);
363 rnow = r.length();
364 time += delta_t;
365
366 if time > TMAX {
367 panic!("roche::stradv taken too long without crossing given radius.")
368 }
369 }
370
371 let mut lo: f64 = 0.0;
375 let mut hi: f64 = delta_t;
376 let mut rlo: f64 = ro.length();
377 let mut rhi: f64 = rnow;
378 let to: f64 = time;
379
380 while (rhi - rlo).abs() > acc {
381 delta_t = (lo + hi) / 2.0;
382 y = yo.clone();
383 *r = ro;
384 *v = vo;
385 time = to;
386
387 integrator
388 .step(&system, delta_t, y.view(), y_next.view_mut())
389 .unwrap();
390 y.assign(&y_next);
391
392 r.set(y[0], y[1], y[2]);
393 v.set(y[3], y[4], y[5]);
394 rnow = r.length();
395
396 if (rhi > rad && rnow > rad) || (rhi < rad && rnow < rad) {
397 rhi = rnow;
398 hi = delta_t;
399 } else {
400 rlo = rnow;
401 lo = delta_t;
402 }
403 }
404
405 time
406}
407
408#[pyfunction]
428#[pyo3(name = "stradv")]
429pub fn stradv_wrapper(
430 q: f64,
431 r: &Vec3,
432 v: &Vec3,
433 rad: f64,
434 acc: f64,
435 smax: f64,
436) -> (f64, Vec3, Vec3) {
437 let mut r_mut = *r;
438 let mut v_mut = *v;
439 let timestep = stradv(q, &mut r_mut, &mut v_mut, rad, acc, smax);
440 (timestep, r_mut, v_mut)
441}
442
443#[pyfunction]
452pub fn rocacc(q: f64, r: &Vec3, v: &Vec3) -> (f64, f64, f64) {
453 let f1: f64 = 1.0 / (1.0 + q);
454 let f2: f64 = f1 * q;
455
456 let yzsq: f64 = r.y * r.y + r.z * r.z;
457 let r1sq: f64 = r.x * r.x + yzsq;
458 let r2sq: f64 = (r.x - 1.0) * (r.x - 1.0) + yzsq;
459 let fm1: f64 = f1 / (r1sq * (r1sq.sqrt()));
460 let fm2: f64 = f2 / (r2sq * (r2sq.sqrt()));
461 let fm3 = fm1 + fm2;
462
463 let x: f64 = -fm3 * r.x + fm2 + 2.0 * v.y + r.x - f2;
464 let y: f64 = -fm3 * r.y - 2.0 * v.x + r.y;
465 let z: f64 = -fm3 * r.z;
466 (x, y, z)
467}
468
469#[pyfunction]
484#[pyo3(signature = (q, rad, acc=1.0e-7, smax=1.0e-2))]
485pub fn brightspot_position(q: f64, rad: f64, acc: f64, smax: f64) -> Result<Vec3, RocheError> {
486 let (mut r, mut v) = strinit(q)?;
487 let _ = stradv(q, &mut r, &mut v, rad, acc, smax);
488
489 Ok(r)
490}
491
492struct OrbitalSystem {
493 q: f64,
494}
495
496impl bulirsch::System for OrbitalSystem {
497 type Float = f64;
498
499 fn system(
500 &self,
501 y: bulirsch::ArrayView1<Self::Float>,
502 mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
503 ) {
504 dydt[[0]] = y[[3]];
505 dydt[[1]] = y[[4]];
506 dydt[[2]] = y[[5]];
507 let r = Vec3::new(y[[0]], y[[1]], y[[2]]);
508 let v = Vec3::new(y[[3]], y[[4]], y[[5]]);
509 (dydt[[3]], dydt[[4]], dydt[[5]]) = rocacc(self.q, &r, &v);
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Distance in the x-y plane between `(x, y)` and `(x_ref, y_ref)`.
    fn planar_offset(x: f64, y: f64, x_ref: f64, y_ref: f64) -> f64 {
        (x - x_ref).hypot(y - y_ref)
    }

    /// `planar_offset` divided by the length of the reference vector.
    fn relative_offset(x: f64, y: f64, x_ref: f64, y_ref: f64) -> f64 {
        planar_offset(x, y, x_ref, y_ref) / x_ref.hypot(y_ref)
    }

    /// Advancing the initial stream state to radius 0.3 reproduces the
    /// reference position and velocity.
    #[test]
    fn strinit_stradv_test() -> Result<(), RocheError> {
        let (mut r, mut v) = strinit(0.2)?;
        stradv(0.2, &mut r, &mut v, 0.3, 1.0e-7, 1.0e-3);

        let r_expected = Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0);
        let v_expected = Vec3::new(-1.4769457229627583, 0.31712381217252994, 0.0);
        assert!((r - r_expected).length() < 1.0e-7);
        assert!((v - v_expected).length() < 1.0e-7);
        Ok(())
    }

    /// Selected stream points match reference values and invalid arguments
    /// are rejected.
    #[test]
    fn stream_test() -> Result<(), RocheError> {
        let (x, y) = stream(0.2, 0.01, 200)?;

        // (index, expected x, expected y)
        let reference = [
            (0, 0.6585557, 0.0),
            (50, 0.18384902, 0.15145306),
            (100, -0.100431986, -0.13697079),
            (150, 0.21720248, -0.4577784),
            (y.len() - 1, 0.15403406, 0.016731631),
        ];
        for &(i, x_ref, y_ref) in &reference {
            assert!(planar_offset(x[i], y[i], x_ref, y_ref) < 1.0e-4);
        }

        // (q, step, n_points) combinations that must be refused.
        let rejected = [
            (-0.2, 0.0001, 200),
            (0.2, 1.1, 200),
            (0.2, -0.1, 200),
            (0.2, 0.0001, 1),
        ];
        for &(q, step, n_points) in &rejected {
            assert!(stream(q, step, n_points).is_err());
        }
        Ok(())
    }

    /// The first turning point and its two transformed velocities match the
    /// reference values to a relative accuracy of 1e-6.
    #[test]
    fn strmnx_test() -> Result<(), RocheError> {
        let (x, y, vx1, vy1, vx2, vy2) = strmnx_wrapper(0.2, 1, 1.0e-7)?;

        assert!(relative_offset(x, y, -0.08613947462186848, 0.05411592729509131) < 1.0e-6);
        assert!(relative_offset(vx1, vy1, -1.9727409465489645, -3.30679322752132) < 1.0e-6);
        assert!(relative_offset(vx2, vy2, -1.5225623467338747, -2.5902178683586605) < 1.0e-6);
        Ok(())
    }

    /// `brightspot_position` agrees with the reference position at radius 0.3.
    #[test]
    fn brightspot_position_test() -> Result<(), RocheError> {
        let r = brightspot_position(0.2, 0.3, 1.0e-7, 1.0e-3)?;
        let r_expected = Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0);
        assert!((r - r_expected).length() < 1.0e-7);
        Ok(())
    }
}