// simulator_client/managed/session.rs

use std::{collections::VecDeque, sync::Arc, time::Duration};

use simulator_api::{
    AgentStatsReport, BacktestError, BacktestStatus, ContinueParams, ContinueToParams,
    CreateBacktestSessionRequest, DiscoveryBatchEvent, PausedEvent, SessionSummary,
};
use solana_transaction_status::EncodedConfirmedTransactionWithStatusMeta;
use thiserror::Error;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;

use super::{
    ConnectionStatus, ControlEvent, ControlHandle, ReconnectCoordinator, SessionInfo,
    SubscriptionHandle, SubscriptionNotification, spawn_account_diff_subscription_manager,
    spawn_control_manager, spawn_transaction_subscription_manager,
};
use crate::subscriptions::AccountDiffNotification;

19#[derive(Debug, Error)]
21pub enum ManagedSessionError {
22 #[error("session create failed: {0}")]
23 Create(String),
24
25 #[error("control channel closed")]
26 ControlClosed,
27
28 #[error("control failed: {0}")]
29 ControlFailed(String),
30
31 #[error("subscription failed: {0}")]
32 SubscriptionFailed(String),
33
34 #[error("cancelled")]
35 Cancelled,
36
37 #[error("control closed while sending continue: {0}")]
38 ContinueSend(String),
39}
40
41#[derive(Debug)]
42pub enum ManagedEvent {
43 ReadyForContinue,
44 Paused(PausedEvent),
47 DiscoveryBatch(DiscoveryBatchEvent),
51 Slot(u64),
52 Status(BacktestStatus),
53 Completed {
56 summary: Option<SessionSummary>,
57 agent_stats: Option<Vec<AgentStatsReport>>,
58 },
59 Error(BacktestError),
60 Transaction(Box<EncodedConfirmedTransactionWithStatusMeta>),
61 AccountDiff(AccountDiffNotification),
62}
63
/// Default idle timeout for the subscription drain that follows completion (one minute).
const DEFAULT_COMPLETION_DRAIN_TIMEOUT: Duration = Duration::from_millis(60_000);
69
70pub(super) enum DrainOutcome {
72 Complete(Vec<ManagedEvent>),
75 Stalled(Vec<ManagedEvent>),
79}
80
81pub struct ManagedBacktestSession {
87 session_info: SessionInfo,
88 control: Option<ControlHandle>,
89 subscriptions: Vec<SubscriptionHandle>,
90 session_cancel: CancellationToken,
91 post_completion: Option<VecDeque<ManagedEvent>>,
94 post_completion_error: Option<ManagedSessionError>,
98 completion_drain_timeout: Duration,
99 reconnect_coordinator: Option<Arc<ReconnectCoordinator>>,
102}
103
104impl ManagedBacktestSession {
105 pub async fn start(
107 url: String,
108 api_key: String,
109 create: CreateBacktestSessionRequest,
110 ) -> Result<Self, ManagedSessionError> {
111 Self::start_with_cancel(url, api_key, create, CancellationToken::new(), None).await
112 }
113
114 pub async fn start_with_cancel(
124 url: String,
125 api_key: String,
126 create: CreateBacktestSessionRequest,
127 parent_cancel: CancellationToken,
128 reconnect_coordinator: Option<Arc<ReconnectCoordinator>>,
129 ) -> Result<Self, ManagedSessionError> {
130 let session_cancel = parent_cancel.child_token();
131 let mut control = spawn_control_manager(url, api_key, create, session_cancel.clone());
132
133 let session_info = tokio::select! {
134 biased;
135 _ = parent_cancel.cancelled() => {
136 session_cancel.cancel();
137 control.join().await;
138 return Err(ManagedSessionError::Cancelled);
139 }
140 result = control.wait_for_session() => {
141 result.map_err(ManagedSessionError::Create)?
142 }
143 };
144
145 Ok(Self {
146 session_info,
147 control: Some(control),
148 subscriptions: Vec::new(),
149 session_cancel,
150 post_completion: None,
151 post_completion_error: None,
152 completion_drain_timeout: DEFAULT_COMPLETION_DRAIN_TIMEOUT,
153 reconnect_coordinator,
154 })
155 }
156
157 pub fn session_info(&self) -> &SessionInfo {
159 &self.session_info
160 }
161
162 pub fn set_completion_drain_timeout(&mut self, idle_timeout: std::time::Duration) {
166 self.completion_drain_timeout = idle_timeout;
167 }
168
169 pub fn subscribe_transactions(&mut self, program_ids: Vec<String>) {
171 self.subscriptions
172 .push(spawn_transaction_subscription_manager(
173 self.session_info.rpc_endpoint.clone(),
174 program_ids,
175 self.session_cancel.clone(),
176 self.reconnect_coordinator.clone(),
177 ));
178 }
179
180 pub fn subscribe_account_diffs(&mut self, program_ids: Vec<String>) {
182 self.subscriptions
183 .push(spawn_account_diff_subscription_manager(
184 self.session_info.rpc_endpoint.clone(),
185 program_ids,
186 self.session_cancel.clone(),
187 self.reconnect_coordinator.clone(),
188 ));
189 }
190
191 async fn drain_until_subscriptions_complete(
196 &mut self,
197 idle_timeout: std::time::Duration,
198 ) -> DrainOutcome {
199 drain_subscriptions_until_complete(
200 &mut self.subscriptions,
201 &self.session_cancel,
202 idle_timeout,
203 )
204 .await
205 }
206
207 pub async fn next_event(&mut self) -> Result<ManagedEvent, ManagedSessionError> {
212 if let Some(buffered) = self.post_completion.as_mut() {
215 if let Some(event) = buffered.pop_front() {
216 return Ok(event);
217 }
218 return Err(self
221 .post_completion_error
222 .take()
223 .unwrap_or(ManagedSessionError::ControlClosed));
224 }
225
226 if let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
227 return Ok(event);
228 }
229
230 let event = {
233 let cancel = self.session_cancel.clone();
234 let control = self
235 .control
236 .as_mut()
237 .ok_or(ManagedSessionError::ControlClosed)?;
238 let subscriptions = &mut self.subscriptions;
239 tokio::select! {
240 biased;
241 _ = cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
242 event = control.events.recv() => {
243 event.map(ManagedEvent::from).ok_or(ManagedSessionError::ControlClosed)?
244 }
245 event = wait_any_subscription_event(subscriptions) => event,
246 }
247 };
248
249 let ManagedEvent::Completed {
251 summary,
252 agent_stats,
253 } = event
254 else {
255 return Ok(event);
256 };
257
258 let (mut buffered, terminal): (VecDeque<ManagedEvent>, _) = match self
261 .drain_until_subscriptions_complete(self.completion_drain_timeout)
262 .await
263 {
264 DrainOutcome::Complete(events) => (
265 events.into(),
266 Ok(ManagedEvent::Completed {
267 summary,
268 agent_stats,
269 }),
270 ),
271 DrainOutcome::Stalled(events) => (
276 events.into(),
277 Err(ManagedSessionError::SubscriptionFailed(
278 "completion drain stalled: subscriptions did not deliver their \
279 end-of-stream terminals; the captured stream is incomplete"
280 .to_string(),
281 )),
282 ),
283 };
284 match terminal {
285 Ok(completed) => buffered.push_back(completed),
286 Err(err) => self.post_completion_error = Some(err),
287 }
288 let first = buffered.pop_front();
289 self.post_completion = Some(buffered);
290 match first {
291 Some(event) => Ok(event),
292 None => Err(self
294 .post_completion_error
295 .take()
296 .unwrap_or(ManagedSessionError::ControlClosed)),
297 }
298 }
299
300 pub async fn send_continue(
307 &mut self,
308 params: ContinueParams,
309 ) -> Result<(), ManagedSessionError> {
310 self.wait_all_up().await?;
311 self.control_mut()?
312 .send_continue(params)
313 .await
314 .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
315 }
316
317 pub async fn send_continue_to(
325 &mut self,
326 params: ContinueToParams,
327 ) -> Result<(), ManagedSessionError> {
328 self.wait_all_up().await?;
329 self.control_mut()?
330 .send_continue_to(params)
331 .await
332 .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
333 }
334
335 pub async fn shutdown(mut self) {
337 self.session_cancel.cancel();
338 if let Some(control) = self.control.take() {
339 control.join().await;
340 }
341 for sub in std::mem::take(&mut self.subscriptions) {
342 let _ = sub.join.await;
343 }
344 }
345
346 fn control_mut(&mut self) -> Result<&mut ControlHandle, ManagedSessionError> {
347 self.control
348 .as_mut()
349 .ok_or(ManagedSessionError::ControlClosed)
350 }
351
352 async fn wait_all_up(&self) -> Result<(), ManagedSessionError> {
353 let control = self
354 .control
355 .as_ref()
356 .ok_or(ManagedSessionError::ControlClosed)?
357 .status
358 .clone();
359 let subscriptions = self
360 .subscriptions
361 .iter()
362 .map(|s| s.status.clone())
363 .collect();
364 wait_connections_up(control, subscriptions, &self.session_cancel).await
365 }
366}
367
368pub(super) async fn wait_connections_up(
372 mut control: watch::Receiver<ConnectionStatus>,
373 mut subscriptions: Vec<watch::Receiver<ConnectionStatus>>,
374 cancel: &CancellationToken,
375) -> Result<(), ManagedSessionError> {
376 loop {
377 let control_status = control.borrow().clone();
378 if let ConnectionStatus::Failed(why) = &control_status {
379 return Err(ManagedSessionError::ControlFailed(why.clone()));
380 }
381
382 let mut all_subscriptions_up = true;
383 for subscription in &subscriptions {
384 match &*subscription.borrow() {
385 ConnectionStatus::Failed(why) => {
386 return Err(ManagedSessionError::SubscriptionFailed(why.clone()));
387 }
388 ConnectionStatus::Up => {}
389 _ => all_subscriptions_up = false,
390 }
391 }
392
393 if control_status == ConnectionStatus::Up && all_subscriptions_up {
394 return Ok(());
395 }
396
397 tokio::select! {
398 _ = cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
399 _ = control.changed() => {}
400 _ = wait_any_subscription_change(&mut subscriptions) => {}
401 }
402 }
403}
404
405pub(super) async fn drain_subscriptions_until_complete(
420 subscriptions: &mut [SubscriptionHandle],
421 cancel: &CancellationToken,
422 idle_timeout: std::time::Duration,
423) -> DrainOutcome {
424 let mut events = Vec::new();
425 if subscriptions.is_empty() {
426 return DrainOutcome::Complete(events);
427 }
428 loop {
429 while let Some(event) = try_next_subscription_event(subscriptions) {
430 events.push(event);
431 }
432 if subscriptions.iter().all(|s| s.notifications.is_closed()) {
433 let any_failed = subscriptions
434 .iter()
435 .any(|s| matches!(*s.status.borrow(), ConnectionStatus::Failed(_)));
436 return if any_failed {
437 DrainOutcome::Stalled(events)
438 } else {
439 DrainOutcome::Complete(events)
440 };
441 }
442 tokio::select! {
443 biased;
444 _ = cancel.cancelled() => return DrainOutcome::Complete(events),
446 _ = tokio::time::sleep(idle_timeout) => {
451 let any_up = subscriptions.iter().any(|s| {
452 !s.notifications.is_closed()
453 && matches!(*s.status.borrow(), ConnectionStatus::Up)
454 });
455 if any_up {
456 return DrainOutcome::Stalled(events);
457 }
458 }
459 received = recv_any_open_subscription(subscriptions) => {
460 if let Some(event) = received {
462 events.push(event);
463 }
464 }
465 }
466 }
467}
468
469impl Drop for ManagedBacktestSession {
470 fn drop(&mut self) {
471 self.session_cancel.cancel();
472 }
473}
474
475pub(super) async fn wait_any_subscription_change(
476 subscriptions: &mut [watch::Receiver<ConnectionStatus>],
477) {
478 if subscriptions.is_empty() {
479 std::future::pending::<()>().await;
480 return;
481 }
482 let _ =
483 futures::future::select_all(subscriptions.iter_mut().map(|s| Box::pin(s.changed()))).await;
484}
485
486pub(super) async fn wait_any_subscription_event(
487 subscriptions: &mut [SubscriptionHandle],
488) -> ManagedEvent {
489 loop {
490 if let Some(event) = try_next_subscription_event(subscriptions) {
491 return event;
492 }
493
494 let futures: Vec<_> = subscriptions
495 .iter_mut()
496 .filter(|s| !s.notifications.is_closed())
497 .map(|s| Box::pin(s.notifications.recv()))
498 .collect();
499
500 if futures.is_empty() {
501 std::future::pending::<()>().await;
502 }
503
504 let (notification, _, _) = futures::future::select_all(futures).await;
505 if let Some(notification) = notification {
506 return notification.into();
507 }
508 }
509}
510
511pub(super) async fn recv_any_open_subscription(
516 subscriptions: &mut [SubscriptionHandle],
517) -> Option<ManagedEvent> {
518 let futures: Vec<_> = subscriptions
519 .iter_mut()
520 .filter(|s| !s.notifications.is_closed())
521 .map(|s| Box::pin(s.notifications.recv()))
522 .collect();
523
524 if futures.is_empty() {
525 return None;
526 }
527
528 let (notification, _, _) = futures::future::select_all(futures).await;
529 notification.map(Into::into)
530}
531
532pub(super) fn try_next_subscription_event(
533 subscriptions: &mut [SubscriptionHandle],
534) -> Option<ManagedEvent> {
535 for subscription in subscriptions {
536 if let Ok(notification) = subscription.notifications.try_recv() {
537 return Some(notification.into());
538 }
539 }
540 None
541}
542
543impl From<ControlEvent> for ManagedEvent {
544 fn from(event: ControlEvent) -> Self {
545 match event {
546 ControlEvent::ReadyForContinue => Self::ReadyForContinue,
547 ControlEvent::Paused(event) => Self::Paused(event),
548 ControlEvent::DiscoveryBatch(event) => Self::DiscoveryBatch(event),
549 ControlEvent::Slot(slot) => Self::Slot(slot),
550 ControlEvent::Status(status) => Self::Status(status),
551 ControlEvent::Completed {
552 summary,
553 agent_stats,
554 } => Self::Completed {
555 summary,
556 agent_stats,
557 },
558 ControlEvent::Error(error) => Self::Error(error),
559 }
560 }
561}
562
563impl From<SubscriptionNotification> for ManagedEvent {
564 fn from(notification: SubscriptionNotification) -> Self {
565 match notification {
566 SubscriptionNotification::Transaction(transaction) => Self::Transaction(transaction),
567 SubscriptionNotification::AccountDiff(diff) => Self::AccountDiff(diff),
568 }
569 }
570}