Skip to main content

roche/
stream_physics.rs

1use crate::errors::RocheError;
2use crate::x_l1;
3use crate::{Vec3, vel_transform};
4use bulirsch::{self, Integrator};
5use pyo3::prelude::*;
6use numpy::{IntoPyArray, PyArray1};
7
8///
9/// strinit sets a particle just inside the L1 point with the
10/// correct velocity as given in Lubow and Shu.
11///
12/// Arguments:
13///
14/// * `q`: mass ratio = M2/M1
15///
16/// Returns:
17///
18/// * start position
19/// * start velocity
20///
21#[pyfunction]
22pub fn strinit(q: f64) -> Result<(Vec3, Vec3), RocheError> {
23    const SMALL: f64 = 1.0e-5;
24    let rl1: f64 = x_l1(q)?;
25    let mu: f64 = q / (1.0 + q);
26    let a: f64 = (1.0 - mu) / rl1.powi(3) + mu / (1.0 - rl1).powi(3);
27    let lambda1: f64 = (((a - 2.0) + (a * (9.0 * a - 8.0)).sqrt()) / 2.0).sqrt();
28    let m1: f64 = (lambda1 * lambda1 - 2.0 * a - 1.0) / 2.0 / lambda1;
29
30    let r: Vec3 = Vec3::new(rl1 - SMALL, -m1 * SMALL, 0.0);
31    let v: Vec3 = Vec3::new(-lambda1 * SMALL, -lambda1 * m1 * SMALL, 0.0);
32
33    Ok((r, v))
34}
35
36///
37/// stream works by integrating the equations of motion for the Roche
38/// potential using Burlisch-Stoer integration. Every time the distance
39/// from the last point exceeds step, it interpolates and stores a new
40/// point. This allows one not to spend loads of points on regions where
41/// nothing is happening.
42///
43/// Arguments:
44///
45/// * `q`:    mass ratio = M2/M1. Stream flows from star 2 to 1.
46/// * `step`: step between points (units of separation).
47/// * `n_points`:    number of points to compute.
48///
49/// Returns:
50///
51/// * `x`:    array of x values returned.
52/// * `y`:    array of y values returned.
53///
54pub fn stream(q: f64, step: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
55    if n_points < 2 {
56        return Err(RocheError::ParameterError(
57            "Need at least 2 points in the stream.".to_string(),
58        ));
59    }
60
61    if step <= 0.0 || step > 1.0 {
62        return Err(RocheError::ParameterError(
63            "Step size must be between 0.0 and 1.0".to_string(),
64        ));
65    }
66
67    if q <= 0.0 {
68        return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
69    }
70
71    let mut x_arr: Vec<f64> = vec![];
72    let mut y_arr: Vec<f64> = vec![];
73
74    // Initialise stream
75    let rl1: f64 = x_l1(q)?;
76    let (mut r, mut v) = strinit(q)?;
77
78    // Store L1 as first point
79    x_arr.push(rl1);
80    y_arr.push(0.0);
81
82    let mut lp: usize = 0;
83
84    // Store interpolation between L1 and initial point if
85    // step has been set small enough
86
87    let mut dist: f64 = (r.x - rl1).hypot(r.y);
88
89    let frac: f64;
90
91    if dist > step {
92        frac = step / dist;
93        x_arr.push(rl1 + (r.x - rl1) * frac);
94        y_arr.push(r.y * frac);
95        lp += 1;
96    }
97
98    // set up Bulirsch-Stoer integrator
99    let system = OrbitalSystem { q };
100    let mut integrator = Integrator::default()
101        .with_abs_tol(1.0e-8)
102        .with_rel_tol(1.0e-8)
103        .into_adaptive();
104    // Initialise arrays
105    let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
106    let mut y_next = ndarray::Array::zeros(y.raw_dim());
107
108    let mut delta_t: f64 = 1.0e-3;
109    let smax: f64 = (1.0e-3_f64).min(step / 2.0);
110
111    let mut vel: f64;
112    while lp < n_points - 1 {
113        integrator
114            .step(&system, delta_t, y.view(), y_next.view_mut())
115            .unwrap();
116        y.assign(&y_next);
117
118        r.set(y[0], y[1], y[2]);
119        v.set(y[3], y[4], y[5]);
120        dist = (r.x - x_arr[lp]).hypot(r.y - y_arr[lp]);
121        if dist > step {
122            let frac: f64 = step / dist;
123            x_arr.push(x_arr[lp] + (r.x - x_arr[lp]) * frac);
124            y_arr.push(y_arr[lp] + (r.y - y_arr[lp]) * frac);
125            lp += 1;
126        }
127        vel = v.x.hypot(v.y);
128        delta_t = (smax / vel).min(delta_t);
129    }
130
131    Ok((x_arr, y_arr))
132}
133
134///
135/// stream works by integrating the equations of motion for the Roche
136/// potential using Burlisch-Stoer integration. Every time the distance
137/// from the last point exceeds step, it interpolates and stores a new
138/// point. This allows one not to spend loads of points on regions where
139/// nothing is happening.
140///
141/// Arguments:
142///
143/// * `q`:    mass ratio = M2/M1. Stream flows from star 2 to 1.
144/// * `step`: step between points (units of separation).
145/// * `n_points`:    number of points to compute.
146///
147/// Returns:
148///
149/// * `x`:    array of x values returned.
150/// * `y`:    array of y values returned.
151///
152#[pyfunction]
153#[pyo3(name = "stream", signature = (q, step, n_points=200))]
154pub fn stream_py(py: Python, q: f64, step: f64, n_points: usize) -> PyResult<(Py<PyArray1<f64>>, Py<PyArray1<f64>>)> {
155    let (x_arr, y_arr) = stream(q, step, n_points)?;
156    Ok((x_arr.into_pyarray(py).unbind(), y_arr.into_pyarray(py).unbind()))
157}
158
159///
160/// strmnx finds the next point at which stream is closest or furthest
161/// from primary.
162///
163/// Arguments:
164///
165/// * `q`: mass ratio = M2/M1
166/// * `r`: initial and final position
167/// * `v`: initial and final velocity
168/// * `acc`: accuracy in time to locate minimum/maximum.
169///
170///
171pub fn strmnx(q: f64, r: &mut Vec3, v: &mut Vec3, acc: f64) -> Result<(), RocheError> {
172    let mut dir: f64;
173    let mut lo: f64;
174    let mut hi: f64;
175    let mut ro: Vec3 = *r;
176    let mut vo: Vec3 = *v;
177
178    let mut delta_t: f64 = 1.0e-2;
179
180    // Store initial direction
181    dir = r.dot(v);
182    let dir1: f64 = dir;
183
184    // set up Bulirsch-Stoer integrator
185    let system = OrbitalSystem { q };
186    let mut integrator = Integrator::default()
187        .with_abs_tol(1.0e-8)
188        .with_rel_tol(1.0e-8)
189        .into_adaptive();
190    // Initialise arrays
191    let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
192    let mut y_next = ndarray::Array::zeros(y.raw_dim());
193    let mut yo = y.clone();
194
195    while (dir > 0.0 && dir1 > 0.0) || (dir < 0.0 && dir1 < 0.0) {
196        ro = *r;
197        vo = *v;
198        yo = y.clone();
199        integrator
200            .step(&system, delta_t, y.view(), y_next.view_mut())
201            .unwrap();
202        y.assign(&y_next);
203        r.set(y[0], y[1], y[2]);
204        v.set(y[3], y[4], y[5]);
205        dir = r.dot(v);
206    }
207
208    //   Now refine by reinitialising and binary chopping until
209    //   close enough to requested radius.
210
211    lo = 0.0;
212    hi = delta_t;
213    while (hi - lo).abs() > acc {
214        delta_t = (lo + hi) / 2.0;
215        y = yo.clone();
216        *r = ro;
217        *v = vo;
218        integrator
219            .step(&system, delta_t, y.view(), y_next.view_mut())
220            .unwrap();
221        y.assign(&y_next);
222
223        r.set(y[0], y[1], y[2]);
224        v.set(y[3], y[4], y[5]);
225        dir = r.dot(v);
226        if (dir > 0.0 && dir1 < 0.0) || (dir < 0.0 && dir1 > 0.0) {
227            hi = delta_t;
228        } else {
229            lo = delta_t;
230        }
231    }
232
233    Ok(())
234}
235
236// wrapper for python library, avoiding mutable references
237
238///
239/// Calculates position & velocity of n-th turning point of stream.
240/// x,y,vx1,vy1,vx2,vy2 = strmnx(q, n=1, acc=1.e-7), q = M2/M1.
241/// Two sets of velocities are reported, the first for the pure stream,
242/// the second for the disk at that point.
243///
244/// Arguments:
245///
246/// * `q`: mass ratio = M2/M1
247/// * `n`: turning point number
248/// * `acc`: accuracy in time to locate minimum/maximum.
249///
250/// Returns:
251/// (x, y, vx1, vy1, vx2, vy2)
252///
253#[pyfunction]
254#[pyo3(name = "strmnx")]
255#[pyo3(signature = (q, n=1, acc=1.0e-7))]
256pub fn strmnx_wrapper(
257    q: f64,
258    n: usize,
259    acc: f64,
260) -> Result<(f64, f64, f64, f64, f64, f64), RocheError> {
261    let (mut r, mut v) = strinit(q)?;
262    for _ in 0..n {
263        strmnx(q, &mut r, &mut v, acc)?
264    }
265    let (tvx1, tvy1) = vel_transform(q, 1, r.x, r.y, v.x, v.y)?;
266    let (tvx2, tvy2) = vel_transform(q, 2, r.x, r.y, v.x, v.y)?;
267    Ok((r.x, r.y, tvx1, tvy1, tvx2, tvy2))
268}
269
270///
271/// streamr works by integrating the equations of motion for the Roche
272/// potential using Burlisch-Stoer integration. It stops when the stream
273/// reaches a target radius or a minimum radius, whichever is the larger.
274///
275/// Arguments:
276///
277/// * `q`: mass ratio = M2/M1. Stream flows from star 2 to 1.
278/// * `rad`: Radius to aim for. If this is less than the minimum, the stream will stop at the minimum
279/// * `n_points`: number of points to compute.
280///
281/// Results:
282///
283/// * `x`: array of x values returned.
284/// * `y`: array of y values returned.
285///
286pub fn streamr(q: f64, rad: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
287    if n_points < 2 {
288        return Err(RocheError::ParameterError(
289            "Need at least 2 points in the stream.".to_string(),
290        ));
291    }
292
293    if q <= 0.0 {
294        return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
295    }
296
297    const EPS: f64 = 1.0e-8;
298
299    let mut x_arr: Vec<f64> = vec![];
300    let mut y_arr: Vec<f64> = vec![];
301
302    // Initialise stream
303    let rl1: f64 = x_l1(q)?;
304    let (mut r, mut v) = strinit(q)?;
305    let rs = r;
306    let vs = v;
307    strmnx(q, &mut r, &mut v, EPS)?;
308    let rmin = if r.length() > rad { r.length() } else { rad };
309
310    r = rs;
311    v = vs;
312    x_arr.push(r.x);
313    y_arr.push(r.y);
314    let mut rnext: f64;
315    for i in 1..n_points {
316        rnext = rl1 + (rmin - rl1) * (i as f64) / (n_points as f64 - 1.0);
317        stradv(q, &mut r, &mut v, rnext, 1.0e-6, 1.0e-4);
318        x_arr.push(r.x);
319        y_arr.push(r.y);
320    }
321
322    Ok((x_arr, y_arr))
323}
324
325///
326/// streamr works by integrating the equations of motion for the Roche
327/// potential using Burlisch-Stoer integration. It stops when the stream
328/// reaches a target radius or a minimum radius, whichever is the larger.
329///
330/// Arguments:
331///
332/// * `q`: mass ratio = M2/M1. Stream flows from star 2 to 1.
333/// * `rad`: Radius to aim for. If this is less than the minimum, the stream will stop at the minimum
334/// * `n_points`: number of points to compute.
335///
336/// Results:
337///
338/// * `x`: array of x values returned.
339/// * `y`: array of y values returned.
340///
341#[pyfunction]
342#[pyo3(name = "streamr", signature = (q, rad, n_points=200))]
343pub fn streamr_py(py: Python, q: f64, rad: f64, n_points: usize) -> PyResult<(Py<PyArray1<f64>>, Py<PyArray1<f64>>)> {
344    let (x_arr, y_arr) = streamr(q, rad, n_points)?;
345    Ok((x_arr.into_pyarray(py).unbind(), y_arr.into_pyarray(py).unbind()))
346}
347
348///
349/// stradv advances a particle of given position and velocity until
350/// it reaches a specified radius. It then returns with updated position and
351/// velocity. It is up to the user not to request a value that cannot be reached.
352///
353/// Arguments:
354///
355/// * `q`:    mass ratio = M2/M1
356/// * `r`:    Initial and final position
357/// * `v`:    Initial and final velocity
358/// * `rad`:  Radius to aim for
359/// * `acc`:  Accuracy with which to place output point at rad.
360/// * `smax`: Largest time step allowed. It is possible that the
361///   routine could take such a large step that it misses
362///   the point when the stream is inside the requested
363///   radius. This allows one to control this. Typical
364///   value = 1.e-3.
365///
366/// Returns:
367///
368/// * time step taken
369///
370pub fn stradv(q: f64, r: &mut Vec3, v: &mut Vec3, rad: f64, acc: f64, smax: f64) -> f64 {
371    const TMAX: f64 = 10.0;
372    let t_next: f64 = 1.0e-2;
373
374    let mut time: f64 = 0.0;
375
376    // let to: f64;
377    let mut ro = *r;
378    let mut vo = *v;
379
380    // Store initial radius
381    let rinit: f64 = r.length();
382    let mut rnow: f64 = rinit;
383
384    // set up Bulirsch-Stoer integrator
385    let system = OrbitalSystem { q };
386    let mut integrator = Integrator::default()
387        .with_abs_tol(1.0e-8)
388        .with_rel_tol(1.0e-8)
389        .into_adaptive();
390    // Initialise arrays
391    let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
392    let mut y_next = ndarray::Array::zeros(y.raw_dim());
393
394    let mut yo = y.clone();
395    let mut delta_t = t_next.min(smax);
396    // Step until radius crossed
397    while (rinit > rad && rnow > rad) || (rinit < rad && rnow < rad) {
398        ro = *r;
399        vo = *v;
400        yo = y.clone();
401        integrator
402            .step(&system, delta_t, y.view(), y_next.view_mut())
403            .unwrap();
404        y.assign(&y_next);
405        r.set(y[0], y[1], y[2]);
406        v.set(y[3], y[4], y[5]);
407        rnow = r.length();
408        time += delta_t;
409
410        if time > TMAX {
411            panic!("roche::stradv taken too long without crossing given radius.")
412        }
413    }
414
415    // Now refine by reinitialising and binary chopping until
416    // close enough to requested radius.
417
418    let mut lo: f64 = 0.0;
419    let mut hi: f64 = delta_t;
420    let mut rlo: f64 = ro.length();
421    let mut rhi: f64 = rnow;
422    let to: f64 = time;
423
424    while (rhi - rlo).abs() > acc {
425        delta_t = (lo + hi) / 2.0;
426        y = yo.clone();
427        *r = ro;
428        *v = vo;
429        time = to;
430
431        integrator
432            .step(&system, delta_t, y.view(), y_next.view_mut())
433            .unwrap();
434        y.assign(&y_next);
435
436        r.set(y[0], y[1], y[2]);
437        v.set(y[3], y[4], y[5]);
438        rnow = r.length();
439
440        if (rhi > rad && rnow > rad) || (rhi < rad && rnow < rad) {
441            rhi = rnow;
442            hi = delta_t;
443        } else {
444            rlo = rnow;
445            lo = delta_t;
446        }
447    }
448
449    time
450}
451
452// wrapper for python library, avoiding mutable references
453
454///
455/// stradv advances a particle of given position and velocity until
456/// it reaches a specified radius. It then returns with updated position and
457/// velocity. It is up to the user not to request a value that cannot be reached.
458///
459/// \param q    mass ratio = M2/M1
460/// \param r    Initial position
461/// \param v    Initial velocity
462/// \param rad  Radius to aim for
463/// \param acc  Accuracy with which to place output point at rad.
464/// \param smax Largest time step allowed. It is possible that the
465/// routine could take such a large step that it misses
466/// the point when the stream is inside the requested
467/// radius. This allows one to control this. Typical
468/// value = 1.e-3.
469/// \returns (timestep, new position, new velocity)
470///
471#[pyfunction]
472#[pyo3(name = "stradv")]
473pub fn stradv_py(
474    q: f64,
475    r: &Vec3,
476    v: &Vec3,
477    rad: f64,
478    acc: f64,
479    smax: f64,
480) -> (f64, Vec3, Vec3) {
481    let mut r_mut = *r;
482    let mut v_mut = *v;
483    let timestep = stradv(q, &mut r_mut, &mut v_mut, rad, acc, smax);
484    (timestep, r_mut, v_mut)
485}
486
487///
488/// rocacc calculates and returns the acceleration (in the rotating frame)
489/// in a Roche potential of a particle of given position and velocity.
490///
491/// \param q mass ratio = M2/M1
492/// \param r position, scaled in units of separation.
493/// \param v velocity, scaled in units of separation
494///
495#[pyfunction]
496pub fn rocacc(q: f64, r: &Vec3, v: &Vec3) -> (f64, f64, f64) {
497    let f1: f64 = 1.0 / (1.0 + q);
498    let f2: f64 = f1 * q;
499
500    let yzsq: f64 = r.y * r.y + r.z * r.z;
501    let r1sq: f64 = r.x * r.x + yzsq;
502    let r2sq: f64 = (r.x - 1.0) * (r.x - 1.0) + yzsq;
503    let fm1: f64 = f1 / (r1sq * (r1sq.sqrt()));
504    let fm2: f64 = f2 / (r2sq * (r2sq.sqrt()));
505    let fm3: f64 = fm1 + fm2;
506
507    let x: f64 = -fm3 * r.x + fm2 + 2.0 * v.y + r.x - f2;
508    let y: f64 = -fm3 * r.y - 2.0 * v.x + r.y;
509    let z: f64 = -fm3 * r.z;
510    (x, y, z)
511}
512
513///
514/// brightspot_position runs strinit then stradv to get the coordinates of
515/// of the gas stream when it reaches a given radius from the primary star.
516///
517/// Arguments:
518///
519/// * `q`:  mass ratio = M2/M1
520/// * `rad`: radius from primary star
521/// * `acc`: computational accuracy
522/// * `smax`: maximum time step of Bulirsch-Stoer integration
523///
524/// Returns:
525/// * `r`: Vec3 coordinates of gas stream at given radius from primary star
526///
527#[pyfunction]
528#[pyo3(signature = (q, rad, acc=1.0e-7, smax=1.0e-2))]
529pub fn brightspot_position(q: f64, rad: f64, acc: f64, smax: f64) -> Result<Vec3, RocheError> {
530    let (mut r, mut v) = strinit(q)?;
531    let _ = stradv(q, &mut r, &mut v, rad, acc, smax);
532
533    Ok(r)
534}
535
536///
537/// bspot runs strinit then stradv to get the coordinate and velocity
538/// vectors of the gas stream when it reaches a given radius from the primary star.
539///
540/// Arguments:
541///
542/// * `q`:  mass ratio = M2/M1
543/// * `rad`: radius from primary star
544/// * `acc`: computational accuracy
545/// * `smax`: maximum time step of Bulirsch-Stoer integration
546///
547/// Returns:
548/// * `r`: Vec3 coordinates of gas stream at given radius from primary star
549/// * `v`: Vec3 velocity of gas stream at given radius from primary star
550///
551#[pyfunction]
552#[pyo3(signature = (q, rad, acc=1.0e-7, smax=1.0e-2))]
553pub fn bspot(q: f64, rad: f64, acc: f64, smax: f64) -> Result<(Vec3, Vec3), RocheError> {
554    let (mut r, mut v) = strinit(q)?;
555    let _ = stradv(q, &mut r, &mut v, rad, acc, smax);
556
557    Ok((r, v))
558}
559
/// Equations of motion of a test particle in the Roche potential, in the
/// form expected by the Bulirsch-Stoer integrator.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OrbitalSystem {
    /// Mass ratio = M2/M1
    pub q: f64,
}
563
564impl bulirsch::System for OrbitalSystem {
565    type Float = f64;
566
567    fn system(
568        &self,
569        y: bulirsch::ArrayView1<Self::Float>,
570        mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
571    ) {
572        dydt[[0]] = y[[3]];
573        dydt[[1]] = y[[4]];
574        dydt[[2]] = y[[5]];
575        let r = Vec3::new(y[[0]], y[[1]], y[[2]]);
576        let v = Vec3::new(y[[3]], y[[4]], y[[5]]);
577        (dydt[[3]], dydt[[4]], dydt[[5]]) = rocacc(self.q, &r, &v);
578    }
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// Planar distance between `(x, y)` and the reference point `(ex, ey)`,
    /// relative to the distance of the reference point from the origin.
    fn rel_err(x: f64, y: f64, ex: f64, ey: f64) -> f64 {
        (x - ex).hypot(y - ey) / ex.hypot(ey)
    }

    #[test]
    fn strinit_stradv_test() -> Result<(), RocheError> {
        // Values from trm.roche.bspot
        let expected_r = Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0);
        let expected_v = Vec3::new(-1.4769457229627583, 0.31712381217252994, 0.0);

        let (mut r, mut v) = strinit(0.2)?;
        stradv(0.2, &mut r, &mut v, 0.3, 1.0e-7, 1.0e-3);

        assert!((r - expected_r).length() < 1.0e-7);
        assert!((v - expected_v).length() < 1.0e-7);
        Ok(())
    }

    #[test]
    fn stream_test() -> Result<(), RocheError> {
        // Values from trm.roche.stream
        let (x, y) = stream(0.2, 0.01, 200)?;
        let last = y.len() - 1;
        let reference = [
            (0, 0.6585557, 0.0),
            (50, 0.18384902, 0.15145306),
            (100, -0.100431986, -0.13697079),
            (150, 0.21720248, -0.4577784),
            (last, 0.15403406, 0.016731631),
        ];
        for &(i, ex, ey) in &reference {
            assert!((x[i] - ex).hypot(y[i] - ey) < 1.0e-4);
        }

        // Invalid q, step too large, step negative, too few points
        let rejected = [
            (-0.2, 0.0001, 200),
            (0.2, 1.1, 200),
            (0.2, -0.1, 200),
            (0.2, 0.0001, 1),
        ];
        for &(q, step, n_points) in &rejected {
            assert!(stream(q, step, n_points).is_err());
        }
        Ok(())
    }

    #[test]
    fn strmnx_test() -> Result<(), RocheError> {
        // Values from trm.roche.strmnx
        let (x, y, vx1, vy1, vx2, vy2) = strmnx_wrapper(0.2, 1, 1.0e-7)?;
        assert!(rel_err(x, y, -0.08613947462186848, 0.05411592729509131) < 1.0e-6);
        assert!(rel_err(vx1, vy1, -1.9727409465489645, -3.30679322752132) < 1.0e-6);
        assert!(rel_err(vx2, vy2, -1.5225623467338747, -2.5902178683586605) < 1.0e-6);
        Ok(())
    }

    #[test]
    fn brightspot_position_test() -> Result<(), RocheError> {
        // Values from trm.roche.bspot
        let expected_r = Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0);
        let r = brightspot_position(0.2, 0.3, 1.0e-7, 1.0e-3)?;
        assert!((r - expected_r).length() < 1.0e-7);
        Ok(())
    }

    #[test]
    fn bspot_test() -> Result<(), RocheError> {
        // Values from trm.roche.bspot
        let expected_r = Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0);
        let expected_v = Vec3::new(-1.476945722613775, 0.31712381223279495, 0.0);
        let (r, v) = bspot(0.2, 0.3, 1.0e-7, 1.0e-3)?;
        assert!((r - expected_r).length() < 1.0e-7);
        assert!((v - expected_v).length() < 1.0e-6);
        Ok(())
    }
}