1use std::{
2 borrow::Cow,
3 collections::{BTreeMap, VecDeque},
4 future::Future,
5 time::Duration,
6};
7
8use futures::{SinkExt, StreamExt, stream};
9use simulator_api::{
10 AccountData, AccountModifications, AgentStatsReport, BacktestError, BacktestRequest,
11 BacktestResponse, BacktestStatus, ContinueParams, CreateBacktestSessionRequest,
12 CreateBacktestSessionRequestV1, SequencedResponse, SessionSummary,
13};
14use solana_address::Address;
15use solana_client::{
16 nonblocking::rpc_client::RpcClient,
17 rpc_response::{Response, RpcLogsResponse},
18};
19use solana_commitment_config::CommitmentConfig;
20use thiserror::Error;
21use tokio::net::TcpStream;
22use tokio_tungstenite::{
23 MaybeTlsStream, WebSocketStream,
24 tungstenite::{
25 Error as WsError, Message,
26 error::ProtocolError,
27 protocol::{CloseFrame, frame::coding::CloseCode},
28 },
29};
30
31use crate::{
32 BacktestClientError, BacktestClientResult, Continue,
33 injection::ProgramModError,
34 subscriptions::{
35 AccountDiffNotification, AccountDiffSubscriptionHandle, LogSubscriptionHandle,
36 SubscriptionError,
37 },
38};
39
/// What [`BacktestSession::ensure_ready`] observed while waiting for the session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadyOutcome {
    /// The session is ready to accept a `Continue` request (either it already was, or a
    /// `ReadyForContinue` response arrived).
    Ready,
    /// A `Completed` response arrived before the session became ready again.
    Completed,
}
48
/// Outcome of one `Continue` round trip, as returned by [`BacktestSession::advance`] and
/// [`BacktestSession::continue_until_ready`].
#[derive(Debug, Default)]
pub struct ContinueResult {
    /// Number of `SlotNotification` responses received during the round trip.
    pub slot_notifications: u64,
    /// Slot carried by the most recently received `SlotNotification`, if any.
    pub last_slot: Option<u64>,
    /// Every `Status` payload received, in arrival order.
    pub statuses: Vec<BacktestStatus>,
    /// Whether a `ReadyForContinue` response was received.
    pub ready_for_continue: bool,
    /// Whether the backtest reported `Completed`.
    pub completed: bool,
}
63
/// Mutable accumulator for a single `Continue` round trip.
///
/// Created with [`AdvanceState::new`] and updated one response at a time by
/// [`BacktestSession::advance_step`].
#[derive(Debug)]
pub struct AdvanceState {
    /// Number of slot notifications the caller expects for this round trip.
    pub expected_slots: u64,
    /// Number of `SlotNotification` responses received so far.
    pub slot_notifications: u64,
    /// Slot carried by the most recently received `SlotNotification`, if any.
    pub last_slot: Option<u64>,
    /// Every `Status` payload received so far, in arrival order.
    pub statuses: Vec<BacktestStatus>,
    /// Set once a `ReadyForContinue` response has been received.
    pub ready_for_continue: bool,
    /// Set once a `Completed` response has been received.
    pub completed: bool,
    /// Session summary carried by the `Completed` response, if any.
    pub summary: Option<SessionSummary>,
    /// Agent statistics carried by the `Completed` response, if any.
    pub agent_stats: Option<Vec<AgentStatsReport>>,
}
84
85impl AdvanceState {
86 pub fn new(expected_slots: u64) -> Self {
88 Self {
89 expected_slots,
90 slot_notifications: 0,
91 last_slot: None,
92 statuses: Vec::new(),
93 ready_for_continue: false,
94 completed: false,
95 summary: None,
96 agent_stats: None,
97 }
98 }
99
100 pub fn is_done(&self, wait_for_slots: bool) -> bool {
102 if self.completed {
103 return true;
104 }
105
106 if !self.ready_for_continue {
107 return false;
108 }
109
110 !wait_for_slots || self.slot_notifications >= self.expected_slots
111 }
112}
113
/// Tracks whether a session ran to completion and the highest slot it reported.
///
/// Feed it responses with `observe_response` (or the lower-level `observe_slot` /
/// `mark_completed`) and check the outcome with `validate_end_slot`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SessionCoverage {
    // Set once completion has been observed.
    completed: bool,
    // Largest slot observed so far; `None` until the first slot arrives.
    highest_slot_seen: Option<u64>,
}
120
121impl SessionCoverage {
122 pub fn observe_slot(&mut self, slot: u64) {
124 self.highest_slot_seen = Some(
125 self.highest_slot_seen
126 .map_or(slot, |current| current.max(slot)),
127 );
128 }
129
130 pub fn mark_completed(&mut self) {
132 self.completed = true;
133 }
134
135 pub fn observe_response(&mut self, response: &BacktestResponse) {
137 match response {
138 BacktestResponse::SlotNotification(slot) => self.observe_slot(*slot),
139 BacktestResponse::Completed { .. } => self.mark_completed(),
140 _ => {}
141 }
142 }
143
144 pub fn is_completed(&self) -> bool {
146 self.completed
147 }
148
149 pub fn highest_slot_seen(&self) -> Option<u64> {
151 self.highest_slot_seen
152 }
153
154 pub fn validate_end_slot(&self, expected_end_slot: u64) -> Result<(), CoverageError> {
156 if !self.completed {
157 return Err(CoverageError::NotCompleted);
158 }
159
160 let Some(actual_end_slot) = self.highest_slot_seen else {
161 return Err(CoverageError::NoSlotsObserved);
162 };
163
164 if actual_end_slot < expected_end_slot {
165 return Err(CoverageError::RangeNotReached {
166 actual_end_slot,
167 expected_end_slot,
168 });
169 }
170
171 Ok(())
172 }
173}
174
/// Reasons [`SessionCoverage::validate_end_slot`] can reject a session.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
pub enum CoverageError {
    /// Completion was never observed.
    #[error("ended before completion")]
    NotCompleted,
    /// The session completed but no slot was ever observed.
    #[error("completed without slot notifications")]
    NoSlotsObserved,
    /// The highest observed slot is below the expected end slot.
    #[error("completed at slot {actual_end_slot} but expected at least {expected_end_slot}")]
    RangeNotReached {
        /// Highest slot actually observed.
        actual_end_slot: u64,
        /// Minimum slot the caller required.
        expected_end_slot: u64,
    },
}
188
/// A backtest session driven over a websocket connection.
///
/// Holds the connection together with the metadata learned from the server (session id,
/// resolved RPC endpoint and client, task id). Dropping it inside a Tokio runtime spawns a
/// best-effort websocket close (see the `Drop` impl).
pub struct BacktestSession {
    /// Websocket connection; `None` once closed or aborted.
    ws: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    /// Server-assigned session id, set on `SessionCreated` / `SessionAttached`.
    session_id: Option<String>,
    /// RPC URL resolved against the client's base URL.
    rpc_endpoint: Option<String>,
    /// Task id reported by the server, if any.
    task_id: Option<String>,
    /// RPC client pointed at `rpc_endpoint`.
    rpc: Option<RpcClient>,
    /// Sequence id of the most recent sequenced response.
    last_sequence: Option<u64>,
    /// Whether the server currently accepts a `Continue` request.
    pub(crate) ready_for_continue: bool,
    /// Default timeout for sends/receives when a call passes none.
    request_timeout: Option<Duration>,
    /// Emit debug logs for raw incoming frames.
    log_raw: bool,
    /// Responses set aside while waiting for a specific reply, each paired with the sequence id
    /// current when it arrived; `next_response` replays these before reading the socket.
    backlog: VecDeque<(Option<u64>, BacktestResponse)>,
}
205
206impl BacktestSession {
207 pub(crate) fn new(
208 ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
209 request_timeout: Option<Duration>,
210 log_raw: bool,
211 ) -> Self {
212 Self {
213 ws: Some(ws),
214 session_id: None,
215 rpc_endpoint: None,
216 task_id: None,
217 rpc: None,
218 last_sequence: None,
219 ready_for_continue: false,
220 request_timeout,
221 log_raw,
222 backlog: VecDeque::new(),
223 }
224 }
225
226 pub fn session_id(&self) -> Option<&str> {
228 self.session_id.as_deref()
229 }
230
231 pub fn rpc_endpoint(&self) -> Option<&str> {
233 self.rpc_endpoint.as_deref()
234 }
235
236 pub fn task_id(&self) -> Option<&str> {
239 self.task_id.as_deref()
240 }
241
242 pub fn last_sequence(&self) -> Option<u64> {
244 self.last_sequence
245 }
246
247 pub fn rpc(&self) -> &RpcClient {
251 self.rpc
252 .as_ref()
253 .expect("rpc is set during session creation")
254 }
255
256 pub fn is_ready_for_continue(&self) -> bool {
258 self.ready_for_continue
259 }
260
261 pub fn apply_response(&mut self, response: &BacktestResponse) {
263 match response {
264 BacktestResponse::ReadyForContinue | BacktestResponse::Paused(_) => {
265 self.ready_for_continue = true;
266 }
267 BacktestResponse::Completed { .. } => {
268 self.ready_for_continue = false;
269 }
270 _ => {}
271 }
272 }
273
274 fn ws_mut(&mut self) -> BacktestClientResult<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
275 self.ws.as_mut().ok_or_else(|| BacktestClientError::Closed {
276 reason: "websocket closed".to_string(),
277 })
278 }
279
280 pub(crate) async fn create_with_request(
281 &mut self,
282 request: CreateBacktestSessionRequest,
283 rpc_base_url: String,
284 mut on_parallel_session_created: Option<&mut (dyn FnMut(String) + Send)>,
285 ) -> BacktestClientResult<CreateRequestResult> {
286 let expect_parallel = matches!(
287 &request,
288 CreateBacktestSessionRequest::V1(CreateBacktestSessionRequestV1 { parallel: true, .. })
289 );
290 self.send(&BacktestRequest::CreateBacktestSession(request), None)
291 .await?;
292 let mut streamed_parallel_session_ids = Vec::new();
293 let mut streamed_parallel_task_ids: Vec<Option<String>> = Vec::new();
294 let mut pending: Vec<(Option<u64>, BacktestResponse)> = Vec::new();
298
299 loop {
300 let response =
301 self.next_response(None)
302 .await?
303 .ok_or_else(|| BacktestClientError::Closed {
304 reason: "websocket ended before SessionCreated".to_string(),
305 })?;
306
307 match response {
308 BacktestResponse::SessionCreated {
309 session_id,
310 rpc_endpoint,
311 task_id,
312 } => {
313 if expect_parallel {
314 if let Some(callback) = on_parallel_session_created.as_mut() {
315 (**callback)(session_id.clone());
316 }
317 streamed_parallel_session_ids.push(session_id);
318 streamed_parallel_task_ids.push(task_id);
319 continue;
320 }
321 let created_session_id = session_id.clone();
322 let created_task_id = task_id.clone();
323 self.session_id = Some(session_id);
324 self.task_id = task_id;
325 let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
326 self.rpc = Some(RpcClient::new_with_commitment(
327 resolved.clone(),
328 CommitmentConfig::confirmed(),
329 ));
330 self.rpc_endpoint = Some(resolved);
331 self.backlog.extend(pending);
332 return Ok(CreateRequestResult::Single {
333 session_id: created_session_id,
334 task_id: created_task_id,
335 });
336 }
337 BacktestResponse::SessionsCreated { session_ids } => {
338 if expect_parallel && session_ids.is_empty() {
339 self.backlog.extend(pending);
340 return Ok(CreateRequestResult::Parallel {
341 session_ids: streamed_parallel_session_ids,
342 task_ids: streamed_parallel_task_ids,
343 });
344 }
345 let len = session_ids.len();
346 self.backlog.extend(pending);
347 return Ok(CreateRequestResult::Parallel {
348 session_ids,
349 task_ids: vec![None; len],
350 });
351 }
352 BacktestResponse::SessionsCreatedV2 {
353 session_ids,
354 task_ids,
355 ..
356 } => {
357 if expect_parallel && session_ids.is_empty() {
358 self.backlog.extend(pending);
359 return Ok(CreateRequestResult::Parallel {
360 session_ids: streamed_parallel_session_ids,
361 task_ids: streamed_parallel_task_ids,
362 });
363 }
364 let task_ids = align_task_ids(task_ids, session_ids.len());
365 self.backlog.extend(pending);
366 return Ok(CreateRequestResult::Parallel {
367 session_ids,
368 task_ids,
369 });
370 }
371 BacktestResponse::ReadyForContinue => {
372 self.ready_for_continue = true;
373 }
374 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
375 other => {
376 pending.push((self.last_sequence, other));
377 }
378 }
379 }
380 }
381
382 pub(crate) async fn attach(
383 &mut self,
384 session_id: String,
385 last_sequence: Option<u64>,
386 rpc_base_url: String,
387 ) -> BacktestClientResult<()> {
388 self.send(
389 &BacktestRequest::AttachBacktestSession {
390 session_id,
391 last_sequence,
392 },
393 None,
394 )
395 .await?;
396
397 self.wait_for_response(
398 || BacktestClientError::Closed {
399 reason: "websocket ended before SessionAttached".to_string(),
400 },
401 move |session, response| match response {
402 BacktestResponse::SessionAttached {
403 session_id,
404 rpc_endpoint,
405 task_id,
406 } => {
407 session.session_id = Some(session_id);
408 session.task_id = task_id;
409 let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
410 session.rpc = Some(RpcClient::new_with_commitment(
411 resolved.clone(),
412 CommitmentConfig::confirmed(),
413 ));
414 session.rpc_endpoint = Some(resolved);
415 Ok(Some(()))
416 }
417 BacktestResponse::ReadyForContinue => {
418 session.ready_for_continue = true;
419 Ok(None)
420 }
421 BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
422 other => {
423 session.backlog.push_back((session.last_sequence, other));
424 Ok(None)
425 }
426 },
427 )
428 .await
429 }
430
431 pub async fn resume_attached_session(&mut self) -> BacktestClientResult<()> {
434 self.send(&BacktestRequest::ResumeAttachedSession, None)
435 .await?;
436
437 self.wait_for_response(
438 || BacktestClientError::Closed {
439 reason: "websocket ended before ResumeAttachedSession acknowledgement".to_string(),
440 },
441 |session, response| match response {
442 BacktestResponse::Success => Ok(Some(())),
443 BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
444 other => {
445 session.backlog.push_back((session.last_sequence, other));
446 Ok(None)
447 }
448 },
449 )
450 .await
451 }
452
453 async fn wait_for_response<T, E, F>(
454 &mut self,
455 closed_error: E,
456 mut handle_response: F,
457 ) -> BacktestClientResult<T>
458 where
459 E: FnOnce() -> BacktestClientError,
460 F: FnMut(&mut Self, BacktestResponse) -> BacktestClientResult<Option<T>>,
461 {
462 let mut closed_error = Some(closed_error);
463
464 loop {
465 let response = self
466 .next_response(None)
467 .await?
468 .ok_or_else(|| closed_error.take().expect("closed error set")())?;
469
470 if let Some(result) = handle_response(self, response)? {
471 return Ok(result);
472 }
473 }
474 }
475
476 pub async fn send(
478 &mut self,
479 request: &BacktestRequest,
480 timeout: Option<Duration>,
481 ) -> BacktestClientResult<()> {
482 let text = serde_json::to_string(request)
483 .map_err(|source| BacktestClientError::SerializeRequest { source })?;
484
485 let request_timeout = self.request_timeout;
486 let timeout = timeout.or(request_timeout);
487
488 let send_fut = self.ws_mut()?.send(Message::Text(text));
489 let send_result = match timeout {
490 Some(duration) => tokio::time::timeout(duration, send_fut)
491 .await
492 .map_err(|_| BacktestClientError::Timeout {
493 action: "sending",
494 duration,
495 })?,
496 None => send_fut.await,
497 };
498
499 send_result.map_err(|source| BacktestClientError::WebSocket {
500 action: "sending",
501 source: Box::new(source),
502 })?;
503
504 Ok(())
505 }
506
507 pub async fn next_response(
509 &mut self,
510 timeout: Option<Duration>,
511 ) -> BacktestClientResult<Option<BacktestResponse>> {
512 if let Some((sequence, response)) = self.backlog.pop_front() {
513 self.last_sequence = sequence.or(self.last_sequence);
514 return Ok(Some(response));
515 }
516
517 let text = match self.next_text(timeout).await? {
518 Some(text) => text,
519 None => return Ok(None),
520 };
521
522 let (sequence, response) = match serde_json::from_str::<SequencedResponse>(&text) {
523 Ok(sequenced) => (Some(sequenced.seq_id), sequenced.response),
524 Err(_) => {
525 let response =
526 serde_json::from_str::<BacktestResponse>(&text).map_err(|source| {
527 BacktestClientError::DeserializeResponse {
528 raw: text.clone(),
529 source,
530 }
531 })?;
532 (None, response)
533 }
534 };
535 self.last_sequence = sequence.or(self.last_sequence);
536
537 Ok(Some(response))
538 }
539
540 pub async fn next_event(
542 &mut self,
543 timeout: Option<Duration>,
544 ) -> BacktestClientResult<Option<BacktestResponse>> {
545 let response = self.next_response(timeout).await?;
546 if let Some(ref response) = response {
547 self.apply_response(response);
548 }
549 Ok(response)
550 }
551
552 pub fn responses(
556 self,
557 timeout: Option<Duration>,
558 ) -> impl futures::Stream<Item = BacktestClientResult<BacktestResponse>> {
559 stream::unfold(Some(self), move |state| async move {
560 let mut session = match state {
561 Some(session) => session,
562 None => return None,
563 };
564
565 match session.next_response(timeout).await {
566 Ok(Some(response)) => {
567 session.apply_response(&response);
568 Some((Ok(response), Some(session)))
569 }
570 Ok(None) => None,
571 Err(err) => Some((Err(err), None)),
572 }
573 })
574 }
575
576 pub async fn ensure_ready(
578 &mut self,
579 timeout: Option<Duration>,
580 ) -> BacktestClientResult<ReadyOutcome> {
581 if self.ready_for_continue {
582 return Ok(ReadyOutcome::Ready);
583 }
584
585 loop {
586 let response =
587 self.next_response(timeout)
588 .await?
589 .ok_or_else(|| BacktestClientError::Closed {
590 reason: "websocket ended while waiting for ReadyForContinue".to_string(),
591 })?;
592
593 match response {
594 BacktestResponse::ReadyForContinue => {
595 self.ready_for_continue = true;
596 return Ok(ReadyOutcome::Ready);
597 }
598 BacktestResponse::Completed { .. } => return Ok(ReadyOutcome::Completed),
599 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
600 _ => {}
601 }
602 }
603 }
604
605 pub async fn wait_for_status(
607 &mut self,
608 desired: BacktestStatus,
609 timeout: Option<Duration>,
610 ) -> BacktestClientResult<()> {
611 let desired = std::mem::discriminant(&desired);
612
613 loop {
614 let response =
615 self.next_response(timeout)
616 .await?
617 .ok_or_else(|| BacktestClientError::Closed {
618 reason: "websocket ended while waiting for status".to_string(),
619 })?;
620
621 match response {
622 BacktestResponse::Status { status }
623 if std::mem::discriminant(&status) == desired =>
624 {
625 return Ok(());
626 }
627 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
628 BacktestResponse::Completed {
629 summary,
630 agent_stats,
631 } => {
632 return Err(BacktestClientError::UnexpectedResponse {
633 context: "waiting for status",
634 response: Box::new(BacktestResponse::Completed {
635 summary,
636 agent_stats,
637 }),
638 });
639 }
640 _ => {}
641 }
642 }
643 }
644
645 pub(crate) fn push_backlog(&mut self, response: BacktestResponse) {
647 self.backlog.push_back((None, response));
648 }
649
650 pub async fn send_continue(
652 &mut self,
653 params: ContinueParams,
654 timeout: Option<Duration>,
655 ) -> BacktestClientResult<()> {
656 self.ready_for_continue = false;
657 self.send(&BacktestRequest::Continue(params), timeout).await
658 }
659
660 pub async fn advance_step<F>(
662 &mut self,
663 state: &mut AdvanceState,
664 wait_for_slots: bool,
665 timeout: Option<Duration>,
666 on_event: &mut F,
667 ) -> BacktestClientResult<()>
668 where
669 F: FnMut(&BacktestResponse),
670 {
671 let Some(response) = self.next_response(timeout).await? else {
672 return Err(BacktestClientError::Closed {
673 reason: "websocket ended while awaiting continue responses".to_string(),
674 });
675 };
676
677 if self.log_raw {
678 tracing::debug!("<- {response:?}");
679 }
680
681 on_event(&response);
682
683 match response {
684 BacktestResponse::ReadyForContinue => {
685 self.ready_for_continue = true;
686 state.ready_for_continue = true;
687 }
688 BacktestResponse::SlotNotification(slot) => {
689 state.slot_notifications += 1;
690 state.last_slot = Some(slot);
691 }
692 BacktestResponse::Status { status } => {
693 state.statuses.push(status);
694 }
695 BacktestResponse::Success => {}
696 BacktestResponse::Completed {
697 summary,
698 agent_stats,
699 } => {
700 state.completed = true;
701 state.summary = summary;
702 state.agent_stats = agent_stats;
703 }
704 BacktestResponse::Error(err @ BacktestError::SimulationError { .. }) => {
705 tracing::warn!(error = %crate::error::err_chain(&err), "simulation error");
706 }
707 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
708 BacktestResponse::SessionCreated { .. }
709 | BacktestResponse::SessionAttached { .. }
710 | BacktestResponse::SessionsCreated { .. }
711 | BacktestResponse::SessionsCreatedV2 { .. }
712 | BacktestResponse::ParallelSessionAttachedV2 { .. }
713 | BacktestResponse::SessionEventV1 { .. }
714 | BacktestResponse::SessionEventV2 { .. }
715 | BacktestResponse::Paused(_)
716 | BacktestResponse::DiscoveryBatch(_) => {
717 return Err(BacktestClientError::UnexpectedResponse {
718 context: "continuing",
719 response: Box::new(response),
720 });
721 }
722 }
723
724 if wait_for_slots && state.slot_notifications > state.expected_slots {
725 tracing::warn!(
726 "received {} slot notifications (expected {})",
727 state.slot_notifications,
728 state.expected_slots
729 );
730 }
731
732 Ok(())
733 }
734
735 pub async fn continue_until_ready<F>(
737 &mut self,
738 cont: Continue,
739 timeout: Option<Duration>,
740 mut on_event: F,
741 ) -> BacktestClientResult<ContinueResult>
742 where
743 F: FnMut(&BacktestResponse),
744 {
745 let expected_slots = cont.advance_count;
746 self.advance_internal(
747 cont.into_params(),
748 expected_slots,
749 false,
750 timeout,
751 &mut on_event,
752 )
753 .await
754 }
755
756 pub async fn advance<F>(
758 &mut self,
759 cont: Continue,
760 timeout: Option<Duration>,
761 mut on_event: F,
762 ) -> BacktestClientResult<ContinueResult>
763 where
764 F: FnMut(&BacktestResponse),
765 {
766 let expected_slots = cont.advance_count;
767 self.advance_internal(
768 cont.into_params(),
769 expected_slots,
770 true,
771 timeout,
772 &mut on_event,
773 )
774 .await
775 }
776
777 async fn advance_internal<F>(
778 &mut self,
779 params: ContinueParams,
780 expected_slots: u64,
781 wait_for_slots: bool,
782 timeout: Option<Duration>,
783 on_event: &mut F,
784 ) -> BacktestClientResult<ContinueResult>
785 where
786 F: FnMut(&BacktestResponse),
787 {
788 self.send_continue(params, timeout).await?;
789
790 let mut state = AdvanceState::new(expected_slots);
791 while !state.is_done(wait_for_slots) {
792 self.advance_step(&mut state, wait_for_slots, timeout, on_event)
793 .await?;
794 }
795
796 Ok(ContinueResult {
797 slot_notifications: state.slot_notifications,
798 last_slot: state.last_slot,
799 statuses: state.statuses,
800 ready_for_continue: state.ready_for_continue,
801 completed: state.completed,
802 })
803 }
804
805 pub async fn modify_program(
814 &self,
815 program_id: &str,
816 elf: &[u8],
817 ) -> Result<BTreeMap<Address, AccountData>, ProgramModError> {
818 let rpc = self.rpc.as_ref().ok_or(ProgramModError::NoRpcEndpoint)?;
819 crate::injection::modify_program_via_rpc(rpc, program_id, elf).await
820 }
821
822 pub async fn modify_accounts(
826 &self,
827 modifications: &AccountModifications,
828 ) -> BacktestClientResult<usize> {
829 let rpc_endpoint =
830 self.rpc_endpoint
831 .as_deref()
832 .ok_or_else(|| BacktestClientError::Closed {
833 reason: "no RPC endpoint available".to_string(),
834 })?;
835
836 crate::rpc::modify_accounts(rpc_endpoint, modifications).await
837 }
838
839 pub async fn subscribe_program_logs<F, Fut>(
845 &self,
846 program_id: &str,
847 commitment: CommitmentConfig,
848 on_notification: F,
849 ) -> Result<LogSubscriptionHandle, SubscriptionError>
850 where
851 F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
852 Fut: Future<Output = ()> + Send + 'static,
853 {
854 let rpc_endpoint = self
855 .rpc_endpoint
856 .as_deref()
857 .ok_or(SubscriptionError::NoRpcEndpoint)?;
858 crate::subscriptions::subscribe_program_logs(
859 rpc_endpoint,
860 program_id,
861 commitment,
862 on_notification,
863 )
864 .await
865 }
866
867 pub async fn subscribe_account_diffs<F, Fut>(
873 &self,
874 account: &str,
875 on_notification: F,
876 ) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
877 where
878 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
879 Fut: Future<Output = ()> + Send + 'static,
880 {
881 let rpc_endpoint = self
882 .rpc_endpoint
883 .as_deref()
884 .ok_or(SubscriptionError::NoRpcEndpoint)?;
885 crate::subscriptions::subscribe_account_diffs(rpc_endpoint, account, on_notification).await
886 }
887
888 pub async fn close(&mut self, timeout: Option<Duration>) -> BacktestClientResult<()> {
892 self.close_with_frame(timeout, None).await
893 }
894
895 pub async fn close_with_frame(
897 &mut self,
898 timeout: Option<Duration>,
899 frame: Option<CloseFrame<'static>>,
900 ) -> BacktestClientResult<()> {
901 if self.ws.is_none() {
902 return Ok(());
903 }
904
905 let mut sent = false;
906 match self
907 .send(&BacktestRequest::CloseBacktestSession, timeout)
908 .await
909 {
910 Ok(()) => sent = true,
911 Err(err) if is_close_ok(&err) => {}
912 Err(err) => return Err(err),
913 }
914
915 if sent {
916 let response = match self.next_response(timeout).await {
917 Ok(Some(r)) => r,
918 Ok(None) => {
919 self.ws.take();
920 return Ok(());
921 }
922 Err(BacktestClientError::Closed { .. }) => {
923 self.ws.take();
924 return Ok(());
925 }
926 Err(BacktestClientError::WebSocket {
927 action: "receiving",
928 source,
929 }) if is_reset_without_close(&source) => {
930 self.ws.take();
931 return Ok(());
932 }
933 Err(err) => return Err(err),
934 };
935
936 match response {
937 BacktestResponse::Success | BacktestResponse::Completed { .. } => {}
938 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
939 other => {
940 return Err(BacktestClientError::UnexpectedResponse {
941 context: "closing session",
942 response: Box::new(other),
943 });
944 }
945 }
946 }
947
948 if let Some(mut ws) = self.ws.take()
951 && let Err(source) = ws.close(frame).await
952 && !is_ws_closed_error(&source)
953 {
954 return Err(BacktestClientError::WebSocket {
955 action: "closing",
956 source: Box::new(source),
957 });
958 }
959 Ok(())
960 }
961
962 pub async fn close_with_reason(
964 &mut self,
965 timeout: Option<Duration>,
966 code: CloseCode,
967 reason: impl Into<String>,
968 ) -> BacktestClientResult<()> {
969 let frame = CloseFrame {
970 code,
971 reason: Cow::Owned(reason.into()),
972 };
973 self.close_with_frame(timeout, Some(frame)).await
974 }
975
976 pub fn abort(mut self) {
983 let _ws = self.ws.take();
984 }
985
986 async fn next_text(
987 &mut self,
988 timeout: Option<Duration>,
989 ) -> BacktestClientResult<Option<String>> {
990 loop {
991 let request_timeout = self.request_timeout;
992 let timeout = timeout.or(request_timeout);
993
994 let next_fut = self.ws_mut()?.next();
995 let msg = match timeout {
996 Some(duration) => tokio::time::timeout(duration, next_fut)
997 .await
998 .map_err(|_| BacktestClientError::Timeout {
999 action: "receiving",
1000 duration,
1001 })?,
1002 None => next_fut.await,
1003 };
1004
1005 let Some(msg) = msg else {
1006 return Ok(None);
1007 };
1008
1009 let msg = match msg {
1010 Ok(msg) => msg,
1011 Err(source) => {
1012 return Err(BacktestClientError::WebSocket {
1013 action: "receiving",
1014 source: Box::new(source),
1015 });
1016 }
1017 };
1018
1019 match msg {
1020 Message::Text(text) => {
1021 if self.log_raw {
1022 tracing::debug!("<- raw: {text}");
1023 }
1024 return Ok(Some(text));
1025 }
1026 Message::Binary(bin) => match String::from_utf8(bin) {
1027 Ok(text) => {
1028 if self.log_raw {
1029 tracing::debug!("<- raw(bin): {text}");
1030 }
1031 return Ok(Some(text));
1032 }
1033 Err(err) => {
1034 tracing::warn!("discarding non-utf8 binary message: {err}");
1035 continue;
1036 }
1037 },
1038 Message::Close(frame) => {
1039 let reason = close_reason(frame);
1040 return Err(BacktestClientError::Closed { reason });
1041 }
1042 Message::Ping(_) | Message::Pong(_) => continue,
1043 Message::Frame(_) => continue,
1044 }
1045 }
1046 }
1047}
1048
1049impl std::fmt::Debug for BacktestSession {
1050 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1051 f.debug_struct("BacktestSession")
1052 .field("session_id", &self.session_id)
1053 .field("rpc_endpoint", &self.rpc_endpoint)
1054 .field(
1055 "rpc",
1056 &self
1057 .rpc
1058 .as_ref()
1059 .map(|_| "<RpcClient>")
1060 .unwrap_or("<not set>"),
1061 )
1062 .field("ready_for_continue", &self.ready_for_continue)
1063 .field("request_timeout", &self.request_timeout)
1064 .finish_non_exhaustive()
1065 }
1066}
1067
/// What `BacktestSession::create_with_request` produced.
#[derive(Debug)]
pub(crate) enum CreateRequestResult {
    /// A single session was created and the `BacktestSession` was configured for it.
    Single {
        session_id: String,
        task_id: Option<String>,
    },
    /// Several sessions were created; `task_ids` is index-aligned with `session_ids`.
    Parallel {
        session_ids: Vec<String>,
        task_ids: Vec<Option<String>>,
    },
}
1079
/// Forces `task_ids` to exactly `expected_len` entries, so it pairs index-for-index with the
/// session ids: missing entries become `None`, surplus entries are dropped.
fn align_task_ids(mut task_ids: Vec<Option<String>>, expected_len: usize) -> Vec<Option<String>> {
    // `Vec::resize` both pads (with clones of `None`) and truncates, covering every case.
    task_ids.resize(expected_len, None);
    task_ids
}
1090
1091impl Drop for BacktestSession {
1092 fn drop(&mut self) {
1093 let Some(ws) = self.ws.take() else {
1094 return;
1095 };
1096
1097 if let Ok(handle) = tokio::runtime::Handle::try_current() {
1098 handle.spawn(async move {
1099 let mut ws = ws;
1100 let _ = ws.close(None).await;
1101 });
1102 }
1103 }
1104}
1105
/// Resolves a server-reported RPC `endpoint` against the client's `base` URL.
///
/// Absolute `http://` / `https://` endpoints are returned unchanged. Relative endpoints are
/// joined to `base` with exactly one `/`. Trailing slashes on `base` and leading slashes on
/// `endpoint` are trimmed, so `"http://h/"` + `"/rpc"` yields `"http://h/rpc"` rather than
/// `"http://h//rpc"`.
fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        return endpoint.to_string();
    }
    format!(
        "{}/{}",
        base.trim_end_matches('/'),
        endpoint.trim_start_matches('/')
    )
}
1113
1114fn close_reason(frame: Option<CloseFrame<'static>>) -> String {
1115 match frame {
1116 Some(frame) => format!("{:?}: {}", frame.code, frame.reason),
1117 None => "no close frame".to_string(),
1118 }
1119}
1120
1121fn is_reset_without_close(err: &WsError) -> bool {
1122 matches!(
1123 err,
1124 WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1125 )
1126}
1127
1128fn is_ws_closed_error(err: &WsError) -> bool {
1129 matches!(
1130 err,
1131 WsError::ConnectionClosed
1132 | WsError::AlreadyClosed
1133 | WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1134 )
1135}
1136
1137fn is_close_ok(err: &BacktestClientError) -> bool {
1138 match err {
1139 BacktestClientError::Closed { .. } => true,
1140 BacktestClientError::WebSocket { source, .. } => is_ws_closed_error(source),
1141 _ => false,
1142 }
1143}
1144
#[cfg(test)]
mod tests {
    use super::*;

    /// Slot notifications raise the high-water mark and `Completed` marks completion.
    #[test]
    fn coverage_tracks_slot_and_completion_from_responses() {
        let mut coverage = SessionCoverage::default();
        for slot in [10, 12] {
            coverage.observe_response(&BacktestResponse::SlotNotification(slot));
        }
        coverage.observe_response(&BacktestResponse::Completed {
            summary: None,
            agent_stats: None,
        });

        assert!(coverage.is_completed());
        assert_eq!(coverage.highest_slot_seen(), Some(12));
    }

    /// Validation walks through each failure mode before finally succeeding.
    #[test]
    fn coverage_validate_end_slot_checks_completion_and_range() {
        let mut coverage = SessionCoverage::default();
        let expected_end = 5;

        assert_eq!(
            coverage.validate_end_slot(expected_end),
            Err(CoverageError::NotCompleted)
        );

        coverage.mark_completed();
        assert_eq!(
            coverage.validate_end_slot(expected_end),
            Err(CoverageError::NoSlotsObserved)
        );

        coverage.observe_slot(4);
        assert_eq!(
            coverage.validate_end_slot(expected_end),
            Err(CoverageError::RangeNotReached {
                actual_end_slot: 4,
                expected_end_slot: 5,
            })
        );

        coverage.observe_slot(6);
        assert_eq!(coverage.validate_end_slot(expected_end), Ok(()));
    }
}