1use anyhow::anyhow;
2use pyo3::exceptions::PyException;
3use std::time::SystemTime;
4
5use ::wingfoil::{Node, NodeOperators, RunFor, RunMode, Stream, StreamOperators, NanoTime};
6
7use pyo3::conversion::IntoPyObjectExt;
8use pyo3::prelude::*;
9
10use std::rc::Rc;
11use std::time::Duration;
12use std::time::UNIX_EPOCH;
13
14#[pyclass(unsendable, name = "Node")]
15#[derive(Clone)]
16struct PyNode(Rc<dyn Node>);
17
18impl PyNode {
19 fn new(node: Rc<dyn Node>) -> Self {
20 Self(node)
21 }
22}
23
24#[pymethods]
25impl PyNode {
26 fn count(&self) -> PyResult<PyStream> {
27 let count = self.0.count().map(move |x| {
28 Python::attach(|py| {
29 let x: Py<PyAny> = x.into_py_any(py).unwrap();
30 PyElement(x)
31 })
32 });
33 Ok(PyStream(count))
34 }
35}
36
37#[pyfunction]
38fn ticker(seconds: f64) -> PyResult<PyNode> {
39 let ticker = ::wingfoil::ticker(Duration::from_secs_f64(seconds));
40 let node = PyNode::new(ticker);
41 Ok(node)
42}
43
44struct PyElement(Py<PyAny>);
45
46impl Default for PyElement {
47 fn default() -> Self {
48 Python::attach(|py| PyElement(py.None()))
49 }
50}
51
52impl std::fmt::Debug for PyElement {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 Python::attach(|py| {
55 let result = self.0.call_method0(py, "__str__").unwrap();
56 write!(f, "{}", result.extract::<String>(py).unwrap())
57 })
58 }
59}
60
61impl Clone for PyElement {
62 fn clone(&self) -> Self {
63 Python::attach(|py| PyElement(self.0.clone_ref(py)))
64 }
65}
66
67#[derive(Clone)]
68#[pyclass(subclass, unsendable)]
69struct PyStream(Rc<dyn Stream<PyElement>>);
70
71#[pymethods]
72impl PyStream {
73 #[pyo3(signature = (realtime=true, start=None, duration=None, cycles=None))]
74 fn run(
75 &self,
76 py: Python<'_>,
77 realtime: Option<bool>,
78 start: Option<Py<PyAny>>,
79 duration: Option<Py<PyAny>>,
80 cycles: Option<u32>,
81 ) -> PyResult<()> {
82 let (run_mode, run_for) =
83 parse_run_args(py, realtime, start, duration, cycles).to_pyresult()?;
84 self.0.run(run_mode, run_for).to_pyresult()?;
85 Ok(())
86 }
87 fn peek_value(&self) -> Py<PyAny> {
88 self.0.peek_value().0
89 }
90 fn logged(&self, label: String) -> PyStream {
91 PyStream(self.0.logged(&label, log::Level::Info))
92 }
93}
94
95pub trait ToPyResult<T> {
96 fn to_pyresult(self) -> PyResult<T>;
97}
98
99impl<T> ToPyResult<T> for anyhow::Result<T> {
100 fn to_pyresult(self) -> PyResult<T> {
101 self.map_err(|e| PyException::new_err(e.to_string()))
102 }
103}
104
105fn parse_run_args(
106 py: Python<'_>,
107 realtime: Option<bool>,
108 start: Option<Py<PyAny>>,
109 duration: Option<Py<PyAny>>,
110 cycles: Option<u32>,
111) -> anyhow::Result<(RunMode, RunFor)> {
112 if duration.is_some() && cycles.is_some() {
113 panic!("Cannot specify both duration and cycles");
114 }
115 let realtime = realtime.unwrap_or(false);
116 if realtime && start.is_some() {
117 panic!("Cannot specify start in realtime mode");
118 }
119 let run_mode = if realtime {
120 RunMode::RealTime
121 } else {
122 let t = match start {
123 Some(start) => to_nano_time(py, start)?,
124 None => NanoTime::ZERO,
125 };
126 RunMode::HistoricalFrom(t)
127 };
128 let run_for = if let Some(cycles) = cycles {
129 RunFor::Cycles(cycles)
130 } else {
131 match duration {
132 Some(duration) => {
133 let duration = to_duration(py, duration)?;
134 RunFor::Duration(duration)
135 }
136 None => RunFor::Forever,
137 }
138 };
139 Ok((run_mode, run_for))
140}
141
142fn to_duration(py: Python<'_>, obj: Py<PyAny>) -> anyhow::Result<Duration> {
143 if let Ok(f) = obj.extract::<f64>(py) {
144 if f < 0.0 {
145 anyhow::bail!("duration can not be negative");
146 }
147 let nanos = (f * 1e9) as u64;
148 Ok(Duration::from_nanos(nanos))
149 } else if let Ok(td) = obj.extract::<Duration>(py) {
150 Ok(td)
151 } else {
152 anyhow::bail!("failed to convert duration");
153 }
154}
155
156fn f64_secs_to_nanos(ts: f64) -> anyhow::Result<u64> {
157 if ts.is_sign_negative() {
158 return Err(anyhow!("Negative timestamps not supported"));
159 }
160 let secs = ts.trunc() as u64;
161 let frac_nanos = (ts.fract() * 1e9) as u64;
162 let nanos = secs
163 .checked_mul(1_000_000_000)
164 .ok_or_else(|| anyhow!("Overflow when converting seconds to nanoseconds"))?;
165 let total_nanos = nanos
166 .checked_add(frac_nanos)
167 .ok_or_else(|| anyhow!("Overflow when adding fractional nanoseconds"))?;
168 Ok(total_nanos)
169}
170
171fn to_nano_time(py: Python<'_>, obj: Py<PyAny>) -> anyhow::Result<NanoTime> {
172 if let Ok(dt) = obj.extract::<SystemTime>(py) {
173 let nanos = dt.duration_since(UNIX_EPOCH)?.as_nanos();
174 Ok(nanos.into())
175 } else if let Ok(ts) = obj.extract::<f64>(py) {
176 Ok(f64_secs_to_nanos(ts)?.into())
177 } else {
178 Err(anyhow!("failed to convert to NanoTime"))
179 }
180}
181
182#[pymodule]
183fn _wingfoil(module: &Bound<'_, PyModule>) -> PyResult<()> {
184 _ = env_logger::try_init();
185 module.add_function(wrap_pyfunction!(ticker, module)?)?;
186 module.add_class::<PyNode>()?;
187 module.add("__version__", env!("CARGO_PKG_VERSION"))?;
188 Ok(())
189}