// simulator_client/managed/control.rs

//! ControlManager — owns the backtest control WebSocket.
//!
//! Responsibilities:
//! - Establish the WS with a bounded connect timeout
//! - Perform the correct handshake (`Create` the first time, `Attach`+`Resume`
//!   on reconnect) with a bounded per-response timeout
//! - Bridge inbound `BacktestResponse`s to driver events, and outbound
//!   `Continue` / `ContinueTo` requests to the socket, via channels
//! - Keep the connection alive with WebSocket ping/pong
//! - Reconnect with bounded backoff; publish `ConnectionStatus` transitions
//! - On total-budget exhaustion, publish `Failed(reason)` and exit

12use std::time::Instant;
13
14use futures::StreamExt;
15use simulator_api::{
16    AgentStatsReport, BacktestError, BacktestRequest, BacktestResponse, BacktestStatus,
17    ContinueParams, ContinueToParams, CreateBacktestSessionRequest, DiscoveryBatchEvent,
18    PausedEvent, SequencedResponse, SessionSummary,
19};
20use tokio::{
21    sync::{mpsc, oneshot, watch},
22    task::JoinHandle,
23};
24use tokio_tungstenite::tungstenite::Message;
25use tokio_util::sync::CancellationToken;
26use tracing::{debug, info, warn};
27
28use super::{
29    ConnectionStatus, ControlConnection, HANDSHAKE_RESPONSE_TIMEOUT, HandshakeError, InboundFrame,
30    KEEPALIVE_INTERVAL, MessageLoopExit, SessionInfo, Ws, classify_inbound, graceful_close,
31    handshake_error_for_response, is_terminal_backtest_error, resolve_rpc_url, run_control_loop,
32    send_keepalive_ping, send_request,
33};
34use crate::{error::err_chain, urls::http_base_from_ws_url};
35
36/// Events the driver observes from the control connection.
37///
38/// Session-lifecycle responses (`SessionCreated`, `SessionAttached`,
39/// `ResumeSuccess`) are handled internally and not forwarded.
40#[derive(Debug)]
41pub enum ControlEvent {
42    ReadyForContinue,
43    /// Server paused at a `ContinueTo` target. The session is ready for
44    /// another `Continue` or `ContinueTo` from this point.
45    Paused(PausedEvent),
46    /// Server discovered an upcoming batch matching a registered
47    /// `DiscoveryFilter`. Send `ContinueTo(slot, batch_index)` to pause
48    /// immediately before it executes.
49    DiscoveryBatch(DiscoveryBatchEvent),
50    Slot(u64),
51    /// High-level progress phase during session startup (e.g. `StartingRuntime`).
52    /// Useful for showing what the server is doing while waiting for the first
53    /// `ReadyForContinue`.
54    Status(BacktestStatus),
55    Completed {
56        summary: Option<SessionSummary>,
57        agent_stats: Option<Vec<AgentStatsReport>>,
58    },
59    Error(BacktestError),
60}
61
62/// Handle to a running `ControlManager` task.
63pub struct ControlHandle {
64    continues: mpsc::Sender<ContinueParams>,
65    continue_tos: mpsc::Sender<ContinueToParams>,
66    pub events: mpsc::Receiver<ControlEvent>,
67    pub status: watch::Receiver<ConnectionStatus>,
68    session_info: Option<oneshot::Receiver<Result<SessionInfo, String>>>,
69    join: JoinHandle<()>,
70}
71
72impl ControlHandle {
73    /// Resolve once the session has been created (or the manager has failed
74    /// before reaching that point). Consumes the one-shot; callable only once.
75    pub async fn wait_for_session(&mut self) -> Result<SessionInfo, String> {
76        let rx = self
77            .session_info
78            .take()
79            .ok_or_else(|| "session_info already consumed".to_string())?;
80        rx.await
81            .map_err(|_| "control manager exited before creating session".to_string())?
82    }
83
84    /// Send a `Continue` request to the control task. Errors if the manager has
85    /// exited.
86    pub async fn send_continue(
87        &self,
88        params: ContinueParams,
89    ) -> Result<(), mpsc::error::SendError<ContinueParams>> {
90        self.continues.send(params).await
91    }
92
93    /// Send a `ContinueTo` request to step to a specific slot/batch boundary.
94    /// Pair with `ControlEvent::DiscoveryBatch` to pause before each
95    /// discovered batch.
96    pub async fn send_continue_to(
97        &self,
98        params: ContinueToParams,
99    ) -> Result<(), mpsc::error::SendError<ContinueToParams>> {
100        self.continue_tos.send(params).await
101    }
102
103    /// Await the control task's exit. The task exits on its own when the
104    /// server reports `Completed`, the cancel token fires, or it hits a
105    /// terminal error; dropping the request channels here nudges it in the
106    /// case where the driver is giving up without having seen `Completed`.
107    pub async fn join(self) {
108        drop(self.continues);
109        drop(self.continue_tos);
110        let _ = self.join.await;
111    }
112}
113
114/// Spawn a `ControlManager` task and return a handle.
115///
116/// The `continues` channel has a bounded capacity of 1: we only ever have one
117/// Continue in flight, and backpressuring the driver is the correct behavior
118/// if the connection is temporarily down.
119pub fn spawn_control_manager(
120    url: String,
121    api_key: String,
122    create: CreateBacktestSessionRequest,
123    cancel: CancellationToken,
124) -> ControlHandle {
125    let (continues_tx, continues_rx) = mpsc::channel::<ContinueParams>(1);
126    let (continue_tos_tx, continue_tos_rx) = mpsc::channel::<ContinueToParams>(1);
127    let (events_tx, events_rx) = mpsc::channel::<ControlEvent>(256);
128    let (status_tx, status_rx) = watch::channel(ConnectionStatus::Down);
129    let (session_tx, session_rx) = oneshot::channel::<Result<SessionInfo, String>>();
130
131    let manager = ControlTask {
132        url,
133        api_key,
134        create: Some(create),
135        session_info: None,
136        session_tx: Some(session_tx),
137        last_sequence: None,
138        continues_rx,
139        continue_tos_rx,
140        events_tx,
141        status_tx,
142        cancel,
143    };
144
145    let join = tokio::spawn(run_control_loop(manager));
146
147    ControlHandle {
148        continues: continues_tx,
149        continue_tos: continue_tos_tx,
150        events: events_rx,
151        status: status_rx,
152        session_info: Some(session_rx),
153        join,
154    }
155}
156
157struct ControlTask {
158    url: String,
159    api_key: String,
160    /// Set on first connect; consumed and cleared after `Create` succeeds.
161    create: Option<CreateBacktestSessionRequest>,
162    /// Populated after `Create` succeeds. On reconnect, used to build `Attach`.
163    session_info: Option<SessionInfo>,
164    /// One-shot result to the handle; fired exactly once.
165    session_tx: Option<oneshot::Sender<Result<SessionInfo, String>>>,
166    /// Highest sequence number observed from the server.
167    last_sequence: Option<u64>,
168    continues_rx: mpsc::Receiver<ContinueParams>,
169    continue_tos_rx: mpsc::Receiver<ContinueToParams>,
170    events_tx: mpsc::Sender<ControlEvent>,
171    status_tx: watch::Sender<ConnectionStatus>,
172    cancel: CancellationToken,
173}
174
175impl ControlConnection for ControlTask {
176    fn url(&self) -> &str {
177        &self.url
178    }
179    fn api_key(&self) -> &str {
180        &self.api_key
181    }
182    fn cancel(&self) -> &CancellationToken {
183        &self.cancel
184    }
185    fn label(&self) -> &'static str {
186        "control"
187    }
188    fn status_tx(&self) -> &watch::Sender<ConnectionStatus> {
189        &self.status_tx
190    }
191
192    fn fail_pending(&mut self, reason: String) {
193        if let Some(tx) = self.session_tx.take() {
194            let _ = tx.send(Err(reason));
195        }
196    }
197
198    async fn handshake(&mut self, mut ws: Ws) -> Result<Ws, HandshakeError> {
199        if let Some(info) = &self.session_info {
200            let info = info.clone();
201            attach(
202                &mut ws,
203                &info.session_id,
204                self.last_sequence,
205                &mut self.events_tx,
206                &mut self.last_sequence,
207            )
208            .await?;
209            resume(&mut ws, &mut self.events_tx, &mut self.last_sequence).await?;
210            debug!(session_id = info.session_id, "control reattached");
211        } else if let Some(create) = self.create.take() {
212            let info = create_session(
213                &mut ws,
214                create,
215                &self.url,
216                &mut self.events_tx,
217                &mut self.last_sequence,
218            )
219            .await?;
220            info!(session_id = info.session_id, "control session created");
221            self.session_info = Some(info.clone());
222            if let Some(tx) = self.session_tx.take() {
223                let _ = tx.send(Ok(info));
224            }
225        } else {
226            return Err(HandshakeError::Fatal(
227                "no create request and no session_id".into(),
228            ));
229        }
230
231        Ok(ws)
232    }
233
234    async fn message_loop(&mut self, mut ws: Ws) -> MessageLoopExit {
235        let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
236        ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
237        let mut last_inbound = Instant::now();
238
239        let exit = loop {
240            tokio::select! {
241                biased;
242                _ = self.cancel.cancelled() => break MessageLoopExit::Cancelled,
243
244                _ = ping_timer.tick() => {
245                    if let Some(why) = send_keepalive_ping(&mut ws, last_inbound).await {
246                        break MessageLoopExit::ConnectionLost(why);
247                    }
248                }
249
250                msg = ws.next() => {
251                    last_inbound = Instant::now();
252                    match classify_inbound(msg) {
253                        InboundFrame::Text(t) => {
254                            if let Err(exit) = self.handle_text(&t).await {
255                                break exit;
256                            }
257                        }
258                        InboundFrame::Ignore => {}
259                        InboundFrame::Lost(why) => break MessageLoopExit::ConnectionLost(why),
260                    }
261                }
262
263                req = self.continues_rx.recv() => {
264                    match req {
265                        Some(params) => {
266                            if let Err(e) = send_request(&mut ws, &BacktestRequest::Continue(params)).await {
267                                break MessageLoopExit::ConnectionLost(format!("continue send: {e}"));
268                            }
269                        }
270                        None => break MessageLoopExit::SessionEnded,
271                    }
272                }
273
274                req = self.continue_tos_rx.recv() => {
275                    match req {
276                        Some(params) => {
277                            if let Err(e) = send_request(&mut ws, &BacktestRequest::ContinueTo(params)).await {
278                                break MessageLoopExit::ConnectionLost(format!("continue_to send: {e}"));
279                            }
280                        }
281                        None => break MessageLoopExit::SessionEnded,
282                    }
283                }
284            }
285        };
286
287        if matches!(
288            exit,
289            MessageLoopExit::SessionEnded | MessageLoopExit::Cancelled
290        ) {
291            graceful_close(&mut ws).await;
292        }
293        exit
294    }
295}
296
297impl ControlTask {
298    /// Returns `Err(MessageLoopExit)` if the message signals we should exit the loop.
299    async fn handle_text(&mut self, text: &str) -> Result<(), MessageLoopExit> {
300        let (seq, response) = match serde_json::from_str::<SequencedResponse>(text) {
301            Ok(s) => (Some(s.seq_id), s.response),
302            Err(_) => match serde_json::from_str::<BacktestResponse>(text) {
303                Ok(r) => (None, r),
304                Err(e) => {
305                    warn!(error = %err_chain(&e), "discarding undeserializable control message");
306                    return Ok(());
307                }
308            },
309        };
310
311        if let Some(s) = seq {
312            self.last_sequence = Some(s);
313        }
314
315        match response {
316            BacktestResponse::ReadyForContinue => {
317                let _ = self.events_tx.send(ControlEvent::ReadyForContinue).await;
318            }
319            BacktestResponse::Paused(event) => {
320                let _ = self.events_tx.send(ControlEvent::Paused(event)).await;
321            }
322            BacktestResponse::DiscoveryBatch(event) => {
323                let _ = self
324                    .events_tx
325                    .send(ControlEvent::DiscoveryBatch(event))
326                    .await;
327            }
328            BacktestResponse::SlotNotification(slot) => {
329                let _ = self.events_tx.send(ControlEvent::Slot(slot)).await;
330            }
331            BacktestResponse::Completed {
332                summary,
333                agent_stats,
334            } => {
335                let _ = self
336                    .events_tx
337                    .send(ControlEvent::Completed {
338                        summary,
339                        agent_stats,
340                    })
341                    .await;
342                return Err(MessageLoopExit::SessionEnded);
343            }
344            BacktestResponse::Error(err) => {
345                // Per-slot simulation errors are non-fatal: log and keep going.
346                if matches!(&err, BacktestError::SimulationError { .. }) {
347                    warn!(error = %err_chain(&err), "simulation error");
348                    return Ok(());
349                }
350                let terminal = is_terminal_backtest_error(&err);
351                let _ = self.events_tx.send(ControlEvent::Error(err)).await;
352                if terminal {
353                    return Err(MessageLoopExit::Terminal(
354                        "server reported terminal error".into(),
355                    ));
356                }
357            }
358            BacktestResponse::Status { status } => {
359                let _ = self.events_tx.send(ControlEvent::Status(status)).await;
360            }
361            BacktestResponse::Success => {
362                // Ack for Close or similar; nothing to forward.
363            }
364            other => {
365                // SessionCreated/Attached/etc. during the message loop are unexpected.
366                debug!(?other, "ignoring unexpected control response");
367            }
368        }
369
370        Ok(())
371    }
372}
373
374async fn create_session(
375    ws: &mut Ws,
376    request: CreateBacktestSessionRequest,
377    url: &str,
378    events: &mut mpsc::Sender<ControlEvent>,
379    last_sequence: &mut Option<u64>,
380) -> Result<SessionInfo, HandshakeError> {
381    send_request(ws, &BacktestRequest::CreateBacktestSession(request))
382        .await
383        .map_err(HandshakeError::Transient)?;
384
385    let rpc_base = http_base_from_ws_url(url);
386
387    loop {
388        let response = next_response_with_timeout(ws, events, last_sequence)
389            .await
390            .map_err(HandshakeError::Transient)?;
391        match response {
392            BacktestResponse::SessionCreated {
393                session_id,
394                rpc_endpoint,
395                task_id,
396            } => {
397                let rpc_endpoint = resolve_rpc_url(&rpc_base, &rpc_endpoint);
398                return Ok(SessionInfo {
399                    session_id,
400                    rpc_endpoint,
401                    task_id,
402                });
403            }
404            BacktestResponse::Error(err) => {
405                return Err(HandshakeError::Fatal(format!(
406                    "server error: {}",
407                    err_chain(&err)
408                )));
409            }
410            _ => {
411                // Any unexpected response before SessionCreated — ignore and
412                // keep waiting. (e.g. statuses, early events.)
413            }
414        }
415    }
416}
417
418async fn attach(
419    ws: &mut Ws,
420    session_id: &str,
421    last_sequence: Option<u64>,
422    events: &mut mpsc::Sender<ControlEvent>,
423    last_seq_state: &mut Option<u64>,
424) -> Result<(), HandshakeError> {
425    send_request(
426        ws,
427        &BacktestRequest::AttachBacktestSession {
428            session_id: session_id.to_string(),
429            last_sequence,
430        },
431    )
432    .await
433    .map_err(HandshakeError::Transient)?;
434
435    loop {
436        let response = next_response_with_timeout(ws, events, last_seq_state)
437            .await
438            .map_err(HandshakeError::Transient)?;
439        match response {
440            BacktestResponse::SessionAttached { .. } => return Ok(()),
441            BacktestResponse::Error(err) => {
442                return Err(handshake_error_for_response("attach", err));
443            }
444            _ => {}
445        }
446    }
447}
448
449async fn resume(
450    ws: &mut Ws,
451    events: &mut mpsc::Sender<ControlEvent>,
452    last_seq_state: &mut Option<u64>,
453) -> Result<(), HandshakeError> {
454    send_request(ws, &BacktestRequest::ResumeAttachedSession)
455        .await
456        .map_err(HandshakeError::Transient)?;
457
458    loop {
459        let response = next_response_with_timeout(ws, events, last_seq_state)
460            .await
461            .map_err(HandshakeError::Transient)?;
462        match response {
463            BacktestResponse::Success => return Ok(()),
464            BacktestResponse::Error(err) => {
465                return Err(handshake_error_for_response("resume", err));
466            }
467            _ => {}
468        }
469    }
470}
471
472/// Read the next response during a handshake, with a bounded timeout.
473///
474/// Any non-handshake responses observed during the wait are forwarded to the
475/// driver (slot notifications, errors) so we don't lose them.
476async fn next_response_with_timeout(
477    ws: &mut Ws,
478    events: &mut mpsc::Sender<ControlEvent>,
479    last_sequence: &mut Option<u64>,
480) -> Result<BacktestResponse, String> {
481    let deadline = tokio::time::Instant::now() + HANDSHAKE_RESPONSE_TIMEOUT;
482    loop {
483        let msg = tokio::time::timeout_at(deadline, ws.next())
484            .await
485            .map_err(|_| format!("handshake timeout after {HANDSHAKE_RESPONSE_TIMEOUT:?}"))?;
486
487        let Some(msg) = msg else {
488            return Err("ws ended during handshake".into());
489        };
490        let msg = msg.map_err(|e| format!("ws read: {}", err_chain(&e)))?;
491
492        let text = match msg {
493            Message::Text(t) => t,
494            Message::Binary(b) => match std::str::from_utf8(&b) {
495                Ok(t) => t.to_string(),
496                Err(_) => continue,
497            },
498            Message::Close(frame) => {
499                return Err(format!("remote close during handshake: {frame:?}"));
500            }
501            _ => continue,
502        };
503
504        let (seq, response) = match serde_json::from_str::<SequencedResponse>(&text) {
505            Ok(s) => (Some(s.seq_id), s.response),
506            Err(_) => (
507                None,
508                serde_json::from_str::<BacktestResponse>(&text)
509                    .map_err(|e| format!("deserialize: {}; raw={text}", err_chain(&e)))?,
510            ),
511        };
512        if let Some(s) = seq {
513            *last_sequence = Some(s);
514        }
515
516        // Forward noisy event kinds to the driver so nothing is lost while we
517        // wait for the handshake response.
518        match response {
519            BacktestResponse::SlotNotification(slot) => {
520                let _ = events.send(ControlEvent::Slot(slot)).await;
521            }
522            BacktestResponse::ReadyForContinue => {
523                let _ = events.send(ControlEvent::ReadyForContinue).await;
524            }
525            BacktestResponse::Paused(event) => {
526                let _ = events.send(ControlEvent::Paused(event)).await;
527            }
528            BacktestResponse::DiscoveryBatch(event) => {
529                let _ = events.send(ControlEvent::DiscoveryBatch(event)).await;
530            }
531            BacktestResponse::Completed {
532                summary,
533                agent_stats,
534            } => {
535                let _ = events
536                    .send(ControlEvent::Completed {
537                        summary,
538                        agent_stats,
539                    })
540                    .await;
541            }
542            other => return Ok(other),
543        }
544    }
545}