1use std::{
2 borrow::Cow,
3 collections::{BTreeMap, VecDeque},
4 future::Future,
5 time::Duration,
6};
7
8use futures::{SinkExt, StreamExt, stream};
9use simulator_api::{
10 AccountData, AccountModifications, AgentStatsReport, BacktestError, BacktestRequest,
11 BacktestResponse, BacktestStatus, ContinueParams, CreateBacktestSessionRequest,
12 CreateBacktestSessionRequestV1, SequencedResponse, SessionSummary,
13};
14use solana_address::Address;
15use solana_client::{
16 nonblocking::rpc_client::RpcClient,
17 rpc_response::{Response, RpcLogsResponse},
18};
19use solana_commitment_config::CommitmentConfig;
20use thiserror::Error;
21use tokio::net::TcpStream;
22use tokio_tungstenite::{
23 MaybeTlsStream, WebSocketStream,
24 tungstenite::{
25 Error as WsError, Message,
26 error::ProtocolError,
27 protocol::{CloseFrame, frame::coding::CloseCode},
28 },
29};
30
31use crate::{
32 BacktestClientError, BacktestClientResult, Continue,
33 injection::ProgramModError,
34 subscriptions::{
35 AccountDiffNotification, AccountDiffSubscriptionHandle, LogSubscriptionHandle,
36 SubscriptionError,
37 },
38};
39
/// Outcome of [`BacktestSession::ensure_ready`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadyOutcome {
    /// The session is ready to accept a `Continue` request
    /// (`ReadyForContinue` was received, or the session was already flagged ready).
    Ready,
    /// The server reported `Completed` before the session became ready.
    Completed,
}
48
49#[derive(Debug, Default)]
51pub struct ContinueResult {
52 pub slot_notifications: u64,
54 pub last_slot: Option<u64>,
56 pub statuses: Vec<BacktestStatus>,
58 pub ready_for_continue: bool,
60 pub completed: bool,
62}
63
64#[derive(Debug)]
66pub struct AdvanceState {
67 pub expected_slots: u64,
69 pub slot_notifications: u64,
71 pub last_slot: Option<u64>,
73 pub statuses: Vec<BacktestStatus>,
75 pub ready_for_continue: bool,
77 pub completed: bool,
79 pub summary: Option<SessionSummary>,
81 pub agent_stats: Option<Vec<AgentStatsReport>>,
83}
84
85impl AdvanceState {
86 pub fn new(expected_slots: u64) -> Self {
88 Self {
89 expected_slots,
90 slot_notifications: 0,
91 last_slot: None,
92 statuses: Vec::new(),
93 ready_for_continue: false,
94 completed: false,
95 summary: None,
96 agent_stats: None,
97 }
98 }
99
100 pub fn is_done(&self, wait_for_slots: bool) -> bool {
102 if self.completed {
103 return true;
104 }
105
106 if !self.ready_for_continue {
107 return false;
108 }
109
110 !wait_for_slots || self.slot_notifications >= self.expected_slots
111 }
112}
113
/// Tracks how far a session got: whether it completed and the highest slot
/// it reported.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SessionCoverage {
    /// Set once a completion has been observed.
    completed: bool,
    /// Highest slot observed so far; `None` until the first slot arrives.
    highest_slot_seen: Option<u64>,
}
120
121impl SessionCoverage {
122 pub fn observe_slot(&mut self, slot: u64) {
124 self.highest_slot_seen = Some(
125 self.highest_slot_seen
126 .map_or(slot, |current| current.max(slot)),
127 );
128 }
129
130 pub fn mark_completed(&mut self) {
132 self.completed = true;
133 }
134
135 pub fn observe_response(&mut self, response: &BacktestResponse) {
137 match response {
138 BacktestResponse::SlotNotification(slot) => self.observe_slot(*slot),
139 BacktestResponse::Completed { .. } => self.mark_completed(),
140 _ => {}
141 }
142 }
143
144 pub fn is_completed(&self) -> bool {
146 self.completed
147 }
148
149 pub fn highest_slot_seen(&self) -> Option<u64> {
151 self.highest_slot_seen
152 }
153
154 pub fn validate_end_slot(&self, expected_end_slot: u64) -> Result<(), CoverageError> {
156 if !self.completed {
157 return Err(CoverageError::NotCompleted);
158 }
159
160 let Some(actual_end_slot) = self.highest_slot_seen else {
161 return Err(CoverageError::NoSlotsObserved);
162 };
163
164 if actual_end_slot < expected_end_slot {
165 return Err(CoverageError::RangeNotReached {
166 actual_end_slot,
167 expected_end_slot,
168 });
169 }
170
171 Ok(())
172 }
173}
174
175#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
177pub enum CoverageError {
178 #[error("ended before completion")]
179 NotCompleted,
180 #[error("completed without slot notifications")]
181 NoSlotsObserved,
182 #[error("completed at slot {actual_end_slot} but expected at least {expected_end_slot}")]
183 RangeNotReached {
184 actual_end_slot: u64,
185 expected_end_slot: u64,
186 },
187}
188
189pub struct BacktestSession {
194 ws: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
195 session_id: Option<String>,
196 rpc_endpoint: Option<String>,
197 task_id: Option<String>,
198 rpc: Option<RpcClient>,
199 last_sequence: Option<u64>,
200 pub(crate) ready_for_continue: bool,
201 request_timeout: Option<Duration>,
202 log_raw: bool,
203 backlog: VecDeque<(Option<u64>, BacktestResponse)>,
204}
205
206impl BacktestSession {
207 pub(crate) fn new(
208 ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
209 request_timeout: Option<Duration>,
210 log_raw: bool,
211 ) -> Self {
212 Self {
213 ws: Some(ws),
214 session_id: None,
215 rpc_endpoint: None,
216 task_id: None,
217 rpc: None,
218 last_sequence: None,
219 ready_for_continue: false,
220 request_timeout,
221 log_raw,
222 backlog: VecDeque::new(),
223 }
224 }
225
226 pub fn session_id(&self) -> Option<&str> {
228 self.session_id.as_deref()
229 }
230
231 pub fn rpc_endpoint(&self) -> Option<&str> {
233 self.rpc_endpoint.as_deref()
234 }
235
236 pub fn task_id(&self) -> Option<&str> {
239 self.task_id.as_deref()
240 }
241
242 pub fn last_sequence(&self) -> Option<u64> {
244 self.last_sequence
245 }
246
247 pub fn rpc(&self) -> &RpcClient {
251 self.rpc
252 .as_ref()
253 .expect("rpc is set during session creation")
254 }
255
256 pub fn is_ready_for_continue(&self) -> bool {
258 self.ready_for_continue
259 }
260
261 pub fn apply_response(&mut self, response: &BacktestResponse) {
263 match response {
264 BacktestResponse::ReadyForContinue | BacktestResponse::Paused(_) => {
265 self.ready_for_continue = true;
266 }
267 BacktestResponse::Completed { .. } => {
268 self.ready_for_continue = false;
269 }
270 _ => {}
271 }
272 }
273
274 fn ws_mut(&mut self) -> BacktestClientResult<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
275 self.ws.as_mut().ok_or_else(|| BacktestClientError::Closed {
276 reason: "websocket closed".to_string(),
277 })
278 }
279
280 pub(crate) async fn create_with_request(
281 &mut self,
282 request: CreateBacktestSessionRequest,
283 rpc_base_url: String,
284 mut on_parallel_session_created: Option<&mut (dyn FnMut(String) + Send)>,
285 ) -> BacktestClientResult<CreateRequestResult> {
286 let expect_parallel = matches!(
287 &request,
288 CreateBacktestSessionRequest::V1(CreateBacktestSessionRequestV1 { parallel: true, .. })
289 );
290 self.send(&BacktestRequest::CreateBacktestSession(request), None)
291 .await?;
292 let mut streamed_parallel_session_ids = Vec::new();
293 let mut streamed_parallel_task_ids: Vec<Option<String>> = Vec::new();
294 let mut pending: Vec<(Option<u64>, BacktestResponse)> = Vec::new();
298
299 loop {
300 let response =
301 self.next_response(None)
302 .await?
303 .ok_or_else(|| BacktestClientError::Closed {
304 reason: "websocket ended before SessionCreated".to_string(),
305 })?;
306
307 match response {
308 BacktestResponse::SessionCreated {
309 session_id,
310 rpc_endpoint,
311 task_id,
312 } => {
313 if expect_parallel {
314 if let Some(callback) = on_parallel_session_created.as_mut() {
315 (**callback)(session_id.clone());
316 }
317 streamed_parallel_session_ids.push(session_id);
318 streamed_parallel_task_ids.push(task_id);
319 continue;
320 }
321 let created_session_id = session_id.clone();
322 let created_task_id = task_id.clone();
323 self.session_id = Some(session_id);
324 self.task_id = task_id;
325 let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
326 self.rpc = Some(RpcClient::new_with_commitment(
327 resolved.clone(),
328 CommitmentConfig::confirmed(),
329 ));
330 self.rpc_endpoint = Some(resolved);
331 self.backlog.extend(pending);
332 return Ok(CreateRequestResult::Single {
333 session_id: created_session_id,
334 task_id: created_task_id,
335 });
336 }
337 BacktestResponse::SessionsCreated { session_ids } => {
338 if expect_parallel && session_ids.is_empty() {
339 self.backlog.extend(pending);
340 return Ok(CreateRequestResult::Parallel {
341 session_ids: streamed_parallel_session_ids,
342 task_ids: streamed_parallel_task_ids,
343 });
344 }
345 let len = session_ids.len();
346 self.backlog.extend(pending);
347 return Ok(CreateRequestResult::Parallel {
348 session_ids,
349 task_ids: vec![None; len],
350 });
351 }
352 BacktestResponse::SessionsCreatedV2 {
353 session_ids,
354 task_ids,
355 ..
356 } => {
357 if expect_parallel && session_ids.is_empty() {
358 self.backlog.extend(pending);
359 return Ok(CreateRequestResult::Parallel {
360 session_ids: streamed_parallel_session_ids,
361 task_ids: streamed_parallel_task_ids,
362 });
363 }
364 let task_ids = align_task_ids(task_ids, session_ids.len());
365 self.backlog.extend(pending);
366 return Ok(CreateRequestResult::Parallel {
367 session_ids,
368 task_ids,
369 });
370 }
371 BacktestResponse::ReadyForContinue => {
372 self.ready_for_continue = true;
373 }
374 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
375 other => {
376 pending.push((self.last_sequence, other));
377 }
378 }
379 }
380 }
381
382 pub(crate) async fn attach(
383 &mut self,
384 session_id: String,
385 last_sequence: Option<u64>,
386 rpc_base_url: String,
387 ) -> BacktestClientResult<()> {
388 self.send(
389 &BacktestRequest::AttachBacktestSession {
390 session_id,
391 last_sequence,
392 },
393 None,
394 )
395 .await?;
396
397 self.wait_for_response(
398 || BacktestClientError::Closed {
399 reason: "websocket ended before SessionAttached".to_string(),
400 },
401 move |session, response| match response {
402 BacktestResponse::SessionAttached {
403 session_id,
404 rpc_endpoint,
405 task_id,
406 } => {
407 session.session_id = Some(session_id);
408 session.task_id = task_id;
409 let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
410 session.rpc = Some(RpcClient::new_with_commitment(
411 resolved.clone(),
412 CommitmentConfig::confirmed(),
413 ));
414 session.rpc_endpoint = Some(resolved);
415 Ok(Some(()))
416 }
417 BacktestResponse::ReadyForContinue => {
418 session.ready_for_continue = true;
419 Ok(None)
420 }
421 BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
422 other => {
423 session.backlog.push_back((session.last_sequence, other));
424 Ok(None)
425 }
426 },
427 )
428 .await
429 }
430
431 pub async fn resume_attached_session(&mut self) -> BacktestClientResult<()> {
434 self.send(&BacktestRequest::ResumeAttachedSession, None)
435 .await?;
436
437 self.wait_for_response(
438 || BacktestClientError::Closed {
439 reason: "websocket ended before ResumeAttachedSession acknowledgement".to_string(),
440 },
441 |session, response| match response {
442 BacktestResponse::Success => Ok(Some(())),
443 BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
444 other => {
445 session.backlog.push_back((session.last_sequence, other));
446 Ok(None)
447 }
448 },
449 )
450 .await
451 }
452
453 async fn wait_for_response<T, E, F>(
454 &mut self,
455 closed_error: E,
456 mut handle_response: F,
457 ) -> BacktestClientResult<T>
458 where
459 E: FnOnce() -> BacktestClientError,
460 F: FnMut(&mut Self, BacktestResponse) -> BacktestClientResult<Option<T>>,
461 {
462 let mut closed_error = Some(closed_error);
463
464 loop {
465 let response = self
466 .next_response(None)
467 .await?
468 .ok_or_else(|| closed_error.take().expect("closed error set")())?;
469
470 if let Some(result) = handle_response(self, response)? {
471 return Ok(result);
472 }
473 }
474 }
475
476 pub async fn send(
478 &mut self,
479 request: &BacktestRequest,
480 timeout: Option<Duration>,
481 ) -> BacktestClientResult<()> {
482 let text = serde_json::to_string(request)
483 .map_err(|source| BacktestClientError::SerializeRequest { source })?;
484
485 let request_timeout = self.request_timeout;
486 let timeout = timeout.or(request_timeout);
487
488 let send_fut = self.ws_mut()?.send(Message::Text(text));
489 let send_result = match timeout {
490 Some(duration) => tokio::time::timeout(duration, send_fut)
491 .await
492 .map_err(|_| BacktestClientError::Timeout {
493 action: "sending",
494 duration,
495 })?,
496 None => send_fut.await,
497 };
498
499 send_result.map_err(|source| BacktestClientError::WebSocket {
500 action: "sending",
501 source: Box::new(source),
502 })?;
503
504 Ok(())
505 }
506
507 pub async fn next_response(
509 &mut self,
510 timeout: Option<Duration>,
511 ) -> BacktestClientResult<Option<BacktestResponse>> {
512 if let Some((sequence, response)) = self.backlog.pop_front() {
513 self.last_sequence = sequence.or(self.last_sequence);
514 return Ok(Some(response));
515 }
516
517 let text = match self.next_text(timeout).await? {
518 Some(text) => text,
519 None => return Ok(None),
520 };
521
522 let (sequence, response) = match serde_json::from_str::<SequencedResponse>(&text) {
523 Ok(sequenced) => (Some(sequenced.seq_id), sequenced.response),
524 Err(_) => {
525 let response =
526 serde_json::from_str::<BacktestResponse>(&text).map_err(|source| {
527 BacktestClientError::DeserializeResponse {
528 raw: text.clone(),
529 source,
530 }
531 })?;
532 (None, response)
533 }
534 };
535 self.last_sequence = sequence.or(self.last_sequence);
536
537 Ok(Some(response))
538 }
539
540 pub async fn next_event(
542 &mut self,
543 timeout: Option<Duration>,
544 ) -> BacktestClientResult<Option<BacktestResponse>> {
545 let response = self.next_response(timeout).await?;
546 if let Some(ref response) = response {
547 self.apply_response(response);
548 }
549 Ok(response)
550 }
551
552 pub fn responses(
556 self,
557 timeout: Option<Duration>,
558 ) -> impl futures::Stream<Item = BacktestClientResult<BacktestResponse>> {
559 stream::unfold(Some(self), move |state| async move {
560 let mut session = match state {
561 Some(session) => session,
562 None => return None,
563 };
564
565 match session.next_response(timeout).await {
566 Ok(Some(response)) => {
567 session.apply_response(&response);
568 Some((Ok(response), Some(session)))
569 }
570 Ok(None) => None,
571 Err(err) => Some((Err(err), None)),
572 }
573 })
574 }
575
576 pub async fn ensure_ready(
578 &mut self,
579 timeout: Option<Duration>,
580 ) -> BacktestClientResult<ReadyOutcome> {
581 if self.ready_for_continue {
582 return Ok(ReadyOutcome::Ready);
583 }
584
585 loop {
586 let response =
587 self.next_response(timeout)
588 .await?
589 .ok_or_else(|| BacktestClientError::Closed {
590 reason: "websocket ended while waiting for ReadyForContinue".to_string(),
591 })?;
592
593 match response {
594 BacktestResponse::ReadyForContinue => {
595 self.ready_for_continue = true;
596 return Ok(ReadyOutcome::Ready);
597 }
598 BacktestResponse::Completed { .. } => return Ok(ReadyOutcome::Completed),
599 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
600 _ => {}
601 }
602 }
603 }
604
605 pub async fn wait_for_status(
607 &mut self,
608 desired: BacktestStatus,
609 timeout: Option<Duration>,
610 ) -> BacktestClientResult<()> {
611 let desired = std::mem::discriminant(&desired);
612
613 loop {
614 let response =
615 self.next_response(timeout)
616 .await?
617 .ok_or_else(|| BacktestClientError::Closed {
618 reason: "websocket ended while waiting for status".to_string(),
619 })?;
620
621 match response {
622 BacktestResponse::Status { status }
623 if std::mem::discriminant(&status) == desired =>
624 {
625 return Ok(());
626 }
627 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
628 BacktestResponse::Completed {
629 summary,
630 agent_stats,
631 } => {
632 return Err(BacktestClientError::UnexpectedResponse {
633 context: "waiting for status",
634 response: Box::new(BacktestResponse::Completed {
635 summary,
636 agent_stats,
637 }),
638 });
639 }
640 _ => {}
641 }
642 }
643 }
644
645 pub(crate) fn push_backlog(&mut self, response: BacktestResponse) {
647 self.backlog.push_back((None, response));
648 }
649
650 pub async fn send_continue(
652 &mut self,
653 params: ContinueParams,
654 timeout: Option<Duration>,
655 ) -> BacktestClientResult<()> {
656 self.ready_for_continue = false;
657 self.send(&BacktestRequest::Continue(params), timeout).await
658 }
659
660 pub async fn advance_step<F>(
662 &mut self,
663 state: &mut AdvanceState,
664 wait_for_slots: bool,
665 timeout: Option<Duration>,
666 on_event: &mut F,
667 ) -> BacktestClientResult<()>
668 where
669 F: FnMut(&BacktestResponse),
670 {
671 let Some(response) = self.next_response(timeout).await? else {
672 return Err(BacktestClientError::Closed {
673 reason: "websocket ended while awaiting continue responses".to_string(),
674 });
675 };
676
677 if self.log_raw {
678 tracing::debug!("<- {response:?}");
679 }
680
681 on_event(&response);
682
683 match response {
684 BacktestResponse::ReadyForContinue => {
685 self.ready_for_continue = true;
686 state.ready_for_continue = true;
687 }
688 BacktestResponse::SlotNotification(slot) => {
689 state.slot_notifications += 1;
690 state.last_slot = Some(slot);
691 }
692 BacktestResponse::Status { status } => {
693 state.statuses.push(status);
694 }
695 BacktestResponse::Success => {}
696 BacktestResponse::Completed {
697 summary,
698 agent_stats,
699 } => {
700 state.completed = true;
701 state.summary = summary;
702 state.agent_stats = agent_stats;
703 }
704 BacktestResponse::Error(err @ BacktestError::SimulationError { .. }) => {
705 tracing::warn!(error = %crate::error::err_chain(&err), "simulation error");
706 }
707 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
708 BacktestResponse::SessionCreated { .. }
709 | BacktestResponse::SessionAttached { .. }
710 | BacktestResponse::SessionsCreated { .. }
711 | BacktestResponse::SessionsCreatedV2 { .. }
712 | BacktestResponse::ParallelSessionAttachedV2 { .. }
713 | BacktestResponse::SessionEventV1 { .. }
714 | BacktestResponse::SessionEventV2 { .. }
715 | BacktestResponse::Paused(_)
716 | BacktestResponse::DiscoveryBatch(_) => {
717 return Err(BacktestClientError::UnexpectedResponse {
718 context: "continuing",
719 response: Box::new(response),
720 });
721 }
722 }
723
724 if wait_for_slots && state.slot_notifications > state.expected_slots {
725 tracing::warn!(
726 "received {} slot notifications (expected {})",
727 state.slot_notifications,
728 state.expected_slots
729 );
730 }
731
732 Ok(())
733 }
734
735 pub async fn continue_until_ready<F>(
737 &mut self,
738 cont: Continue,
739 timeout: Option<Duration>,
740 mut on_event: F,
741 ) -> BacktestClientResult<ContinueResult>
742 where
743 F: FnMut(&BacktestResponse),
744 {
745 let expected_slots = cont.advance_count;
746 self.advance_internal(
747 cont.into_params(),
748 expected_slots,
749 false,
750 timeout,
751 &mut on_event,
752 )
753 .await
754 }
755
756 pub async fn advance<F>(
758 &mut self,
759 cont: Continue,
760 timeout: Option<Duration>,
761 mut on_event: F,
762 ) -> BacktestClientResult<ContinueResult>
763 where
764 F: FnMut(&BacktestResponse),
765 {
766 let expected_slots = cont.advance_count;
767 self.advance_internal(
768 cont.into_params(),
769 expected_slots,
770 true,
771 timeout,
772 &mut on_event,
773 )
774 .await
775 }
776
777 async fn advance_internal<F>(
778 &mut self,
779 params: ContinueParams,
780 expected_slots: u64,
781 wait_for_slots: bool,
782 timeout: Option<Duration>,
783 on_event: &mut F,
784 ) -> BacktestClientResult<ContinueResult>
785 where
786 F: FnMut(&BacktestResponse),
787 {
788 self.send_continue(params, timeout).await?;
789
790 let mut state = AdvanceState::new(expected_slots);
791 while !state.is_done(wait_for_slots) {
792 self.advance_step(&mut state, wait_for_slots, timeout, on_event)
793 .await?;
794 }
795
796 Ok(ContinueResult {
797 slot_notifications: state.slot_notifications,
798 last_slot: state.last_slot,
799 statuses: state.statuses,
800 ready_for_continue: state.ready_for_continue,
801 completed: state.completed,
802 })
803 }
804
805 pub async fn modify_program(
814 &self,
815 program_id: &str,
816 elf: &[u8],
817 ) -> Result<BTreeMap<Address, AccountData>, ProgramModError> {
818 let rpc = self.rpc.as_ref().ok_or(ProgramModError::NoRpcEndpoint)?;
819 crate::injection::modify_program_via_rpc(rpc, program_id, elf).await
820 }
821
822 pub async fn modify_accounts(
826 &self,
827 modifications: &AccountModifications,
828 ) -> BacktestClientResult<usize> {
829 let rpc_endpoint =
830 self.rpc_endpoint
831 .as_deref()
832 .ok_or_else(|| BacktestClientError::Closed {
833 reason: "no RPC endpoint available".to_string(),
834 })?;
835
836 crate::rpc::modify_accounts(rpc_endpoint, modifications).await
837 }
838
839 pub async fn subscribe_program_logs<F, Fut>(
845 &self,
846 program_id: &str,
847 commitment: CommitmentConfig,
848 on_notification: F,
849 ) -> Result<LogSubscriptionHandle, SubscriptionError>
850 where
851 F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
852 Fut: Future<Output = ()> + Send + 'static,
853 {
854 let rpc_endpoint = self
855 .rpc_endpoint
856 .as_deref()
857 .ok_or(SubscriptionError::NoRpcEndpoint)?;
858 crate::subscriptions::subscribe_program_logs(
859 rpc_endpoint,
860 program_id,
861 commitment,
862 on_notification,
863 )
864 .await
865 }
866
867 pub async fn subscribe_account_diffs<F, Fut>(
873 &self,
874 account: &str,
875 on_notification: F,
876 ) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
877 where
878 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
879 Fut: Future<Output = ()> + Send + 'static,
880 {
881 let rpc_endpoint = self
882 .rpc_endpoint
883 .as_deref()
884 .ok_or(SubscriptionError::NoRpcEndpoint)?;
885 crate::subscriptions::subscribe_account_diffs(rpc_endpoint, account, on_notification).await
886 }
887
888 pub async fn close(&mut self, timeout: Option<Duration>) -> BacktestClientResult<()> {
892 self.close_with_frame(timeout, None).await
893 }
894
895 pub async fn close_with_frame(
897 &mut self,
898 timeout: Option<Duration>,
899 frame: Option<CloseFrame<'static>>,
900 ) -> BacktestClientResult<()> {
901 if self.ws.is_none() {
902 return Ok(());
903 }
904
905 let mut sent = false;
906 match self
907 .send(&BacktestRequest::CloseBacktestSession, timeout)
908 .await
909 {
910 Ok(()) => sent = true,
911 Err(err) if is_close_ok(&err) => {}
912 Err(err) => return Err(err),
913 }
914
915 if sent {
916 let response = match self.next_response(timeout).await {
917 Ok(Some(r)) => r,
918 Ok(None) => {
919 self.ws.take();
920 return Ok(());
921 }
922 Err(BacktestClientError::Closed { .. }) => {
923 self.ws.take();
924 return Ok(());
925 }
926 Err(BacktestClientError::WebSocket {
927 action: "receiving",
928 source,
929 }) if is_reset_without_close(&source) => {
930 self.ws.take();
931 return Ok(());
932 }
933 Err(err) => return Err(err),
934 };
935
936 match response {
937 BacktestResponse::Success | BacktestResponse::Completed { .. } => {}
938 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
939 other => {
940 return Err(BacktestClientError::UnexpectedResponse {
941 context: "closing session",
942 response: Box::new(other),
943 });
944 }
945 }
946 }
947
948 if let Some(ws) = self.ws.as_mut() {
949 let close_result = ws.close(frame).await;
950 if let Err(source) = close_result
951 && !is_ws_closed_error(&source)
952 {
953 return Err(BacktestClientError::WebSocket {
954 action: "closing",
955 source: Box::new(source),
956 });
957 }
958 }
959
960 match self.next_response(timeout).await {
962 Ok(_) => {}
963 Err(BacktestClientError::Closed { .. }) => {}
964 Err(BacktestClientError::WebSocket {
965 action: "receiving",
966 source,
967 }) if is_reset_without_close(&source) => {}
968 Err(err) => return Err(err),
969 }
970
971 tokio::time::sleep(Duration::from_millis(100)).await;
973 self.ws.take();
974 Ok(())
975 }
976
977 pub async fn close_with_reason(
979 &mut self,
980 timeout: Option<Duration>,
981 code: CloseCode,
982 reason: impl Into<String>,
983 ) -> BacktestClientResult<()> {
984 let frame = CloseFrame {
985 code,
986 reason: Cow::Owned(reason.into()),
987 };
988 self.close_with_frame(timeout, Some(frame)).await
989 }
990
991 async fn next_text(
992 &mut self,
993 timeout: Option<Duration>,
994 ) -> BacktestClientResult<Option<String>> {
995 loop {
996 let request_timeout = self.request_timeout;
997 let timeout = timeout.or(request_timeout);
998
999 let next_fut = self.ws_mut()?.next();
1000 let msg = match timeout {
1001 Some(duration) => tokio::time::timeout(duration, next_fut)
1002 .await
1003 .map_err(|_| BacktestClientError::Timeout {
1004 action: "receiving",
1005 duration,
1006 })?,
1007 None => next_fut.await,
1008 };
1009
1010 let Some(msg) = msg else {
1011 return Ok(None);
1012 };
1013
1014 let msg = match msg {
1015 Ok(msg) => msg,
1016 Err(source) => {
1017 return Err(BacktestClientError::WebSocket {
1018 action: "receiving",
1019 source: Box::new(source),
1020 });
1021 }
1022 };
1023
1024 match msg {
1025 Message::Text(text) => {
1026 if self.log_raw {
1027 tracing::debug!("<- raw: {text}");
1028 }
1029 return Ok(Some(text));
1030 }
1031 Message::Binary(bin) => match String::from_utf8(bin) {
1032 Ok(text) => {
1033 if self.log_raw {
1034 tracing::debug!("<- raw(bin): {text}");
1035 }
1036 return Ok(Some(text));
1037 }
1038 Err(err) => {
1039 tracing::warn!("discarding non-utf8 binary message: {err}");
1040 continue;
1041 }
1042 },
1043 Message::Close(frame) => {
1044 let reason = close_reason(frame);
1045 return Err(BacktestClientError::Closed { reason });
1046 }
1047 Message::Ping(_) | Message::Pong(_) => continue,
1048 Message::Frame(_) => continue,
1049 }
1050 }
1051 }
1052}
1053
1054impl std::fmt::Debug for BacktestSession {
1055 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1056 f.debug_struct("BacktestSession")
1057 .field("session_id", &self.session_id)
1058 .field("rpc_endpoint", &self.rpc_endpoint)
1059 .field(
1060 "rpc",
1061 &self
1062 .rpc
1063 .as_ref()
1064 .map(|_| "<RpcClient>")
1065 .unwrap_or("<not set>"),
1066 )
1067 .field("ready_for_continue", &self.ready_for_continue)
1068 .field("request_timeout", &self.request_timeout)
1069 .finish_non_exhaustive()
1070 }
1071}
1072
/// What a session-creation request produced.
#[derive(Debug)]
pub(crate) enum CreateRequestResult {
    /// A single session, now bound to this connection.
    Single {
        /// Id of the created session.
        session_id: String,
        /// Task id reported alongside it, if any.
        task_id: Option<String>,
    },
    /// A batch of parallel sessions. `task_ids` has one entry per session id.
    Parallel {
        /// Ids of the created sessions.
        session_ids: Vec<String>,
        /// Task id for each session, `None` where unknown.
        task_ids: Vec<Option<String>>,
    },
}
1084
/// Makes `task_ids` exactly `expected_len` long: missing entries become `None`,
/// and surplus entries are dropped.
fn align_task_ids(mut task_ids: Vec<Option<String>>, expected_len: usize) -> Vec<Option<String>> {
    // `Vec::resize` both pads with the fill value and truncates.
    task_ids.resize(expected_len, None);
    task_ids
}
1095
1096impl Drop for BacktestSession {
1097 fn drop(&mut self) {
1098 let Some(ws) = self.ws.take() else {
1099 return;
1100 };
1101
1102 if let Ok(handle) = tokio::runtime::Handle::try_current() {
1103 handle.spawn(async move {
1104 let mut ws = ws;
1105 let _ = ws.close(None).await;
1106 });
1107 }
1108 }
1109}
1110
/// Resolves the RPC endpoint reported by the server into an absolute URL.
///
/// An `endpoint` that already starts with `http://` or `https://` is returned
/// unchanged. The scheme is matched case-insensitively, because URI schemes
/// are case-insensitive. Any other `endpoint` is treated as a path under
/// `base` and joined with exactly one `/`, whether or not `base` ends with a
/// slash or `endpoint` starts with one.
fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
    let is_absolute = ["http://", "https://"].iter().any(|scheme| {
        // `get` returns None on short input or a non-char boundary, so this
        // never panics on arbitrary UTF-8.
        endpoint
            .get(..scheme.len())
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case(scheme))
    });

    if is_absolute {
        endpoint.to_string()
    } else {
        // Trim both sides so "http://host/" + "/rpc" does not produce "//".
        format!(
            "{}/{}",
            base.trim_end_matches('/'),
            endpoint.trim_start_matches('/')
        )
    }
}
1118
1119fn close_reason(frame: Option<CloseFrame<'static>>) -> String {
1120 match frame {
1121 Some(frame) => format!("{:?}: {}", frame.code, frame.reason),
1122 None => "no close frame".to_string(),
1123 }
1124}
1125
1126fn is_reset_without_close(err: &WsError) -> bool {
1127 matches!(
1128 err,
1129 WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1130 )
1131}
1132
1133fn is_ws_closed_error(err: &WsError) -> bool {
1134 matches!(
1135 err,
1136 WsError::ConnectionClosed
1137 | WsError::AlreadyClosed
1138 | WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1139 )
1140}
1141
1142fn is_close_ok(err: &BacktestClientError) -> bool {
1143 match err {
1144 BacktestClientError::Closed { .. } => true,
1145 BacktestClientError::WebSocket { source, .. } => is_ws_closed_error(source),
1146 _ => false,
1147 }
1148}
1149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn coverage_tracks_slot_and_completion_from_responses() {
        let mut coverage = SessionCoverage::default();
        for slot in [10, 12] {
            coverage.observe_response(&BacktestResponse::SlotNotification(slot));
        }
        coverage.observe_response(&BacktestResponse::Completed {
            summary: None,
            agent_stats: None,
        });

        assert!(coverage.is_completed());
        assert_eq!(coverage.highest_slot_seen(), Some(12));
    }

    #[test]
    fn coverage_validate_end_slot_checks_completion_and_range() {
        let mut coverage = SessionCoverage::default();
        // Nothing observed yet: not completed.
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::NotCompleted)
        );

        // Completed, but no slots ever reported.
        coverage.mark_completed();
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::NoSlotsObserved)
        );

        // Highest slot below the expected end.
        coverage.observe_slot(4);
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::RangeNotReached {
                actual_end_slot: 4,
                expected_end_slot: 5,
            })
        );

        // Highest slot at or beyond the expected end.
        coverage.observe_slot(6);
        assert_eq!(coverage.validate_end_slot(5), Ok(()));
    }
}
1194}