1use std::{
17 collections::{HashMap, VecDeque},
18 sync::Arc,
19};
20
21use futures::StreamExt;
22use simulator_api::{
23 BacktestRequest, BacktestResponse, ContinueParams, ContinueSessionRequestV1,
24 CreateBacktestSessionRequest, SessionEventKind,
25};
26use tokio::{
27 sync::{mpsc, oneshot, watch},
28 task::JoinHandle,
29};
30use tokio_tungstenite::tungstenite::Message;
31use tokio_util::sync::CancellationToken;
32use tracing::{debug, info, warn};
33
34use super::{
35 ConnectionStatus, ControlConnection, ControlEvent, HANDSHAKE_RESPONSE_TIMEOUT, HandshakeError,
36 InboundFrame, KEEPALIVE_INTERVAL, ManagedEvent, ManagedSessionError, MessageLoopExit,
37 ReconnectCoordinator, SessionInfo, SubscriptionHandle, Ws, classify_inbound, graceful_close,
38 handshake_error_for_response, is_terminal_backtest_error, resolve_rpc_url, run_control_loop,
39 send_keepalive_ping, send_request,
40 session::{
41 DrainOutcome, drain_subscriptions_until_complete, try_next_subscription_event,
42 wait_any_subscription_event, wait_connections_up,
43 },
44 spawn_account_diff_subscription_manager, spawn_transaction_subscription_manager,
45};
46use crate::{error::err_chain, urls::http_base_from_ws_url};
47
/// Maximum time `create_sessions` waits for each individual server response
/// while sub-sessions are being created (900 s). The deadline is re-armed for
/// every response, so it bounds the gap between messages rather than the
/// whole create handshake.
const CREATE_RESPONSE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(900);

/// Capacity of the bounded channel that carries continue requests from the
/// sub-sessions to the control task.
const CONTINUE_CHANNEL_CAPACITY: usize = 256;

/// Value passed as `idle_timeout` when draining subscriptions after a
/// sub-session reports `Completed` (60 s).
const COMPLETION_DRAIN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);

/// A continue request paired with the id of the sub-session it is addressed
/// to: `(session_id, params)`.
type TaggedContinue = (String, ContinueParams);
60
/// Result of a successful create handshake, handed to the caller of
/// `ParallelControlHandle::wait_created` through the `created` oneshot.
struct ParallelCreated {
    /// Control-session id taken from the server's `SessionsCreatedV2` response.
    control_session_id: String,
    /// One entry per `SessionCreated` response seen during the handshake.
    sessions: Vec<CreatedSubSession>,
}
66
/// Everything the control task knows about one sub-session right after it was
/// created.
struct CreatedSubSession {
    /// Session id, task id and resolved RPC endpoint of the sub-session.
    info: SessionInfo,
    /// Receiving half of the per-session event channel; the control task keeps
    /// the sending half in `ParallelControlTask::event_txs`.
    events: mpsc::UnboundedReceiver<ControlEvent>,
    /// Initialised to 0 on `SessionCreated`, then overwritten from
    /// `SessionsCreatedV2::start_slots` when the session id is listed there.
    start_slot: u64,
    /// Initialised to 0 on `SessionCreated`, then overwritten from
    /// `SessionsCreatedV2::end_slots` when the session id is listed there.
    end_slot: u64,
}
76
/// Caller-side handle to the spawned parallel control task.
struct ParallelControlHandle {
    /// Sender for continue requests; cloned into every `ParallelSubSession`.
    continues: mpsc::Sender<TaggedContinue>,
    /// Connection status published by the control task. The channel starts out
    /// as `ConnectionStatus::Down`.
    status: watch::Receiver<ConnectionStatus>,
    /// One-shot result of the create handshake. `wait_created` takes it, so it
    /// is `None` after the first call.
    created: Option<oneshot::Receiver<Result<ParallelCreated, String>>>,
    /// Join handle of the task running `run_control_loop`.
    join: JoinHandle<()>,
}
84
85impl ParallelControlHandle {
86 async fn wait_created(&mut self) -> Result<ParallelCreated, String> {
89 let rx = self
90 .created
91 .take()
92 .ok_or_else(|| "parallel create already consumed".to_string())?;
93 rx.await
94 .map_err(|_| "control manager exited before creating sessions".to_string())?
95 }
96
97 async fn join(self) {
100 drop(self.continues);
101 let _ = self.join.await;
102 }
103}
104
105fn spawn_parallel_control_manager(
106 url: String,
107 api_key: String,
108 create: CreateBacktestSessionRequest,
109 cancel: CancellationToken,
110) -> ParallelControlHandle {
111 let (continues_tx, continues_rx) = mpsc::channel::<TaggedContinue>(CONTINUE_CHANNEL_CAPACITY);
112 let (status_tx, status_rx) = watch::channel(ConnectionStatus::Down);
113 let (created_tx, created_rx) = oneshot::channel::<Result<ParallelCreated, String>>();
114
115 let task = ParallelControlTask {
116 url,
117 api_key,
118 create: Some(create),
119 control_session_id: None,
120 event_txs: HashMap::new(),
121 last_sequences: HashMap::new(),
122 completed: std::collections::HashSet::new(),
123 continues_rx,
124 status_tx,
125 created_tx: Some(created_tx),
126 cancel,
127 };
128
129 let join = tokio::spawn(run_control_loop(task));
130
131 ParallelControlHandle {
132 continues: continues_tx,
133 status: status_rx,
134 created: Some(created_rx),
135 join,
136 }
137}
138
/// State owned by the spawned control task that multiplexes all sub-sessions
/// over a single control connection.
struct ParallelControlTask {
    /// WebSocket URL of the control endpoint.
    url: String,
    /// API key exposed to the connection loop through `ControlConnection`.
    api_key: String,
    /// Create request used for the first handshake; set to `None` once
    /// `SessionsCreatedV2` has been received.
    create: Option<CreateBacktestSessionRequest>,
    /// Set once sessions were created; later handshakes attach to this id
    /// instead of creating again.
    control_session_id: Option<String>,
    /// Per-session event senders, keyed by session id.
    event_txs: HashMap<String, mpsc::UnboundedSender<ControlEvent>>,
    /// Highest `seq_id` routed per session; used to drop duplicates and sent
    /// to the server on attach.
    last_sequences: HashMap<String, u64>,
    /// Session ids for which a `Completed` event has been routed.
    completed: std::collections::HashSet<String>,
    /// Continue requests coming from the sub-sessions.
    continues_rx: mpsc::Receiver<TaggedContinue>,
    /// Connection status publisher, handed out through
    /// `ControlConnection::status_tx`.
    status_tx: watch::Sender<ConnectionStatus>,
    /// One-shot used to report the create outcome; taken on first use.
    created_tx: Option<oneshot::Sender<Result<ParallelCreated, String>>>,
    /// Cancellation token for the whole control task.
    cancel: CancellationToken,
}
160
161impl ControlConnection for ParallelControlTask {
162 fn url(&self) -> &str {
163 &self.url
164 }
165 fn api_key(&self) -> &str {
166 &self.api_key
167 }
168 fn cancel(&self) -> &CancellationToken {
169 &self.cancel
170 }
171 fn label(&self) -> &'static str {
172 "parallel control"
173 }
174 fn status_tx(&self) -> &watch::Sender<ConnectionStatus> {
175 &self.status_tx
176 }
177
178 fn fail_pending(&mut self, reason: String) {
179 if let Some(tx) = self.created_tx.take() {
180 let _ = tx.send(Err(reason));
181 }
182 }
183
184 async fn handshake(&mut self, ws: Ws) -> Result<Ws, HandshakeError> {
185 if let Some(control_session_id) = self.control_session_id.clone() {
186 self.attach(ws, &control_session_id).await
187 } else if let Some(create) = self.create.clone() {
188 self.create_sessions(ws, create).await
193 } else {
194 Err(HandshakeError::Fatal(
195 "no create request and no control_session_id".into(),
196 ))
197 }
198 }
199
200 async fn message_loop(&mut self, mut ws: Ws) -> MessageLoopExit {
201 let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
202 ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
203 let mut last_inbound = std::time::Instant::now();
204
205 let exit = loop {
206 tokio::select! {
207 biased;
208 _ = self.cancel.cancelled() => break MessageLoopExit::Cancelled,
209
210 _ = ping_timer.tick() => {
211 if let Some(why) = send_keepalive_ping(&mut ws, last_inbound).await {
212 break MessageLoopExit::ConnectionLost(why);
213 }
214 }
215
216 msg = ws.next() => {
217 last_inbound = std::time::Instant::now();
218 match classify_inbound(msg) {
219 InboundFrame::Text(t) => {
220 if let Some(exit) = self.handle_text(&t) {
221 break exit;
222 }
223 }
224 InboundFrame::Ignore => {}
225 InboundFrame::Lost(why) => break MessageLoopExit::ConnectionLost(why),
226 }
227 }
228
229 req = self.continues_rx.recv() => {
230 match req {
231 Some((session_id, request)) => {
232 let msg = BacktestRequest::ContinueSessionV1(ContinueSessionRequestV1 { session_id, request });
233 if let Err(e) = send_request(&mut ws, &msg).await {
234 break MessageLoopExit::ConnectionLost(format!("continue send: {e}"));
235 }
236 }
237 None => break MessageLoopExit::SessionEnded,
238 }
239 }
240 }
241 };
242
243 if matches!(
244 exit,
245 MessageLoopExit::SessionEnded | MessageLoopExit::Cancelled
246 ) {
247 graceful_close(&mut ws).await;
248 }
249 exit
250 }
251}
252
253impl ParallelControlTask {
254 async fn create_sessions(
258 &mut self,
259 mut ws: Ws,
260 create: CreateBacktestSessionRequest,
261 ) -> Result<Ws, HandshakeError> {
262 send_request(&mut ws, &BacktestRequest::CreateBacktestSession(create))
263 .await
264 .map_err(HandshakeError::Transient)?;
265
266 let rpc_base = http_base_from_ws_url(&self.url);
267 let mut sessions: Vec<CreatedSubSession> = Vec::new();
268
269 loop {
270 let response = next_response(&mut ws, CREATE_RESPONSE_TIMEOUT)
271 .await
272 .map_err(HandshakeError::Transient)?;
273 match response {
274 BacktestResponse::SessionCreated {
275 session_id,
276 rpc_endpoint,
277 task_id,
278 } => {
279 let (event_tx, event_rx) = mpsc::unbounded_channel::<ControlEvent>();
280 self.event_txs.insert(session_id.clone(), event_tx);
281 sessions.push(CreatedSubSession {
282 info: SessionInfo {
283 rpc_endpoint: resolve_rpc_url(&rpc_base, &rpc_endpoint),
284 session_id,
285 task_id,
286 },
287 events: event_rx,
288 start_slot: 0,
289 end_slot: 0,
290 });
291 }
292 BacktestResponse::SessionEventV2 {
293 session_id,
294 seq_id,
295 event,
296 } => {
297 self.route_event(&session_id, seq_id, event);
298 }
299 BacktestResponse::SessionsCreatedV2 {
300 control_session_id,
301 session_ids,
302 start_slots,
303 end_slots,
304 ..
305 } => {
306 info!(
307 %control_session_id,
308 sessions = sessions.len(),
309 "parallel sessions created"
310 );
311 if session_ids.len() != start_slots.len()
318 || session_ids.len() != end_slots.len()
319 {
320 return Err(HandshakeError::Fatal(format!(
321 "server did not report per-sub-session ranges \
322 (session_ids={}, start_slots={}, end_slots={}); \
323 server is too old for the multiplexed parallel client",
324 session_ids.len(),
325 start_slots.len(),
326 end_slots.len(),
327 )));
328 }
329 for ((id, start), end) in session_ids.iter().zip(&start_slots).zip(&end_slots) {
330 if let Some(s) = sessions.iter_mut().find(|s| s.info.session_id == *id) {
331 s.start_slot = *start;
332 s.end_slot = *end;
333 }
334 }
335 self.control_session_id = Some(control_session_id.clone());
336 self.create = None;
339 if let Some(tx) = self.created_tx.take() {
340 let _ = tx.send(Ok(ParallelCreated {
341 control_session_id,
342 sessions,
343 }));
344 }
345 return Ok(ws);
346 }
347 BacktestResponse::Error(err) => {
348 return Err(HandshakeError::Fatal(format!(
349 "server error: {}",
350 err_chain(&err)
351 )));
352 }
353 _ => {}
354 }
355 }
356 }
357
358 async fn attach(&mut self, mut ws: Ws, control_session_id: &str) -> Result<Ws, HandshakeError> {
362 send_request(
363 &mut ws,
364 &BacktestRequest::AttachParallelControlSessionV2 {
365 control_session_id: control_session_id.to_string(),
366 last_sequences: self.last_sequences.clone().into_iter().collect(),
367 },
368 )
369 .await
370 .map_err(HandshakeError::Transient)?;
371
372 loop {
373 let response = next_response(&mut ws, HANDSHAKE_RESPONSE_TIMEOUT)
374 .await
375 .map_err(HandshakeError::Transient)?;
376 match response {
377 BacktestResponse::ParallelSessionAttachedV2 { .. } => {
378 debug!(%control_session_id, "parallel control reattached");
379 return Ok(ws);
380 }
381 BacktestResponse::SessionEventV2 {
382 session_id,
383 seq_id,
384 event,
385 } => {
386 self.route_event(&session_id, seq_id, event);
387 }
388 BacktestResponse::Error(err) => {
389 return Err(handshake_error_for_response("attach", err));
390 }
391 _ => {}
392 }
393 }
394 }
395
396 fn handle_text(&mut self, text: &str) -> Option<MessageLoopExit> {
399 let response = match serde_json::from_str::<BacktestResponse>(text) {
400 Ok(r) => r,
401 Err(e) => {
402 warn!(error = %err_chain(&e), "discarding undeserializable parallel control message");
403 return None;
404 }
405 };
406
407 match response {
408 BacktestResponse::SessionEventV2 {
409 session_id,
410 seq_id,
411 event,
412 } => {
413 self.route_event(&session_id, seq_id, event);
414 if self.completed.len() == self.event_txs.len() && !self.event_txs.is_empty() {
415 return Some(MessageLoopExit::SessionEnded);
416 }
417 }
418 BacktestResponse::Error(err) => {
419 if is_terminal_backtest_error(&err) {
420 return Some(MessageLoopExit::Terminal(format!(
421 "control session error: {}",
422 err_chain(&err)
423 )));
424 }
425 warn!(error = %err_chain(&err), "non-terminal parallel control error");
426 }
427 other => {
428 debug!(?other, "ignoring unexpected parallel control response");
429 }
430 }
431 None
432 }
433
434 fn route_event(&mut self, session_id: &str, seq_id: u64, event: SessionEventKind) {
437 match self.last_sequences.get_mut(session_id) {
441 Some(last) if seq_id <= *last => return,
442 Some(last) => *last = seq_id,
443 None => {
444 self.last_sequences.insert(session_id.to_string(), seq_id);
445 }
446 }
447
448 let is_completed = matches!(event, SessionEventKind::Completed { .. });
449 let Some(control_event) = session_event_to_control(event) else {
450 return;
451 };
452 if let Some(tx) = self.event_txs.get(session_id) {
453 let _ = tx.send(control_event);
455 }
456 if is_completed {
457 self.completed.insert(session_id.to_string());
458 }
459 }
460}
461
462fn session_event_to_control(event: SessionEventKind) -> Option<ControlEvent> {
465 Some(match event {
466 SessionEventKind::ReadyForContinue => ControlEvent::ReadyForContinue,
467 SessionEventKind::SlotNotification(slot) => ControlEvent::Slot(slot),
468 SessionEventKind::Paused(event) => ControlEvent::Paused(event),
469 SessionEventKind::DiscoveryBatch(event) => ControlEvent::DiscoveryBatch(event),
470 SessionEventKind::Error(error) => ControlEvent::Error(error),
471 SessionEventKind::Completed { summary } => ControlEvent::Completed {
472 summary,
473 agent_stats: None,
474 },
475 SessionEventKind::Status { status } => ControlEvent::Status(status),
476 SessionEventKind::Success => return None,
477 })
478}
479
480async fn next_response(
482 ws: &mut Ws,
483 timeout: std::time::Duration,
484) -> Result<BacktestResponse, String> {
485 let deadline = tokio::time::Instant::now() + timeout;
486 loop {
487 let msg = tokio::time::timeout_at(deadline, ws.next())
488 .await
489 .map_err(|_| format!("handshake timeout after {timeout:?}"))?;
490
491 let Some(msg) = msg else {
492 return Err("ws ended during handshake".into());
493 };
494 let msg = msg.map_err(|e| format!("ws read: {}", err_chain(&e)))?;
495
496 let text = match msg {
497 Message::Text(t) => t,
498 Message::Binary(b) => match String::from_utf8(b) {
499 Ok(t) => t,
500 Err(_) => continue,
501 },
502 Message::Close(frame) => {
503 return Err(format!("remote close during handshake: {frame:?}"));
504 }
505 _ => continue,
506 };
507
508 return serde_json::from_str::<BacktestResponse>(&text)
509 .map_err(|e| format!("deserialize: {}; raw={text}", err_chain(&e)));
510 }
511}
512
/// Owner of a multiplexed parallel backtest: one control connection shared by
/// several sub-sessions.
pub struct ManagedParallelSession {
    /// Control-session id reported by the server at creation.
    control_session_id: String,
    /// Handle to the control task; taken by `shutdown`.
    control: Option<ParallelControlHandle>,
    /// Sub-sessions not yet handed out through `take_sub_sessions`.
    sub_sessions: Vec<ParallelSubSession>,
    /// Child of the caller's token; cancelled on `shutdown` and on drop. Each
    /// sub-session holds a child of this token.
    session_cancel: CancellationToken,
}
522
523impl ManagedParallelSession {
524 pub async fn start_with_cancel(
526 url: String,
527 api_key: String,
528 create: CreateBacktestSessionRequest,
529 parent_cancel: CancellationToken,
530 ) -> Result<Self, ManagedSessionError> {
531 let session_cancel = parent_cancel.child_token();
532 let mut control =
533 spawn_parallel_control_manager(url, api_key, create, session_cancel.clone());
534
535 let created = tokio::select! {
536 biased;
537 _ = parent_cancel.cancelled() => {
538 session_cancel.cancel();
539 control.join().await;
540 return Err(ManagedSessionError::Cancelled);
541 }
542 result = control.wait_created() => {
543 result.map_err(ManagedSessionError::Create)?
544 }
545 };
546
547 let reconnect_coordinator = Arc::new(ReconnectCoordinator::new());
551 let sub_sessions = created
552 .sessions
553 .into_iter()
554 .map(|s| ParallelSubSession {
555 session_info: s.info,
556 events: s.events,
557 continues: control.continues.clone(),
558 status: control.status.clone(),
559 subscriptions: Vec::new(),
560 session_cancel: session_cancel.child_token(),
561 post_completion: None,
562 post_completion_error: None,
563 reconnect_coordinator: Some(reconnect_coordinator.clone()),
564 start_slot: s.start_slot,
565 end_slot: s.end_slot,
566 })
567 .collect();
568
569 Ok(Self {
570 control_session_id: created.control_session_id,
571 control: Some(control),
572 sub_sessions,
573 session_cancel,
574 })
575 }
576
577 pub fn control_session_id(&self) -> &str {
579 &self.control_session_id
580 }
581
582 pub fn take_sub_sessions(&mut self) -> Vec<ParallelSubSession> {
585 std::mem::take(&mut self.sub_sessions)
586 }
587
588 pub async fn shutdown(mut self) {
590 self.session_cancel.cancel();
591 if let Some(control) = self.control.take() {
592 control.join().await;
593 }
594 }
595}
596
/// Dropping the session cancels its token. The sub-sessions' tokens are
/// children of it (see `start_with_cancel`), so this applies to sub-sessions
/// already handed out as well.
impl Drop for ManagedParallelSession {
    fn drop(&mut self) {
        self.session_cancel.cancel();
    }
}
602
/// One sub-session of a parallel backtest. Events arrive through the shared
/// control task; continues are sent back through the shared continue channel.
pub struct ParallelSubSession {
    /// Session id, task id and RPC endpoint of this sub-session.
    session_info: SessionInfo,
    /// Control events routed to this sub-session by the control task.
    events: mpsc::UnboundedReceiver<ControlEvent>,
    /// Shared sender for continue requests; each request is tagged with this
    /// sub-session's id.
    continues: mpsc::Sender<TaggedContinue>,
    /// Connection status of the shared control connection.
    status: watch::Receiver<ConnectionStatus>,
    /// Subscription managers started through the `subscribe_*` methods.
    subscriptions: Vec<SubscriptionHandle>,
    /// Child of the parent session's token; cancelled on `shutdown` and on drop.
    session_cancel: CancellationToken,
    /// `Some` once `Completed` has been observed: the events still to be
    /// returned by `next_event`, in order.
    post_completion: Option<VecDeque<ManagedEvent>>,
    /// Error to return once `post_completion` is exhausted; set when the
    /// completion drain stalled.
    post_completion_error: Option<ManagedSessionError>,
    /// Coordinator shared by all sub-sessions and passed to every subscription
    /// manager.
    reconnect_coordinator: Option<Arc<ReconnectCoordinator>>,
    /// First element of the pair returned by `range`.
    start_slot: u64,
    /// Second element of the pair returned by `range`.
    end_slot: u64,
}
628
629impl ParallelSubSession {
630 pub fn session_info(&self) -> &SessionInfo {
631 &self.session_info
632 }
633
634 pub fn range(&self) -> (u64, u64) {
636 (self.start_slot, self.end_slot)
637 }
638
639 pub fn subscribe_transactions(&mut self, program_ids: Vec<String>) {
640 self.subscriptions
641 .push(spawn_transaction_subscription_manager(
642 self.session_info.rpc_endpoint.clone(),
643 program_ids,
644 self.session_cancel.clone(),
645 self.reconnect_coordinator.clone(),
646 ));
647 }
648
649 pub fn subscribe_account_diffs(&mut self, program_ids: Vec<String>) {
650 self.subscriptions
651 .push(spawn_account_diff_subscription_manager(
652 self.session_info.rpc_endpoint.clone(),
653 program_ids,
654 self.session_cancel.clone(),
655 self.reconnect_coordinator.clone(),
656 ));
657 }
658
659 pub async fn next_event(&mut self) -> Result<ManagedEvent, ManagedSessionError> {
663 if let Some(buffered) = self.post_completion.as_mut() {
666 if let Some(event) = buffered.pop_front() {
667 return Ok(event);
668 }
669 return Err(self
672 .post_completion_error
673 .take()
674 .unwrap_or(ManagedSessionError::ControlClosed));
675 }
676
677 if let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
678 return Ok(event);
679 }
680
681 let event = {
684 let cancel = &self.session_cancel;
685 let subscriptions = &mut self.subscriptions;
686 tokio::select! {
687 biased;
688 _ = cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
689 event = self.events.recv() => {
690 event.map(ManagedEvent::from).ok_or(ManagedSessionError::ControlClosed)?
691 }
692 event = wait_any_subscription_event(subscriptions) => event,
693 }
694 };
695
696 let ManagedEvent::Completed {
698 summary,
699 agent_stats,
700 } = event
701 else {
702 return Ok(event);
703 };
704
705 let (mut buffered, terminal): (VecDeque<ManagedEvent>, _) = match self
709 .drain_until_subscriptions_complete(COMPLETION_DRAIN_TIMEOUT)
710 .await
711 {
712 DrainOutcome::Complete(events) => (
713 events.into(),
714 Ok(ManagedEvent::Completed {
715 summary,
716 agent_stats,
717 }),
718 ),
719 DrainOutcome::Stalled(events) => (
723 events.into(),
724 Err(ManagedSessionError::SubscriptionFailed(
725 "completion drain stalled: subscriptions did not deliver their \
726 end-of-stream terminals; the captured stream is incomplete"
727 .to_string(),
728 )),
729 ),
730 };
731 match terminal {
732 Ok(completed) => buffered.push_back(completed),
733 Err(err) => self.post_completion_error = Some(err),
734 }
735 let first = buffered.pop_front();
736 self.post_completion = Some(buffered);
737 match first {
738 Some(event) => Ok(event),
739 None => Err(self
740 .post_completion_error
741 .take()
742 .unwrap_or(ManagedSessionError::ControlClosed)),
743 }
744 }
745
746 pub async fn send_continue(
749 &mut self,
750 params: ContinueParams,
751 ) -> Result<(), ManagedSessionError> {
752 self.wait_all_up().await?;
753 self.continues
754 .send((self.session_info.session_id.clone(), params))
755 .await
756 .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
757 }
758
759 async fn drain_until_subscriptions_complete(
761 &mut self,
762 idle_timeout: std::time::Duration,
763 ) -> DrainOutcome {
764 drain_subscriptions_until_complete(
765 &mut self.subscriptions,
766 &self.session_cancel,
767 idle_timeout,
768 )
769 .await
770 }
771
772 pub async fn shutdown(mut self) {
775 self.session_cancel.cancel();
776 for sub in std::mem::take(&mut self.subscriptions) {
777 let _ = sub.join.await;
778 }
779 }
780
781 async fn wait_all_up(&self) -> Result<(), ManagedSessionError> {
782 let subscriptions = self
783 .subscriptions
784 .iter()
785 .map(|s| s.status.clone())
786 .collect();
787 wait_connections_up(self.status.clone(), subscriptions, &self.session_cancel).await
788 }
789}
790
/// Dropping a sub-session cancels its own token, which is the token handed to
/// its subscription managers.
impl Drop for ParallelSubSession {
    fn drop(&mut self) {
        self.session_cancel.cancel();
    }
}