1use std::time::Instant;
13
14use futures::{SinkExt, StreamExt};
15use simulator_api::{
16 BacktestError, BacktestRequest, BacktestResponse, BacktestStatus, ContinueParams,
17 ContinueToParams, CreateBacktestSessionRequest, DiscoveryBatchEvent, PausedEvent,
18 SequencedResponse,
19};
20use tokio::{
21 net::TcpStream,
22 sync::{mpsc, oneshot, watch},
23 task::JoinHandle,
24};
25use tokio_tungstenite::{
26 MaybeTlsStream, WebSocketStream, connect_async,
27 tungstenite::{Message, client::IntoClientRequest, http::HeaderValue},
28};
29use tokio_util::sync::CancellationToken;
30use tracing::{debug, info, warn};
31
32use super::{
33 CONNECT_TIMEOUT, ConnectionStatus, GRACEFUL_CLOSE_TIMEOUT, HANDSHAKE_RESPONSE_TIMEOUT,
34 KEEPALIVE_INTERVAL, KEEPALIVE_MISS_DEADLINE, RECONNECT_UPTIME_RESET, ReconnectBudget,
35 SessionInfo, cancellable_sleep,
36};
37use crate::{error::err_chain, urls::http_base_from_ws_url};
38
/// Events the control task forwards to the consumer through
/// [`ControlHandle::events`].
///
/// Each variant mirrors one server response relayed by the control task; the
/// protocol-level meaning of the payloads is defined by `simulator_api`.
#[derive(Debug)]
pub enum ControlEvent {
    /// The server sent `BacktestResponse::ReadyForContinue`.
    ReadyForContinue,
    /// The server sent `BacktestResponse::Paused`; the payload is passed through unchanged.
    Paused(PausedEvent),
    /// The server sent `BacktestResponse::DiscoveryBatch`; the payload is passed through unchanged.
    DiscoveryBatch(DiscoveryBatchEvent),
    /// The server sent `BacktestResponse::SlotNotification` carrying this value.
    Slot(u64),
    /// The server sent `BacktestResponse::Status`.
    Status(BacktestStatus),
    /// The server sent `BacktestResponse::Completed`.
    Completed,
    /// The server sent `BacktestResponse::Error`. `SimulationError` is only
    /// logged by the message loop and is not forwarded as this variant.
    Error(BacktestError),
}
61
/// Caller-side handle to a control task started by [`spawn_control_manager`].
pub struct ControlHandle {
    /// Sends `Continue` requests to the task (bounded channel, capacity 1).
    continues: mpsc::Sender<ContinueParams>,
    /// Sends `ContinueTo` requests to the task (bounded channel, capacity 1).
    continue_tos: mpsc::Sender<ContinueToParams>,
    /// Events relayed from the server (bounded channel, capacity 256).
    pub events: mpsc::Receiver<ControlEvent>,
    /// Latest connection status published by the task; starts as `Down`.
    pub status: watch::Receiver<ConnectionStatus>,
    /// One-shot result of session creation; `None` once `wait_for_session`
    /// has taken it.
    session_info: Option<oneshot::Receiver<Result<SessionInfo, String>>>,
    /// Join handle of the spawned control task.
    join: JoinHandle<()>,
}
71
72impl ControlHandle {
73 pub async fn wait_for_session(&mut self) -> Result<SessionInfo, String> {
76 let rx = self
77 .session_info
78 .take()
79 .ok_or_else(|| "session_info already consumed".to_string())?;
80 rx.await
81 .map_err(|_| "control manager exited before creating session".to_string())?
82 }
83
84 pub async fn send_continue(
87 &self,
88 params: ContinueParams,
89 ) -> Result<(), mpsc::error::SendError<ContinueParams>> {
90 self.continues.send(params).await
91 }
92
93 pub async fn send_continue_to(
97 &self,
98 params: ContinueToParams,
99 ) -> Result<(), mpsc::error::SendError<ContinueToParams>> {
100 self.continue_tos.send(params).await
101 }
102
103 pub async fn join(self) {
108 drop(self.continues);
109 drop(self.continue_tos);
110 let _ = self.join.await;
111 }
112}
113
114pub fn spawn_control_manager(
120 url: String,
121 api_key: String,
122 create: CreateBacktestSessionRequest,
123 cancel: CancellationToken,
124) -> ControlHandle {
125 let (continues_tx, continues_rx) = mpsc::channel::<ContinueParams>(1);
126 let (continue_tos_tx, continue_tos_rx) = mpsc::channel::<ContinueToParams>(1);
127 let (events_tx, events_rx) = mpsc::channel::<ControlEvent>(256);
128 let (status_tx, status_rx) = watch::channel(ConnectionStatus::Down);
129 let (session_tx, session_rx) = oneshot::channel::<Result<SessionInfo, String>>();
130
131 let manager = ControlTask {
132 url,
133 api_key,
134 create: Some(create),
135 session_info: None,
136 session_tx: Some(session_tx),
137 last_sequence: None,
138 continues_rx,
139 continue_tos_rx,
140 events_tx,
141 status_tx,
142 cancel,
143 };
144
145 let join = tokio::spawn(manager.run());
146
147 ControlHandle {
148 continues: continues_tx,
149 continue_tos: continue_tos_tx,
150 events: events_rx,
151 status: status_rx,
152 session_info: Some(session_rx),
153 join,
154 }
155}
156
/// Client WebSocket stream type returned by `connect_async` (plain or TLS over TCP).
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
158
/// State owned by the spawned control task.
struct ControlTask {
    /// WebSocket URL to connect to; also used to derive the HTTP base for the
    /// session's RPC endpoint.
    url: String,
    /// Value sent in the `X-API-Key` header on every connect.
    api_key: String,
    /// Session-creation request; taken by the first create handshake.
    create: Option<CreateBacktestSessionRequest>,
    /// Set once the session has been created; its presence makes later
    /// handshakes attach to the existing session instead of creating one.
    session_info: Option<SessionInfo>,
    /// One-shot used to report the session-creation result to the handle.
    session_tx: Option<oneshot::Sender<Result<SessionInfo, String>>>,
    /// Most recent `seq_id` seen on a sequenced response; sent with attach requests.
    last_sequence: Option<u64>,
    /// Incoming `Continue` requests from the handle.
    continues_rx: mpsc::Receiver<ContinueParams>,
    /// Incoming `ContinueTo` requests from the handle.
    continue_tos_rx: mpsc::Receiver<ContinueToParams>,
    /// Outgoing events to the handle.
    events_tx: mpsc::Sender<ControlEvent>,
    /// Publishes the connection status to the handle.
    status_tx: watch::Sender<ConnectionStatus>,
    /// Stops the task when cancelled.
    cancel: CancellationToken,
}
176
/// Reason the message loop stopped; `ControlTask::run` decides what happens next.
enum MessageLoopExit {
    /// The session is over (server reported completion or a request channel
    /// was closed); the task exits without reconnecting.
    SessionEnded,
    /// The cancellation token fired; the task exits.
    Cancelled,
    /// The connection dropped or went silent; the task tries to reconnect.
    ConnectionLost(String),
    /// The server reported a terminal error; the task exits as failed.
    Terminal(String),
}
188
189impl ControlTask {
190 async fn run(mut self) {
191 let mut budget = ReconnectBudget::new();
192
193 loop {
194 if self.cancel.is_cancelled() {
195 self.fail_session_info_if_pending("cancelled before session created");
196 return;
197 }
198 self.publish(ConnectionStatus::Down);
199
200 let ws = match self.connect().await {
202 Ok(ws) => ws,
203 Err(why) => {
204 if let Some(delay) = budget.next_backoff() {
205 warn!(attempt = budget.attempt(), error = %why, ?delay, "control connect failed, retrying");
206 if !cancellable_sleep(delay, &self.cancel).await {
207 return;
208 }
209 continue;
210 }
211 self.finish_failed(format!("connect: {why}"));
212 return;
213 }
214 };
215
216 let ws = match self.handshake(ws).await {
218 Ok(ws) => ws,
219 Err(HandshakeError::Fatal(why)) => {
220 self.finish_failed(format!("handshake: {why}"));
221 return;
222 }
223 Err(HandshakeError::Transient(why)) => {
224 if let Some(delay) = budget.next_backoff() {
225 warn!(attempt = budget.attempt(), error = %why, ?delay, "control handshake failed, retrying");
226 if !cancellable_sleep(delay, &self.cancel).await {
227 return;
228 }
229 continue;
230 }
231 self.finish_failed(format!("handshake: {why}"));
232 return;
233 }
234 };
235
236 self.publish(ConnectionStatus::Up);
237 let connected_at = Instant::now();
238
239 let exit = self.message_loop(ws).await;
240
241 match exit {
242 MessageLoopExit::SessionEnded => return,
243 MessageLoopExit::Cancelled => return,
244 MessageLoopExit::ConnectionLost(why) => {
245 if connected_at.elapsed() >= RECONNECT_UPTIME_RESET {
246 budget.reset();
247 }
248 if let Some(delay) = budget.next_backoff() {
249 warn!(attempt = budget.attempt(), reason = %why, ?delay, "control connection lost, reconnecting");
250 if !cancellable_sleep(delay, &self.cancel).await {
251 return;
252 }
253 continue;
254 }
255 self.finish_failed(format!("connection lost: {why}"));
256 return;
257 }
258 MessageLoopExit::Terminal(why) => {
259 self.finish_failed(why);
260 return;
261 }
262 }
263 }
264 }
265
266 fn publish(&self, status: ConnectionStatus) {
267 self.status_tx.send_if_modified(|current| {
268 if *current == status {
269 false
270 } else {
271 *current = status;
272 true
273 }
274 });
275 }
276
277 fn fail_session_info_if_pending(&mut self, reason: &str) {
278 if let Some(tx) = self.session_tx.take() {
279 let _ = tx.send(Err(reason.to_string()));
280 }
281 }
282
283 fn finish_failed(&mut self, reason: String) {
284 self.fail_session_info_if_pending(&reason);
285 self.publish(ConnectionStatus::Failed(reason));
286 }
287
288 async fn connect(&self) -> Result<Ws, String> {
289 let mut request = self
290 .url
291 .clone()
292 .into_client_request()
293 .map_err(|e| format!("build request: {}", err_chain(&e)))?;
294
295 request.headers_mut().insert(
296 "X-API-Key",
297 HeaderValue::from_str(&self.api_key)
298 .map_err(|e| format!("api key header: {}", err_chain(&e)))?,
299 );
300
301 let connect = tokio::time::timeout(CONNECT_TIMEOUT, connect_async(request))
302 .await
303 .map_err(|_| format!("connect timeout after {CONNECT_TIMEOUT:?}"))?
304 .map_err(|e| format!("connect: {}", err_chain(&e)))?;
305
306 Ok(connect.0)
307 }
308
309 async fn handshake(&mut self, mut ws: Ws) -> Result<Ws, HandshakeError> {
310 if let Some(info) = &self.session_info {
311 let info = info.clone();
312 attach(
313 &mut ws,
314 &info.session_id,
315 self.last_sequence,
316 &mut self.events_tx,
317 &mut self.last_sequence,
318 )
319 .await?;
320 resume(&mut ws, &mut self.events_tx, &mut self.last_sequence).await?;
321 debug!(session_id = info.session_id, "control reattached");
322 } else if let Some(create) = self.create.take() {
323 let info = create_session(
324 &mut ws,
325 create,
326 &self.url,
327 &mut self.events_tx,
328 &mut self.last_sequence,
329 )
330 .await?;
331 info!(session_id = info.session_id, "control session created");
332 self.session_info = Some(info.clone());
333 if let Some(tx) = self.session_tx.take() {
334 let _ = tx.send(Ok(info));
335 }
336 } else {
337 return Err(HandshakeError::Fatal(
338 "no create request and no session_id".into(),
339 ));
340 }
341
342 Ok(ws)
343 }
344
345 async fn message_loop(&mut self, mut ws: Ws) -> MessageLoopExit {
346 let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
347 ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
348 let mut last_inbound = Instant::now();
349
350 let exit = loop {
351 tokio::select! {
352 biased;
353 _ = self.cancel.cancelled() => break MessageLoopExit::Cancelled,
354
355 _ = ping_timer.tick() => {
356 if last_inbound.elapsed() > KEEPALIVE_MISS_DEADLINE {
357 break MessageLoopExit::ConnectionLost(format!(
358 "no traffic for {:?}", last_inbound.elapsed()
359 ));
360 }
361 if let Err(e) = ws.send(Message::Ping(vec![])).await {
362 break MessageLoopExit::ConnectionLost(format!("ping send: {}", err_chain(&e)));
363 }
364 }
365
366 msg = ws.next() => {
367 last_inbound = Instant::now();
368 match msg {
369 Some(Ok(Message::Text(t))) => {
370 if let Err(exit) = self.handle_text(&t).await {
371 break exit;
372 }
373 }
374 Some(Ok(Message::Binary(b))) => {
375 if let Ok(t) = std::str::from_utf8(&b)
376 && let Err(exit) = self.handle_text(t).await {
377 break exit;
378 }
379 }
380 Some(Ok(Message::Pong(_))) | Some(Ok(Message::Ping(_))) => {}
381 Some(Ok(Message::Close(frame))) => {
382 break MessageLoopExit::ConnectionLost(format!("remote close: {frame:?}"));
383 }
384 Some(Ok(Message::Frame(_))) => {}
385 Some(Err(e)) => {
386 break MessageLoopExit::ConnectionLost(format!("ws read: {}", err_chain(&e)));
387 }
388 None => break MessageLoopExit::ConnectionLost("ws stream ended".into()),
389 }
390 }
391
392 req = self.continues_rx.recv() => {
393 match req {
394 Some(params) => {
395 if let Err(e) = send_request(&mut ws, &BacktestRequest::Continue(params)).await {
396 break MessageLoopExit::ConnectionLost(format!("continue send: {e}"));
397 }
398 }
399 None => {
400 break MessageLoopExit::SessionEnded;
402 }
403 }
404 }
405
406 req = self.continue_tos_rx.recv() => {
407 match req {
408 Some(params) => {
409 if let Err(e) = send_request(&mut ws, &BacktestRequest::ContinueTo(params)).await {
410 break MessageLoopExit::ConnectionLost(format!("continue_to send: {e}"));
411 }
412 }
413 None => break MessageLoopExit::SessionEnded,
414 }
415 }
416 }
417 };
418
419 if matches!(exit, MessageLoopExit::SessionEnded) {
420 graceful_close(&mut ws).await;
421 }
422 exit
423 }
424
425 async fn handle_text(&mut self, text: &str) -> Result<(), MessageLoopExit> {
427 let (seq, response) = match serde_json::from_str::<SequencedResponse>(text) {
428 Ok(s) => (Some(s.seq_id), s.response),
429 Err(_) => match serde_json::from_str::<BacktestResponse>(text) {
430 Ok(r) => (None, r),
431 Err(e) => {
432 warn!(error = %err_chain(&e), "discarding undeserializable control message");
433 return Ok(());
434 }
435 },
436 };
437
438 if let Some(s) = seq {
439 self.last_sequence = Some(s);
440 }
441
442 match response {
443 BacktestResponse::ReadyForContinue => {
444 let _ = self.events_tx.send(ControlEvent::ReadyForContinue).await;
445 }
446 BacktestResponse::Paused(event) => {
447 let _ = self.events_tx.send(ControlEvent::Paused(event)).await;
448 }
449 BacktestResponse::DiscoveryBatch(event) => {
450 let _ = self
451 .events_tx
452 .send(ControlEvent::DiscoveryBatch(event))
453 .await;
454 }
455 BacktestResponse::SlotNotification(slot) => {
456 let _ = self.events_tx.send(ControlEvent::Slot(slot)).await;
457 }
458 BacktestResponse::Completed { .. } => {
459 let _ = self.events_tx.send(ControlEvent::Completed).await;
460 return Err(MessageLoopExit::SessionEnded);
461 }
462 BacktestResponse::Error(err) => {
463 if matches!(&err, BacktestError::SimulationError { .. }) {
465 warn!(error = %err_chain(&err), "simulation error");
466 return Ok(());
467 }
468 let terminal = matches!(
469 &err,
470 BacktestError::NoMoreBlocks
471 | BacktestError::AdvanceSlotFailed { .. }
472 | BacktestError::FinalizeSlotFailed { .. }
473 | BacktestError::Internal { .. }
474 );
475 let _ = self.events_tx.send(ControlEvent::Error(err)).await;
476 if terminal {
477 return Err(MessageLoopExit::Terminal(
478 "server reported terminal error".into(),
479 ));
480 }
481 }
482 BacktestResponse::Status { status } => {
483 let _ = self.events_tx.send(ControlEvent::Status(status)).await;
484 }
485 BacktestResponse::Success => {
486 }
488 other => {
489 debug!(?other, "ignoring unexpected control response");
491 }
492 }
493
494 Ok(())
495 }
496}
497
/// Failure of the post-connect handshake, classified by how `ControlTask::run`
/// reacts to it.
enum HandshakeError {
    /// Retried with backoff while the reconnect budget lasts.
    Transient(String),
    /// Ends the task immediately as failed.
    Fatal(String),
}
504
505async fn create_session(
506 ws: &mut Ws,
507 request: CreateBacktestSessionRequest,
508 url: &str,
509 events: &mut mpsc::Sender<ControlEvent>,
510 last_sequence: &mut Option<u64>,
511) -> Result<SessionInfo, HandshakeError> {
512 send_request(ws, &BacktestRequest::CreateBacktestSession(request))
513 .await
514 .map_err(HandshakeError::Transient)?;
515
516 let rpc_base = http_base_from_ws_url(url);
517
518 loop {
519 let response = next_response_with_timeout(ws, events, last_sequence)
520 .await
521 .map_err(HandshakeError::Transient)?;
522 match response {
523 BacktestResponse::SessionCreated {
524 session_id,
525 rpc_endpoint,
526 task_id,
527 } => {
528 let rpc_endpoint = resolve_rpc_url(&rpc_base, &rpc_endpoint);
529 return Ok(SessionInfo {
530 session_id,
531 rpc_endpoint,
532 task_id,
533 });
534 }
535 BacktestResponse::Error(err) => {
536 return Err(HandshakeError::Fatal(format!(
537 "server error: {}",
538 err_chain(&err)
539 )));
540 }
541 _ => {
542 }
545 }
546 }
547}
548
549async fn attach(
550 ws: &mut Ws,
551 session_id: &str,
552 last_sequence: Option<u64>,
553 events: &mut mpsc::Sender<ControlEvent>,
554 last_seq_state: &mut Option<u64>,
555) -> Result<(), HandshakeError> {
556 send_request(
557 ws,
558 &BacktestRequest::AttachBacktestSession {
559 session_id: session_id.to_string(),
560 last_sequence,
561 },
562 )
563 .await
564 .map_err(HandshakeError::Transient)?;
565
566 loop {
567 let response = next_response_with_timeout(ws, events, last_seq_state)
568 .await
569 .map_err(HandshakeError::Transient)?;
570 match response {
571 BacktestResponse::SessionAttached { .. } => return Ok(()),
572 BacktestResponse::Error(err) => {
573 return Err(handshake_error_for_response("attach", err));
574 }
575 _ => {}
576 }
577 }
578}
579
580async fn resume(
581 ws: &mut Ws,
582 events: &mut mpsc::Sender<ControlEvent>,
583 last_seq_state: &mut Option<u64>,
584) -> Result<(), HandshakeError> {
585 send_request(ws, &BacktestRequest::ResumeAttachedSession)
586 .await
587 .map_err(HandshakeError::Transient)?;
588
589 loop {
590 let response = next_response_with_timeout(ws, events, last_seq_state)
591 .await
592 .map_err(HandshakeError::Transient)?;
593 match response {
594 BacktestResponse::Success => return Ok(()),
595 BacktestResponse::Error(err) => {
596 return Err(handshake_error_for_response("resume", err));
597 }
598 _ => {}
599 }
600 }
601}
602
603fn handshake_error_for_response(stage: &'static str, err: BacktestError) -> HandshakeError {
608 match err {
609 BacktestError::SessionOwnershipBusy { .. } => {
610 HandshakeError::Transient(format!("{stage} contended: {}", err_chain(&err)))
611 }
612 _ => HandshakeError::Fatal(format!("{stage} rejected: {}", err_chain(&err))),
613 }
614}
615
616async fn send_request(ws: &mut Ws, req: &BacktestRequest) -> Result<(), String> {
617 let text = serde_json::to_string(req).map_err(|e| format!("serialize: {}", err_chain(&e)))?;
618 ws.send(Message::Text(text))
619 .await
620 .map_err(|e| format!("send: {}", err_chain(&e)))
621}
622
623async fn next_response_with_timeout(
628 ws: &mut Ws,
629 events: &mut mpsc::Sender<ControlEvent>,
630 last_sequence: &mut Option<u64>,
631) -> Result<BacktestResponse, String> {
632 let deadline = tokio::time::Instant::now() + HANDSHAKE_RESPONSE_TIMEOUT;
633 loop {
634 let msg = tokio::time::timeout_at(deadline, ws.next())
635 .await
636 .map_err(|_| format!("handshake timeout after {HANDSHAKE_RESPONSE_TIMEOUT:?}"))?;
637
638 let Some(msg) = msg else {
639 return Err("ws ended during handshake".into());
640 };
641 let msg = msg.map_err(|e| format!("ws read: {}", err_chain(&e)))?;
642
643 let text = match msg {
644 Message::Text(t) => t,
645 Message::Binary(b) => match std::str::from_utf8(&b) {
646 Ok(t) => t.to_string(),
647 Err(_) => continue,
648 },
649 Message::Close(frame) => {
650 return Err(format!("remote close during handshake: {frame:?}"));
651 }
652 _ => continue,
653 };
654
655 let (seq, response) = match serde_json::from_str::<SequencedResponse>(&text) {
656 Ok(s) => (Some(s.seq_id), s.response),
657 Err(_) => (
658 None,
659 serde_json::from_str::<BacktestResponse>(&text)
660 .map_err(|e| format!("deserialize: {}; raw={text}", err_chain(&e)))?,
661 ),
662 };
663 if let Some(s) = seq {
664 *last_sequence = Some(s);
665 }
666
667 match response {
670 BacktestResponse::SlotNotification(slot) => {
671 let _ = events.send(ControlEvent::Slot(slot)).await;
672 }
673 BacktestResponse::ReadyForContinue => {
674 let _ = events.send(ControlEvent::ReadyForContinue).await;
675 }
676 BacktestResponse::Paused(event) => {
677 let _ = events.send(ControlEvent::Paused(event)).await;
678 }
679 BacktestResponse::DiscoveryBatch(event) => {
680 let _ = events.send(ControlEvent::DiscoveryBatch(event)).await;
681 }
682 BacktestResponse::Completed { .. } => {
683 let _ = events.send(ControlEvent::Completed).await;
684 }
685 other => return Ok(other),
686 }
687 }
688}
689
690async fn graceful_close(ws: &mut Ws) {
691 let _ = tokio::time::timeout(
692 GRACEFUL_CLOSE_TIMEOUT,
693 send_request(ws, &BacktestRequest::CloseBacktestSession),
694 )
695 .await;
696 let _ = tokio::time::timeout(GRACEFUL_CLOSE_TIMEOUT, ws.close(None)).await;
697}
698
/// Turns the RPC endpoint reported by the server into an absolute URL.
///
/// An `endpoint` that already starts with `http://` or `https://` is returned
/// unchanged. Anything else is treated as a path and joined onto `base` with
/// exactly one `/` between them: trailing slashes on `base` and leading
/// slashes on `endpoint` are stripped first, so a base such as `http://host/`
/// no longer yields `http://host//path`.
fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        return endpoint.to_string();
    }
    format!(
        "{}/{}",
        base.trim_end_matches('/'),
        endpoint.trim_start_matches('/')
    )
}