simulator_client/managed/
mod.rs1use std::{
13 future::Future,
14 sync::{
15 Arc,
16 atomic::{AtomicUsize, Ordering},
17 },
18 time::{Duration, Instant},
19};
20
21use futures::SinkExt;
22use rand::Rng;
23use simulator_api::{BacktestError, BacktestRequest};
24use tokio::{
25 net::TcpStream,
26 sync::{Notify, OwnedSemaphorePermit, Semaphore, watch},
27};
28use tokio_tungstenite::{
29 MaybeTlsStream, WebSocketStream, connect_async,
30 tungstenite::{Error as WsError, Message, client::IntoClientRequest, http::HeaderValue},
31};
32use tokio_util::sync::CancellationToken;
33use tracing::warn;
34
35use crate::error::err_chain;
36
37mod control;
38mod parallel;
39mod session;
40mod subscription;
41
42pub use control::{ControlEvent, ControlHandle, spawn_control_manager};
43pub use parallel::{ManagedParallelSession, ParallelSubSession};
44pub use session::{ManagedBacktestSession, ManagedEvent, ManagedSessionError};
45pub use subscription::{
46 SubscriptionHandle, SubscriptionNotification, spawn_account_diff_subscription_manager,
47 spawn_transaction_subscription_manager,
48};
49
/// Upper bound on establishing the WebSocket connection: `connect_ws` wraps
/// `connect_async` in a timeout of this length.
pub const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Timeout for a handshake response. Not referenced in this module;
/// presumably applied by the individual `handshake` implementations — TODO confirm.
pub const HANDSHAKE_RESPONSE_TIMEOUT: Duration = Duration::from_secs(120);

/// Keepalive period. Not referenced in this module; presumably the cadence at
/// which message loops call `send_keepalive_ping` — TODO confirm.
pub const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(15);

/// Longest inbound silence tolerated: `send_keepalive_ping` reports the link as
/// dead once more than this has elapsed since the last inbound traffic.
pub const KEEPALIVE_MISS_DEADLINE: Duration = Duration::from_secs(45);

/// Total time `graceful_close` may spend sending the close request and closing
/// the socket before it gives up.
pub const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5);

/// Un-jittered delay of the first attempt in a fresh `ReconnectBudget`.
pub const RECONNECT_INITIAL_BACKOFF: Duration = Duration::from_secs(1);
/// Cap on the un-jittered delay between reconnect attempts.
pub const RECONNECT_MAX_BACKOFF: Duration = Duration::from_secs(30);
/// Factor by which the un-jittered delay grows after every attempt.
pub const RECONNECT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Jitter fraction `J`: each delay is scaled by a random factor in `[1 - J, 1 + J)`.
pub const RECONNECT_JITTER: f64 = 0.2;
/// Wall-clock limit of one budget window; `ReconnectBudget::next_backoff`
/// refuses further attempts once this much time has elapsed.
pub const RECONNECT_MAX_TOTAL: Duration = Duration::from_secs(5 * 60);
/// Maximum number of attempts one budget window grants.
pub const RECONNECT_MAX_ATTEMPTS: u32 = 20;

/// Not referenced in this module. Presumably the number of reconnect attempts
/// made before callers start gating on `ReconnectCoordinator` — TODO confirm.
pub const RECONNECT_UNGATED_ATTEMPTS: u32 = 5;

/// Minimum time a connection must have stayed up for `run_control_loop` to
/// reset its reconnect budget when that connection is lost.
pub const RECONNECT_UPTIME_RESET: Duration = Duration::from_secs(30);
90
/// Connection state published to observers through a `watch` channel
/// (see `publish_status`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// Connected and handshaken; `run_control_loop` publishes this right after
    /// a successful handshake.
    Up,
    /// Not connected; `run_control_loop` publishes this at the start of every
    /// connect / reconnect iteration.
    Down,
    /// The connection was given up on; carries a human-readable reason.
    Failed(String),
}
101
/// Identifying details of a backtest session.
///
/// NOTE(review): the values are filled in outside this module; the field
/// descriptions below are inferred from the names — confirm against the
/// session handshake code.
#[derive(Clone, Debug)]
pub struct SessionInfo {
    /// Presumably the server-assigned session identifier.
    pub session_id: String,
    /// Presumably the RPC endpoint for the session (see `resolve_rpc_url`).
    pub rpc_endpoint: String,
    /// Optional task identifier; absent when none applies.
    pub task_id: Option<String>,
}
110
/// Bounds a run of reconnect attempts by both count and elapsed wall-clock
/// time, and tracks the exponentially growing delay between attempts.
pub(crate) struct ReconnectBudget {
    /// Attempts granted since creation or the last `reset`.
    attempts: u32,
    /// Start of the current budget window; its age is compared against
    /// `RECONNECT_MAX_TOTAL`.
    started_at: std::time::Instant,
    /// Un-jittered delay the next granted attempt is based on.
    current_backoff: Duration,
}
117
118impl ReconnectBudget {
119 pub fn new() -> Self {
120 Self {
121 attempts: 0,
122 started_at: std::time::Instant::now(),
123 current_backoff: RECONNECT_INITIAL_BACKOFF,
124 }
125 }
126
127 pub fn reset(&mut self) {
128 self.attempts = 0;
129 self.started_at = std::time::Instant::now();
130 self.current_backoff = RECONNECT_INITIAL_BACKOFF;
131 }
132
133 pub fn attempt(&self) -> u32 {
134 self.attempts
135 }
136
137 pub fn discount_parked(&mut self, parked: Duration) {
144 self.started_at += parked;
145 }
146
147 pub fn next_backoff(&mut self) -> Option<Duration> {
150 if self.attempts >= RECONNECT_MAX_ATTEMPTS
151 || self.started_at.elapsed() >= RECONNECT_MAX_TOTAL
152 {
153 return None;
154 }
155 self.attempts += 1;
156 let backoff = with_jitter(self.current_backoff);
157 self.current_backoff = std::cmp::min(
158 RECONNECT_MAX_BACKOFF,
159 Duration::from_secs_f64(
160 self.current_backoff.as_secs_f64() * RECONNECT_BACKOFF_MULTIPLIER,
161 ),
162 );
163 Some(backoff)
164 }
165}
166
/// Shared gate that lets a connection reconnect only while no sibling holds a
/// `StreamingGuard`, and lets only one holder of the handshake permit proceed
/// at a time.
///
/// NOTE(review): presumably shared (via `Arc`) by connections multiplexed over
/// one link, so a reconnect handshake does not compete with active streaming —
/// confirm against the callers.
pub struct ReconnectCoordinator {
    /// Number of live `StreamingGuard`s.
    streaming: AtomicUsize,
    /// Signalled every time a `StreamingGuard` is dropped.
    drained: Notify,
    /// Single-permit semaphore handed out by `reconnect_slot`.
    handshake: Arc<Semaphore>,
}
188
impl Default for ReconnectCoordinator {
    /// Equivalent to `ReconnectCoordinator::new()`.
    fn default() -> Self {
        Self::new()
    }
}
194
impl ReconnectCoordinator {
    /// Creates a coordinator with no active streamers and one free handshake permit.
    pub fn new() -> Self {
        Self {
            streaming: AtomicUsize::new(0),
            drained: Notify::new(),
            handshake: Arc::new(Semaphore::new(1)),
        }
    }

    /// Registers the caller as actively streaming.
    ///
    /// The returned guard decrements the count again when it is dropped.
    pub fn enter(self: &Arc<Self>) -> StreamingGuard {
        self.streaming.fetch_add(1, Ordering::SeqCst);
        StreamingGuard(self.clone())
    }

    /// Waits until no `StreamingGuard` is alive and then takes the handshake permit.
    ///
    /// Returns `None` if `cancel` fires while waiting, or if acquiring the
    /// permit fails. Otherwise the caller holds the only permit until it drops it.
    pub async fn reconnect_slot(&self, cancel: &CancellationToken) -> Option<OwnedSemaphorePermit> {
        loop {
            // Phase 1: park until the streaming count reaches zero.
            loop {
                // The `Notified` future is created BEFORE the counter is read, so a
                // guard dropped between the check and the await below is not missed
                // (tokio delivers `notify_waiters` to already-created futures).
                let drained = self.drained.notified();
                if self.streaming.load(Ordering::SeqCst) == 0 {
                    break;
                }
                tokio::select! {
                    // Cancellation takes priority over a wake-up.
                    biased;
                    _ = cancel.cancelled() => return None,
                    _ = drained => {}
                }
            }
            // Phase 2: take the single handshake permit (or bail out on cancel).
            let permit = tokio::select! {
                biased;
                _ = cancel.cancelled() => return None,
                p = self.handshake.clone().acquire_owned() => p.ok()?,
            };
            // Phase 3: a streamer may have entered while we waited for the permit.
            // If so, the permit is dropped at the end of this iteration and we park again.
            if self.streaming.load(Ordering::SeqCst) == 0 {
                return Some(permit);
            }
        }
    }
}
247
248pub struct StreamingGuard(Arc<ReconnectCoordinator>);
251
252impl Drop for StreamingGuard {
253 fn drop(&mut self) {
254 self.0.streaming.fetch_sub(1, Ordering::SeqCst);
255 self.0.drained.notify_waiters();
256 }
257}
258
259fn with_jitter(d: Duration) -> Duration {
260 let jitter = rand::rng().random_range(-RECONNECT_JITTER..RECONNECT_JITTER);
261 let secs = (d.as_secs_f64() * (1.0 + jitter)).max(0.0);
262 Duration::from_secs_f64(secs)
263}
264
265pub(crate) async fn cancellable_sleep(delay: Duration, cancel: &CancellationToken) -> bool {
268 tokio::select! {
269 _ = tokio::time::sleep(delay) => true,
270 _ = cancel.cancelled() => false,
271 }
272}
273
/// The WebSocket stream type used throughout this module: a tungstenite stream
/// over a TCP socket that may or may not be wrapped in TLS.
pub(super) type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
276
277pub(super) async fn connect_ws(url: &str, api_key: &str) -> Result<Ws, String> {
280 let mut request = url
281 .into_client_request()
282 .map_err(|e| format!("build request: {}", err_chain(&e)))?;
283 request.headers_mut().insert(
284 "X-API-Key",
285 HeaderValue::from_str(api_key).map_err(|e| format!("api key header: {}", err_chain(&e)))?,
286 );
287
288 let connect = tokio::time::timeout(CONNECT_TIMEOUT, connect_async(request))
289 .await
290 .map_err(|_| format!("connect timeout after {CONNECT_TIMEOUT:?}"))?
291 .map_err(|e| format!("connect: {}", err_chain(&e)))?;
292 Ok(connect.0)
293}
294
295pub(super) async fn send_request(ws: &mut Ws, req: &BacktestRequest) -> Result<(), String> {
297 let text = serde_json::to_string(req).map_err(|e| format!("serialize: {}", err_chain(&e)))?;
298 ws.send(Message::Text(text))
299 .await
300 .map_err(|e| format!("send: {}", err_chain(&e)))
301}
302
303pub(super) fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
306 if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
307 endpoint.to_string()
308 } else {
309 format!("{}/{}", base, endpoint.trim_start_matches('/'))
310 }
311}
312
/// Outcome of a failed handshake, classified by whether retrying can help.
pub(super) enum HandshakeError {
    /// Worth retrying: `run_control_loop` backs off and reconnects while its
    /// reconnect budget lasts.
    Transient(String),
    /// Not worth retrying: `run_control_loop` fails the connection immediately.
    Fatal(String),
}
320
321pub(super) fn handshake_error_for_response(
325 stage: &'static str,
326 err: BacktestError,
327) -> HandshakeError {
328 match err {
329 BacktestError::SessionOwnershipBusy { .. } => {
330 HandshakeError::Transient(format!("{stage} contended: {}", err_chain(&err)))
331 }
332 _ => HandshakeError::Fatal(format!("{stage} rejected: {}", err_chain(&err))),
333 }
334}
335
/// Why a connection's message loop returned; drives what `run_control_loop`
/// does next.
pub(super) enum MessageLoopExit {
    /// The session finished; the control loop returns without reconnecting.
    SessionEnded,
    /// The loop stopped because of cancellation; the control loop returns.
    Cancelled,
    /// The connection dropped; the control loop reconnects while its budget
    /// lasts. Carries the reason.
    ConnectionLost(String),
    /// An unrecoverable condition; the control loop fails the connection with
    /// the carried reason.
    Terminal(String),
}
350
351pub(super) fn publish_status(
353 status_tx: &watch::Sender<ConnectionStatus>,
354 status: ConnectionStatus,
355) {
356 status_tx.send_if_modified(|current| {
357 if *current == status {
358 false
359 } else {
360 *current = status;
361 true
362 }
363 });
364}
365
366pub(super) fn is_terminal_backtest_error(err: &BacktestError) -> bool {
368 matches!(
369 err,
370 BacktestError::NoMoreBlocks
371 | BacktestError::AdvanceSlotFailed { .. }
372 | BacktestError::FinalizeSlotFailed { .. }
373 | BacktestError::Internal { .. }
374 )
375}
376
377pub(super) async fn graceful_close(ws: &mut Ws) {
382 let _ = tokio::time::timeout(GRACEFUL_CLOSE_TIMEOUT, async {
383 let _ = send_request(ws, &BacktestRequest::CloseBacktestSession).await;
384 let _ = ws.close(None).await;
385 })
386 .await;
387}
388
/// What `classify_inbound` makes of one item read from the WebSocket.
pub(super) enum InboundFrame {
    /// A text payload (from a text frame, or a binary frame holding valid UTF-8).
    Text(String),
    /// Nothing to process: ping/pong, a raw frame, or non-UTF-8 binary data.
    Ignore,
    /// The connection is gone (close frame, read error or end of stream);
    /// carries a description.
    Lost(String),
}
398
399pub(super) fn classify_inbound(msg: Option<Result<Message, WsError>>) -> InboundFrame {
401 match msg {
402 Some(Ok(Message::Text(t))) => InboundFrame::Text(t),
403 Some(Ok(Message::Binary(b))) => match String::from_utf8(b) {
404 Ok(t) => InboundFrame::Text(t),
405 Err(_) => InboundFrame::Ignore,
406 },
407 Some(Ok(Message::Pong(_) | Message::Ping(_) | Message::Frame(_))) => InboundFrame::Ignore,
408 Some(Ok(Message::Close(frame))) => InboundFrame::Lost(format!("remote close: {frame:?}")),
409 Some(Err(e)) => InboundFrame::Lost(format!("ws read: {}", err_chain(&e))),
410 None => InboundFrame::Lost("ws stream ended".into()),
411 }
412}
413
414pub(super) async fn send_keepalive_ping(ws: &mut Ws, last_inbound: Instant) -> Option<String> {
418 if last_inbound.elapsed() > KEEPALIVE_MISS_DEADLINE {
419 return Some(format!("no traffic for {:?}", last_inbound.elapsed()));
420 }
421 if let Err(e) = ws.send(Message::Ping(vec![])).await {
422 return Some(format!("ping send: {}", err_chain(&e)));
423 }
424 None
425}
426
/// A connection that `run_control_loop` can drive: the loop owns the
/// connect / handshake / message-loop / reconnect cycle, the implementor
/// supplies the protocol-specific parts.
pub(super) trait ControlConnection: Send + 'static {
    /// WebSocket URL to connect to.
    fn url(&self) -> &str;
    /// API key sent with the connection request.
    fn api_key(&self) -> &str;
    /// Token whose cancellation stops the control loop.
    fn cancel(&self) -> &CancellationToken;
    /// Short name used as a prefix in log messages.
    fn label(&self) -> &'static str;
    /// Channel on which connection status changes are published.
    fn status_tx(&self) -> &watch::Sender<ConnectionStatus>;
    /// Called with a reason when the connection will not (or no longer) come up.
    /// NOTE(review): presumably fails requests still waiting on the connection —
    /// confirm against the implementors.
    fn fail_pending(&mut self, reason: String);
    /// Runs the protocol handshake on a freshly connected socket and hands the
    /// socket back on success.
    fn handshake(&mut self, ws: Ws) -> impl Future<Output = Result<Ws, HandshakeError>> + Send;
    /// Processes traffic on a handshaken socket until the loop has a reason to stop.
    fn message_loop(&mut self, ws: Ws) -> impl Future<Output = MessageLoopExit> + Send;

    /// Publishes `status` on `status_tx`, skipping the notification if unchanged.
    fn publish(&self, status: ConnectionStatus) {
        publish_status(self.status_tx(), status);
    }

    /// Gives up on the connection: calls `fail_pending` with `reason`, then
    /// publishes `ConnectionStatus::Failed(reason)`.
    fn finish_failed(&mut self, reason: String) {
        self.fail_pending(reason.clone());
        self.publish(ConnectionStatus::Failed(reason));
    }
}
453
/// Drives `task` through connect → handshake → message loop, reconnecting with
/// backoff until the session ends, the task is cancelled, a fatal/terminal
/// error occurs, or the reconnect budget runs out.
///
/// One `ReconnectBudget` is shared by connect failures, transient handshake
/// failures and lost connections. It is reset only when a connection that has
/// been up for at least `RECONNECT_UPTIME_RESET` is lost.
///
/// NOTE(review): when cancellation interrupts a backoff sleep the function
/// returns without calling `fail_pending`, whereas cancellation observed at the
/// top of the loop does call it — confirm this asymmetry is intended.
pub(super) async fn run_control_loop<T: ControlConnection>(mut task: T) {
    let mut budget = ReconnectBudget::new();

    loop {
        if task.cancel().is_cancelled() {
            task.fail_pending("cancelled before session created".to_string());
            return;
        }
        // Every iteration starts disconnected.
        task.publish(ConnectionStatus::Down);

        // Step 1: open the socket; any connect failure is retried under the budget.
        let ws = match connect_ws(task.url(), task.api_key()).await {
            Ok(ws) => ws,
            Err(why) => {
                if let Some(delay) = budget.next_backoff() {
                    warn!(attempt = budget.attempt(), error = %why, ?delay, "{} connect failed, retrying", task.label());
                    if !cancellable_sleep(delay, task.cancel()).await {
                        return;
                    }
                    continue;
                }
                // Budget exhausted.
                task.finish_failed(format!("connect: {why}"));
                return;
            }
        };

        // Step 2: protocol handshake; only transient failures are retried.
        let ws = match task.handshake(ws).await {
            Ok(ws) => ws,
            Err(HandshakeError::Fatal(why)) => {
                task.finish_failed(format!("handshake: {why}"));
                return;
            }
            Err(HandshakeError::Transient(why)) => {
                if let Some(delay) = budget.next_backoff() {
                    warn!(attempt = budget.attempt(), error = %why, ?delay, "{} handshake failed, retrying", task.label());
                    if !cancellable_sleep(delay, task.cancel()).await {
                        return;
                    }
                    continue;
                }
                // Budget exhausted.
                task.finish_failed(format!("handshake: {why}"));
                return;
            }
        };

        task.publish(ConnectionStatus::Up);
        // Uptime is measured from here to decide whether a later loss resets the budget.
        let connected_at = Instant::now();

        // Step 3: run until the message loop gives a reason to stop.
        match task.message_loop(ws).await {
            MessageLoopExit::SessionEnded | MessageLoopExit::Cancelled => return,
            MessageLoopExit::ConnectionLost(why) => {
                // A connection that stayed up long enough earns a fresh budget,
                // so an occasional drop does not accumulate towards the limit.
                if connected_at.elapsed() >= RECONNECT_UPTIME_RESET {
                    budget.reset();
                }
                if let Some(delay) = budget.next_backoff() {
                    warn!(attempt = budget.attempt(), reason = %why, ?delay, "{} connection lost, reconnecting", task.label());
                    if !cancellable_sleep(delay, task.cancel()).await {
                        return;
                    }
                    continue;
                }
                // Budget exhausted.
                task.finish_failed(format!("connection lost: {why}"));
                return;
            }
            MessageLoopExit::Terminal(why) => {
                task.finish_failed(why);
                return;
            }
        }
    }
}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh budget grants exactly `RECONNECT_MAX_ATTEMPTS` backoffs, then refuses.
    #[test]
    fn budget_exhausts_after_max_attempts() {
        let mut b = ReconnectBudget::new();
        for _ in 0..RECONNECT_MAX_ATTEMPTS {
            assert!(b.next_backoff().is_some());
        }
        assert!(b.next_backoff().is_none());
    }

    /// `reset` clears the attempt counter.
    #[test]
    fn budget_reset_restores_full_budget() {
        let mut b = ReconnectBudget::new();
        b.next_backoff();
        b.next_backoff();
        b.reset();
        assert_eq!(b.attempt(), 0);
    }

    /// `enter` increments the streaming count and dropping the guard decrements it.
    #[test]
    fn streaming_guard_balances_the_count() {
        let coord = Arc::new(ReconnectCoordinator::new());
        assert_eq!(coord.streaming.load(Ordering::SeqCst), 0);
        let g = coord.enter();
        assert_eq!(coord.streaming.load(Ordering::SeqCst), 1);
        drop(g);
        assert_eq!(coord.streaming.load(Ordering::SeqCst), 0);
    }

    /// With nobody streaming, a slot is granted without waiting.
    #[tokio::test]
    async fn reconnect_slot_available_when_link_is_quiet() {
        let coord = Arc::new(ReconnectCoordinator::new());
        let cancel = CancellationToken::new();
        assert!(coord.reconnect_slot(&cancel).await.is_some());
    }

    /// A parked waiter is released once the last streaming guard is dropped.
    #[tokio::test]
    async fn reconnect_slot_unparks_when_last_sibling_leaves() {
        let coord = Arc::new(ReconnectCoordinator::new());
        let cancel = CancellationToken::new();
        let guard = coord.enter();
        // A sibling is streaming, so the waiter has to park.
        let waiter = tokio::spawn({
            let coord = coord.clone();
            let cancel = cancel.clone();
            async move { coord.reconnect_slot(&cancel).await.is_some() }
        });

        // Give the spawned task a chance to run up to its park point.
        tokio::task::yield_now().await;
        assert!(!waiter.is_finished());

        // Dropping the guard must wake the waiter and let it take the slot.
        drop(guard);
        assert!(waiter.await.unwrap());
    }

    /// Cancellation wins over parking: no slot is handed out.
    #[tokio::test]
    async fn reconnect_slot_returns_none_on_cancel_while_parked() {
        let coord = Arc::new(ReconnectCoordinator::new());
        // Keep a streamer alive so `reconnect_slot` cannot succeed.
        let _guard = coord.enter();
        let cancel = CancellationToken::new();
        cancel.cancel();
        assert!(coord.reconnect_slot(&cancel).await.is_none());
    }

    /// Discounted (parked) time does not count against the time window.
    #[test]
    fn discount_parked_does_not_consume_the_budget() {
        let mut b = ReconnectBudget::new();
        b.discount_parked(2 * RECONNECT_MAX_TOTAL);
        // The full attempt count must still be available afterwards.
        for _ in 0..RECONNECT_MAX_ATTEMPTS {
            assert!(b.next_backoff().is_some());
        }
    }
}