Skip to main content

simulator_client/managed/
session.rs

1use std::{collections::VecDeque, time::Duration};
2
3use simulator_api::{
4    BacktestError, BacktestStatus, ContinueParams, ContinueToParams, CreateBacktestSessionRequest,
5    DiscoveryBatchEvent, PausedEvent,
6};
7use solana_transaction_status::EncodedConfirmedTransactionWithStatusMeta;
8use thiserror::Error;
9use tokio::sync::watch;
10use tokio_util::sync::CancellationToken;
11
12use super::{
13    ConnectionStatus, ControlEvent, ControlHandle, SessionInfo, SubscriptionHandle,
14    SubscriptionNotification, spawn_account_diff_subscription_manager, spawn_control_manager,
15    spawn_transaction_subscription_manager,
16};
17use crate::subscriptions::AccountDiffNotification;
18
/// Error returned by the high-level managed session wrapper.
#[derive(Debug, Error)]
pub enum ManagedSessionError {
    /// Session creation failed during startup; carries the failure message
    /// reported while waiting for the session.
    #[error("session create failed: {0}")]
    Create(String),

    /// The control handle or its event stream is no longer available.
    #[error("control channel closed")]
    ControlClosed,

    /// The control connection reported a `Failed` status; carries the reason.
    #[error("control failed: {0}")]
    ControlFailed(String),

    /// A subscription connection reported a `Failed` status; carries the reason.
    #[error("subscription failed: {0}")]
    SubscriptionFailed(String),

    /// The session's cancellation token fired before the operation finished.
    #[error("cancelled")]
    Cancelled,

    /// Sending a `Continue` or `ContinueTo` request over the control handle
    /// failed; carries the underlying error rendered as text.
    #[error("control closed while sending continue: {0}")]
    ContinueSend(String),
}
40
/// Event surfaced by [`ManagedBacktestSession::next_event`]: either a
/// control-plane event or a subscription notification.
#[derive(Debug)]
pub enum ManagedEvent {
    /// The session is ready to accept a `Continue` request.
    ReadyForContinue,
    /// Server paused at a `ContinueTo` target. The session is ready for
    /// another `Continue` or `ContinueTo` from this point.
    Paused(PausedEvent),
    /// Server discovered an upcoming batch matching a registered
    /// `DiscoveryFilter`. Send `send_continue_to(slot, batch_index)` to pause
    /// immediately before it executes.
    DiscoveryBatch(DiscoveryBatchEvent),
    /// Slot value forwarded from the control stream.
    Slot(u64),
    /// Backtest status update forwarded from the control stream.
    Status(BacktestStatus),
    /// The backtest completed. `next_event` delivers this only after the
    /// trailing subscription notifications it managed to drain.
    Completed,
    /// Error event forwarded from the control stream.
    Error(BacktestError),
    /// Transaction notification received from a subscription (boxed).
    Transaction(Box<EncodedConfirmedTransactionWithStatusMeta>),
    /// Account-diff notification received from a subscription.
    AccountDiff(AccountDiffNotification),
}
58
/// Default upper bound on the completion drain (60 seconds).
///
/// Liveness backstop: a data plane mid-reconnect may never deliver its
/// end-of-stream terminal, so this caps how long `next_event` waits for
/// trailing notifications after `Completed` before returning anyway.
const DEFAULT_COMPLETION_DRAIN_TIMEOUT: Duration = Duration::from_secs(60);
63
/// High-level managed backtest session.
///
/// This wrapper owns the control manager, supported subscription managers,
/// cancellation, status gating, and shutdown order. Callers only need to react
/// to [`ManagedEvent`]s and send [`ContinueParams`] after `ReadyForContinue`.
pub struct ManagedBacktestSession {
    /// Metadata reported by the server when the session was created.
    session_info: SessionInfo,
    /// Handle to the control manager; taken (left `None`) by `shutdown`.
    control: Option<ControlHandle>,
    /// One handle per subscription started through the `subscribe_*` methods.
    subscriptions: Vec<SubscriptionHandle>,
    /// Token shared with every manager task; cancelled on shutdown and drop.
    session_cancel: CancellationToken,
    /// Notifications drained on `Completed`, followed by `Completed`; served in
    /// order by `next_event`. `None` until completion.
    post_completion: Option<VecDeque<ManagedEvent>>,
    /// Maximum time `next_event` spends draining trailing notifications once
    /// `Completed` arrives.
    completion_drain_timeout: Duration,
}
79
80impl ManagedBacktestSession {
81    /// Start a managed session with an internally owned cancellation token.
82    pub async fn start(
83        url: String,
84        api_key: String,
85        create: CreateBacktestSessionRequest,
86    ) -> Result<Self, ManagedSessionError> {
87        Self::start_with_cancel(url, api_key, create, CancellationToken::new()).await
88    }
89
90    /// Start a managed session tied to a caller-owned cancellation token.
91    ///
92    /// Cancelling `parent_cancel` aborts startup and stops manager tasks.
93    pub async fn start_with_cancel(
94        url: String,
95        api_key: String,
96        create: CreateBacktestSessionRequest,
97        parent_cancel: CancellationToken,
98    ) -> Result<Self, ManagedSessionError> {
99        let session_cancel = parent_cancel.child_token();
100        let mut control = spawn_control_manager(url, api_key, create, session_cancel.clone());
101
102        let session_info = tokio::select! {
103            biased;
104            _ = parent_cancel.cancelled() => {
105                session_cancel.cancel();
106                control.join().await;
107                return Err(ManagedSessionError::Cancelled);
108            }
109            result = control.wait_for_session() => {
110                result.map_err(ManagedSessionError::Create)?
111            }
112        };
113
114        Ok(Self {
115            session_info,
116            control: Some(control),
117            subscriptions: Vec::new(),
118            session_cancel,
119            post_completion: None,
120            completion_drain_timeout: DEFAULT_COMPLETION_DRAIN_TIMEOUT,
121        })
122    }
123
124    /// Metadata reported by the server when the session was created.
125    pub fn session_info(&self) -> &SessionInfo {
126        &self.session_info
127    }
128
129    /// Subscribe to transaction notifications for the configured programs.
130    pub fn subscribe_transactions(&mut self, program_ids: Vec<String>) {
131        self.subscriptions
132            .push(spawn_transaction_subscription_manager(
133                self.session_info.rpc_endpoint.clone(),
134                program_ids,
135                self.session_cancel.clone(),
136            ));
137    }
138
139    /// Subscribe to account-diff notifications for the configured programs.
140    pub fn subscribe_account_diffs(&mut self, program_ids: Vec<String>) {
141        self.subscriptions
142            .push(spawn_account_diff_subscription_manager(
143                self.session_info.rpc_endpoint.clone(),
144                program_ids,
145                self.session_cancel.clone(),
146            ));
147    }
148
149    /// Drain notifications until every subscription delivers its end-of-stream
150    /// terminal (closing its channel), the session is cancelled, or `timeout`
151    /// elapses. The server orders the terminal after every notification, so
152    /// draining to closure yields every trailing transaction without racing the
153    /// control-plane `Completed`.
154    async fn drain_until_subscriptions_complete(
155        &mut self,
156        timeout: std::time::Duration,
157    ) -> Vec<ManagedEvent> {
158        let mut events = Vec::new();
159        if self.subscriptions.is_empty() {
160            return events;
161        }
162        let deadline = tokio::time::Instant::now() + timeout;
163        loop {
164            while let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
165                events.push(event);
166            }
167            if self
168                .subscriptions
169                .iter()
170                .all(|s| s.notifications.is_closed())
171            {
172                return events;
173            }
174            tokio::select! {
175                biased;
176                _ = self.session_cancel.cancelled() => return events,
177                _ = tokio::time::sleep_until(deadline) => return events,
178                received = recv_any_open_subscription(&mut self.subscriptions) => {
179                    // `None` means a channel closed; the loop re-checks all-closed.
180                    if let Some(event) = received {
181                        events.push(event);
182                    }
183                }
184            }
185        }
186    }
187
188    /// Receive the next control or subscription event.
189    ///
190    /// On `Completed`, trailing subscription notifications are drained and
191    /// delivered before the `Completed` event.
192    pub async fn next_event(&mut self) -> Result<ManagedEvent, ManagedSessionError> {
193        // Serve buffered post-completion events (trailing notifications, then
194        // `Completed`); the control stream is gone once they're exhausted.
195        if let Some(buffered) = self.post_completion.as_mut() {
196            return buffered
197                .pop_front()
198                .ok_or(ManagedSessionError::ControlClosed);
199        }
200
201        if let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
202            return Ok(event);
203        }
204
205        // Scope the borrows to the `select!` so the completion drain below can
206        // re-borrow `self`.
207        let event = {
208            let cancel = self.session_cancel.clone();
209            let control = self
210                .control
211                .as_mut()
212                .ok_or(ManagedSessionError::ControlClosed)?;
213            let subscriptions = &mut self.subscriptions;
214            tokio::select! {
215                biased;
216                _ = cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
217                event = control.events.recv() => {
218                    event.map(ManagedEvent::from).ok_or(ManagedSessionError::ControlClosed)?
219                }
220                event = wait_any_subscription_event(subscriptions) => event,
221            }
222        };
223
224        if matches!(event, ManagedEvent::Completed) {
225            // Flush trailing notifications up to each subscription's terminal,
226            // delivering them before `Completed` so none are dropped.
227            let mut buffered: VecDeque<ManagedEvent> = self
228                .drain_until_subscriptions_complete(self.completion_drain_timeout)
229                .await
230                .into();
231            buffered.push_back(ManagedEvent::Completed);
232            let first = buffered.pop_front().expect("buffer contains Completed");
233            self.post_completion = Some(buffered);
234            return Ok(first);
235        }
236
237        Ok(event)
238    }
239
240    /// Wait until the control connection and all subscription connections are
241    /// up, then send a `Continue` request.
242    ///
243    /// Call this after receiving [`ManagedEvent::ReadyForContinue`] or
244    /// [`ManagedEvent::Paused`]. If there are no subscriptions, only the
245    /// control connection is gated.
246    pub async fn send_continue(
247        &mut self,
248        params: ContinueParams,
249    ) -> Result<(), ManagedSessionError> {
250        self.wait_all_up().await?;
251        self.control_mut()?
252            .send_continue(params)
253            .await
254            .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
255    }
256
257    /// Wait until the control connection and all subscription connections are
258    /// up, then send a `ContinueTo` request to step to a specific slot/batch
259    /// boundary.
260    ///
261    /// Pair with [`ManagedEvent::DiscoveryBatch`] to drive a discovery-paced
262    /// loop: receive a discovery event, send `ContinueTo(slot, batch_index)`,
263    /// and wait for [`ManagedEvent::Paused`] before inspecting state.
264    pub async fn send_continue_to(
265        &mut self,
266        params: ContinueToParams,
267    ) -> Result<(), ManagedSessionError> {
268        self.wait_all_up().await?;
269        self.control_mut()?
270            .send_continue_to(params)
271            .await
272            .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
273    }
274
275    /// Cancel the session and join all manager tasks.
276    pub async fn shutdown(mut self) {
277        self.session_cancel.cancel();
278        if let Some(control) = self.control.take() {
279            control.join().await;
280        }
281        for sub in std::mem::take(&mut self.subscriptions) {
282            let _ = sub.join.await;
283        }
284    }
285
286    fn control_mut(&mut self) -> Result<&mut ControlHandle, ManagedSessionError> {
287        self.control
288            .as_mut()
289            .ok_or(ManagedSessionError::ControlClosed)
290    }
291
292    async fn wait_all_up(&self) -> Result<(), ManagedSessionError> {
293        let mut control = self
294            .control
295            .as_ref()
296            .ok_or(ManagedSessionError::ControlClosed)?
297            .status
298            .clone();
299        let mut subscriptions: Vec<watch::Receiver<ConnectionStatus>> = self
300            .subscriptions
301            .iter()
302            .map(|s| s.status.clone())
303            .collect();
304
305        loop {
306            let control_status = control.borrow().clone();
307            if let ConnectionStatus::Failed(why) = &control_status {
308                return Err(ManagedSessionError::ControlFailed(why.clone()));
309            }
310
311            let mut all_subscriptions_up = true;
312            for subscription in &subscriptions {
313                match &*subscription.borrow() {
314                    ConnectionStatus::Failed(why) => {
315                        return Err(ManagedSessionError::SubscriptionFailed(why.clone()));
316                    }
317                    ConnectionStatus::Up => {}
318                    _ => all_subscriptions_up = false,
319                }
320            }
321
322            if control_status == ConnectionStatus::Up && all_subscriptions_up {
323                return Ok(());
324            }
325
326            tokio::select! {
327                _ = self.session_cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
328                _ = control.changed() => {}
329                _ = wait_any_subscription_change(&mut subscriptions) => {}
330            }
331        }
332    }
333}
334
335impl Drop for ManagedBacktestSession {
336    fn drop(&mut self) {
337        self.session_cancel.cancel();
338    }
339}
340
341async fn wait_any_subscription_change(subscriptions: &mut [watch::Receiver<ConnectionStatus>]) {
342    if subscriptions.is_empty() {
343        std::future::pending::<()>().await;
344        return;
345    }
346    let _ =
347        futures::future::select_all(subscriptions.iter_mut().map(|s| Box::pin(s.changed()))).await;
348}
349
350async fn wait_any_subscription_event(subscriptions: &mut [SubscriptionHandle]) -> ManagedEvent {
351    loop {
352        if let Some(event) = try_next_subscription_event(subscriptions) {
353            return event;
354        }
355
356        let futures: Vec<_> = subscriptions
357            .iter_mut()
358            .filter(|s| !s.notifications.is_closed())
359            .map(|s| Box::pin(s.notifications.recv()))
360            .collect();
361
362        if futures.is_empty() {
363            std::future::pending::<()>().await;
364        }
365
366        let (notification, _, _) = futures::future::select_all(futures).await;
367        if let Some(notification) = notification {
368            return notification.into();
369        }
370    }
371}
372
373/// Await the next notification from any still-open subscription channel,
374/// returning `None` when one closes. Unlike [`wait_any_subscription_event`],
375/// which never resolves on closure, this lets the completion drain observe
376/// per-channel end-of-stream.
377async fn recv_any_open_subscription(
378    subscriptions: &mut [SubscriptionHandle],
379) -> Option<ManagedEvent> {
380    let futures: Vec<_> = subscriptions
381        .iter_mut()
382        .filter(|s| !s.notifications.is_closed())
383        .map(|s| Box::pin(s.notifications.recv()))
384        .collect();
385
386    if futures.is_empty() {
387        return None;
388    }
389
390    let (notification, _, _) = futures::future::select_all(futures).await;
391    notification.map(Into::into)
392}
393
394fn try_next_subscription_event(subscriptions: &mut [SubscriptionHandle]) -> Option<ManagedEvent> {
395    for subscription in subscriptions {
396        if let Ok(notification) = subscription.notifications.try_recv() {
397            return Some(notification.into());
398        }
399    }
400    None
401}
402
403impl From<ControlEvent> for ManagedEvent {
404    fn from(event: ControlEvent) -> Self {
405        match event {
406            ControlEvent::ReadyForContinue => Self::ReadyForContinue,
407            ControlEvent::Paused(event) => Self::Paused(event),
408            ControlEvent::DiscoveryBatch(event) => Self::DiscoveryBatch(event),
409            ControlEvent::Slot(slot) => Self::Slot(slot),
410            ControlEvent::Status(status) => Self::Status(status),
411            ControlEvent::Completed => Self::Completed,
412            ControlEvent::Error(error) => Self::Error(error),
413        }
414    }
415}
416
417impl From<SubscriptionNotification> for ManagedEvent {
418    fn from(notification: SubscriptionNotification) -> Self {
419        match notification {
420            SubscriptionNotification::Transaction(transaction) => Self::Transaction(transaction),
421            SubscriptionNotification::AccountDiff(diff) => Self::AccountDiff(diff),
422        }
423    }
424}