1use std::{
2 borrow::Cow,
3 collections::{BTreeMap, VecDeque},
4 future::Future,
5 time::Duration,
6};
7
8use futures::{SinkExt, StreamExt, stream};
9use simulator_api::{
10 AccountData, AccountModifications, AgentStatsReport, BacktestError, BacktestRequest,
11 BacktestResponse, BacktestStatus, ContinueParams, CreateBacktestSessionRequest,
12 CreateBacktestSessionRequestV1, SessionSummary,
13};
14use solana_address::Address;
15use solana_client::{
16 nonblocking::rpc_client::RpcClient,
17 rpc_response::{Response, RpcLogsResponse},
18};
19use solana_commitment_config::CommitmentConfig;
20use thiserror::Error;
21use tokio::net::TcpStream;
22use tokio_tungstenite::{
23 MaybeTlsStream, WebSocketStream,
24 tungstenite::{
25 Error as WsError, Message,
26 error::ProtocolError,
27 protocol::{CloseFrame, frame::coding::CloseCode},
28 },
29};
30
31use crate::{
32 BacktestClientError, BacktestClientResult, Continue,
33 injection::{ProgramModError, build_program_injection},
34 subscriptions::{
35 AccountDiffNotification, AccountDiffSubscriptionHandle, LogSubscriptionHandle,
36 SubscriptionError,
37 },
38};
39
/// Outcome of waiting for the server to accept another `Continue` request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReadyOutcome {
    /// The server has signalled `ReadyForContinue` (now or earlier).
    Ready,
    /// The backtest reported `Completed` before becoming ready again.
    Completed,
}
48
49#[derive(Debug, Default)]
51pub struct ContinueResult {
52 pub slot_notifications: u64,
54 pub last_slot: Option<u64>,
56 pub statuses: Vec<BacktestStatus>,
58 pub ready_for_continue: bool,
60 pub completed: bool,
62}
63
/// Mutable progress tracker updated by `BacktestSession::advance_step` while a
/// `Continue` request is in flight.
#[derive(Debug)]
pub struct AdvanceState {
    /// Slot notifications the caller expects for this step (the `Continue` advance count).
    pub expected_slots: u64,
    /// Number of `SlotNotification` responses seen so far.
    pub slot_notifications: u64,
    /// Slot from the latest `SlotNotification`, if any.
    pub last_slot: Option<u64>,
    /// `Status` payloads seen so far, in arrival order.
    pub statuses: Vec<BacktestStatus>,
    /// Set once `ReadyForContinue` has been received.
    pub ready_for_continue: bool,
    /// Set once `Completed` has been received.
    pub completed: bool,
    /// Summary carried by the `Completed` response, if one arrived.
    pub summary: Option<SessionSummary>,
    /// Agent statistics carried by the `Completed` response, if one arrived.
    pub agent_stats: Option<Vec<AgentStatsReport>>,
}
84
85impl AdvanceState {
86 pub fn new(expected_slots: u64) -> Self {
88 Self {
89 expected_slots,
90 slot_notifications: 0,
91 last_slot: None,
92 statuses: Vec::new(),
93 ready_for_continue: false,
94 completed: false,
95 summary: None,
96 agent_stats: None,
97 }
98 }
99
100 pub fn is_done(&self, wait_for_slots: bool) -> bool {
102 if self.completed {
103 return true;
104 }
105
106 if !self.ready_for_continue {
107 return false;
108 }
109
110 !wait_for_slots || self.slot_notifications >= self.expected_slots
111 }
112}
113
/// Tracks how far a session got: whether it completed and the highest slot observed.
#[derive(Default, Debug, Clone, Copy, Eq, PartialEq)]
pub struct SessionCoverage {
    /// Set once a `Completed` response (or an explicit `mark_completed`) is seen.
    completed: bool,
    /// Largest slot reported so far; `None` until the first observation.
    highest_slot_seen: Option<u64>,
}
120
121impl SessionCoverage {
122 pub fn observe_slot(&mut self, slot: u64) {
124 self.highest_slot_seen = Some(
125 self.highest_slot_seen
126 .map_or(slot, |current| current.max(slot)),
127 );
128 }
129
130 pub fn mark_completed(&mut self) {
132 self.completed = true;
133 }
134
135 pub fn observe_response(&mut self, response: &BacktestResponse) {
137 match response {
138 BacktestResponse::SlotNotification(slot) => self.observe_slot(*slot),
139 BacktestResponse::Completed { .. } => self.mark_completed(),
140 _ => {}
141 }
142 }
143
144 pub fn is_completed(&self) -> bool {
146 self.completed
147 }
148
149 pub fn highest_slot_seen(&self) -> Option<u64> {
151 self.highest_slot_seen
152 }
153
154 pub fn validate_end_slot(&self, expected_end_slot: u64) -> Result<(), CoverageError> {
156 if !self.completed {
157 return Err(CoverageError::NotCompleted);
158 }
159
160 let Some(actual_end_slot) = self.highest_slot_seen else {
161 return Err(CoverageError::NoSlotsObserved);
162 };
163
164 if actual_end_slot < expected_end_slot {
165 return Err(CoverageError::RangeNotReached {
166 actual_end_slot,
167 expected_end_slot,
168 });
169 }
170
171 Ok(())
172 }
173}
174
175#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
177pub enum CoverageError {
178 #[error("ended before completion")]
179 NotCompleted,
180 #[error("completed without slot notifications")]
181 NoSlotsObserved,
182 #[error("completed at slot {actual_end_slot} but expected at least {expected_end_slot}")]
183 RangeNotReached {
184 actual_end_slot: u64,
185 expected_end_slot: u64,
186 },
187}
188
/// Client-side handle for one backtest session over a websocket connection.
///
/// The handle starts unbound and becomes bound to a session once creation or
/// attachment succeeds. If it is dropped while the websocket is still open, a
/// best-effort close is spawned on the current Tokio runtime, when one exists.
pub struct BacktestSession {
    // Control websocket; `None` once the connection has been closed and taken.
    ws: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    // Id of the session this handle is bound to, set by create/attach.
    session_id: Option<String>,
    // Fully resolved RPC URL for the bound session.
    rpc_endpoint: Option<String>,
    // RPC client pointed at `rpc_endpoint`.
    rpc: Option<RpcClient>,
    // Whether the server has signalled `ReadyForContinue` since the last `Continue`.
    ready_for_continue: bool,
    // Default send/receive timeout used when a call passes `None`.
    request_timeout: Option<Duration>,
    // Enables debug logging of raw inbound frames and parsed events.
    log_raw: bool,
    // Responses received out of turn (e.g. during create/attach), returned first by
    // `next_response`.
    backlog: VecDeque<BacktestResponse>,
}
203
204impl BacktestSession {
205 pub(crate) fn new(
206 ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
207 request_timeout: Option<Duration>,
208 log_raw: bool,
209 ) -> Self {
210 Self {
211 ws: Some(ws),
212 session_id: None,
213 rpc_endpoint: None,
214 rpc: None,
215 ready_for_continue: false,
216 request_timeout,
217 log_raw,
218 backlog: VecDeque::new(),
219 }
220 }
221
222 pub fn session_id(&self) -> Option<&str> {
224 self.session_id.as_deref()
225 }
226
227 pub fn rpc_endpoint(&self) -> Option<&str> {
229 self.rpc_endpoint.as_deref()
230 }
231
232 pub fn rpc(&self) -> &RpcClient {
236 self.rpc
237 .as_ref()
238 .expect("rpc is set during session creation")
239 }
240
241 pub fn is_ready_for_continue(&self) -> bool {
243 self.ready_for_continue
244 }
245
246 pub fn apply_response(&mut self, response: &BacktestResponse) {
248 match response {
249 BacktestResponse::ReadyForContinue => {
250 self.ready_for_continue = true;
251 }
252 BacktestResponse::Completed { .. } => {
253 self.ready_for_continue = false;
254 }
255 _ => {}
256 }
257 }
258
259 fn ws_mut(&mut self) -> BacktestClientResult<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
260 self.ws.as_mut().ok_or_else(|| BacktestClientError::Closed {
261 reason: "websocket closed".to_string(),
262 })
263 }
264
265 pub(crate) async fn create_with_request(
266 &mut self,
267 request: CreateBacktestSessionRequest,
268 rpc_base_url: String,
269 mut on_parallel_session_created: Option<&mut (dyn FnMut(String) + Send)>,
270 ) -> BacktestClientResult<CreateRequestResult> {
271 let expect_parallel = matches!(
272 &request,
273 CreateBacktestSessionRequest::V1(CreateBacktestSessionRequestV1 { parallel: true, .. })
274 );
275 self.send(&BacktestRequest::CreateBacktestSession(request), None)
276 .await?;
277 let mut streamed_parallel_session_ids = Vec::new();
278
279 loop {
280 let response =
281 self.next_response(None)
282 .await?
283 .ok_or_else(|| BacktestClientError::Closed {
284 reason: "websocket ended before SessionCreated".to_string(),
285 })?;
286
287 match response {
288 BacktestResponse::SessionCreated {
289 session_id,
290 rpc_endpoint,
291 } => {
292 if expect_parallel {
293 if let Some(callback) = on_parallel_session_created.as_mut() {
294 (**callback)(session_id.clone());
295 }
296 streamed_parallel_session_ids.push(session_id);
297 continue;
298 }
299 let created_session_id = session_id.clone();
300 self.session_id = Some(session_id);
301 let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
302 self.rpc = Some(RpcClient::new_with_commitment(
303 resolved.clone(),
304 CommitmentConfig::confirmed(),
305 ));
306 self.rpc_endpoint = Some(resolved);
307 return Ok(CreateRequestResult::Single {
308 session_id: created_session_id,
309 });
310 }
311 BacktestResponse::SessionsCreated { session_ids } => {
312 if expect_parallel && session_ids.is_empty() {
313 return Ok(CreateRequestResult::Parallel {
314 session_ids: streamed_parallel_session_ids,
315 });
316 }
317 return Ok(CreateRequestResult::Parallel { session_ids });
318 }
319 BacktestResponse::SessionsCreatedV2 { session_ids, .. } => {
320 if expect_parallel && session_ids.is_empty() {
321 return Ok(CreateRequestResult::Parallel {
322 session_ids: streamed_parallel_session_ids,
323 });
324 }
325 return Ok(CreateRequestResult::Parallel { session_ids });
326 }
327 BacktestResponse::ReadyForContinue => {
328 self.ready_for_continue = true;
329 }
330 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
331 other => {
332 self.backlog.push_back(other);
333 }
334 }
335 }
336 }
337
338 pub(crate) async fn attach(
339 &mut self,
340 session_id: String,
341 last_sequence: Option<u64>,
342 rpc_base_url: String,
343 ) -> BacktestClientResult<()> {
344 self.send(
345 &BacktestRequest::AttachBacktestSession {
346 session_id,
347 last_sequence,
348 },
349 None,
350 )
351 .await?;
352
353 loop {
354 let response =
355 self.next_response(None)
356 .await?
357 .ok_or_else(|| BacktestClientError::Closed {
358 reason: "websocket ended before SessionAttached".to_string(),
359 })?;
360
361 match response {
362 BacktestResponse::SessionAttached {
363 session_id,
364 rpc_endpoint,
365 } => {
366 self.session_id = Some(session_id);
367 let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
368 self.rpc = Some(RpcClient::new_with_commitment(
369 resolved.clone(),
370 CommitmentConfig::confirmed(),
371 ));
372 self.rpc_endpoint = Some(resolved);
373 return Ok(());
374 }
375 BacktestResponse::ReadyForContinue => {
376 self.ready_for_continue = true;
377 }
378 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
379 other => {
380 self.backlog.push_back(other);
381 }
382 }
383 }
384 }
385
386 pub async fn send(
388 &mut self,
389 request: &BacktestRequest,
390 timeout: Option<Duration>,
391 ) -> BacktestClientResult<()> {
392 let text = serde_json::to_string(request)
393 .map_err(|source| BacktestClientError::SerializeRequest { source })?;
394
395 let request_timeout = self.request_timeout;
396 let timeout = timeout.or(request_timeout);
397
398 let send_fut = self.ws_mut()?.send(Message::Text(text));
399 let send_result = match timeout {
400 Some(duration) => tokio::time::timeout(duration, send_fut)
401 .await
402 .map_err(|_| BacktestClientError::Timeout {
403 action: "sending",
404 duration,
405 })?,
406 None => send_fut.await,
407 };
408
409 send_result.map_err(|source| BacktestClientError::WebSocket {
410 action: "sending",
411 source: Box::new(source),
412 })?;
413
414 Ok(())
415 }
416
417 pub async fn next_response(
419 &mut self,
420 timeout: Option<Duration>,
421 ) -> BacktestClientResult<Option<BacktestResponse>> {
422 if let Some(response) = self.backlog.pop_front() {
423 return Ok(Some(response));
424 }
425
426 let text = match self.next_text(timeout).await? {
427 Some(text) => text,
428 None => return Ok(None),
429 };
430
431 let response = serde_json::from_str::<BacktestResponse>(&text).map_err(|source| {
432 BacktestClientError::DeserializeResponse {
433 raw: text.clone(),
434 source,
435 }
436 })?;
437
438 Ok(Some(response))
439 }
440
441 pub async fn next_event(
443 &mut self,
444 timeout: Option<Duration>,
445 ) -> BacktestClientResult<Option<BacktestResponse>> {
446 let response = self.next_response(timeout).await?;
447 if let Some(ref response) = response {
448 self.apply_response(response);
449 }
450 Ok(response)
451 }
452
453 pub fn responses(
457 self,
458 timeout: Option<Duration>,
459 ) -> impl futures::Stream<Item = BacktestClientResult<BacktestResponse>> {
460 stream::unfold(Some(self), move |state| async move {
461 let mut session = match state {
462 Some(session) => session,
463 None => return None,
464 };
465
466 match session.next_response(timeout).await {
467 Ok(Some(response)) => {
468 session.apply_response(&response);
469 Some((Ok(response), Some(session)))
470 }
471 Ok(None) => None,
472 Err(err) => Some((Err(err), None)),
473 }
474 })
475 }
476
477 pub async fn ensure_ready(
479 &mut self,
480 timeout: Option<Duration>,
481 ) -> BacktestClientResult<ReadyOutcome> {
482 if self.ready_for_continue {
483 return Ok(ReadyOutcome::Ready);
484 }
485
486 loop {
487 let response =
488 self.next_response(timeout)
489 .await?
490 .ok_or_else(|| BacktestClientError::Closed {
491 reason: "websocket ended while waiting for ReadyForContinue".to_string(),
492 })?;
493
494 match response {
495 BacktestResponse::ReadyForContinue => {
496 self.ready_for_continue = true;
497 return Ok(ReadyOutcome::Ready);
498 }
499 BacktestResponse::Completed { .. } => return Ok(ReadyOutcome::Completed),
500 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
501 _ => {}
502 }
503 }
504 }
505
506 pub async fn wait_for_status(
508 &mut self,
509 desired: BacktestStatus,
510 timeout: Option<Duration>,
511 ) -> BacktestClientResult<()> {
512 let desired = std::mem::discriminant(&desired);
513
514 loop {
515 let response =
516 self.next_response(timeout)
517 .await?
518 .ok_or_else(|| BacktestClientError::Closed {
519 reason: "websocket ended while waiting for status".to_string(),
520 })?;
521
522 match response {
523 BacktestResponse::Status { status }
524 if std::mem::discriminant(&status) == desired =>
525 {
526 return Ok(());
527 }
528 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
529 BacktestResponse::Completed {
530 summary,
531 agent_stats,
532 } => {
533 return Err(BacktestClientError::UnexpectedResponse {
534 context: "waiting for status",
535 response: Box::new(BacktestResponse::Completed {
536 summary,
537 agent_stats,
538 }),
539 });
540 }
541 _ => {}
542 }
543 }
544 }
545
546 pub async fn send_continue(
548 &mut self,
549 params: ContinueParams,
550 timeout: Option<Duration>,
551 ) -> BacktestClientResult<()> {
552 self.ready_for_continue = false;
553 self.send(&BacktestRequest::Continue(params), timeout).await
554 }
555
556 pub async fn advance_step<F>(
558 &mut self,
559 state: &mut AdvanceState,
560 wait_for_slots: bool,
561 timeout: Option<Duration>,
562 on_event: &mut F,
563 ) -> BacktestClientResult<()>
564 where
565 F: FnMut(&BacktestResponse),
566 {
567 let Some(response) = self.next_response(timeout).await? else {
568 return Err(BacktestClientError::Closed {
569 reason: "websocket ended while awaiting continue responses".to_string(),
570 });
571 };
572
573 if self.log_raw {
574 tracing::debug!("<- {response:?}");
575 }
576
577 on_event(&response);
578
579 match response {
580 BacktestResponse::ReadyForContinue => {
581 self.ready_for_continue = true;
582 state.ready_for_continue = true;
583 }
584 BacktestResponse::SlotNotification(slot) => {
585 state.slot_notifications += 1;
586 state.last_slot = Some(slot);
587 }
588 BacktestResponse::Status { status } => {
589 state.statuses.push(status);
590 }
591 BacktestResponse::Success => {}
592 BacktestResponse::Completed {
593 summary,
594 agent_stats,
595 } => {
596 state.completed = true;
597 state.summary = summary;
598 state.agent_stats = agent_stats;
599 }
600 BacktestResponse::Error(err @ BacktestError::SimulationError { .. }) => {
601 tracing::warn!(error = %err, "simulation error");
602 }
603 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
604 BacktestResponse::SessionCreated { .. }
605 | BacktestResponse::SessionAttached { .. }
606 | BacktestResponse::SessionsCreated { .. }
607 | BacktestResponse::SessionsCreatedV2 { .. }
608 | BacktestResponse::ParallelSessionAttachedV2 { .. }
609 | BacktestResponse::SessionEventV1 { .. }
610 | BacktestResponse::SessionEventV2 { .. } => {
611 return Err(BacktestClientError::UnexpectedResponse {
612 context: "continuing",
613 response: Box::new(response),
614 });
615 }
616 }
617
618 if wait_for_slots && state.slot_notifications > state.expected_slots {
619 tracing::warn!(
620 "received {} slot notifications (expected {})",
621 state.slot_notifications,
622 state.expected_slots
623 );
624 }
625
626 Ok(())
627 }
628
629 pub async fn continue_until_ready<F>(
631 &mut self,
632 cont: Continue,
633 timeout: Option<Duration>,
634 mut on_event: F,
635 ) -> BacktestClientResult<ContinueResult>
636 where
637 F: FnMut(&BacktestResponse),
638 {
639 let expected_slots = cont.advance_count;
640 self.advance_internal(
641 cont.into_params(),
642 expected_slots,
643 false,
644 timeout,
645 &mut on_event,
646 )
647 .await
648 }
649
650 pub async fn advance<F>(
652 &mut self,
653 cont: Continue,
654 timeout: Option<Duration>,
655 mut on_event: F,
656 ) -> BacktestClientResult<ContinueResult>
657 where
658 F: FnMut(&BacktestResponse),
659 {
660 let expected_slots = cont.advance_count;
661 self.advance_internal(
662 cont.into_params(),
663 expected_slots,
664 true,
665 timeout,
666 &mut on_event,
667 )
668 .await
669 }
670
671 async fn advance_internal<F>(
672 &mut self,
673 params: ContinueParams,
674 expected_slots: u64,
675 wait_for_slots: bool,
676 timeout: Option<Duration>,
677 on_event: &mut F,
678 ) -> BacktestClientResult<ContinueResult>
679 where
680 F: FnMut(&BacktestResponse),
681 {
682 self.send_continue(params, timeout).await?;
683
684 let mut state = AdvanceState::new(expected_slots);
685 while !state.is_done(wait_for_slots) {
686 self.advance_step(&mut state, wait_for_slots, timeout, on_event)
687 .await?;
688 }
689
690 Ok(ContinueResult {
691 slot_notifications: state.slot_notifications,
692 last_slot: state.last_slot,
693 statuses: state.statuses,
694 ready_for_continue: state.ready_for_continue,
695 completed: state.completed,
696 })
697 }
698
699 pub async fn modify_program(
708 &self,
709 program_id: &str,
710 elf: &[u8],
711 ) -> Result<BTreeMap<Address, AccountData>, ProgramModError> {
712 let rpc = self.rpc.as_ref().ok_or(ProgramModError::NoRpcEndpoint)?;
713
714 let program_addr: Address =
715 program_id
716 .parse()
717 .map_err(|_| ProgramModError::InvalidProgramId {
718 id: program_id.to_string(),
719 })?;
720 let programdata_addr = solana_loader_v3_interface::get_program_data_address(&program_addr);
721
722 let slot = rpc.get_slot().await.map_err(|e| ProgramModError::Rpc {
723 source: Box::new(e),
724 })?;
725 let deploy_slot = slot.saturating_sub(1);
726
727 let existing =
730 rpc.get_account(&programdata_addr)
731 .await
732 .map_err(|e| ProgramModError::Rpc {
733 source: Box::new(e),
734 })?;
735
736 let upgrade_authority = if existing.data.get(12).copied() == Some(1) {
737 existing.data.get(13..45).and_then(|b| {
738 let bytes: [u8; 32] = b.try_into().ok()?;
739 Some(Address::from(bytes))
740 })
741 } else {
742 None
743 };
744
745 let data_len = upgrade_authority.map_or(13, |_| 45) + elf.len();
746 let lamports = rpc
747 .get_minimum_balance_for_rent_exemption(data_len)
748 .await
749 .map_err(|e| ProgramModError::Rpc {
750 source: Box::new(e),
751 })?;
752
753 Ok(build_program_injection(
754 programdata_addr,
755 elf,
756 deploy_slot,
757 upgrade_authority,
758 lamports,
759 ))
760 }
761
762 pub async fn modify_accounts(
766 &self,
767 modifications: &AccountModifications,
768 ) -> BacktestClientResult<usize> {
769 let rpc_endpoint =
770 self.rpc_endpoint
771 .as_deref()
772 .ok_or_else(|| BacktestClientError::Closed {
773 reason: "no RPC endpoint available".to_string(),
774 })?;
775
776 crate::rpc::modify_accounts(rpc_endpoint, modifications).await
777 }
778
779 pub async fn subscribe_program_logs<F, Fut>(
785 &self,
786 program_id: &str,
787 commitment: CommitmentConfig,
788 on_notification: F,
789 ) -> Result<LogSubscriptionHandle, SubscriptionError>
790 where
791 F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
792 Fut: Future<Output = ()> + Send + 'static,
793 {
794 let rpc_endpoint = self
795 .rpc_endpoint
796 .as_deref()
797 .ok_or(SubscriptionError::NoRpcEndpoint)?;
798 crate::subscriptions::subscribe_program_logs(
799 rpc_endpoint,
800 program_id,
801 commitment,
802 on_notification,
803 )
804 .await
805 }
806
807 pub async fn subscribe_account_diffs<F, Fut>(
813 &self,
814 account: &str,
815 on_notification: F,
816 ) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
817 where
818 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
819 Fut: Future<Output = ()> + Send + 'static,
820 {
821 let rpc_endpoint = self
822 .rpc_endpoint
823 .as_deref()
824 .ok_or(SubscriptionError::NoRpcEndpoint)?;
825 crate::subscriptions::subscribe_account_diffs(rpc_endpoint, account, on_notification).await
826 }
827
828 pub async fn close(&mut self, timeout: Option<Duration>) -> BacktestClientResult<()> {
832 self.close_with_frame(timeout, None).await
833 }
834
835 pub async fn close_with_frame(
837 &mut self,
838 timeout: Option<Duration>,
839 frame: Option<CloseFrame<'static>>,
840 ) -> BacktestClientResult<()> {
841 if self.ws.is_none() {
842 return Ok(());
843 }
844
845 let mut sent = false;
846 match self
847 .send(&BacktestRequest::CloseBacktestSession, timeout)
848 .await
849 {
850 Ok(()) => sent = true,
851 Err(err) if is_close_ok(&err) => {}
852 Err(err) => return Err(err),
853 }
854
855 if sent {
856 let response = match self.next_response(timeout).await {
857 Ok(Some(r)) => r,
858 Ok(None) => {
859 self.ws.take();
860 return Ok(());
861 }
862 Err(BacktestClientError::Closed { .. }) => {
863 self.ws.take();
864 return Ok(());
865 }
866 Err(BacktestClientError::WebSocket {
867 action: "receiving",
868 source,
869 }) if is_reset_without_close(&source) => {
870 self.ws.take();
871 return Ok(());
872 }
873 Err(err) => return Err(err),
874 };
875
876 match response {
877 BacktestResponse::Success | BacktestResponse::Completed { .. } => {}
878 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
879 other => {
880 return Err(BacktestClientError::UnexpectedResponse {
881 context: "closing session",
882 response: Box::new(other),
883 });
884 }
885 }
886 }
887
888 if let Some(ws) = self.ws.as_mut() {
889 let close_result = ws.close(frame).await;
890 if let Err(source) = close_result
891 && !is_ws_closed_error(&source)
892 {
893 return Err(BacktestClientError::WebSocket {
894 action: "closing",
895 source: Box::new(source),
896 });
897 }
898 }
899
900 match self.next_response(timeout).await {
902 Ok(_) => {}
903 Err(BacktestClientError::Closed { .. }) => {}
904 Err(BacktestClientError::WebSocket {
905 action: "receiving",
906 source,
907 }) if is_reset_without_close(&source) => {}
908 Err(err) => return Err(err),
909 }
910
911 tokio::time::sleep(Duration::from_millis(100)).await;
913 self.ws.take();
914 Ok(())
915 }
916
917 pub async fn close_with_reason(
919 &mut self,
920 timeout: Option<Duration>,
921 code: CloseCode,
922 reason: impl Into<String>,
923 ) -> BacktestClientResult<()> {
924 let frame = CloseFrame {
925 code,
926 reason: Cow::Owned(reason.into()),
927 };
928 self.close_with_frame(timeout, Some(frame)).await
929 }
930
931 async fn next_text(
932 &mut self,
933 timeout: Option<Duration>,
934 ) -> BacktestClientResult<Option<String>> {
935 loop {
936 let request_timeout = self.request_timeout;
937 let timeout = timeout.or(request_timeout);
938
939 let next_fut = self.ws_mut()?.next();
940 let msg = match timeout {
941 Some(duration) => tokio::time::timeout(duration, next_fut)
942 .await
943 .map_err(|_| BacktestClientError::Timeout {
944 action: "receiving",
945 duration,
946 })?,
947 None => next_fut.await,
948 };
949
950 let Some(msg) = msg else {
951 return Ok(None);
952 };
953
954 let msg = match msg {
955 Ok(msg) => msg,
956 Err(source) => {
957 return Err(BacktestClientError::WebSocket {
958 action: "receiving",
959 source: Box::new(source),
960 });
961 }
962 };
963
964 match msg {
965 Message::Text(text) => {
966 if self.log_raw {
967 tracing::debug!("<- raw: {text}");
968 }
969 return Ok(Some(text));
970 }
971 Message::Binary(bin) => match String::from_utf8(bin) {
972 Ok(text) => {
973 if self.log_raw {
974 tracing::debug!("<- raw(bin): {text}");
975 }
976 return Ok(Some(text));
977 }
978 Err(err) => {
979 tracing::warn!("discarding non-utf8 binary message: {err}");
980 continue;
981 }
982 },
983 Message::Close(frame) => {
984 let reason = close_reason(frame);
985 return Err(BacktestClientError::Closed { reason });
986 }
987 Message::Ping(_) | Message::Pong(_) => continue,
988 Message::Frame(_) => continue,
989 }
990 }
991 }
992}
993
994impl std::fmt::Debug for BacktestSession {
995 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
996 f.debug_struct("BacktestSession")
997 .field("session_id", &self.session_id)
998 .field("rpc_endpoint", &self.rpc_endpoint)
999 .field(
1000 "rpc",
1001 &self
1002 .rpc
1003 .as_ref()
1004 .map(|_| "<RpcClient>")
1005 .unwrap_or("<not set>"),
1006 )
1007 .field("ready_for_continue", &self.ready_for_continue)
1008 .field("request_timeout", &self.request_timeout)
1009 .finish_non_exhaustive()
1010 }
1011}
1012
/// What a create request produced.
#[derive(Debug)]
pub(crate) enum CreateRequestResult {
    /// One session was created, and the handle that sent the request is now bound to it.
    Single {
        /// Id of the created session.
        session_id: String,
    },
    /// A parallel batch was created; the sending handle is not bound to any of them.
    Parallel {
        /// Ids of every session in the batch.
        session_ids: Vec<String>,
    },
}
1018
1019impl Drop for BacktestSession {
1020 fn drop(&mut self) {
1021 let Some(ws) = self.ws.take() else {
1022 return;
1023 };
1024
1025 if let Ok(handle) = tokio::runtime::Handle::try_current() {
1026 handle.spawn(async move {
1027 let mut ws = ws;
1028 let _ = ws.close(None).await;
1029 });
1030 }
1031 }
1032}
1033
/// Resolves a server-reported RPC endpoint against the client's base URL.
///
/// Absolute `http://`/`https://` endpoints are returned unchanged. Relative
/// endpoints are joined to `base` with exactly one `/`: trailing slashes on
/// `base` and leading slashes on `endpoint` are stripped first, so
/// `"http://h/"` + `"/rpc"` yields `"http://h/rpc"` rather than `"http://h//rpc"`.
fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        return endpoint.to_string();
    }
    format!(
        "{}/{}",
        base.trim_end_matches('/'),
        endpoint.trim_start_matches('/')
    )
}
1041
1042fn close_reason(frame: Option<CloseFrame<'static>>) -> String {
1043 match frame {
1044 Some(frame) => format!("{:?}: {}", frame.code, frame.reason),
1045 None => "no close frame".to_string(),
1046 }
1047}
1048
1049fn is_reset_without_close(err: &WsError) -> bool {
1050 matches!(
1051 err,
1052 WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1053 )
1054}
1055
1056fn is_ws_closed_error(err: &WsError) -> bool {
1057 matches!(
1058 err,
1059 WsError::ConnectionClosed
1060 | WsError::AlreadyClosed
1061 | WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1062 )
1063}
1064
1065fn is_close_ok(err: &BacktestClientError) -> bool {
1066 match err {
1067 BacktestClientError::Closed { .. } => true,
1068 BacktestClientError::WebSocket { source, .. } => is_ws_closed_error(source),
1069 _ => false,
1070 }
1071}
1072
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn coverage_tracks_slot_and_completion_from_responses() {
        let mut coverage = SessionCoverage::default();
        coverage.observe_response(&BacktestResponse::SlotNotification(10));
        coverage.observe_response(&BacktestResponse::SlotNotification(12));
        coverage.observe_response(&BacktestResponse::Completed {
            summary: None,
            agent_stats: None,
        });

        assert!(coverage.is_completed());
        assert_eq!(coverage.highest_slot_seen(), Some(12));
    }

    #[test]
    fn coverage_validate_end_slot_checks_completion_and_range() {
        let mut coverage = SessionCoverage::default();
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::NotCompleted)
        );

        coverage.mark_completed();
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::NoSlotsObserved)
        );

        coverage.observe_slot(4);
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::RangeNotReached {
                actual_end_slot: 4,
                expected_end_slot: 5,
            })
        );

        coverage.observe_slot(6);
        assert_eq!(coverage.validate_end_slot(5), Ok(()));
    }

    #[test]
    fn advance_state_is_done_requires_readiness_and_optionally_slots() {
        let mut state = AdvanceState::new(2);
        assert!(!state.is_done(false));
        assert!(!state.is_done(true));

        state.ready_for_continue = true;
        assert!(state.is_done(false));
        assert!(!state.is_done(true));

        state.slot_notifications = 2;
        assert!(state.is_done(true));
    }

    #[test]
    fn advance_state_is_done_once_completed() {
        let mut state = AdvanceState::new(5);
        state.completed = true;
        assert!(state.is_done(true));
        assert!(state.is_done(false));
    }

    #[test]
    fn resolve_rpc_url_passes_absolute_urls_through() {
        assert_eq!(
            resolve_rpc_url("http://base", "https://other/rpc"),
            "https://other/rpc"
        );
        assert_eq!(
            resolve_rpc_url("http://base", "http://other/rpc"),
            "http://other/rpc"
        );
    }

    #[test]
    fn resolve_rpc_url_joins_relative_paths() {
        assert_eq!(resolve_rpc_url("http://base", "/rpc/abc"), "http://base/rpc/abc");
        assert_eq!(resolve_rpc_url("http://base", "rpc/abc"), "http://base/rpc/abc");
    }

    #[test]
    fn close_ok_accepts_closed_and_rejects_timeouts() {
        assert!(is_close_ok(&BacktestClientError::Closed {
            reason: "gone".to_string(),
        }));
        assert!(!is_close_ok(&BacktestClientError::Timeout {
            action: "sending",
            duration: Duration::from_secs(1),
        }));
    }
}