1use std::time::Instant;
13
14use futures::StreamExt;
15use simulator_api::{
16 AgentStatsReport, BacktestError, BacktestRequest, BacktestResponse, BacktestStatus,
17 ContinueParams, ContinueToParams, CreateBacktestSessionRequest, DiscoveryBatchEvent,
18 PausedEvent, SequencedResponse, SessionSummary,
19};
20use tokio::{
21 sync::{mpsc, oneshot, watch},
22 task::JoinHandle,
23};
24use tokio_tungstenite::tungstenite::Message;
25use tokio_util::sync::CancellationToken;
26use tracing::{debug, info, warn};
27
28use super::{
29 ConnectionStatus, ControlConnection, HANDSHAKE_RESPONSE_TIMEOUT, HandshakeError, InboundFrame,
30 KEEPALIVE_INTERVAL, MessageLoopExit, SessionInfo, Ws, classify_inbound, graceful_close,
31 handshake_error_for_response, is_terminal_backtest_error, resolve_rpc_url, run_control_loop,
32 send_keepalive_ping, send_request,
33};
34use crate::{error::err_chain, urls::http_base_from_ws_url};
35
/// Events delivered to the owner of a [`ControlHandle`] through its `events` channel.
///
/// Each variant is produced from a `BacktestResponse` read off the control websocket.
#[derive(Debug)]
pub enum ControlEvent {
    /// Forwarded when the server sends `BacktestResponse::ReadyForContinue`.
    ReadyForContinue,
    /// Forwarded when the server sends `BacktestResponse::Paused`; carries its payload unchanged.
    Paused(PausedEvent),
    /// Forwarded when the server sends `BacktestResponse::DiscoveryBatch`; carries its payload
    /// unchanged.
    DiscoveryBatch(DiscoveryBatchEvent),
    /// Forwarded when the server sends `BacktestResponse::SlotNotification`; carries the slot
    /// value from that notification.
    Slot(u64),
    /// Forwarded from the message loop when the server sends `BacktestResponse::Status`.
    Status(BacktestStatus),
    /// Forwarded when the server sends `BacktestResponse::Completed`. When it arrives in the
    /// message loop, the loop exits with `MessageLoopExit::SessionEnded` right after forwarding.
    Completed {
        /// Summary attached to the completion message, if any.
        summary: Option<SessionSummary>,
        /// Agent statistics attached to the completion message, if any.
        agent_stats: Option<Vec<AgentStatsReport>>,
    },
    /// Forwarded from the message loop when the server sends `BacktestResponse::Error`, except
    /// for `BacktestError::SimulationError`, which is only logged and never forwarded.
    Error(BacktestError),
}
61
/// Caller-side handle to the background control task started by `spawn_control_manager`.
pub struct ControlHandle {
    /// Sends `ContinueParams` to the task, which writes them to the socket as
    /// `BacktestRequest::Continue`.
    continues: mpsc::Sender<ContinueParams>,
    /// Sends `ContinueToParams` to the task, which writes them to the socket as
    /// `BacktestRequest::ContinueTo`.
    continue_tos: mpsc::Sender<ContinueToParams>,
    /// Stream of [`ControlEvent`]s produced by the task.
    pub events: mpsc::Receiver<ControlEvent>,
    /// Connection status watch; its initial value is `ConnectionStatus::Down`.
    // NOTE(review): updates presumably come from `run_control_loop` via
    // `ControlConnection::status_tx` — confirm there.
    pub status: watch::Receiver<ConnectionStatus>,
    /// One-shot result of session creation; taken by `wait_for_session` on first use.
    session_info: Option<oneshot::Receiver<Result<SessionInfo, String>>>,
    /// Join handle of the spawned control task.
    join: JoinHandle<()>,
}
71
72impl ControlHandle {
73 pub async fn wait_for_session(&mut self) -> Result<SessionInfo, String> {
76 let rx = self
77 .session_info
78 .take()
79 .ok_or_else(|| "session_info already consumed".to_string())?;
80 rx.await
81 .map_err(|_| "control manager exited before creating session".to_string())?
82 }
83
84 pub async fn send_continue(
87 &self,
88 params: ContinueParams,
89 ) -> Result<(), mpsc::error::SendError<ContinueParams>> {
90 self.continues.send(params).await
91 }
92
93 pub async fn send_continue_to(
97 &self,
98 params: ContinueToParams,
99 ) -> Result<(), mpsc::error::SendError<ContinueToParams>> {
100 self.continue_tos.send(params).await
101 }
102
103 pub async fn join(self) {
108 drop(self.continues);
109 drop(self.continue_tos);
110 let _ = self.join.await;
111 }
112}
113
114pub fn spawn_control_manager(
120 url: String,
121 api_key: String,
122 create: CreateBacktestSessionRequest,
123 cancel: CancellationToken,
124) -> ControlHandle {
125 let (continues_tx, continues_rx) = mpsc::channel::<ContinueParams>(1);
126 let (continue_tos_tx, continue_tos_rx) = mpsc::channel::<ContinueToParams>(1);
127 let (events_tx, events_rx) = mpsc::channel::<ControlEvent>(256);
128 let (status_tx, status_rx) = watch::channel(ConnectionStatus::Down);
129 let (session_tx, session_rx) = oneshot::channel::<Result<SessionInfo, String>>();
130
131 let manager = ControlTask {
132 url,
133 api_key,
134 create: Some(create),
135 session_info: None,
136 session_tx: Some(session_tx),
137 last_sequence: None,
138 continues_rx,
139 continue_tos_rx,
140 events_tx,
141 status_tx,
142 cancel,
143 };
144
145 let join = tokio::spawn(run_control_loop(manager));
146
147 ControlHandle {
148 continues: continues_tx,
149 continue_tos: continue_tos_tx,
150 events: events_rx,
151 status: status_rx,
152 session_info: Some(session_rx),
153 join,
154 }
155}
156
/// State owned by the background control task; driven through its `ControlConnection` impl.
struct ControlTask {
    /// Endpoint URL, exposed via `ControlConnection::url` and used to derive the HTTP base for
    /// the session's RPC endpoint.
    url: String,
    /// API key, exposed via `ControlConnection::api_key`.
    api_key: String,
    /// Pending session-creation request; taken by the first handshake that has no session yet.
    create: Option<CreateBacktestSessionRequest>,
    /// Set once a session has been created; when present, handshakes reattach instead of
    /// creating a new session.
    session_info: Option<SessionInfo>,
    /// One-shot used to report the session-creation outcome to the handle; taken when fulfilled
    /// on success or by `fail_pending`.
    session_tx: Option<oneshot::Sender<Result<SessionInfo, String>>>,
    /// `seq_id` of the most recent sequenced response seen; sent as `last_sequence` when
    /// reattaching.
    last_sequence: Option<u64>,
    /// Receives `Continue` parameters from the handle.
    continues_rx: mpsc::Receiver<ContinueParams>,
    /// Receives `ContinueTo` parameters from the handle.
    continue_tos_rx: mpsc::Receiver<ContinueToParams>,
    /// Sends [`ControlEvent`]s to the handle.
    events_tx: mpsc::Sender<ControlEvent>,
    /// Connection status publisher, exposed via `ControlConnection::status_tx`.
    status_tx: watch::Sender<ConnectionStatus>,
    /// Cancellation token; when it fires, the message loop exits with
    /// `MessageLoopExit::Cancelled`.
    cancel: CancellationToken,
}
174
175impl ControlConnection for ControlTask {
176 fn url(&self) -> &str {
177 &self.url
178 }
179 fn api_key(&self) -> &str {
180 &self.api_key
181 }
182 fn cancel(&self) -> &CancellationToken {
183 &self.cancel
184 }
185 fn label(&self) -> &'static str {
186 "control"
187 }
188 fn status_tx(&self) -> &watch::Sender<ConnectionStatus> {
189 &self.status_tx
190 }
191
192 fn fail_pending(&mut self, reason: String) {
193 if let Some(tx) = self.session_tx.take() {
194 let _ = tx.send(Err(reason));
195 }
196 }
197
198 async fn handshake(&mut self, mut ws: Ws) -> Result<Ws, HandshakeError> {
199 if let Some(info) = &self.session_info {
200 let info = info.clone();
201 attach(
202 &mut ws,
203 &info.session_id,
204 self.last_sequence,
205 &mut self.events_tx,
206 &mut self.last_sequence,
207 )
208 .await?;
209 resume(&mut ws, &mut self.events_tx, &mut self.last_sequence).await?;
210 debug!(session_id = info.session_id, "control reattached");
211 } else if let Some(create) = self.create.take() {
212 let info = create_session(
213 &mut ws,
214 create,
215 &self.url,
216 &mut self.events_tx,
217 &mut self.last_sequence,
218 )
219 .await?;
220 info!(session_id = info.session_id, "control session created");
221 self.session_info = Some(info.clone());
222 if let Some(tx) = self.session_tx.take() {
223 let _ = tx.send(Ok(info));
224 }
225 } else {
226 return Err(HandshakeError::Fatal(
227 "no create request and no session_id".into(),
228 ));
229 }
230
231 Ok(ws)
232 }
233
234 async fn message_loop(&mut self, mut ws: Ws) -> MessageLoopExit {
235 let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
236 ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
237 let mut last_inbound = Instant::now();
238
239 let exit = loop {
240 tokio::select! {
241 biased;
242 _ = self.cancel.cancelled() => break MessageLoopExit::Cancelled,
243
244 _ = ping_timer.tick() => {
245 if let Some(why) = send_keepalive_ping(&mut ws, last_inbound).await {
246 break MessageLoopExit::ConnectionLost(why);
247 }
248 }
249
250 msg = ws.next() => {
251 last_inbound = Instant::now();
252 match classify_inbound(msg) {
253 InboundFrame::Text(t) => {
254 if let Err(exit) = self.handle_text(&t).await {
255 break exit;
256 }
257 }
258 InboundFrame::Ignore => {}
259 InboundFrame::Lost(why) => break MessageLoopExit::ConnectionLost(why),
260 }
261 }
262
263 req = self.continues_rx.recv() => {
264 match req {
265 Some(params) => {
266 if let Err(e) = send_request(&mut ws, &BacktestRequest::Continue(params)).await {
267 break MessageLoopExit::ConnectionLost(format!("continue send: {e}"));
268 }
269 }
270 None => break MessageLoopExit::SessionEnded,
271 }
272 }
273
274 req = self.continue_tos_rx.recv() => {
275 match req {
276 Some(params) => {
277 if let Err(e) = send_request(&mut ws, &BacktestRequest::ContinueTo(params)).await {
278 break MessageLoopExit::ConnectionLost(format!("continue_to send: {e}"));
279 }
280 }
281 None => break MessageLoopExit::SessionEnded,
282 }
283 }
284 }
285 };
286
287 if matches!(
288 exit,
289 MessageLoopExit::SessionEnded | MessageLoopExit::Cancelled
290 ) {
291 graceful_close(&mut ws).await;
292 }
293 exit
294 }
295}
296
297impl ControlTask {
298 async fn handle_text(&mut self, text: &str) -> Result<(), MessageLoopExit> {
300 let (seq, response) = match serde_json::from_str::<SequencedResponse>(text) {
301 Ok(s) => (Some(s.seq_id), s.response),
302 Err(_) => match serde_json::from_str::<BacktestResponse>(text) {
303 Ok(r) => (None, r),
304 Err(e) => {
305 warn!(error = %err_chain(&e), "discarding undeserializable control message");
306 return Ok(());
307 }
308 },
309 };
310
311 if let Some(s) = seq {
312 self.last_sequence = Some(s);
313 }
314
315 match response {
316 BacktestResponse::ReadyForContinue => {
317 let _ = self.events_tx.send(ControlEvent::ReadyForContinue).await;
318 }
319 BacktestResponse::Paused(event) => {
320 let _ = self.events_tx.send(ControlEvent::Paused(event)).await;
321 }
322 BacktestResponse::DiscoveryBatch(event) => {
323 let _ = self
324 .events_tx
325 .send(ControlEvent::DiscoveryBatch(event))
326 .await;
327 }
328 BacktestResponse::SlotNotification(slot) => {
329 let _ = self.events_tx.send(ControlEvent::Slot(slot)).await;
330 }
331 BacktestResponse::Completed {
332 summary,
333 agent_stats,
334 } => {
335 let _ = self
336 .events_tx
337 .send(ControlEvent::Completed {
338 summary,
339 agent_stats,
340 })
341 .await;
342 return Err(MessageLoopExit::SessionEnded);
343 }
344 BacktestResponse::Error(err) => {
345 if matches!(&err, BacktestError::SimulationError { .. }) {
347 warn!(error = %err_chain(&err), "simulation error");
348 return Ok(());
349 }
350 let terminal = is_terminal_backtest_error(&err);
351 let _ = self.events_tx.send(ControlEvent::Error(err)).await;
352 if terminal {
353 return Err(MessageLoopExit::Terminal(
354 "server reported terminal error".into(),
355 ));
356 }
357 }
358 BacktestResponse::Status { status } => {
359 let _ = self.events_tx.send(ControlEvent::Status(status)).await;
360 }
361 BacktestResponse::Success => {
362 }
364 other => {
365 debug!(?other, "ignoring unexpected control response");
367 }
368 }
369
370 Ok(())
371 }
372}
373
374async fn create_session(
375 ws: &mut Ws,
376 request: CreateBacktestSessionRequest,
377 url: &str,
378 events: &mut mpsc::Sender<ControlEvent>,
379 last_sequence: &mut Option<u64>,
380) -> Result<SessionInfo, HandshakeError> {
381 send_request(ws, &BacktestRequest::CreateBacktestSession(request))
382 .await
383 .map_err(HandshakeError::Transient)?;
384
385 let rpc_base = http_base_from_ws_url(url);
386
387 loop {
388 let response = next_response_with_timeout(ws, events, last_sequence)
389 .await
390 .map_err(HandshakeError::Transient)?;
391 match response {
392 BacktestResponse::SessionCreated {
393 session_id,
394 rpc_endpoint,
395 task_id,
396 } => {
397 let rpc_endpoint = resolve_rpc_url(&rpc_base, &rpc_endpoint);
398 return Ok(SessionInfo {
399 session_id,
400 rpc_endpoint,
401 task_id,
402 });
403 }
404 BacktestResponse::Error(err) => {
405 return Err(HandshakeError::Fatal(format!(
406 "server error: {}",
407 err_chain(&err)
408 )));
409 }
410 _ => {
411 }
414 }
415 }
416}
417
418async fn attach(
419 ws: &mut Ws,
420 session_id: &str,
421 last_sequence: Option<u64>,
422 events: &mut mpsc::Sender<ControlEvent>,
423 last_seq_state: &mut Option<u64>,
424) -> Result<(), HandshakeError> {
425 send_request(
426 ws,
427 &BacktestRequest::AttachBacktestSession {
428 session_id: session_id.to_string(),
429 last_sequence,
430 },
431 )
432 .await
433 .map_err(HandshakeError::Transient)?;
434
435 loop {
436 let response = next_response_with_timeout(ws, events, last_seq_state)
437 .await
438 .map_err(HandshakeError::Transient)?;
439 match response {
440 BacktestResponse::SessionAttached { .. } => return Ok(()),
441 BacktestResponse::Error(err) => {
442 return Err(handshake_error_for_response("attach", err));
443 }
444 _ => {}
445 }
446 }
447}
448
449async fn resume(
450 ws: &mut Ws,
451 events: &mut mpsc::Sender<ControlEvent>,
452 last_seq_state: &mut Option<u64>,
453) -> Result<(), HandshakeError> {
454 send_request(ws, &BacktestRequest::ResumeAttachedSession)
455 .await
456 .map_err(HandshakeError::Transient)?;
457
458 loop {
459 let response = next_response_with_timeout(ws, events, last_seq_state)
460 .await
461 .map_err(HandshakeError::Transient)?;
462 match response {
463 BacktestResponse::Success => return Ok(()),
464 BacktestResponse::Error(err) => {
465 return Err(handshake_error_for_response("resume", err));
466 }
467 _ => {}
468 }
469 }
470}
471
472async fn next_response_with_timeout(
477 ws: &mut Ws,
478 events: &mut mpsc::Sender<ControlEvent>,
479 last_sequence: &mut Option<u64>,
480) -> Result<BacktestResponse, String> {
481 let deadline = tokio::time::Instant::now() + HANDSHAKE_RESPONSE_TIMEOUT;
482 loop {
483 let msg = tokio::time::timeout_at(deadline, ws.next())
484 .await
485 .map_err(|_| format!("handshake timeout after {HANDSHAKE_RESPONSE_TIMEOUT:?}"))?;
486
487 let Some(msg) = msg else {
488 return Err("ws ended during handshake".into());
489 };
490 let msg = msg.map_err(|e| format!("ws read: {}", err_chain(&e)))?;
491
492 let text = match msg {
493 Message::Text(t) => t,
494 Message::Binary(b) => match std::str::from_utf8(&b) {
495 Ok(t) => t.to_string(),
496 Err(_) => continue,
497 },
498 Message::Close(frame) => {
499 return Err(format!("remote close during handshake: {frame:?}"));
500 }
501 _ => continue,
502 };
503
504 let (seq, response) = match serde_json::from_str::<SequencedResponse>(&text) {
505 Ok(s) => (Some(s.seq_id), s.response),
506 Err(_) => (
507 None,
508 serde_json::from_str::<BacktestResponse>(&text)
509 .map_err(|e| format!("deserialize: {}; raw={text}", err_chain(&e)))?,
510 ),
511 };
512 if let Some(s) = seq {
513 *last_sequence = Some(s);
514 }
515
516 match response {
519 BacktestResponse::SlotNotification(slot) => {
520 let _ = events.send(ControlEvent::Slot(slot)).await;
521 }
522 BacktestResponse::ReadyForContinue => {
523 let _ = events.send(ControlEvent::ReadyForContinue).await;
524 }
525 BacktestResponse::Paused(event) => {
526 let _ = events.send(ControlEvent::Paused(event)).await;
527 }
528 BacktestResponse::DiscoveryBatch(event) => {
529 let _ = events.send(ControlEvent::DiscoveryBatch(event)).await;
530 }
531 BacktestResponse::Completed {
532 summary,
533 agent_stats,
534 } => {
535 let _ = events
536 .send(ControlEvent::Completed {
537 summary,
538 agent_stats,
539 })
540 .await;
541 }
542 other => return Ok(other),
543 }
544 }
545}