Skip to main content

wingfoil/
lib.rs

1mod proxy_stream;
2mod py_element;
3mod py_kdb;
4mod py_stream;
5mod py_zmq;
6mod types;
7
8use ::wingfoil::{Dep, Node, NodeOperators};
9use py_element::*;
10use py_stream::*;
11use types::ToPyResult;
12
13use pyo3::prelude::*;
14use std::rc::Rc;
15use std::time::Duration;
16
17#[pyclass(unsendable, name = "Node")]
18#[derive(Clone)]
19pub(crate) struct PyNode(Rc<dyn Node>);
20
21impl PyNode {
22    pub(crate) fn new(node: Rc<dyn Node>) -> Self {
23        Self(node)
24    }
25}
26
27#[pymethods]
28impl PyNode {
29    /// Counts how many times upstream node has ticked.
30    fn count(&self) -> PyStream {
31        self.0.count().as_py_stream()
32    }
33
34    #[pyo3(signature = (realtime=true, start=None, duration=None, cycles=None))]
35    fn run(
36        &self,
37        py: Python<'_>,
38        realtime: Option<bool>,
39        start: Option<Py<PyAny>>,
40        duration: Option<Py<PyAny>>,
41        cycles: Option<u32>,
42    ) -> PyResult<()> {
43        let (run_mode, run_for) =
44            types::parse_run_args(py, realtime, start, duration, cycles).to_pyresult()?;
45
46        // Convert fat pointer to (addr, vtable) pair which is Send+Sync
47        let node_ptr = Rc::as_ptr(&self.0);
48        let (addr, vtable): (usize, usize) = unsafe { std::mem::transmute(node_ptr) };
49
50        // Release GIL during the run to allow async tasks to acquire it
51        // SAFETY: The Rc is kept alive by self for the duration of this call
52        let result = py.detach(move || {
53            // Reconstruct the fat pointer from (addr, vtable)
54            let node_ptr: *const dyn Node = unsafe { std::mem::transmute((addr, vtable)) };
55            // Temporarily reconstruct the Rc without taking ownership
56            let node = unsafe { Rc::from_raw(node_ptr) };
57            let result = ::wingfoil::NodeOperators::run(&node, run_mode, run_for);
58            std::mem::forget(node); // Don't drop the Rc (self.0 still owns it)
59            result
60        });
61        result.to_pyresult()?;
62        Ok(())
63    }
64}
65
66/// A node that ticks at the specified period
67#[pyfunction]
68fn ticker(seconds: f64) -> PyNode {
69    let ticker = ::wingfoil::ticker(Duration::from_secs_f64(seconds));
70    PyNode::new(ticker)
71}
72
73/// A stream that ticks once, on first engine cycle
74#[pyfunction]
75fn constant(val: Py<PyAny>) -> PyStream {
76    let strm = ::wingfoil::constant(PyElement::new(val));
77    PyStream(strm)
78}
79
80/// maps steams a amd b into a new stream using func (e.g lambda a, b: a + b)
81#[pyfunction]
82fn bimap(a: Py<PyAny>, b: Py<PyAny>, func: Py<PyAny>) -> PyStream {
83    Python::attach(|py| {
84        let a = a
85            .as_ref()
86            .extract::<PyRef<PyStream>>(py)
87            .unwrap()
88            .inner_stream();
89        let b = b
90            .as_ref()
91            .extract::<PyRef<PyStream>>(py)
92            .unwrap()
93            .inner_stream();
94        let stream = ::wingfoil::bimap(
95            Dep::Active(a),
96            Dep::Active(b),
97            move |a: PyElement, b: PyElement| {
98                Python::attach(|py: Python<'_>| {
99                    let res = func.call1(py, (a.value(), b.value())).unwrap();
100                    PyElement::new(res)
101                })
102            },
103        );
104        PyStream(stream)
105    })
106}
107
108#[pyclass(unsendable, name = "Graph")]
109#[derive(Clone)]
110pub(crate) struct PyGraph(Vec<Rc<dyn Node>>);
111
112#[pymethods]
113impl PyGraph {
114    #[new]
115    fn new(nodes: Vec<Py<PyAny>>) -> PyResult<Self> {
116        Python::attach(|py| {
117            let mut roots: Vec<Rc<dyn Node>> = Vec::new();
118            for obj in nodes {
119                if let Ok(stream) = obj.extract::<PyRef<PyStream>>(py) {
120                    roots.push(stream.0.clone().as_node());
121                } else if let Ok(node) = obj.extract::<PyRef<PyNode>>(py) {
122                    roots.push(node.0.clone());
123                } else {
124                    return Err(pyo3::exceptions::PyTypeError::new_err(
125                        "Graph components must be Stream or Node",
126                    ));
127                }
128            }
129            Ok(PyGraph(roots))
130        })
131    }
132
133    #[pyo3(signature = (realtime=true, start=None, duration=None, cycles=None))]
134    fn run(
135        &self,
136        py: Python<'_>,
137        realtime: Option<bool>,
138        start: Option<Py<PyAny>>,
139        duration: Option<Py<PyAny>>,
140        cycles: Option<u32>,
141    ) -> PyResult<()> {
142        let (run_mode, run_for) =
143            types::parse_run_args(py, realtime, start, duration, cycles).to_pyresult()?;
144
145        let mut ptrs: Vec<(usize, usize)> = Vec::with_capacity(self.0.len());
146        for node in &self.0 {
147            let node_ptr = Rc::as_ptr(node);
148            let (addr, vtable): (usize, usize) = unsafe { std::mem::transmute(node_ptr) };
149            ptrs.push((addr, vtable));
150        }
151
152        let result = py.detach(move || {
153            let mut roots: Vec<Rc<dyn Node>> = Vec::with_capacity(ptrs.len());
154            for (addr, vtable) in ptrs {
155                let node_ptr: *const dyn Node = unsafe { std::mem::transmute((addr, vtable)) };
156                let node = unsafe { Rc::from_raw(node_ptr) };
157                roots.push(node.clone());
158                std::mem::forget(node);
159            }
160
161            let mut graph = ::wingfoil::Graph::new(roots, run_mode, run_for);
162            graph.run()
163        });
164        result.to_pyresult()?;
165        Ok(())
166    }
167}
168
169/// Wingfoil is a blazingly fast, highly scalable stream processing
170/// framework designed for latency-critical use cases such as electronic
171/// trading and real-time AI systems
172#[pymodule]
173fn _wingfoil(module: &Bound<'_, PyModule>) -> PyResult<()> {
174    let env = env_logger::Env::default().default_filter_or("info");
175    env_logger::Builder::from_env(env).init();
176    module.add_function(wrap_pyfunction!(ticker, module)?)?;
177    module.add_function(wrap_pyfunction!(constant, module)?)?;
178    module.add_function(wrap_pyfunction!(bimap, module)?)?;
179    module.add_function(wrap_pyfunction!(py_kdb::py_kdb_read, module)?)?;
180    module.add_function(wrap_pyfunction!(py_kdb::py_kdb_write, module)?)?;
181    module.add_function(wrap_pyfunction!(py_zmq::py_zmq_sub, module)?)?;
182    module.add_class::<PyNode>()?;
183    module.add_class::<PyStream>()?;
184    module.add_class::<PyGraph>()?;
185    module.add("__version__", env!("CARGO_PKG_VERSION"))?;
186    Ok(())
187}