simulator_client/managed/
session.rs1use std::{collections::VecDeque, time::Duration};
2
3use simulator_api::{
4 BacktestError, BacktestStatus, ContinueParams, ContinueToParams, CreateBacktestSessionRequest,
5 DiscoveryBatchEvent, PausedEvent,
6};
7use solana_transaction_status::EncodedConfirmedTransactionWithStatusMeta;
8use thiserror::Error;
9use tokio::sync::watch;
10use tokio_util::sync::CancellationToken;
11
12use super::{
13 ConnectionStatus, ControlEvent, ControlHandle, SessionInfo, SubscriptionHandle,
14 SubscriptionNotification, spawn_account_diff_subscription_manager, spawn_control_manager,
15 spawn_transaction_subscription_manager,
16};
17use crate::subscriptions::AccountDiffNotification;
18
19#[derive(Debug, Error)]
21pub enum ManagedSessionError {
22 #[error("session create failed: {0}")]
23 Create(String),
24
25 #[error("control channel closed")]
26 ControlClosed,
27
28 #[error("control failed: {0}")]
29 ControlFailed(String),
30
31 #[error("subscription failed: {0}")]
32 SubscriptionFailed(String),
33
34 #[error("cancelled")]
35 Cancelled,
36
37 #[error("control closed while sending continue: {0}")]
38 ContinueSend(String),
39}
40
41#[derive(Debug)]
42pub enum ManagedEvent {
43 ReadyForContinue,
44 Paused(PausedEvent),
47 DiscoveryBatch(DiscoveryBatchEvent),
51 Slot(u64),
52 Status(BacktestStatus),
53 Completed,
54 Error(BacktestError),
55 Transaction(Box<EncodedConfirmedTransactionWithStatusMeta>),
56 AccountDiff(AccountDiffNotification),
57}
58
/// Default upper bound on how long subscription notifications are drained
/// after the control channel reports completion.
const DEFAULT_COMPLETION_DRAIN_TIMEOUT: Duration = Duration::from_secs(60);
63
64pub struct ManagedBacktestSession {
70 session_info: SessionInfo,
71 control: Option<ControlHandle>,
72 subscriptions: Vec<SubscriptionHandle>,
73 session_cancel: CancellationToken,
74 post_completion: Option<VecDeque<ManagedEvent>>,
77 completion_drain_timeout: Duration,
78}
79
80impl ManagedBacktestSession {
81 pub async fn start(
83 url: String,
84 api_key: String,
85 create: CreateBacktestSessionRequest,
86 ) -> Result<Self, ManagedSessionError> {
87 Self::start_with_cancel(url, api_key, create, CancellationToken::new()).await
88 }
89
90 pub async fn start_with_cancel(
94 url: String,
95 api_key: String,
96 create: CreateBacktestSessionRequest,
97 parent_cancel: CancellationToken,
98 ) -> Result<Self, ManagedSessionError> {
99 let session_cancel = parent_cancel.child_token();
100 let mut control = spawn_control_manager(url, api_key, create, session_cancel.clone());
101
102 let session_info = tokio::select! {
103 biased;
104 _ = parent_cancel.cancelled() => {
105 session_cancel.cancel();
106 control.join().await;
107 return Err(ManagedSessionError::Cancelled);
108 }
109 result = control.wait_for_session() => {
110 result.map_err(ManagedSessionError::Create)?
111 }
112 };
113
114 Ok(Self {
115 session_info,
116 control: Some(control),
117 subscriptions: Vec::new(),
118 session_cancel,
119 post_completion: None,
120 completion_drain_timeout: DEFAULT_COMPLETION_DRAIN_TIMEOUT,
121 })
122 }
123
124 pub fn session_info(&self) -> &SessionInfo {
126 &self.session_info
127 }
128
129 pub fn subscribe_transactions(&mut self, program_ids: Vec<String>) {
131 self.subscriptions
132 .push(spawn_transaction_subscription_manager(
133 self.session_info.rpc_endpoint.clone(),
134 program_ids,
135 self.session_cancel.clone(),
136 ));
137 }
138
139 pub fn subscribe_account_diffs(&mut self, program_ids: Vec<String>) {
141 self.subscriptions
142 .push(spawn_account_diff_subscription_manager(
143 self.session_info.rpc_endpoint.clone(),
144 program_ids,
145 self.session_cancel.clone(),
146 ));
147 }
148
149 async fn drain_until_subscriptions_complete(
155 &mut self,
156 timeout: std::time::Duration,
157 ) -> Vec<ManagedEvent> {
158 let mut events = Vec::new();
159 if self.subscriptions.is_empty() {
160 return events;
161 }
162 let deadline = tokio::time::Instant::now() + timeout;
163 loop {
164 while let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
165 events.push(event);
166 }
167 if self
168 .subscriptions
169 .iter()
170 .all(|s| s.notifications.is_closed())
171 {
172 return events;
173 }
174 tokio::select! {
175 biased;
176 _ = self.session_cancel.cancelled() => return events,
177 _ = tokio::time::sleep_until(deadline) => return events,
178 received = recv_any_open_subscription(&mut self.subscriptions) => {
179 if let Some(event) = received {
181 events.push(event);
182 }
183 }
184 }
185 }
186 }
187
188 pub async fn next_event(&mut self) -> Result<ManagedEvent, ManagedSessionError> {
193 if let Some(buffered) = self.post_completion.as_mut() {
196 return buffered
197 .pop_front()
198 .ok_or(ManagedSessionError::ControlClosed);
199 }
200
201 if let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
202 return Ok(event);
203 }
204
205 let event = {
208 let cancel = self.session_cancel.clone();
209 let control = self
210 .control
211 .as_mut()
212 .ok_or(ManagedSessionError::ControlClosed)?;
213 let subscriptions = &mut self.subscriptions;
214 tokio::select! {
215 biased;
216 _ = cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
217 event = control.events.recv() => {
218 event.map(ManagedEvent::from).ok_or(ManagedSessionError::ControlClosed)?
219 }
220 event = wait_any_subscription_event(subscriptions) => event,
221 }
222 };
223
224 if matches!(event, ManagedEvent::Completed) {
225 let mut buffered: VecDeque<ManagedEvent> = self
228 .drain_until_subscriptions_complete(self.completion_drain_timeout)
229 .await
230 .into();
231 buffered.push_back(ManagedEvent::Completed);
232 let first = buffered.pop_front().expect("buffer contains Completed");
233 self.post_completion = Some(buffered);
234 return Ok(first);
235 }
236
237 Ok(event)
238 }
239
240 pub async fn send_continue(
247 &mut self,
248 params: ContinueParams,
249 ) -> Result<(), ManagedSessionError> {
250 self.wait_all_up().await?;
251 self.control_mut()?
252 .send_continue(params)
253 .await
254 .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
255 }
256
257 pub async fn send_continue_to(
265 &mut self,
266 params: ContinueToParams,
267 ) -> Result<(), ManagedSessionError> {
268 self.wait_all_up().await?;
269 self.control_mut()?
270 .send_continue_to(params)
271 .await
272 .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
273 }
274
275 pub async fn shutdown(mut self) {
277 self.session_cancel.cancel();
278 if let Some(control) = self.control.take() {
279 control.join().await;
280 }
281 for sub in std::mem::take(&mut self.subscriptions) {
282 let _ = sub.join.await;
283 }
284 }
285
286 fn control_mut(&mut self) -> Result<&mut ControlHandle, ManagedSessionError> {
287 self.control
288 .as_mut()
289 .ok_or(ManagedSessionError::ControlClosed)
290 }
291
292 async fn wait_all_up(&self) -> Result<(), ManagedSessionError> {
293 let mut control = self
294 .control
295 .as_ref()
296 .ok_or(ManagedSessionError::ControlClosed)?
297 .status
298 .clone();
299 let mut subscriptions: Vec<watch::Receiver<ConnectionStatus>> = self
300 .subscriptions
301 .iter()
302 .map(|s| s.status.clone())
303 .collect();
304
305 loop {
306 let control_status = control.borrow().clone();
307 if let ConnectionStatus::Failed(why) = &control_status {
308 return Err(ManagedSessionError::ControlFailed(why.clone()));
309 }
310
311 let mut all_subscriptions_up = true;
312 for subscription in &subscriptions {
313 match &*subscription.borrow() {
314 ConnectionStatus::Failed(why) => {
315 return Err(ManagedSessionError::SubscriptionFailed(why.clone()));
316 }
317 ConnectionStatus::Up => {}
318 _ => all_subscriptions_up = false,
319 }
320 }
321
322 if control_status == ConnectionStatus::Up && all_subscriptions_up {
323 return Ok(());
324 }
325
326 tokio::select! {
327 _ = self.session_cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
328 _ = control.changed() => {}
329 _ = wait_any_subscription_change(&mut subscriptions) => {}
330 }
331 }
332 }
333}
334
335impl Drop for ManagedBacktestSession {
336 fn drop(&mut self) {
337 self.session_cancel.cancel();
338 }
339}
340
341async fn wait_any_subscription_change(subscriptions: &mut [watch::Receiver<ConnectionStatus>]) {
342 if subscriptions.is_empty() {
343 std::future::pending::<()>().await;
344 return;
345 }
346 let _ =
347 futures::future::select_all(subscriptions.iter_mut().map(|s| Box::pin(s.changed()))).await;
348}
349
350async fn wait_any_subscription_event(subscriptions: &mut [SubscriptionHandle]) -> ManagedEvent {
351 loop {
352 if let Some(event) = try_next_subscription_event(subscriptions) {
353 return event;
354 }
355
356 let futures: Vec<_> = subscriptions
357 .iter_mut()
358 .filter(|s| !s.notifications.is_closed())
359 .map(|s| Box::pin(s.notifications.recv()))
360 .collect();
361
362 if futures.is_empty() {
363 std::future::pending::<()>().await;
364 }
365
366 let (notification, _, _) = futures::future::select_all(futures).await;
367 if let Some(notification) = notification {
368 return notification.into();
369 }
370 }
371}
372
373async fn recv_any_open_subscription(
378 subscriptions: &mut [SubscriptionHandle],
379) -> Option<ManagedEvent> {
380 let futures: Vec<_> = subscriptions
381 .iter_mut()
382 .filter(|s| !s.notifications.is_closed())
383 .map(|s| Box::pin(s.notifications.recv()))
384 .collect();
385
386 if futures.is_empty() {
387 return None;
388 }
389
390 let (notification, _, _) = futures::future::select_all(futures).await;
391 notification.map(Into::into)
392}
393
394fn try_next_subscription_event(subscriptions: &mut [SubscriptionHandle]) -> Option<ManagedEvent> {
395 for subscription in subscriptions {
396 if let Ok(notification) = subscription.notifications.try_recv() {
397 return Some(notification.into());
398 }
399 }
400 None
401}
402
403impl From<ControlEvent> for ManagedEvent {
404 fn from(event: ControlEvent) -> Self {
405 match event {
406 ControlEvent::ReadyForContinue => Self::ReadyForContinue,
407 ControlEvent::Paused(event) => Self::Paused(event),
408 ControlEvent::DiscoveryBatch(event) => Self::DiscoveryBatch(event),
409 ControlEvent::Slot(slot) => Self::Slot(slot),
410 ControlEvent::Status(status) => Self::Status(status),
411 ControlEvent::Completed => Self::Completed,
412 ControlEvent::Error(error) => Self::Error(error),
413 }
414 }
415}
416
417impl From<SubscriptionNotification> for ManagedEvent {
418 fn from(notification: SubscriptionNotification) -> Self {
419 match notification {
420 SubscriptionNotification::Transaction(transaction) => Self::Transaction(transaction),
421 SubscriptionNotification::AccountDiff(diff) => Self::AccountDiff(diff),
422 }
423 }
424}