Skip to main content

roche/
stream_physics.rs

1use crate::errors::RocheError;
2use crate::x_l1;
3use crate::{Vec3, vel_transform};
4use bulirsch::{self, Integrator};
5use pyo3::prelude::*;
6
7///
8/// strinit sets a particle just inside the L1 point with the
9/// correct velocity as given in Lubow and Shu.
10///
11/// Arguments:
12///
13/// * `q`: mass ratio = M2/M1
14///
15/// Returns:
16///
17/// * start position
18/// * start velocity
19///
20#[pyfunction]
21pub fn strinit(q: f64) -> Result<(Vec3, Vec3), RocheError> {
22    const SMALL: f64 = 1.0e-5;
23    let rl1: f64 = x_l1(q)?;
24    let mu: f64 = q / (1.0 + q);
25    let a: f64 = (1.0 - mu) / rl1.powi(3) + mu / (1.0 - rl1).powi(3);
26    let lambda1: f64 = (((a - 2.0) + (a * (9.0 * a - 8.0)).sqrt()) / 2.0).sqrt();
27    let m1: f64 = (lambda1 * lambda1 - 2.0 * a - 1.0) / 2.0 / lambda1;
28
29    let r: Vec3 = Vec3::new(rl1 - SMALL, -m1 * SMALL, 0.0);
30    let v: Vec3 = Vec3::new(-lambda1 * SMALL, -lambda1 * m1 * SMALL, 0.0);
31
32    Ok((r, v))
33}
34
35///
36/// stream works by integrating the equations of motion for the Roche
37/// potential using Burlisch-Stoer integration. Every time the distance
38/// from the last point exceeds step, it interpolates and stores a new
39/// point. This allows one not to spend loads of points on regions where
40/// nothing is happening.
41///
42/// Arguments:
43///
44/// * `q`:    mass ratio = M2/M1. Stream flows from star 2 to 1.
45/// * `step`: step between points (units of separation).
46/// * `n_points`:    number of points to compute.
47///
48/// Returns:
49///
50/// * `x`:    array of x values returned.
51/// * `y`:    array of y values returned.
52///
53#[pyfunction]
54#[pyo3(signature = (q, step, n_points=200))]
55pub fn stream(q: f64, step: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
56    if n_points < 2 {
57        return Err(RocheError::ParameterError(
58            "Need at least 2 points in the stream.".to_string(),
59        ));
60    }
61
62    if step <= 0.0 || step > 1.0 {
63        return Err(RocheError::ParameterError(
64            "Step size must be between 0.0 and 1.0".to_string(),
65        ));
66    }
67
68    if q <= 0.0 {
69        return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
70    }
71
72    let mut x_arr: Vec<f64> = vec![];
73    let mut y_arr: Vec<f64> = vec![];
74
75    // Initialise stream
76    let rl1: f64 = x_l1(q)?;
77    let (mut r, mut v) = strinit(q)?;
78
79    // Store L1 as first point
80    x_arr.push(rl1);
81    y_arr.push(0.0);
82
83    let mut lp: usize = 0;
84
85    // Store interpolation between L1 and initial point if
86    // step has been set small enough
87
88    let mut dist: f64 = (r.x - rl1).hypot(r.y);
89
90    let frac: f64;
91
92    if dist > step {
93        frac = step / dist;
94        x_arr.push(rl1 + (r.x - rl1) * frac);
95        y_arr.push(r.y * frac);
96        lp += 1;
97    }
98
99    // set up Bulirsch-Stoer integrator
100    let system = OrbitalSystem { q };
101    let mut integrator = Integrator::default()
102        .with_abs_tol(1.0e-8)
103        .with_rel_tol(1.0e-8)
104        .into_adaptive();
105    // Initialise arrays
106    let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
107    let mut y_next = ndarray::Array::zeros(y.raw_dim());
108
109    let mut delta_t: f64 = 1.0e-3;
110    let smax: f64 = (1.0e-3_f64).min(step / 2.0);
111
112    let mut vel: f64;
113    while lp < n_points - 1 {
114        integrator
115            .step(&system, delta_t, y.view(), y_next.view_mut())
116            .unwrap();
117        y.assign(&y_next);
118
119        r.set(y[0], y[1], y[2]);
120        v.set(y[3], y[4], y[5]);
121        dist = (r.x - x_arr[lp]).hypot(r.y - y_arr[lp]);
122        if dist > step {
123            let frac: f64 = step / dist;
124            x_arr.push(x_arr[lp] + (r.x - x_arr[lp]) * frac);
125            y_arr.push(y_arr[lp] + (r.y - y_arr[lp]) * frac);
126            lp += 1;
127        }
128        vel = v.x.hypot(v.y);
129        delta_t = (smax / vel).min(delta_t);
130    }
131
132    Ok((x_arr, y_arr))
133}
134
135///
136/// strmnx finds the next point at which stream is closest or furthest
137/// from primary.
138///
139/// Arguments:
140///
141/// * `q`: mass ratio = M2/M1
142/// * `r`: initial and final position
143/// * `v`: initial and final velocity
144/// * `acc`: accuracy in time to locate minimum/maximum.
145///
146///
147pub fn strmnx(q: f64, r: &mut Vec3, v: &mut Vec3, acc: f64) -> Result<(), RocheError> {
148    let mut dir: f64;
149    let mut lo: f64;
150    let mut hi: f64;
151    let mut ro: Vec3 = *r;
152    let mut vo: Vec3 = *v;
153
154    let mut delta_t: f64 = 1.0e-2;
155
156    // Store initial direction
157    dir = r.dot(v);
158    let dir1: f64 = dir;
159
160    // set up Bulirsch-Stoer integrator
161    let system = OrbitalSystem { q };
162    let mut integrator = Integrator::default()
163        .with_abs_tol(1.0e-8)
164        .with_rel_tol(1.0e-8)
165        .into_adaptive();
166    // Initialise arrays
167    let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
168    let mut y_next = ndarray::Array::zeros(y.raw_dim());
169    let mut yo = y.clone();
170
171    while (dir > 0.0 && dir1 > 0.0) || (dir < 0.0 && dir1 < 0.0) {
172        ro = *r;
173        vo = *v;
174        yo = y.clone();
175        integrator
176            .step(&system, delta_t, y.view(), y_next.view_mut())
177            .unwrap();
178        y.assign(&y_next);
179        r.set(y[0], y[1], y[2]);
180        v.set(y[3], y[4], y[5]);
181        dir = r.dot(v);
182    }
183
184    //   Now refine by reinitialising and binary chopping until
185    //   close enough to requested radius.
186
187    lo = 0.0;
188    hi = delta_t;
189    while (hi - lo).abs() > acc {
190        delta_t = (lo + hi) / 2.0;
191        y = yo.clone();
192        *r = ro;
193        *v = vo;
194        integrator
195            .step(&system, delta_t, y.view(), y_next.view_mut())
196            .unwrap();
197        y.assign(&y_next);
198
199        r.set(y[0], y[1], y[2]);
200        v.set(y[3], y[4], y[5]);
201        dir = r.dot(v);
202        if (dir > 0.0 && dir1 < 0.0) || (dir < 0.0 && dir1 > 0.0) {
203            hi = delta_t;
204        } else {
205            lo = delta_t;
206        }
207    }
208
209    Ok(())
210}
211
212// wrapper for python library, avoiding mutable references
213
214///
215/// Calculates position & velocity of n-th turning point of stream.
216/// x,y,vx1,vy1,vx2,vy2 = strmnx(q, n=1, acc=1.e-7), q = M2/M1.
217/// Two sets of velocities are reported, the first for the pure stream,
218/// the second for the disk at that point.
219///
220/// Arguments:
221///
222/// * `q`: mass ratio = M2/M1
223/// * `n`: turning point number
224/// * `acc`: accuracy in time to locate minimum/maximum.
225///
226/// Returns:
227/// (x, y, vx1, vy1, vx2, vy2)
228///
229#[pyfunction]
230#[pyo3(name = "strmnx")]
231#[pyo3(signature = (q, n=1, acc=1.0e-7))]
232pub fn strmnx_wrapper(
233    q: f64,
234    n: usize,
235    acc: f64,
236) -> Result<(f64, f64, f64, f64, f64, f64), RocheError> {
237    let (mut r, mut v) = strinit(q)?;
238    for _ in 0..n {
239        strmnx(q, &mut r, &mut v, acc)?
240    }
241    let (tvx1, tvy1) = vel_transform(q, 1, r.x, r.y, v.x, v.y)?;
242    let (tvx2, tvy2) = vel_transform(q, 2, r.x, r.y, v.x, v.y)?;
243    Ok((r.x, r.y, tvx1, tvy1, tvx2, tvy2))
244}
245
246///
247/// streamr works by integrating the equations of motion for the Roche
248/// potential using Burlisch-Stoer integration. It stops when the stream
249/// reaches a target radius or a minimum radius, whichever is the larger.
250///
251/// Arguments:
252///
253/// * `q`: mass ratio = M2/M1. Stream flows from star 2 to 1.
254/// * `rad`: Radius to aim for. If this is less than the minimum, the stream will stop at the minimum
255/// * `n_points`: number of points to compute.
256///
257/// Results:
258///
259/// * `x`: array of x values returned.
260/// * `y`: array of y values returned.
261///
262#[pyfunction]
263#[pyo3(signature = (q, rad, n_points=200))]
264pub fn streamr(q: f64, rad: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
265    if n_points < 2 {
266        return Err(RocheError::ParameterError(
267            "Need at least 2 points in the stream.".to_string(),
268        ));
269    }
270
271    if q <= 0.0 {
272        return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
273    }
274
275    const EPS: f64 = 1.0e-8;
276
277    let mut x_arr: Vec<f64> = vec![];
278    let mut y_arr: Vec<f64> = vec![];
279
280    // Initialise stream
281    let rl1: f64 = x_l1(q)?;
282    let (mut r, mut v) = strinit(q)?;
283    let rs = r;
284    let vs = v;
285    strmnx(q, &mut r, &mut v, EPS)?;
286    let rmin = if r.length() > rad { r.length() } else { rad };
287
288    r = rs;
289    v = vs;
290    x_arr.push(r.x);
291    y_arr.push(r.y);
292    let mut rnext: f64;
293    for i in 1..n_points {
294        rnext = rl1 + (rmin - rl1) * (i as f64) / (n_points as f64 - 1.0);
295        stradv(q, &mut r, &mut v, rnext, 1.0e-6, 1.0e-4);
296        x_arr.push(r.x);
297        y_arr.push(r.y);
298    }
299
300    Ok((x_arr, y_arr))
301}
302
303///
304/// stradv advances a particle of given position and velocity until
305/// it reaches a specified radius. It then returns with updated position and
306/// velocity. It is up to the user not to request a value that cannot be reached.
307///
308/// Arguments:
309///
310/// * `q`:    mass ratio = M2/M1
311/// * `r`:    Initial and final position
312/// * `v`:    Initial and final velocity
313/// * `rad`:  Radius to aim for
314/// * `acc`:  Accuracy with which to place output point at rad.
315/// * `smax`: Largest time step allowed. It is possible that the
316///   routine could take such a large step that it misses
317///   the point when the stream is inside the requested
318///   radius. This allows one to control this. Typical
319///   value = 1.e-3.
320///
321/// Returns:
322///
323/// * time step taken
324///
325pub fn stradv(q: f64, r: &mut Vec3, v: &mut Vec3, rad: f64, acc: f64, smax: f64) -> f64 {
326    const TMAX: f64 = 10.0;
327    let t_next: f64 = 1.0e-2;
328
329    let mut time: f64 = 0.0;
330
331    // let to: f64;
332    let mut ro = *r;
333    let mut vo = *v;
334
335    // Store initial radius
336    let rinit: f64 = r.length();
337    let mut rnow: f64 = rinit;
338
339    // set up Bulirsch-Stoer integrator
340    let system = OrbitalSystem { q };
341    let mut integrator = Integrator::default()
342        .with_abs_tol(1.0e-8)
343        .with_rel_tol(1.0e-8)
344        .into_adaptive();
345    // Initialise arrays
346    let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
347    let mut y_next = ndarray::Array::zeros(y.raw_dim());
348
349    let mut yo = y.clone();
350    let mut delta_t = t_next.min(smax);
351    // Step until radius crossed
352    while (rinit > rad && rnow > rad) || (rinit < rad && rnow < rad) {
353        ro = *r;
354        vo = *v;
355        yo = y.clone();
356        integrator
357            .step(&system, delta_t, y.view(), y_next.view_mut())
358            .unwrap();
359        y.assign(&y_next);
360        r.set(y[0], y[1], y[2]);
361        v.set(y[3], y[4], y[5]);
362        rnow = r.length();
363        time += delta_t;
364
365        if time > TMAX {
366            panic!("roche::stradv taken too long without crossing given radius.")
367        }
368    }
369
370    // Now refine by reinitialising and binary chopping until
371    // close enough to requested radius.
372
373    let mut lo: f64 = 0.0;
374    let mut hi: f64 = delta_t;
375    let mut rlo: f64 = ro.length();
376    let mut rhi: f64 = rnow;
377    let to: f64 = time;
378
379    while (rhi - rlo).abs() > acc {
380        delta_t = (lo + hi) / 2.0;
381        y = yo.clone();
382        *r = ro;
383        *v = vo;
384        time = to;
385
386        integrator
387            .step(&system, delta_t, y.view(), y_next.view_mut())
388            .unwrap();
389        y.assign(&y_next);
390
391        r.set(y[0], y[1], y[2]);
392        v.set(y[3], y[4], y[5]);
393        rnow = r.length();
394
395        if (rhi > rad && rnow > rad) || (rhi < rad && rnow < rad) {
396            rhi = rnow;
397            hi = delta_t;
398        } else {
399            rlo = rnow;
400            lo = delta_t;
401        }
402    }
403
404    time
405}
406
407// wrapper for python library, avoiding mutable references
408
409///
410/// stradv advances a particle of given position and velocity until
411/// it reaches a specified radius. It then returns with updated position and
412/// velocity. It is up to the user not to request a value that cannot be reached.
413///
414/// \param q    mass ratio = M2/M1
415/// \param r    Initial position
416/// \param v    Initial velocity
417/// \param rad  Radius to aim for
418/// \param acc  Accuracy with which to place output point at rad.
419/// \param smax Largest time step allowed. It is possible that the
420/// routine could take such a large step that it misses
421/// the point when the stream is inside the requested
422/// radius. This allows one to control this. Typical
423/// value = 1.e-3.
424/// \returns (timestep, new position, new velocity)
425///
426#[pyfunction]
427#[pyo3(name = "stradv")]
428pub fn stradv_wrapper(
429    q: f64,
430    r: &Vec3,
431    v: &Vec3,
432    rad: f64,
433    acc: f64,
434    smax: f64,
435) -> (f64, Vec3, Vec3) {
436    let mut r_mut = *r;
437    let mut v_mut = *v;
438    let timestep = stradv(q, &mut r_mut, &mut v_mut, rad, acc, smax);
439    (timestep, r_mut, v_mut)
440}
441
442///
443/// rocacc calculates and returns the acceleration (in the rotating frame)
444/// in a Roche potential of a particle of given position and velocity.
445///
446/// \param q mass ratio = M2/M1
447/// \param r position, scaled in units of separation.
448/// \param v velocity, scaled in units of separation
449///
450#[pyfunction]
451pub fn rocacc(q: f64, r: &Vec3, v: &Vec3) -> (f64, f64, f64) {
452    let f1: f64 = 1.0 / (1.0 + q);
453    let f2: f64 = f1 * q;
454
455    let yzsq: f64 = r.y * r.y + r.z * r.z;
456    let r1sq: f64 = r.x * r.x + yzsq;
457    let r2sq: f64 = (r.x - 1.0) * (r.x - 1.0) + yzsq;
458    let fm1: f64 = f1 / (r1sq * (r1sq.sqrt()));
459    let fm2: f64 = f2 / (r2sq * (r2sq.sqrt()));
460    let fm3: f64 = fm1 + fm2;
461
462    let x: f64 = -fm3 * r.x + fm2 + 2.0 * v.y + r.x - f2;
463    let y: f64 = -fm3 * r.y - 2.0 * v.x + r.y;
464    let z: f64 = -fm3 * r.z;
465    (x, y, z)
466}
467
468///
469/// brightspot_position runs strinit then stradv to get the coordinates of
470/// of the gas stream when it reaches a given radius from the primary star.
471///
472/// Arguments:
473///
474/// * `q`:  mass ratio = M2/M1
475/// * `rad`: radius from primary star
476/// * `acc`: computational accuracy
477/// * `smax`: maximum time step of Bulirsch-Stoer integration
478///
479/// Returns:
480/// * `r`: Vec3 coordinates of gas stream at given radius from primary star
481///
482#[pyfunction]
483#[pyo3(signature = (q, rad, acc=1.0e-7, smax=1.0e-2))]
484pub fn brightspot_position(q: f64, rad: f64, acc: f64, smax: f64) -> Result<Vec3, RocheError> {
485    let (mut r, mut v) = strinit(q)?;
486    let _ = stradv(q, &mut r, &mut v, rad, acc, smax);
487
488    Ok(r)
489}
490
/// Equations of motion in the Roche potential, in the form expected by the
/// `bulirsch` integrator. The state vector is (x, y, z, vx, vy, vz).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OrbitalSystem {
    /// Mass ratio M2/M1.
    pub q: f64,
}
494
495impl bulirsch::System for OrbitalSystem {
496    type Float = f64;
497
498    fn system(
499        &self,
500        y: bulirsch::ArrayView1<Self::Float>,
501        mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
502    ) {
503        dydt[[0]] = y[[3]];
504        dydt[[1]] = y[[4]];
505        dydt[[2]] = y[[5]];
506        let r = Vec3::new(y[[0]], y[[1]], y[[2]]);
507        let v = Vec3::new(y[[3]], y[[4]], y[[5]]);
508        (dydt[[3]], dydt[[4]], dydt[[5]]) = rocacc(self.q, &r, &v);
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean distance between (x, y) and the reference point (x0, y0).
    fn offset(x: f64, y: f64, x0: f64, y0: f64) -> f64 {
        (x - x0).hypot(y - y0)
    }

    /// Distance from (x, y) to (x0, y0), relative to the length of (x0, y0).
    fn relative_offset(x: f64, y: f64, x0: f64, y0: f64) -> f64 {
        offset(x, y, x0, y0) / x0.hypot(y0)
    }

    #[test]
    fn strinit_stradv_test() -> Result<(), RocheError> {
        // Values from trm.roche.bspot
        let q = 0.2;
        let (mut r, mut v) = strinit(q)?;
        let _time: f64 = stradv(q, &mut r, &mut v, 0.3, 1.0e-7, 1.0e-3);

        let r_expected = Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0);
        let v_expected = Vec3::new(-1.4769457229627583, 0.31712381217252994, 0.0);
        assert!((r - r_expected).length() < 1.0e-7);
        assert!((v - v_expected).length() < 1.0e-7);
        Ok(())
    }

    #[test]
    fn stream_test() -> Result<(), RocheError> {
        // Values from trm.roche.stream
        let (x, y) = stream(0.2, 0.01, 200)?;
        let last = y.len() - 1;
        let checkpoints = [
            (0, 0.6585557, 0.0),
            (50, 0.18384902, 0.15145306),
            (100, -0.100431986, -0.13697079),
            (150, 0.21720248, -0.4577784),
            (last, 0.15403406, 0.016731631),
        ];
        for (i, x_ref, y_ref) in checkpoints {
            assert!(offset(x[i], y[i], x_ref, y_ref) < 1.0e-4);
        }

        // Invalid q, step too large, negative step, too few points.
        let bad_inputs = [
            (-0.2, 0.0001, 200),
            (0.2, 1.1, 200),
            (0.2, -0.1, 200),
            (0.2, 0.0001, 1),
        ];
        for (q, step, n_points) in bad_inputs {
            assert!(stream(q, step, n_points).is_err());
        }
        Ok(())
    }

    #[test]
    fn strmnx_test() -> Result<(), RocheError> {
        // Values from trm.roche.strmnx
        let (x, y, vx1, vy1, vx2, vy2) = strmnx_wrapper(0.2, 1, 1.0e-7)?;
        assert!(relative_offset(x, y, -0.08613947462186848, 0.05411592729509131) < 1.0e-6);
        assert!(relative_offset(vx1, vy1, -1.9727409465489645, -3.30679322752132) < 1.0e-6);
        assert!(relative_offset(vx2, vy2, -1.5225623467338747, -2.5902178683586605) < 1.0e-6);
        Ok(())
    }

    #[test]
    fn brightspot_position_test() -> Result<(), RocheError> {
        // Values from trm.roche.bspot
        let r = brightspot_position(0.2, 0.3, 1.0e-7, 1.0e-3)?;
        let expected = Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0);
        assert!((r - expected).length() < 1.0e-7);
        Ok(())
    }
}