sqlsync_reducer/
guest_reactor.rs

1use std::{
2    collections::BTreeMap,
3    future::Future,
4    mem::MaybeUninit,
5    pin::Pin,
6    sync::Once,
7    task::{Context, Poll},
8};
9
10use serde::de::DeserializeOwned;
11
12use crate::{
13    guest_ffi::{fbm, FFIBufPtr},
14    types::{
15        ErrorResponse, ExecResponse, QueryResponse, ReducerError, Request, RequestId, Requests,
16        Responses, SqliteValue,
17    },
18};
19
20pub fn reactor() -> &'static mut Reactor {
21    static mut SINGLETON: MaybeUninit<Reactor> = MaybeUninit::uninit();
22    static ONCE: Once = Once::new();
23    unsafe {
24        ONCE.call_once(|| {
25            let singleton = Reactor::new();
26            SINGLETON.write(singleton);
27        });
28        SINGLETON.assume_init_mut()
29    }
30}
31
32type ReducerTask = Pin<Box<dyn Future<Output = Result<(), ReducerError>>>>;
33
34#[derive(Default)]
35pub struct Reactor {
36    task: Option<ReducerTask>,
37    request_id_generator: RequestId,
38
39    // requests from guest -> host
40    requests: Requests,
41    // responses from host -> guest
42    responses: Responses,
43}
44
45impl Reactor {
46    pub fn new() -> Self {
47        Reactor::default()
48    }
49
50    fn queue_request(&mut self, request: Request) -> RequestId {
51        let id = self.request_id_generator;
52        self.request_id_generator = self.request_id_generator.wrapping_add(1);
53        self.requests
54            .get_or_insert_with(BTreeMap::new)
55            .insert(id, request);
56        id
57    }
58
59    fn get_response<T: DeserializeOwned>(&mut self, id: RequestId) -> Option<T> {
60        self.responses
61            .as_mut()
62            .and_then(|b| b.remove(&id))
63            .map(|ptr| {
64                let f = fbm();
65                unsafe { f.decode(ptr as *mut u8).unwrap() }
66            })
67    }
68
69    pub fn spawn(&mut self, task: ReducerTask) {
70        if self.task.is_some() {
71            panic!("Reducer task already running");
72        }
73        self.task = Some(task);
74    }
75
76    pub fn step(&mut self, responses: Responses) -> Result<Requests, ReducerError> {
77        if let Some(ref mut previous) = self.responses {
78            // if we still have previous responses, merge new responses in
79            // this replaces keys in previous with those in next - as long
80            // as the host respects the request indexes this is safe
81            if let Some(mut next) = responses {
82                previous.append(&mut next);
83            }
84        } else {
85            // otherwise, just use the new responses
86            self.responses = responses;
87        }
88
89        if let Some(mut task) = self.task.take() {
90            let mut ctx = Context::from_waker(futures::task::noop_waker_ref());
91            match task.as_mut().poll(&mut ctx) {
92                Poll::Ready(result) => result?,
93                Poll::Pending => {
94                    self.task = Some(task);
95                }
96            }
97        }
98
99        Ok(self.requests.take())
100    }
101}
102
103#[must_use]
104pub struct ResponseFuture<T: DeserializeOwned> {
105    id: RequestId,
106    _marker: std::marker::PhantomData<T>,
107}
108
109impl<T: DeserializeOwned> ResponseFuture<T> {
110    fn new(id: RequestId) -> Self {
111        Self { id, _marker: std::marker::PhantomData }
112    }
113}
114
115impl<T: DeserializeOwned> Future for ResponseFuture<T> {
116    type Output = T;
117
118    fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
119        match reactor().get_response(self.id) {
120            Some(response) => Poll::Ready(response),
121            None => Poll::Pending,
122        }
123    }
124}
125
126pub fn raw_query(
127    sql: String,
128    params: Vec<SqliteValue>,
129) -> ResponseFuture<Result<QueryResponse, ErrorResponse>> {
130    let request = Request::Query { sql, params };
131    let id = reactor().queue_request(request);
132    ResponseFuture::new(id)
133}
134
135pub fn raw_execute(
136    sql: String,
137    params: Vec<SqliteValue>,
138) -> ResponseFuture<Result<ExecResponse, ErrorResponse>> {
139    let request = Request::Exec { sql, params };
140    let id = reactor().queue_request(request);
141    ResponseFuture::new(id)
142}
143
/// Expands to a `raw_query` call.
///
/// The SQL expression and each argument are converted with `.into()`. A
/// trailing comma is accepted. `$crate` is used so the macro resolves
/// correctly even if the dependency is renamed or used inside this crate.
#[macro_export]
macro_rules! query {
    ($sql:expr $(, $arg:expr)* $(,)?) => {
        $crate::guest_reactor::raw_query($sql.into(), vec![$($arg.into()),*])
    };
}
150
/// Expands to a `raw_execute` call.
///
/// The SQL expression and each argument are converted with `.into()`. A
/// trailing comma is accepted. `$crate` is used so the macro resolves
/// correctly even if the dependency is renamed or used inside this crate.
#[macro_export]
macro_rules! execute {
    ($sql:expr $(, $arg:expr)* $(,)?) => {
        $crate::guest_reactor::raw_execute($sql.into(), vec![$($arg.into()),*])
    };
}
157
/// Generates the FFI entry points (`ffi_reduce`, `ffi_init_reducer`) for a
/// reducer function.
///
/// `$fn` may be any path, for example `reduce` or `my_mod::reduce`.
#[macro_export]
macro_rules! init_reducer {
    // fn should be (Vec<u8>) -> Future<Output = Result<(), ReducerError>>
    ($fn:path $(,)?) => {
        /// ffi_reduce is called by the host to cause the reducer to start processing a new mutation.
        ///
        /// It spawns the reducer task, steps it once, and returns a pointer to the
        /// encoded step result.
        ///
        /// # Panics
        /// Panics if the mutation buffer cannot be consumed, if a previous reducer
        /// task is still running, or if the step result cannot be encoded.
        /// # Safety
        /// The host must pass in a valid pointer to a Mutation buffer.
        #[no_mangle]
        pub unsafe extern "C" fn ffi_reduce(
            mutation_ptr: $crate::guest_ffi::FFIBufPtr,
        ) -> $crate::guest_ffi::FFIBufPtr {
            let reactor = $crate::guest_reactor::reactor();
            let fbm = $crate::guest_ffi::fbm();
            let mutation = fbm.consume(mutation_ptr);

            reactor.spawn(Box::pin(async move { $fn(mutation).await }));

            let requests = reactor.step(None);
            fbm.encode(&requests)
                .expect("ffi_reduce: failed to encode step result")
        }

        static LOGGER: $crate::guest_ffi::FFILogger = $crate::guest_ffi::FFILogger;

        #[no_mangle]
        pub extern "C" fn ffi_init_reducer() {
            // NOTE(review): `log` resolves in the invoking crate, which must
            // therefore depend on the `log` crate itself.
            LOGGER.init(log::Level::Trace).unwrap();
            $crate::guest_ffi::install_panic_hook();
        }
    };
}
192
193/// ffi_reactor_step is called by the host to advance the reactor forward.
194///
195/// # Panics
196/// Panics if the host passes in an invalid pointer.
197///
198/// # Safety
199/// The host must pass in a valid pointer to a serialized Responses object.
200#[no_mangle]
201pub unsafe fn ffi_reactor_step(responses_ptr: FFIBufPtr) -> FFIBufPtr {
202    let fbm = fbm();
203    let responses = fbm.decode(responses_ptr).unwrap();
204    let out = reactor().step(responses);
205    fbm.encode(&out).unwrap()
206}