1use std::{
2 borrow::Cow,
3 collections::{BTreeMap, VecDeque},
4 future::Future,
5 time::Duration,
6};
7
8use futures::{SinkExt, StreamExt, stream};
9use simulator_api::{
10 AccountData, AccountModifications, AgentStatsReport, BacktestError, BacktestRequest,
11 BacktestResponse, BacktestStatus, ContinueParams, CreateBacktestSessionRequest,
12 CreateBacktestSessionRequestV1, SequencedResponse, SessionSummary,
13};
14use solana_address::Address;
15use solana_client::{
16 nonblocking::rpc_client::RpcClient,
17 rpc_response::{Response, RpcLogsResponse},
18};
19use solana_commitment_config::CommitmentConfig;
20use thiserror::Error;
21use tokio::net::TcpStream;
22use tokio_tungstenite::{
23 MaybeTlsStream, WebSocketStream,
24 tungstenite::{
25 Error as WsError, Message,
26 error::ProtocolError,
27 protocol::{CloseFrame, frame::coding::CloseCode},
28 },
29};
30
31use crate::{
32 BacktestClientError, BacktestClientResult, Continue,
33 injection::ProgramModError,
34 subscriptions::{
35 AccountDiffNotification, AccountDiffSubscriptionHandle, LogSubscriptionHandle,
36 SubscriptionError,
37 },
38};
39
/// Result of waiting for the server to accept another `Continue` request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadyOutcome {
    /// The server sent `ReadyForContinue`; a new `Continue` may be issued.
    Ready,
    /// The backtest finished (`Completed`) before becoming ready again.
    Completed,
}
48
49#[derive(Debug, Default)]
51pub struct ContinueResult {
52 pub slot_notifications: u64,
54 pub last_slot: Option<u64>,
56 pub statuses: Vec<BacktestStatus>,
58 pub ready_for_continue: bool,
60 pub completed: bool,
62}
63
64#[derive(Debug)]
66pub struct AdvanceState {
67 pub expected_slots: u64,
69 pub slot_notifications: u64,
71 pub last_slot: Option<u64>,
73 pub statuses: Vec<BacktestStatus>,
75 pub ready_for_continue: bool,
77 pub completed: bool,
79 pub summary: Option<SessionSummary>,
81 pub agent_stats: Option<Vec<AgentStatsReport>>,
83}
84
85impl AdvanceState {
86 pub fn new(expected_slots: u64) -> Self {
88 Self {
89 expected_slots,
90 slot_notifications: 0,
91 last_slot: None,
92 statuses: Vec::new(),
93 ready_for_continue: false,
94 completed: false,
95 summary: None,
96 agent_stats: None,
97 }
98 }
99
100 pub fn is_done(&self, wait_for_slots: bool) -> bool {
102 if self.completed {
103 return true;
104 }
105
106 if !self.ready_for_continue {
107 return false;
108 }
109
110 !wait_for_slots || self.slot_notifications >= self.expected_slots
111 }
112}
113
/// Tracks how far a session progressed: the highest slot seen and whether the
/// server reported completion.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SessionCoverage {
    /// True once a `Completed` response (or `mark_completed`) has been observed.
    completed: bool,
    /// Largest slot observed so far; `None` until the first slot arrives.
    highest_slot_seen: Option<u64>,
}
120
121impl SessionCoverage {
122 pub fn observe_slot(&mut self, slot: u64) {
124 self.highest_slot_seen = Some(
125 self.highest_slot_seen
126 .map_or(slot, |current| current.max(slot)),
127 );
128 }
129
130 pub fn mark_completed(&mut self) {
132 self.completed = true;
133 }
134
135 pub fn observe_response(&mut self, response: &BacktestResponse) {
137 match response {
138 BacktestResponse::SlotNotification(slot) => self.observe_slot(*slot),
139 BacktestResponse::Completed { .. } => self.mark_completed(),
140 _ => {}
141 }
142 }
143
144 pub fn is_completed(&self) -> bool {
146 self.completed
147 }
148
149 pub fn highest_slot_seen(&self) -> Option<u64> {
151 self.highest_slot_seen
152 }
153
154 pub fn validate_end_slot(&self, expected_end_slot: u64) -> Result<(), CoverageError> {
156 if !self.completed {
157 return Err(CoverageError::NotCompleted);
158 }
159
160 let Some(actual_end_slot) = self.highest_slot_seen else {
161 return Err(CoverageError::NoSlotsObserved);
162 };
163
164 if actual_end_slot < expected_end_slot {
165 return Err(CoverageError::RangeNotReached {
166 actual_end_slot,
167 expected_end_slot,
168 });
169 }
170
171 Ok(())
172 }
173}
174
175#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
177pub enum CoverageError {
178 #[error("ended before completion")]
179 NotCompleted,
180 #[error("completed without slot notifications")]
181 NoSlotsObserved,
182 #[error("completed at slot {actual_end_slot} but expected at least {expected_end_slot}")]
183 RangeNotReached {
184 actual_end_slot: u64,
185 expected_end_slot: u64,
186 },
187}
188
189pub struct BacktestSession {
194 ws: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
195 session_id: Option<String>,
196 rpc_endpoint: Option<String>,
197 rpc: Option<RpcClient>,
198 last_sequence: Option<u64>,
199 pub(crate) ready_for_continue: bool,
200 request_timeout: Option<Duration>,
201 log_raw: bool,
202 backlog: VecDeque<(Option<u64>, BacktestResponse)>,
203}
204
205impl BacktestSession {
206 pub(crate) fn new(
207 ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
208 request_timeout: Option<Duration>,
209 log_raw: bool,
210 ) -> Self {
211 Self {
212 ws: Some(ws),
213 session_id: None,
214 rpc_endpoint: None,
215 rpc: None,
216 last_sequence: None,
217 ready_for_continue: false,
218 request_timeout,
219 log_raw,
220 backlog: VecDeque::new(),
221 }
222 }
223
224 pub fn session_id(&self) -> Option<&str> {
226 self.session_id.as_deref()
227 }
228
229 pub fn rpc_endpoint(&self) -> Option<&str> {
231 self.rpc_endpoint.as_deref()
232 }
233
234 pub fn last_sequence(&self) -> Option<u64> {
236 self.last_sequence
237 }
238
239 pub fn rpc(&self) -> &RpcClient {
243 self.rpc
244 .as_ref()
245 .expect("rpc is set during session creation")
246 }
247
248 pub fn is_ready_for_continue(&self) -> bool {
250 self.ready_for_continue
251 }
252
253 pub fn apply_response(&mut self, response: &BacktestResponse) {
255 match response {
256 BacktestResponse::ReadyForContinue | BacktestResponse::Paused(_) => {
257 self.ready_for_continue = true;
258 }
259 BacktestResponse::Completed { .. } => {
260 self.ready_for_continue = false;
261 }
262 _ => {}
263 }
264 }
265
266 fn ws_mut(&mut self) -> BacktestClientResult<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
267 self.ws.as_mut().ok_or_else(|| BacktestClientError::Closed {
268 reason: "websocket closed".to_string(),
269 })
270 }
271
272 pub(crate) async fn create_with_request(
273 &mut self,
274 request: CreateBacktestSessionRequest,
275 rpc_base_url: String,
276 mut on_parallel_session_created: Option<&mut (dyn FnMut(String) + Send)>,
277 ) -> BacktestClientResult<CreateRequestResult> {
278 let expect_parallel = matches!(
279 &request,
280 CreateBacktestSessionRequest::V1(CreateBacktestSessionRequestV1 { parallel: true, .. })
281 );
282 self.send(&BacktestRequest::CreateBacktestSession(request), None)
283 .await?;
284 let mut streamed_parallel_session_ids = Vec::new();
285
286 loop {
287 let response =
288 self.next_response(None)
289 .await?
290 .ok_or_else(|| BacktestClientError::Closed {
291 reason: "websocket ended before SessionCreated".to_string(),
292 })?;
293
294 match response {
295 BacktestResponse::SessionCreated {
296 session_id,
297 rpc_endpoint,
298 } => {
299 if expect_parallel {
300 if let Some(callback) = on_parallel_session_created.as_mut() {
301 (**callback)(session_id.clone());
302 }
303 streamed_parallel_session_ids.push(session_id);
304 continue;
305 }
306 let created_session_id = session_id.clone();
307 self.session_id = Some(session_id);
308 let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
309 self.rpc = Some(RpcClient::new_with_commitment(
310 resolved.clone(),
311 CommitmentConfig::confirmed(),
312 ));
313 self.rpc_endpoint = Some(resolved);
314 return Ok(CreateRequestResult::Single {
315 session_id: created_session_id,
316 });
317 }
318 BacktestResponse::SessionsCreated { session_ids } => {
319 if expect_parallel && session_ids.is_empty() {
320 return Ok(CreateRequestResult::Parallel {
321 session_ids: streamed_parallel_session_ids,
322 });
323 }
324 return Ok(CreateRequestResult::Parallel { session_ids });
325 }
326 BacktestResponse::SessionsCreatedV2 { session_ids, .. } => {
327 if expect_parallel && session_ids.is_empty() {
328 return Ok(CreateRequestResult::Parallel {
329 session_ids: streamed_parallel_session_ids,
330 });
331 }
332 return Ok(CreateRequestResult::Parallel { session_ids });
333 }
334 BacktestResponse::ReadyForContinue => {
335 self.ready_for_continue = true;
336 }
337 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
338 other => {
339 self.backlog.push_back((self.last_sequence, other));
340 }
341 }
342 }
343 }
344
345 pub(crate) async fn attach(
346 &mut self,
347 session_id: String,
348 last_sequence: Option<u64>,
349 rpc_base_url: String,
350 ) -> BacktestClientResult<()> {
351 self.send(
352 &BacktestRequest::AttachBacktestSession {
353 session_id,
354 last_sequence,
355 },
356 None,
357 )
358 .await?;
359
360 self.wait_for_response(
361 || BacktestClientError::Closed {
362 reason: "websocket ended before SessionAttached".to_string(),
363 },
364 move |session, response| match response {
365 BacktestResponse::SessionAttached {
366 session_id,
367 rpc_endpoint,
368 } => {
369 session.session_id = Some(session_id);
370 let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
371 session.rpc = Some(RpcClient::new_with_commitment(
372 resolved.clone(),
373 CommitmentConfig::confirmed(),
374 ));
375 session.rpc_endpoint = Some(resolved);
376 Ok(Some(()))
377 }
378 BacktestResponse::ReadyForContinue => {
379 session.ready_for_continue = true;
380 Ok(None)
381 }
382 BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
383 other => {
384 session.backlog.push_back((session.last_sequence, other));
385 Ok(None)
386 }
387 },
388 )
389 .await
390 }
391
392 pub async fn resume_attached_session(&mut self) -> BacktestClientResult<()> {
395 self.send(&BacktestRequest::ResumeAttachedSession, None)
396 .await?;
397
398 self.wait_for_response(
399 || BacktestClientError::Closed {
400 reason: "websocket ended before ResumeAttachedSession acknowledgement".to_string(),
401 },
402 |session, response| match response {
403 BacktestResponse::Success => Ok(Some(())),
404 BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
405 other => {
406 session.backlog.push_back((session.last_sequence, other));
407 Ok(None)
408 }
409 },
410 )
411 .await
412 }
413
414 async fn wait_for_response<T, E, F>(
415 &mut self,
416 closed_error: E,
417 mut handle_response: F,
418 ) -> BacktestClientResult<T>
419 where
420 E: FnOnce() -> BacktestClientError,
421 F: FnMut(&mut Self, BacktestResponse) -> BacktestClientResult<Option<T>>,
422 {
423 let mut closed_error = Some(closed_error);
424
425 loop {
426 let response = self
427 .next_response(None)
428 .await?
429 .ok_or_else(|| closed_error.take().expect("closed error set")())?;
430
431 if let Some(result) = handle_response(self, response)? {
432 return Ok(result);
433 }
434 }
435 }
436
437 pub async fn send(
439 &mut self,
440 request: &BacktestRequest,
441 timeout: Option<Duration>,
442 ) -> BacktestClientResult<()> {
443 let text = serde_json::to_string(request)
444 .map_err(|source| BacktestClientError::SerializeRequest { source })?;
445
446 let request_timeout = self.request_timeout;
447 let timeout = timeout.or(request_timeout);
448
449 let send_fut = self.ws_mut()?.send(Message::Text(text));
450 let send_result = match timeout {
451 Some(duration) => tokio::time::timeout(duration, send_fut)
452 .await
453 .map_err(|_| BacktestClientError::Timeout {
454 action: "sending",
455 duration,
456 })?,
457 None => send_fut.await,
458 };
459
460 send_result.map_err(|source| BacktestClientError::WebSocket {
461 action: "sending",
462 source: Box::new(source),
463 })?;
464
465 Ok(())
466 }
467
468 pub async fn next_response(
470 &mut self,
471 timeout: Option<Duration>,
472 ) -> BacktestClientResult<Option<BacktestResponse>> {
473 if let Some((sequence, response)) = self.backlog.pop_front() {
474 self.last_sequence = sequence.or(self.last_sequence);
475 return Ok(Some(response));
476 }
477
478 let text = match self.next_text(timeout).await? {
479 Some(text) => text,
480 None => return Ok(None),
481 };
482
483 let (sequence, response) = match serde_json::from_str::<SequencedResponse>(&text) {
484 Ok(sequenced) => (Some(sequenced.seq_id), sequenced.response),
485 Err(_) => {
486 let response =
487 serde_json::from_str::<BacktestResponse>(&text).map_err(|source| {
488 BacktestClientError::DeserializeResponse {
489 raw: text.clone(),
490 source,
491 }
492 })?;
493 (None, response)
494 }
495 };
496 self.last_sequence = sequence.or(self.last_sequence);
497
498 Ok(Some(response))
499 }
500
501 pub async fn next_event(
503 &mut self,
504 timeout: Option<Duration>,
505 ) -> BacktestClientResult<Option<BacktestResponse>> {
506 let response = self.next_response(timeout).await?;
507 if let Some(ref response) = response {
508 self.apply_response(response);
509 }
510 Ok(response)
511 }
512
513 pub fn responses(
517 self,
518 timeout: Option<Duration>,
519 ) -> impl futures::Stream<Item = BacktestClientResult<BacktestResponse>> {
520 stream::unfold(Some(self), move |state| async move {
521 let mut session = match state {
522 Some(session) => session,
523 None => return None,
524 };
525
526 match session.next_response(timeout).await {
527 Ok(Some(response)) => {
528 session.apply_response(&response);
529 Some((Ok(response), Some(session)))
530 }
531 Ok(None) => None,
532 Err(err) => Some((Err(err), None)),
533 }
534 })
535 }
536
537 pub async fn ensure_ready(
539 &mut self,
540 timeout: Option<Duration>,
541 ) -> BacktestClientResult<ReadyOutcome> {
542 if self.ready_for_continue {
543 return Ok(ReadyOutcome::Ready);
544 }
545
546 loop {
547 let response =
548 self.next_response(timeout)
549 .await?
550 .ok_or_else(|| BacktestClientError::Closed {
551 reason: "websocket ended while waiting for ReadyForContinue".to_string(),
552 })?;
553
554 match response {
555 BacktestResponse::ReadyForContinue => {
556 self.ready_for_continue = true;
557 return Ok(ReadyOutcome::Ready);
558 }
559 BacktestResponse::Completed { .. } => return Ok(ReadyOutcome::Completed),
560 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
561 _ => {}
562 }
563 }
564 }
565
566 pub async fn wait_for_status(
568 &mut self,
569 desired: BacktestStatus,
570 timeout: Option<Duration>,
571 ) -> BacktestClientResult<()> {
572 let desired = std::mem::discriminant(&desired);
573
574 loop {
575 let response =
576 self.next_response(timeout)
577 .await?
578 .ok_or_else(|| BacktestClientError::Closed {
579 reason: "websocket ended while waiting for status".to_string(),
580 })?;
581
582 match response {
583 BacktestResponse::Status { status }
584 if std::mem::discriminant(&status) == desired =>
585 {
586 return Ok(());
587 }
588 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
589 BacktestResponse::Completed {
590 summary,
591 agent_stats,
592 } => {
593 return Err(BacktestClientError::UnexpectedResponse {
594 context: "waiting for status",
595 response: Box::new(BacktestResponse::Completed {
596 summary,
597 agent_stats,
598 }),
599 });
600 }
601 _ => {}
602 }
603 }
604 }
605
606 pub(crate) fn push_backlog(&mut self, response: BacktestResponse) {
608 self.backlog.push_back((None, response));
609 }
610
611 pub async fn send_continue(
613 &mut self,
614 params: ContinueParams,
615 timeout: Option<Duration>,
616 ) -> BacktestClientResult<()> {
617 self.ready_for_continue = false;
618 self.send(&BacktestRequest::Continue(params), timeout).await
619 }
620
621 pub async fn advance_step<F>(
623 &mut self,
624 state: &mut AdvanceState,
625 wait_for_slots: bool,
626 timeout: Option<Duration>,
627 on_event: &mut F,
628 ) -> BacktestClientResult<()>
629 where
630 F: FnMut(&BacktestResponse),
631 {
632 let Some(response) = self.next_response(timeout).await? else {
633 return Err(BacktestClientError::Closed {
634 reason: "websocket ended while awaiting continue responses".to_string(),
635 });
636 };
637
638 if self.log_raw {
639 tracing::debug!("<- {response:?}");
640 }
641
642 on_event(&response);
643
644 match response {
645 BacktestResponse::ReadyForContinue => {
646 self.ready_for_continue = true;
647 state.ready_for_continue = true;
648 }
649 BacktestResponse::SlotNotification(slot) => {
650 state.slot_notifications += 1;
651 state.last_slot = Some(slot);
652 }
653 BacktestResponse::Status { status } => {
654 state.statuses.push(status);
655 }
656 BacktestResponse::Success => {}
657 BacktestResponse::Completed {
658 summary,
659 agent_stats,
660 } => {
661 state.completed = true;
662 state.summary = summary;
663 state.agent_stats = agent_stats;
664 }
665 BacktestResponse::Error(err @ BacktestError::SimulationError { .. }) => {
666 tracing::warn!(error = %err, "simulation error");
667 }
668 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
669 BacktestResponse::SessionCreated { .. }
670 | BacktestResponse::SessionAttached { .. }
671 | BacktestResponse::SessionsCreated { .. }
672 | BacktestResponse::SessionsCreatedV2 { .. }
673 | BacktestResponse::ParallelSessionAttachedV2 { .. }
674 | BacktestResponse::SessionEventV1 { .. }
675 | BacktestResponse::SessionEventV2 { .. }
676 | BacktestResponse::Paused(_)
677 | BacktestResponse::DiscoveryBatch(_) => {
678 return Err(BacktestClientError::UnexpectedResponse {
679 context: "continuing",
680 response: Box::new(response),
681 });
682 }
683 }
684
685 if wait_for_slots && state.slot_notifications > state.expected_slots {
686 tracing::warn!(
687 "received {} slot notifications (expected {})",
688 state.slot_notifications,
689 state.expected_slots
690 );
691 }
692
693 Ok(())
694 }
695
696 pub async fn continue_until_ready<F>(
698 &mut self,
699 cont: Continue,
700 timeout: Option<Duration>,
701 mut on_event: F,
702 ) -> BacktestClientResult<ContinueResult>
703 where
704 F: FnMut(&BacktestResponse),
705 {
706 let expected_slots = cont.advance_count;
707 self.advance_internal(
708 cont.into_params(),
709 expected_slots,
710 false,
711 timeout,
712 &mut on_event,
713 )
714 .await
715 }
716
717 pub async fn advance<F>(
719 &mut self,
720 cont: Continue,
721 timeout: Option<Duration>,
722 mut on_event: F,
723 ) -> BacktestClientResult<ContinueResult>
724 where
725 F: FnMut(&BacktestResponse),
726 {
727 let expected_slots = cont.advance_count;
728 self.advance_internal(
729 cont.into_params(),
730 expected_slots,
731 true,
732 timeout,
733 &mut on_event,
734 )
735 .await
736 }
737
738 async fn advance_internal<F>(
739 &mut self,
740 params: ContinueParams,
741 expected_slots: u64,
742 wait_for_slots: bool,
743 timeout: Option<Duration>,
744 on_event: &mut F,
745 ) -> BacktestClientResult<ContinueResult>
746 where
747 F: FnMut(&BacktestResponse),
748 {
749 self.send_continue(params, timeout).await?;
750
751 let mut state = AdvanceState::new(expected_slots);
752 while !state.is_done(wait_for_slots) {
753 self.advance_step(&mut state, wait_for_slots, timeout, on_event)
754 .await?;
755 }
756
757 Ok(ContinueResult {
758 slot_notifications: state.slot_notifications,
759 last_slot: state.last_slot,
760 statuses: state.statuses,
761 ready_for_continue: state.ready_for_continue,
762 completed: state.completed,
763 })
764 }
765
766 pub async fn modify_program(
775 &self,
776 program_id: &str,
777 elf: &[u8],
778 ) -> Result<BTreeMap<Address, AccountData>, ProgramModError> {
779 let rpc = self.rpc.as_ref().ok_or(ProgramModError::NoRpcEndpoint)?;
780 crate::injection::modify_program_via_rpc(rpc, program_id, elf).await
781 }
782
783 pub async fn modify_accounts(
787 &self,
788 modifications: &AccountModifications,
789 ) -> BacktestClientResult<usize> {
790 let rpc_endpoint =
791 self.rpc_endpoint
792 .as_deref()
793 .ok_or_else(|| BacktestClientError::Closed {
794 reason: "no RPC endpoint available".to_string(),
795 })?;
796
797 crate::rpc::modify_accounts(rpc_endpoint, modifications).await
798 }
799
800 pub async fn subscribe_program_logs<F, Fut>(
806 &self,
807 program_id: &str,
808 commitment: CommitmentConfig,
809 on_notification: F,
810 ) -> Result<LogSubscriptionHandle, SubscriptionError>
811 where
812 F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
813 Fut: Future<Output = ()> + Send + 'static,
814 {
815 let rpc_endpoint = self
816 .rpc_endpoint
817 .as_deref()
818 .ok_or(SubscriptionError::NoRpcEndpoint)?;
819 crate::subscriptions::subscribe_program_logs(
820 rpc_endpoint,
821 program_id,
822 commitment,
823 on_notification,
824 )
825 .await
826 }
827
828 pub async fn subscribe_account_diffs<F, Fut>(
834 &self,
835 account: &str,
836 on_notification: F,
837 ) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
838 where
839 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
840 Fut: Future<Output = ()> + Send + 'static,
841 {
842 let rpc_endpoint = self
843 .rpc_endpoint
844 .as_deref()
845 .ok_or(SubscriptionError::NoRpcEndpoint)?;
846 crate::subscriptions::subscribe_account_diffs(rpc_endpoint, account, on_notification).await
847 }
848
849 pub async fn close(&mut self, timeout: Option<Duration>) -> BacktestClientResult<()> {
853 self.close_with_frame(timeout, None).await
854 }
855
856 pub async fn close_with_frame(
858 &mut self,
859 timeout: Option<Duration>,
860 frame: Option<CloseFrame<'static>>,
861 ) -> BacktestClientResult<()> {
862 if self.ws.is_none() {
863 return Ok(());
864 }
865
866 let mut sent = false;
867 match self
868 .send(&BacktestRequest::CloseBacktestSession, timeout)
869 .await
870 {
871 Ok(()) => sent = true,
872 Err(err) if is_close_ok(&err) => {}
873 Err(err) => return Err(err),
874 }
875
876 if sent {
877 let response = match self.next_response(timeout).await {
878 Ok(Some(r)) => r,
879 Ok(None) => {
880 self.ws.take();
881 return Ok(());
882 }
883 Err(BacktestClientError::Closed { .. }) => {
884 self.ws.take();
885 return Ok(());
886 }
887 Err(BacktestClientError::WebSocket {
888 action: "receiving",
889 source,
890 }) if is_reset_without_close(&source) => {
891 self.ws.take();
892 return Ok(());
893 }
894 Err(err) => return Err(err),
895 };
896
897 match response {
898 BacktestResponse::Success | BacktestResponse::Completed { .. } => {}
899 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
900 other => {
901 return Err(BacktestClientError::UnexpectedResponse {
902 context: "closing session",
903 response: Box::new(other),
904 });
905 }
906 }
907 }
908
909 if let Some(ws) = self.ws.as_mut() {
910 let close_result = ws.close(frame).await;
911 if let Err(source) = close_result
912 && !is_ws_closed_error(&source)
913 {
914 return Err(BacktestClientError::WebSocket {
915 action: "closing",
916 source: Box::new(source),
917 });
918 }
919 }
920
921 match self.next_response(timeout).await {
923 Ok(_) => {}
924 Err(BacktestClientError::Closed { .. }) => {}
925 Err(BacktestClientError::WebSocket {
926 action: "receiving",
927 source,
928 }) if is_reset_without_close(&source) => {}
929 Err(err) => return Err(err),
930 }
931
932 tokio::time::sleep(Duration::from_millis(100)).await;
934 self.ws.take();
935 Ok(())
936 }
937
938 pub async fn close_with_reason(
940 &mut self,
941 timeout: Option<Duration>,
942 code: CloseCode,
943 reason: impl Into<String>,
944 ) -> BacktestClientResult<()> {
945 let frame = CloseFrame {
946 code,
947 reason: Cow::Owned(reason.into()),
948 };
949 self.close_with_frame(timeout, Some(frame)).await
950 }
951
952 async fn next_text(
953 &mut self,
954 timeout: Option<Duration>,
955 ) -> BacktestClientResult<Option<String>> {
956 loop {
957 let request_timeout = self.request_timeout;
958 let timeout = timeout.or(request_timeout);
959
960 let next_fut = self.ws_mut()?.next();
961 let msg = match timeout {
962 Some(duration) => tokio::time::timeout(duration, next_fut)
963 .await
964 .map_err(|_| BacktestClientError::Timeout {
965 action: "receiving",
966 duration,
967 })?,
968 None => next_fut.await,
969 };
970
971 let Some(msg) = msg else {
972 return Ok(None);
973 };
974
975 let msg = match msg {
976 Ok(msg) => msg,
977 Err(source) => {
978 return Err(BacktestClientError::WebSocket {
979 action: "receiving",
980 source: Box::new(source),
981 });
982 }
983 };
984
985 match msg {
986 Message::Text(text) => {
987 if self.log_raw {
988 tracing::debug!("<- raw: {text}");
989 }
990 return Ok(Some(text));
991 }
992 Message::Binary(bin) => match String::from_utf8(bin) {
993 Ok(text) => {
994 if self.log_raw {
995 tracing::debug!("<- raw(bin): {text}");
996 }
997 return Ok(Some(text));
998 }
999 Err(err) => {
1000 tracing::warn!("discarding non-utf8 binary message: {err}");
1001 continue;
1002 }
1003 },
1004 Message::Close(frame) => {
1005 let reason = close_reason(frame);
1006 return Err(BacktestClientError::Closed { reason });
1007 }
1008 Message::Ping(_) | Message::Pong(_) => continue,
1009 Message::Frame(_) => continue,
1010 }
1011 }
1012 }
1013}
1014
1015impl std::fmt::Debug for BacktestSession {
1016 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1017 f.debug_struct("BacktestSession")
1018 .field("session_id", &self.session_id)
1019 .field("rpc_endpoint", &self.rpc_endpoint)
1020 .field(
1021 "rpc",
1022 &self
1023 .rpc
1024 .as_ref()
1025 .map(|_| "<RpcClient>")
1026 .unwrap_or("<not set>"),
1027 )
1028 .field("ready_for_continue", &self.ready_for_continue)
1029 .field("request_timeout", &self.request_timeout)
1030 .finish_non_exhaustive()
1031 }
1032}
1033
/// Outcome of a session-creation request.
#[derive(Debug)]
pub(crate) enum CreateRequestResult {
    /// A single session was created and installed on this connection.
    Single {
        /// Id of the created session.
        session_id: String,
    },
    /// A parallel request created several sessions.
    Parallel {
        /// Ids of every created session.
        session_ids: Vec<String>,
    },
}
1039
1040impl Drop for BacktestSession {
1041 fn drop(&mut self) {
1042 let Some(ws) = self.ws.take() else {
1043 return;
1044 };
1045
1046 if let Ok(handle) = tokio::runtime::Handle::try_current() {
1047 handle.spawn(async move {
1048 let mut ws = ws;
1049 let _ = ws.close(None).await;
1050 });
1051 }
1052 }
1053}
1054
/// Resolves a server-provided RPC endpoint against `base`.
///
/// Absolute `http://`/`https://` endpoints are returned unchanged. Relative endpoints
/// are joined to `base` with exactly one `/` between them. Trailing slashes on `base`
/// and leading slashes on `endpoint` are both trimmed, so `"http://h/"` + `"/rpc"`
/// yields `"http://h/rpc"` rather than `"http://h//rpc"`.
fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        return endpoint.to_string();
    }
    format!(
        "{}/{}",
        base.trim_end_matches('/'),
        endpoint.trim_start_matches('/')
    )
}
1062
1063fn close_reason(frame: Option<CloseFrame<'static>>) -> String {
1064 match frame {
1065 Some(frame) => format!("{:?}: {}", frame.code, frame.reason),
1066 None => "no close frame".to_string(),
1067 }
1068}
1069
1070fn is_reset_without_close(err: &WsError) -> bool {
1071 matches!(
1072 err,
1073 WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1074 )
1075}
1076
1077fn is_ws_closed_error(err: &WsError) -> bool {
1078 matches!(
1079 err,
1080 WsError::ConnectionClosed
1081 | WsError::AlreadyClosed
1082 | WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1083 )
1084}
1085
1086fn is_close_ok(err: &BacktestClientError) -> bool {
1087 match err {
1088 BacktestClientError::Closed { .. } => true,
1089 BacktestClientError::WebSocket { source, .. } => is_ws_closed_error(source),
1090 _ => false,
1091 }
1092}
1093
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn coverage_tracks_slot_and_completion_from_responses() {
        let mut coverage = SessionCoverage::default();
        coverage.observe_response(&BacktestResponse::SlotNotification(10));
        coverage.observe_response(&BacktestResponse::SlotNotification(12));
        coverage.observe_response(&BacktestResponse::Completed {
            summary: None,
            agent_stats: None,
        });

        assert!(coverage.is_completed());
        assert_eq!(coverage.highest_slot_seen(), Some(12));
    }

    #[test]
    fn coverage_validate_end_slot_checks_completion_and_range() {
        let mut coverage = SessionCoverage::default();
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::NotCompleted)
        );

        coverage.mark_completed();
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::NoSlotsObserved)
        );

        coverage.observe_slot(4);
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::RangeNotReached {
                actual_end_slot: 4,
                expected_end_slot: 5,
            })
        );

        coverage.observe_slot(6);
        assert_eq!(coverage.validate_end_slot(5), Ok(()));
    }

    #[test]
    fn coverage_keeps_highest_slot_when_slots_arrive_out_of_order() {
        let mut coverage = SessionCoverage::default();
        coverage.observe_slot(9);
        coverage.observe_slot(3);
        assert_eq!(coverage.highest_slot_seen(), Some(9));
    }

    #[test]
    fn advance_state_requires_readiness_and_slot_budget() {
        let mut state = AdvanceState::new(2);
        assert!(!state.is_done(false));
        assert!(!state.is_done(true));

        state.ready_for_continue = true;
        assert!(state.is_done(false));
        assert!(!state.is_done(true));

        state.slot_notifications = 2;
        assert!(state.is_done(true));
    }

    #[test]
    fn advance_state_completed_is_always_done() {
        let mut state = AdvanceState::new(10);
        state.completed = true;
        assert!(state.is_done(true));
        assert!(state.is_done(false));
    }

    #[test]
    fn resolve_rpc_url_passes_absolute_urls_through() {
        assert_eq!(
            resolve_rpc_url("http://base", "https://other/rpc"),
            "https://other/rpc"
        );
        assert_eq!(
            resolve_rpc_url("http://base", "http://other/rpc"),
            "http://other/rpc"
        );
    }

    #[test]
    fn resolve_rpc_url_joins_relative_paths() {
        assert_eq!(resolve_rpc_url("http://base", "/rpc/1"), "http://base/rpc/1");
        assert_eq!(resolve_rpc_url("http://base", "rpc/1"), "http://base/rpc/1");
    }

    #[test]
    fn close_reason_without_frame() {
        assert_eq!(close_reason(None), "no close frame");
    }

    #[test]
    fn websocket_error_classification() {
        assert!(is_ws_closed_error(&WsError::ConnectionClosed));
        assert!(is_ws_closed_error(&WsError::AlreadyClosed));
        assert!(is_ws_closed_error(&WsError::Protocol(
            ProtocolError::ResetWithoutClosingHandshake
        )));
        assert!(is_reset_without_close(&WsError::Protocol(
            ProtocolError::ResetWithoutClosingHandshake
        )));
        assert!(!is_reset_without_close(&WsError::AlreadyClosed));
    }
}