1use std::{
17 collections::{HashMap, VecDeque},
18 sync::Arc,
19};
20
21use futures::StreamExt;
22use simulator_api::{
23 BacktestError, BacktestRequest, BacktestResponse, ContinueParams, ContinueSessionRequestV1,
24 CreateBacktestSessionRequest, SessionEventKind,
25};
26use tokio::{
27 sync::{mpsc, oneshot, watch},
28 task::JoinHandle,
29};
30use tokio_tungstenite::tungstenite::Message;
31use tokio_util::sync::CancellationToken;
32use tracing::{debug, info, warn};
33
34use super::{
35 ConnectionStatus, ControlConnection, ControlEvent, HANDSHAKE_RESPONSE_TIMEOUT, HandshakeError,
36 InboundFrame, KEEPALIVE_INTERVAL, ManagedEvent, ManagedSessionError, MessageLoopExit,
37 ReconnectCoordinator, SessionInfo, SubscriptionHandle, Ws, classify_inbound, graceful_close,
38 handshake_error_for_response, is_terminal_backtest_error, resolve_rpc_url, run_control_loop,
39 send_keepalive_ping, send_request,
40 session::{
41 DrainOutcome, drain_subscriptions_until_complete, try_next_subscription_event,
42 wait_any_subscription_event, wait_connections_up,
43 },
44 spawn_account_diff_subscription_manager, spawn_action_subscription_manager,
45 spawn_transaction_subscription_manager,
46};
47use crate::{error::err_chain, urls::http_base_from_ws_url};
48
/// Upper bound on the wait for the server's responses while the parallel
/// sub-sessions are being created (15 minutes); passed to `next_response`.
const CREATE_RESPONSE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15 * 60);
/// Capacity of the bounded channel that carries tagged continue requests
/// from the sub-session handles to the shared control task.
const CONTINUE_CHANNEL_CAPACITY: usize = 256;
/// Idle timeout used when draining subscriptions after a sub-session
/// reports completion.
const COMPLETION_DRAIN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
58
/// A continue request tagged with the ID of the sub-session it targets.
/// `ParallelSubSession::send_continue` sends these to the shared control task.
type TaggedContinue = (String, ContinueParams);
61
/// Result the control task delivers once the server has reported
/// `SessionsCreatedV2`.
struct ParallelCreated {
    /// Server-assigned ID of the control session grouping all sub-sessions.
    control_session_id: String,
    /// One entry per `SessionCreated` response received during creation.
    sessions: Vec<CreatedSubSession>,
}
67
/// One sub-session created by the server, before it is wrapped in a
/// `ParallelSubSession`.
struct CreatedSubSession {
    /// Session ID, resolved RPC endpoint and task ID.
    info: SessionInfo,
    /// Receiving end of the per-sub-session channel that the control task
    /// feeds with routed events.
    events: mpsc::UnboundedReceiver<ControlEvent>,
    /// First slot of this sub-session's range, as reported in
    /// `SessionsCreatedV2`. Stays 0 if the server did not list this session.
    start_slot: u64,
    /// Last slot of this sub-session's range (same source and default as
    /// `start_slot`).
    end_slot: u64,
}
77
/// Caller-side handle to the spawned parallel control task.
struct ParallelControlHandle {
    /// Sender for tagged continue requests. It is cloned into every
    /// sub-session, and the control task's message loop ends once all
    /// senders are dropped.
    continues: mpsc::Sender<TaggedContinue>,
    /// Connection status of the control websocket.
    status: watch::Receiver<ConnectionStatus>,
    /// One-shot result of session creation. `None` once `wait_created` has
    /// consumed it.
    created: Option<oneshot::Receiver<Result<ParallelCreated, String>>>,
    /// Handle of the task running `run_control_loop`.
    join: JoinHandle<()>,
}
85
86impl ParallelControlHandle {
87 async fn wait_created(&mut self) -> Result<ParallelCreated, String> {
90 let rx = self
91 .created
92 .take()
93 .ok_or_else(|| "parallel create already consumed".to_string())?;
94 rx.await
95 .map_err(|_| "control manager exited before creating sessions".to_string())?
96 }
97
98 async fn join(self) {
101 drop(self.continues);
102 let _ = self.join.await;
103 }
104}
105
106fn spawn_parallel_control_manager(
107 url: String,
108 api_key: String,
109 create: CreateBacktestSessionRequest,
110 cancel: CancellationToken,
111) -> ParallelControlHandle {
112 let (continues_tx, continues_rx) = mpsc::channel::<TaggedContinue>(CONTINUE_CHANNEL_CAPACITY);
113 let (status_tx, status_rx) = watch::channel(ConnectionStatus::Down);
114 let (created_tx, created_rx) = oneshot::channel::<Result<ParallelCreated, String>>();
115
116 let task = ParallelControlTask {
117 url,
118 api_key,
119 create: Some(create),
120 control_session_id: None,
121 event_txs: HashMap::new(),
122 last_sequences: HashMap::new(),
123 completed: std::collections::HashSet::new(),
124 continues_rx,
125 status_tx,
126 created_tx: Some(created_tx),
127 cancel,
128 };
129
130 let join = tokio::spawn(run_control_loop(task));
131
132 ParallelControlHandle {
133 continues: continues_tx,
134 status: status_rx,
135 created: Some(created_rx),
136 join,
137 }
138}
139
/// State owned by the spawned control task (driven by `run_control_loop`).
/// A single websocket connection carries every parallel sub-session.
struct ParallelControlTask {
    url: String,
    api_key: String,
    /// Pending create request. Kept until `SessionsCreatedV2` arrives, so a
    /// later handshake can re-send it if creation did not finish; `None`
    /// afterwards.
    create: Option<CreateBacktestSessionRequest>,
    /// Set once the server reports `SessionsCreatedV2`. Later handshakes
    /// reattach to this control session instead of creating new sessions.
    control_session_id: Option<String>,
    /// Per-sub-session event senders, keyed by sub-session ID.
    event_txs: HashMap<String, mpsc::UnboundedSender<ControlEvent>>,
    /// Highest `seq_id` seen per sub-session. Used to drop duplicate events
    /// and sent to the server on reattach.
    last_sequences: HashMap<String, u64>,
    /// IDs of sub-sessions that have delivered a `Completed` event.
    completed: std::collections::HashSet<String>,
    /// Tagged continue requests coming from sub-session handles.
    continues_rx: mpsc::Receiver<TaggedContinue>,
    /// Connection status, exposed through `ControlConnection::status_tx`.
    status_tx: watch::Sender<ConnectionStatus>,
    /// Resolves `ParallelControlHandle::wait_created`. Taken on successful
    /// creation or by `fail_pending`.
    created_tx: Option<oneshot::Sender<Result<ParallelCreated, String>>>,
    cancel: CancellationToken,
}
161
162impl ControlConnection for ParallelControlTask {
163 fn url(&self) -> &str {
164 &self.url
165 }
166 fn api_key(&self) -> &str {
167 &self.api_key
168 }
169 fn cancel(&self) -> &CancellationToken {
170 &self.cancel
171 }
172 fn label(&self) -> &'static str {
173 "parallel control"
174 }
175 fn status_tx(&self) -> &watch::Sender<ConnectionStatus> {
176 &self.status_tx
177 }
178
179 fn fail_pending(&mut self, reason: String) {
180 if let Some(tx) = self.created_tx.take() {
181 let _ = tx.send(Err(reason));
182 }
183 }
184
185 async fn handshake(&mut self, ws: Ws) -> Result<Ws, HandshakeError> {
186 if let Some(control_session_id) = self.control_session_id.clone() {
187 self.attach(ws, &control_session_id).await
188 } else if let Some(create) = self.create.clone() {
189 self.create_sessions(ws, create).await
194 } else {
195 Err(HandshakeError::Fatal(
196 "no create request and no control_session_id".into(),
197 ))
198 }
199 }
200
201 async fn message_loop(&mut self, mut ws: Ws) -> MessageLoopExit {
202 let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
203 ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
204 let mut last_inbound = std::time::Instant::now();
205
206 let exit = loop {
207 tokio::select! {
208 biased;
209 _ = self.cancel.cancelled() => break MessageLoopExit::Cancelled,
210
211 _ = ping_timer.tick() => {
212 if let Some(why) = send_keepalive_ping(&mut ws, last_inbound).await {
213 break MessageLoopExit::ConnectionLost(why);
214 }
215 }
216
217 msg = ws.next() => {
218 last_inbound = std::time::Instant::now();
219 match classify_inbound(msg) {
220 InboundFrame::Text(t) => {
221 if let Some(exit) = self.handle_text(&t) {
222 break exit;
223 }
224 }
225 InboundFrame::Ignore => {}
226 InboundFrame::Lost(why) => break MessageLoopExit::ConnectionLost(why),
227 }
228 }
229
230 req = self.continues_rx.recv() => {
231 match req {
232 Some((session_id, request)) => {
233 let msg = BacktestRequest::ContinueSessionV1(ContinueSessionRequestV1 { session_id, request });
234 if let Err(e) = send_request(&mut ws, &msg).await {
235 break MessageLoopExit::ConnectionLost(format!("continue send: {e}"));
236 }
237 }
238 None => break MessageLoopExit::SessionEnded,
239 }
240 }
241 }
242 };
243
244 if matches!(
245 exit,
246 MessageLoopExit::SessionEnded | MessageLoopExit::Cancelled
247 ) {
248 graceful_close(&mut ws).await;
249 }
250 exit
251 }
252}
253
254impl ParallelControlTask {
255 async fn create_sessions(
259 &mut self,
260 mut ws: Ws,
261 create: CreateBacktestSessionRequest,
262 ) -> Result<Ws, HandshakeError> {
263 send_request(&mut ws, &BacktestRequest::CreateBacktestSession(create))
264 .await
265 .map_err(HandshakeError::Transient)?;
266
267 let rpc_base = http_base_from_ws_url(&self.url);
268 let mut sessions: Vec<CreatedSubSession> = Vec::new();
269
270 loop {
271 let response = next_response(&mut ws, CREATE_RESPONSE_TIMEOUT)
272 .await
273 .map_err(HandshakeError::Transient)?;
274 match response {
275 BacktestResponse::SessionCreated {
276 session_id,
277 rpc_endpoint,
278 task_id,
279 } => {
280 let (event_tx, event_rx) = mpsc::unbounded_channel::<ControlEvent>();
281 self.event_txs.insert(session_id.clone(), event_tx);
282 sessions.push(CreatedSubSession {
283 info: SessionInfo {
284 rpc_endpoint: resolve_rpc_url(&rpc_base, &rpc_endpoint),
285 session_id,
286 task_id,
287 },
288 events: event_rx,
289 start_slot: 0,
290 end_slot: 0,
291 });
292 }
293 BacktestResponse::SessionEventV2 {
294 session_id,
295 seq_id,
296 event,
297 } => {
298 self.route_event(&session_id, seq_id, event);
299 }
300 BacktestResponse::SessionsCreatedV2 {
301 control_session_id,
302 session_ids,
303 start_slots,
304 end_slots,
305 ..
306 } => {
307 info!(
308 %control_session_id,
309 sessions = sessions.len(),
310 "parallel sessions created"
311 );
312 if session_ids.len() != start_slots.len()
319 || session_ids.len() != end_slots.len()
320 {
321 return Err(HandshakeError::Fatal(format!(
322 "server did not report per-sub-session ranges \
323 (session_ids={}, start_slots={}, end_slots={}); \
324 server is too old for the multiplexed parallel client",
325 session_ids.len(),
326 start_slots.len(),
327 end_slots.len(),
328 )));
329 }
330 for ((id, start), end) in session_ids.iter().zip(&start_slots).zip(&end_slots) {
331 if let Some(s) = sessions.iter_mut().find(|s| s.info.session_id == *id) {
332 s.start_slot = *start;
333 s.end_slot = *end;
334 }
335 }
336 self.control_session_id = Some(control_session_id.clone());
337 self.create = None;
340 if let Some(tx) = self.created_tx.take() {
341 let _ = tx.send(Ok(ParallelCreated {
342 control_session_id,
343 sessions,
344 }));
345 }
346 return Ok(ws);
347 }
348 BacktestResponse::Error(err) => {
349 return Err(HandshakeError::Fatal(format!(
350 "server error: {}",
351 err_chain(&err)
352 )));
353 }
354 _ => {}
355 }
356 }
357 }
358
359 async fn attach(&mut self, mut ws: Ws, control_session_id: &str) -> Result<Ws, HandshakeError> {
363 send_request(
364 &mut ws,
365 &BacktestRequest::AttachParallelControlSessionV2 {
366 control_session_id: control_session_id.to_string(),
367 last_sequences: self.last_sequences.clone().into_iter().collect(),
368 },
369 )
370 .await
371 .map_err(HandshakeError::Transient)?;
372
373 loop {
374 let response = next_response(&mut ws, HANDSHAKE_RESPONSE_TIMEOUT)
375 .await
376 .map_err(HandshakeError::Transient)?;
377 match response {
378 BacktestResponse::ParallelSessionAttachedV2 { .. } => {
379 debug!(%control_session_id, "parallel control reattached");
380 return Ok(ws);
381 }
382 BacktestResponse::SessionEventV2 {
383 session_id,
384 seq_id,
385 event,
386 } => {
387 self.route_event(&session_id, seq_id, event);
388 }
389 BacktestResponse::Error(err) => {
390 return Err(handshake_error_for_response("attach", err));
391 }
392 _ => {}
393 }
394 }
395 }
396
397 fn handle_text(&mut self, text: &str) -> Option<MessageLoopExit> {
400 let response = match serde_json::from_str::<BacktestResponse>(text) {
401 Ok(r) => r,
402 Err(e) => {
403 warn!(error = %err_chain(&e), "discarding undeserializable parallel control message");
404 return None;
405 }
406 };
407
408 match response {
409 BacktestResponse::SessionEventV2 {
410 session_id,
411 seq_id,
412 event,
413 } => {
414 self.route_event(&session_id, seq_id, event);
415 if self.completed.len() == self.event_txs.len() && !self.event_txs.is_empty() {
416 return Some(MessageLoopExit::SessionEnded);
417 }
418 }
419 BacktestResponse::Error(err) => {
420 if is_terminal_backtest_error(&err) {
421 return Some(MessageLoopExit::Terminal(format!(
422 "control session error: {}",
423 err_chain(&err)
424 )));
425 }
426 warn!(error = %err_chain(&err), "non-terminal parallel control error");
427 }
428 other => {
429 debug!(?other, "ignoring unexpected parallel control response");
430 }
431 }
432 None
433 }
434
435 fn route_event(&mut self, session_id: &str, seq_id: u64, event: SessionEventKind) {
438 match self.last_sequences.get_mut(session_id) {
442 Some(last) if seq_id <= *last => return,
443 Some(last) => *last = seq_id,
444 None => {
445 self.last_sequences.insert(session_id.to_string(), seq_id);
446 }
447 }
448
449 let is_completed = matches!(event, SessionEventKind::Completed { .. });
450 let Some(control_event) = session_event_to_control(event) else {
451 return;
452 };
453 if let Some(tx) = self.event_txs.get(session_id) {
454 let _ = tx.send(control_event);
456 }
457 if is_completed {
458 self.completed.insert(session_id.to_string());
459 }
460 }
461}
462
463fn session_event_to_control(event: SessionEventKind) -> Option<ControlEvent> {
466 Some(match event {
467 SessionEventKind::ReadyForContinue => ControlEvent::ReadyForContinue,
468 SessionEventKind::SlotNotification(slot) => ControlEvent::Slot(slot),
469 SessionEventKind::Paused(event) => ControlEvent::Paused(event),
470 SessionEventKind::DiscoveryBatch(event) => ControlEvent::DiscoveryBatch(event),
471 SessionEventKind::Error(error)
474 if matches!(error, BacktestError::SimulationError { .. }) =>
475 {
476 warn!(error = %err_chain(&error), "simulation error");
477 return None;
478 }
479 SessionEventKind::Error(error) => ControlEvent::Error(error),
480 SessionEventKind::Completed { summary } => ControlEvent::Completed {
481 summary,
482 agent_stats: None,
483 },
484 SessionEventKind::Status { status } => ControlEvent::Status(status),
485 SessionEventKind::Success => return None,
486 })
487}
488
489async fn next_response(
491 ws: &mut Ws,
492 timeout: std::time::Duration,
493) -> Result<BacktestResponse, String> {
494 let deadline = tokio::time::Instant::now() + timeout;
495 loop {
496 let msg = tokio::time::timeout_at(deadline, ws.next())
497 .await
498 .map_err(|_| format!("handshake timeout after {timeout:?}"))?;
499
500 let Some(msg) = msg else {
501 return Err("ws ended during handshake".into());
502 };
503 let msg = msg.map_err(|e| format!("ws read: {}", err_chain(&e)))?;
504
505 let text = match msg {
506 Message::Text(t) => t,
507 Message::Binary(b) => match String::from_utf8(b) {
508 Ok(t) => t,
509 Err(_) => continue,
510 },
511 Message::Close(frame) => {
512 return Err(format!("remote close during handshake: {frame:?}"));
513 }
514 _ => continue,
515 };
516
517 return serde_json::from_str::<BacktestResponse>(&text)
518 .map_err(|e| format!("deserialize: {}; raw={text}", err_chain(&e)));
519 }
520}
521
/// A backtest split server-side into parallel sub-sessions that share one
/// multiplexed control connection.
///
/// Its cancellation token is cancelled by `shutdown` and on drop. Every
/// sub-session token is a child of it, so they are cancelled as well.
pub struct ManagedParallelSession {
    /// Server-assigned ID of the control session grouping the sub-sessions.
    control_session_id: String,
    /// Handle to the control task; taken by `shutdown`.
    control: Option<ParallelControlHandle>,
    /// Sub-session handles not yet handed out by `take_sub_sessions`.
    sub_sessions: Vec<ParallelSubSession>,
    /// Child of the caller's parent token. The control task runs under it,
    /// and each sub-session runs under a child of it.
    session_cancel: CancellationToken,
}
531
532impl ManagedParallelSession {
533 pub async fn start_with_cancel(
535 url: String,
536 api_key: String,
537 create: CreateBacktestSessionRequest,
538 parent_cancel: CancellationToken,
539 ) -> Result<Self, ManagedSessionError> {
540 let session_cancel = parent_cancel.child_token();
541 let mut control =
542 spawn_parallel_control_manager(url, api_key, create, session_cancel.clone());
543
544 let created = tokio::select! {
545 biased;
546 _ = parent_cancel.cancelled() => {
547 session_cancel.cancel();
548 control.join().await;
549 return Err(ManagedSessionError::Cancelled);
550 }
551 result = control.wait_created() => {
552 result.map_err(ManagedSessionError::Create)?
553 }
554 };
555
556 let reconnect_coordinator = Arc::new(ReconnectCoordinator::new());
560 let sub_sessions = created
561 .sessions
562 .into_iter()
563 .map(|s| ParallelSubSession {
564 session_info: s.info,
565 events: s.events,
566 continues: control.continues.clone(),
567 status: control.status.clone(),
568 subscriptions: Vec::new(),
569 session_cancel: session_cancel.child_token(),
570 post_completion: None,
571 post_completion_error: None,
572 reconnect_coordinator: Some(reconnect_coordinator.clone()),
573 start_slot: s.start_slot,
574 end_slot: s.end_slot,
575 })
576 .collect();
577
578 Ok(Self {
579 control_session_id: created.control_session_id,
580 control: Some(control),
581 sub_sessions,
582 session_cancel,
583 })
584 }
585
586 pub fn control_session_id(&self) -> &str {
588 &self.control_session_id
589 }
590
591 pub fn take_sub_sessions(&mut self) -> Vec<ParallelSubSession> {
594 std::mem::take(&mut self.sub_sessions)
595 }
596
597 pub async fn shutdown(mut self) {
599 self.session_cancel.cancel();
600 if let Some(control) = self.control.take() {
601 control.join().await;
602 }
603 }
604}
605
606impl Drop for ManagedParallelSession {
607 fn drop(&mut self) {
608 self.session_cancel.cancel();
609 }
610}
611
/// Handle to one sub-session of a `ManagedParallelSession`.
///
/// Control-plane events arrive over the shared control connection. Data
/// subscriptions connect directly to the sub-session's own RPC endpoint.
pub struct ParallelSubSession {
    session_info: SessionInfo,
    /// Events routed to this sub-session by the control task.
    events: mpsc::UnboundedReceiver<ControlEvent>,
    /// Shared sender to the control task; requests are tagged with this
    /// sub-session's ID.
    continues: mpsc::Sender<TaggedContinue>,
    /// Status of the shared control connection.
    status: watch::Receiver<ConnectionStatus>,
    /// Subscription managers spawned by the `subscribe_*` methods.
    subscriptions: Vec<SubscriptionHandle>,
    /// Child of the parent session's token; cancelled on shutdown or drop.
    session_cancel: CancellationToken,
    /// `Some` once a `Completed` event has been observed. Holds the events
    /// still to be returned by `next_event`, in order.
    post_completion: Option<VecDeque<ManagedEvent>>,
    /// Error to report after the buffered events, if the completion drain
    /// stalled.
    post_completion_error: Option<ManagedSessionError>,
    /// Coordinator shared by all sub-sessions of the same parallel session
    /// and passed to each subscription manager.
    reconnect_coordinator: Option<Arc<ReconnectCoordinator>>,
    /// First slot of this sub-session's range.
    start_slot: u64,
    /// Last slot of this sub-session's range.
    end_slot: u64,
}
637
638impl ParallelSubSession {
639 pub fn session_info(&self) -> &SessionInfo {
640 &self.session_info
641 }
642
643 pub fn range(&self) -> (u64, u64) {
645 (self.start_slot, self.end_slot)
646 }
647
648 pub fn subscribe_transactions(&mut self, program_ids: Vec<String>) {
649 self.subscriptions
650 .push(spawn_transaction_subscription_manager(
651 self.session_info.rpc_endpoint.clone(),
652 program_ids,
653 self.session_cancel.clone(),
654 self.reconnect_coordinator.clone(),
655 ));
656 }
657
658 pub fn subscribe_account_diffs(&mut self, program_ids: Vec<String>) {
659 self.subscriptions
660 .push(spawn_account_diff_subscription_manager(
661 self.session_info.rpc_endpoint.clone(),
662 program_ids,
663 self.session_cancel.clone(),
664 self.reconnect_coordinator.clone(),
665 ));
666 }
667
668 pub fn subscribe_actions(&mut self) {
670 self.subscriptions.push(spawn_action_subscription_manager(
671 self.session_info.rpc_endpoint.clone(),
672 self.session_cancel.clone(),
673 self.reconnect_coordinator.clone(),
674 ));
675 }
676
677 pub async fn next_event(&mut self) -> Result<ManagedEvent, ManagedSessionError> {
681 if let Some(buffered) = self.post_completion.as_mut() {
684 if let Some(event) = buffered.pop_front() {
685 return Ok(event);
686 }
687 return Err(self
690 .post_completion_error
691 .take()
692 .unwrap_or(ManagedSessionError::ControlClosed));
693 }
694
695 if let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
696 return Ok(event);
697 }
698
699 let event = {
702 let cancel = &self.session_cancel;
703 let subscriptions = &mut self.subscriptions;
704 tokio::select! {
705 biased;
706 _ = cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
707 event = self.events.recv() => {
708 event.map(ManagedEvent::from).ok_or(ManagedSessionError::ControlClosed)?
709 }
710 event = wait_any_subscription_event(subscriptions) => event,
711 }
712 };
713
714 let ManagedEvent::Completed {
716 summary,
717 agent_stats,
718 } = event
719 else {
720 return Ok(event);
721 };
722
723 let (mut buffered, terminal): (VecDeque<ManagedEvent>, _) = match self
727 .drain_until_subscriptions_complete(COMPLETION_DRAIN_TIMEOUT)
728 .await
729 {
730 DrainOutcome::Complete(events) => (
731 events.into(),
732 Ok(ManagedEvent::Completed {
733 summary,
734 agent_stats,
735 }),
736 ),
737 DrainOutcome::Stalled(events) => (
741 events.into(),
742 Err(ManagedSessionError::SubscriptionFailed(
743 "completion drain stalled: subscriptions did not deliver their \
744 end-of-stream terminals; the captured stream is incomplete"
745 .to_string(),
746 )),
747 ),
748 };
749 match terminal {
750 Ok(completed) => buffered.push_back(completed),
751 Err(err) => self.post_completion_error = Some(err),
752 }
753 let first = buffered.pop_front();
754 self.post_completion = Some(buffered);
755 match first {
756 Some(event) => Ok(event),
757 None => Err(self
758 .post_completion_error
759 .take()
760 .unwrap_or(ManagedSessionError::ControlClosed)),
761 }
762 }
763
764 pub async fn send_continue(
767 &mut self,
768 params: ContinueParams,
769 ) -> Result<(), ManagedSessionError> {
770 self.wait_all_up().await?;
771 self.continues
772 .send((self.session_info.session_id.clone(), params))
773 .await
774 .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
775 }
776
777 async fn drain_until_subscriptions_complete(
779 &mut self,
780 idle_timeout: std::time::Duration,
781 ) -> DrainOutcome {
782 drain_subscriptions_until_complete(
783 &mut self.subscriptions,
784 &self.session_cancel,
785 idle_timeout,
786 )
787 .await
788 }
789
790 pub async fn shutdown(mut self) {
793 self.session_cancel.cancel();
794 for sub in std::mem::take(&mut self.subscriptions) {
795 let _ = sub.join.await;
796 }
797 }
798
799 async fn wait_all_up(&self) -> Result<(), ManagedSessionError> {
800 let subscriptions = self
801 .subscriptions
802 .iter()
803 .map(|s| s.status.clone())
804 .collect();
805 wait_connections_up(self.status.clone(), subscriptions, &self.session_cancel).await
806 }
807}
808
809impl Drop for ParallelSubSession {
810 fn drop(&mut self) {
811 self.session_cancel.cancel();
812 }
813}