1use std::{
2 borrow::Cow,
3 collections::{BTreeMap, VecDeque},
4 future::Future,
5 time::Duration,
6};
7
8use futures::{SinkExt, StreamExt, stream};
9use simulator_api::{
10 AccountData, AccountModifications, AgentStatsReport, BacktestError, BacktestRequest,
11 BacktestResponse, BacktestStatus, ContinueParams, CreateBacktestSessionRequest,
12 CreateBacktestSessionRequestV1, SequencedResponse, SessionSummary,
13};
14use solana_address::Address;
15use solana_client::{
16 nonblocking::rpc_client::RpcClient,
17 rpc_response::{Response, RpcLogsResponse},
18};
19use solana_commitment_config::CommitmentConfig;
20use thiserror::Error;
21use tokio::net::TcpStream;
22use tokio_tungstenite::{
23 MaybeTlsStream, WebSocketStream,
24 tungstenite::{
25 Error as WsError, Message,
26 error::ProtocolError,
27 protocol::{CloseFrame, frame::coding::CloseCode},
28 },
29};
30
31use crate::{
32 BacktestClientError, BacktestClientResult, Continue,
33 injection::{ProgramModError, build_program_injection},
34 subscriptions::{
35 AccountDiffNotification, AccountDiffSubscriptionHandle, LogSubscriptionHandle,
36 SubscriptionError,
37 },
38};
39
/// Result of waiting for the session to accept another `Continue` request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadyOutcome {
    /// The session is ready: either it already was, or `ReadyForContinue` arrived.
    Ready,
    /// A `Completed` response arrived before the session became ready.
    Completed,
}
48
49#[derive(Debug, Default)]
51pub struct ContinueResult {
52 pub slot_notifications: u64,
54 pub last_slot: Option<u64>,
56 pub statuses: Vec<BacktestStatus>,
58 pub ready_for_continue: bool,
60 pub completed: bool,
62}
63
64#[derive(Debug)]
66pub struct AdvanceState {
67 pub expected_slots: u64,
69 pub slot_notifications: u64,
71 pub last_slot: Option<u64>,
73 pub statuses: Vec<BacktestStatus>,
75 pub ready_for_continue: bool,
77 pub completed: bool,
79 pub summary: Option<SessionSummary>,
81 pub agent_stats: Option<Vec<AgentStatsReport>>,
83}
84
85impl AdvanceState {
86 pub fn new(expected_slots: u64) -> Self {
88 Self {
89 expected_slots,
90 slot_notifications: 0,
91 last_slot: None,
92 statuses: Vec::new(),
93 ready_for_continue: false,
94 completed: false,
95 summary: None,
96 agent_stats: None,
97 }
98 }
99
100 pub fn is_done(&self, wait_for_slots: bool) -> bool {
102 if self.completed {
103 return true;
104 }
105
106 if !self.ready_for_continue {
107 return false;
108 }
109
110 !wait_for_slots || self.slot_notifications >= self.expected_slots
111 }
112}
113
/// Tracks how far a session progressed: the highest slot observed and whether
/// the session reported completion.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SessionCoverage {
    /// Set once completion has been recorded.
    completed: bool,
    /// Largest slot number recorded so far; `None` until the first slot is seen.
    highest_slot_seen: Option<u64>,
}
120
121impl SessionCoverage {
122 pub fn observe_slot(&mut self, slot: u64) {
124 self.highest_slot_seen = Some(
125 self.highest_slot_seen
126 .map_or(slot, |current| current.max(slot)),
127 );
128 }
129
130 pub fn mark_completed(&mut self) {
132 self.completed = true;
133 }
134
135 pub fn observe_response(&mut self, response: &BacktestResponse) {
137 match response {
138 BacktestResponse::SlotNotification(slot) => self.observe_slot(*slot),
139 BacktestResponse::Completed { .. } => self.mark_completed(),
140 _ => {}
141 }
142 }
143
144 pub fn is_completed(&self) -> bool {
146 self.completed
147 }
148
149 pub fn highest_slot_seen(&self) -> Option<u64> {
151 self.highest_slot_seen
152 }
153
154 pub fn validate_end_slot(&self, expected_end_slot: u64) -> Result<(), CoverageError> {
156 if !self.completed {
157 return Err(CoverageError::NotCompleted);
158 }
159
160 let Some(actual_end_slot) = self.highest_slot_seen else {
161 return Err(CoverageError::NoSlotsObserved);
162 };
163
164 if actual_end_slot < expected_end_slot {
165 return Err(CoverageError::RangeNotReached {
166 actual_end_slot,
167 expected_end_slot,
168 });
169 }
170
171 Ok(())
172 }
173}
174
175#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
177pub enum CoverageError {
178 #[error("ended before completion")]
179 NotCompleted,
180 #[error("completed without slot notifications")]
181 NoSlotsObserved,
182 #[error("completed at slot {actual_end_slot} but expected at least {expected_end_slot}")]
183 RangeNotReached {
184 actual_end_slot: u64,
185 expected_end_slot: u64,
186 },
187}
188
189pub struct BacktestSession {
194 ws: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
195 session_id: Option<String>,
196 rpc_endpoint: Option<String>,
197 rpc: Option<RpcClient>,
198 last_sequence: Option<u64>,
199 ready_for_continue: bool,
200 request_timeout: Option<Duration>,
201 log_raw: bool,
202 backlog: VecDeque<(Option<u64>, BacktestResponse)>,
203}
204
205impl BacktestSession {
206 pub(crate) fn new(
207 ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
208 request_timeout: Option<Duration>,
209 log_raw: bool,
210 ) -> Self {
211 Self {
212 ws: Some(ws),
213 session_id: None,
214 rpc_endpoint: None,
215 rpc: None,
216 last_sequence: None,
217 ready_for_continue: false,
218 request_timeout,
219 log_raw,
220 backlog: VecDeque::new(),
221 }
222 }
223
224 pub fn session_id(&self) -> Option<&str> {
226 self.session_id.as_deref()
227 }
228
229 pub fn rpc_endpoint(&self) -> Option<&str> {
231 self.rpc_endpoint.as_deref()
232 }
233
234 pub fn last_sequence(&self) -> Option<u64> {
236 self.last_sequence
237 }
238
239 pub fn rpc(&self) -> &RpcClient {
243 self.rpc
244 .as_ref()
245 .expect("rpc is set during session creation")
246 }
247
248 pub fn is_ready_for_continue(&self) -> bool {
250 self.ready_for_continue
251 }
252
253 pub fn apply_response(&mut self, response: &BacktestResponse) {
255 match response {
256 BacktestResponse::ReadyForContinue => {
257 self.ready_for_continue = true;
258 }
259 BacktestResponse::Completed { .. } => {
260 self.ready_for_continue = false;
261 }
262 _ => {}
263 }
264 }
265
266 fn ws_mut(&mut self) -> BacktestClientResult<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
267 self.ws.as_mut().ok_or_else(|| BacktestClientError::Closed {
268 reason: "websocket closed".to_string(),
269 })
270 }
271
272 pub(crate) async fn create_with_request(
273 &mut self,
274 request: CreateBacktestSessionRequest,
275 rpc_base_url: String,
276 mut on_parallel_session_created: Option<&mut (dyn FnMut(String) + Send)>,
277 ) -> BacktestClientResult<CreateRequestResult> {
278 let expect_parallel = matches!(
279 &request,
280 CreateBacktestSessionRequest::V1(CreateBacktestSessionRequestV1 { parallel: true, .. })
281 );
282 self.send(&BacktestRequest::CreateBacktestSession(request), None)
283 .await?;
284 let mut streamed_parallel_session_ids = Vec::new();
285
286 loop {
287 let response =
288 self.next_response(None)
289 .await?
290 .ok_or_else(|| BacktestClientError::Closed {
291 reason: "websocket ended before SessionCreated".to_string(),
292 })?;
293
294 match response {
295 BacktestResponse::SessionCreated {
296 session_id,
297 rpc_endpoint,
298 } => {
299 if expect_parallel {
300 if let Some(callback) = on_parallel_session_created.as_mut() {
301 (**callback)(session_id.clone());
302 }
303 streamed_parallel_session_ids.push(session_id);
304 continue;
305 }
306 let created_session_id = session_id.clone();
307 self.session_id = Some(session_id);
308 let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
309 self.rpc = Some(RpcClient::new_with_commitment(
310 resolved.clone(),
311 CommitmentConfig::confirmed(),
312 ));
313 self.rpc_endpoint = Some(resolved);
314 return Ok(CreateRequestResult::Single {
315 session_id: created_session_id,
316 });
317 }
318 BacktestResponse::SessionsCreated { session_ids } => {
319 if expect_parallel && session_ids.is_empty() {
320 return Ok(CreateRequestResult::Parallel {
321 session_ids: streamed_parallel_session_ids,
322 });
323 }
324 return Ok(CreateRequestResult::Parallel { session_ids });
325 }
326 BacktestResponse::SessionsCreatedV2 { session_ids, .. } => {
327 if expect_parallel && session_ids.is_empty() {
328 return Ok(CreateRequestResult::Parallel {
329 session_ids: streamed_parallel_session_ids,
330 });
331 }
332 return Ok(CreateRequestResult::Parallel { session_ids });
333 }
334 BacktestResponse::ReadyForContinue => {
335 self.ready_for_continue = true;
336 }
337 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
338 other => {
339 self.backlog.push_back((self.last_sequence, other));
340 }
341 }
342 }
343 }
344
345 pub(crate) async fn attach(
346 &mut self,
347 session_id: String,
348 last_sequence: Option<u64>,
349 rpc_base_url: String,
350 ) -> BacktestClientResult<()> {
351 self.send(
352 &BacktestRequest::AttachBacktestSession {
353 session_id,
354 last_sequence,
355 },
356 None,
357 )
358 .await?;
359
360 self.wait_for_response(
361 || BacktestClientError::Closed {
362 reason: "websocket ended before SessionAttached".to_string(),
363 },
364 move |session, response| match response {
365 BacktestResponse::SessionAttached {
366 session_id,
367 rpc_endpoint,
368 } => {
369 session.session_id = Some(session_id);
370 let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
371 session.rpc = Some(RpcClient::new_with_commitment(
372 resolved.clone(),
373 CommitmentConfig::confirmed(),
374 ));
375 session.rpc_endpoint = Some(resolved);
376 Ok(Some(()))
377 }
378 BacktestResponse::ReadyForContinue => {
379 session.ready_for_continue = true;
380 Ok(None)
381 }
382 BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
383 other => {
384 session.backlog.push_back((session.last_sequence, other));
385 Ok(None)
386 }
387 },
388 )
389 .await
390 }
391
392 pub async fn resume_attached_session(&mut self) -> BacktestClientResult<()> {
395 self.send(&BacktestRequest::ResumeAttachedSession, None)
396 .await?;
397
398 self.wait_for_response(
399 || BacktestClientError::Closed {
400 reason: "websocket ended before ResumeAttachedSession acknowledgement".to_string(),
401 },
402 |session, response| match response {
403 BacktestResponse::Success => Ok(Some(())),
404 BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
405 other => {
406 session.backlog.push_back((session.last_sequence, other));
407 Ok(None)
408 }
409 },
410 )
411 .await
412 }
413
414 async fn wait_for_response<T, E, F>(
415 &mut self,
416 closed_error: E,
417 mut handle_response: F,
418 ) -> BacktestClientResult<T>
419 where
420 E: FnOnce() -> BacktestClientError,
421 F: FnMut(&mut Self, BacktestResponse) -> BacktestClientResult<Option<T>>,
422 {
423 let mut closed_error = Some(closed_error);
424
425 loop {
426 let response = self
427 .next_response(None)
428 .await?
429 .ok_or_else(|| closed_error.take().expect("closed error set")())?;
430
431 if let Some(result) = handle_response(self, response)? {
432 return Ok(result);
433 }
434 }
435 }
436
437 pub async fn send(
439 &mut self,
440 request: &BacktestRequest,
441 timeout: Option<Duration>,
442 ) -> BacktestClientResult<()> {
443 let text = serde_json::to_string(request)
444 .map_err(|source| BacktestClientError::SerializeRequest { source })?;
445
446 let request_timeout = self.request_timeout;
447 let timeout = timeout.or(request_timeout);
448
449 let send_fut = self.ws_mut()?.send(Message::Text(text));
450 let send_result = match timeout {
451 Some(duration) => tokio::time::timeout(duration, send_fut)
452 .await
453 .map_err(|_| BacktestClientError::Timeout {
454 action: "sending",
455 duration,
456 })?,
457 None => send_fut.await,
458 };
459
460 send_result.map_err(|source| BacktestClientError::WebSocket {
461 action: "sending",
462 source: Box::new(source),
463 })?;
464
465 Ok(())
466 }
467
468 pub async fn next_response(
470 &mut self,
471 timeout: Option<Duration>,
472 ) -> BacktestClientResult<Option<BacktestResponse>> {
473 if let Some((sequence, response)) = self.backlog.pop_front() {
474 self.last_sequence = sequence.or(self.last_sequence);
475 return Ok(Some(response));
476 }
477
478 let text = match self.next_text(timeout).await? {
479 Some(text) => text,
480 None => return Ok(None),
481 };
482
483 let (sequence, response) = match serde_json::from_str::<SequencedResponse>(&text) {
484 Ok(sequenced) => (Some(sequenced.seq_id), sequenced.response),
485 Err(_) => {
486 let response =
487 serde_json::from_str::<BacktestResponse>(&text).map_err(|source| {
488 BacktestClientError::DeserializeResponse {
489 raw: text.clone(),
490 source,
491 }
492 })?;
493 (None, response)
494 }
495 };
496 self.last_sequence = sequence.or(self.last_sequence);
497
498 Ok(Some(response))
499 }
500
501 pub async fn next_event(
503 &mut self,
504 timeout: Option<Duration>,
505 ) -> BacktestClientResult<Option<BacktestResponse>> {
506 let response = self.next_response(timeout).await?;
507 if let Some(ref response) = response {
508 self.apply_response(response);
509 }
510 Ok(response)
511 }
512
513 pub fn responses(
517 self,
518 timeout: Option<Duration>,
519 ) -> impl futures::Stream<Item = BacktestClientResult<BacktestResponse>> {
520 stream::unfold(Some(self), move |state| async move {
521 let mut session = match state {
522 Some(session) => session,
523 None => return None,
524 };
525
526 match session.next_response(timeout).await {
527 Ok(Some(response)) => {
528 session.apply_response(&response);
529 Some((Ok(response), Some(session)))
530 }
531 Ok(None) => None,
532 Err(err) => Some((Err(err), None)),
533 }
534 })
535 }
536
537 pub async fn ensure_ready(
539 &mut self,
540 timeout: Option<Duration>,
541 ) -> BacktestClientResult<ReadyOutcome> {
542 if self.ready_for_continue {
543 return Ok(ReadyOutcome::Ready);
544 }
545
546 loop {
547 let response =
548 self.next_response(timeout)
549 .await?
550 .ok_or_else(|| BacktestClientError::Closed {
551 reason: "websocket ended while waiting for ReadyForContinue".to_string(),
552 })?;
553
554 match response {
555 BacktestResponse::ReadyForContinue => {
556 self.ready_for_continue = true;
557 return Ok(ReadyOutcome::Ready);
558 }
559 BacktestResponse::Completed { .. } => return Ok(ReadyOutcome::Completed),
560 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
561 _ => {}
562 }
563 }
564 }
565
566 pub async fn wait_for_status(
568 &mut self,
569 desired: BacktestStatus,
570 timeout: Option<Duration>,
571 ) -> BacktestClientResult<()> {
572 let desired = std::mem::discriminant(&desired);
573
574 loop {
575 let response =
576 self.next_response(timeout)
577 .await?
578 .ok_or_else(|| BacktestClientError::Closed {
579 reason: "websocket ended while waiting for status".to_string(),
580 })?;
581
582 match response {
583 BacktestResponse::Status { status }
584 if std::mem::discriminant(&status) == desired =>
585 {
586 return Ok(());
587 }
588 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
589 BacktestResponse::Completed {
590 summary,
591 agent_stats,
592 } => {
593 return Err(BacktestClientError::UnexpectedResponse {
594 context: "waiting for status",
595 response: Box::new(BacktestResponse::Completed {
596 summary,
597 agent_stats,
598 }),
599 });
600 }
601 _ => {}
602 }
603 }
604 }
605
606 pub async fn send_continue(
608 &mut self,
609 params: ContinueParams,
610 timeout: Option<Duration>,
611 ) -> BacktestClientResult<()> {
612 self.ready_for_continue = false;
613 self.send(&BacktestRequest::Continue(params), timeout).await
614 }
615
616 pub async fn advance_step<F>(
618 &mut self,
619 state: &mut AdvanceState,
620 wait_for_slots: bool,
621 timeout: Option<Duration>,
622 on_event: &mut F,
623 ) -> BacktestClientResult<()>
624 where
625 F: FnMut(&BacktestResponse),
626 {
627 let Some(response) = self.next_response(timeout).await? else {
628 return Err(BacktestClientError::Closed {
629 reason: "websocket ended while awaiting continue responses".to_string(),
630 });
631 };
632
633 if self.log_raw {
634 tracing::debug!("<- {response:?}");
635 }
636
637 on_event(&response);
638
639 match response {
640 BacktestResponse::ReadyForContinue => {
641 self.ready_for_continue = true;
642 state.ready_for_continue = true;
643 }
644 BacktestResponse::SlotNotification(slot) => {
645 state.slot_notifications += 1;
646 state.last_slot = Some(slot);
647 }
648 BacktestResponse::Status { status } => {
649 state.statuses.push(status);
650 }
651 BacktestResponse::Success => {}
652 BacktestResponse::Completed {
653 summary,
654 agent_stats,
655 } => {
656 state.completed = true;
657 state.summary = summary;
658 state.agent_stats = agent_stats;
659 }
660 BacktestResponse::Error(err @ BacktestError::SimulationError { .. }) => {
661 tracing::warn!(error = %err, "simulation error");
662 }
663 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
664 BacktestResponse::SessionCreated { .. }
665 | BacktestResponse::SessionAttached { .. }
666 | BacktestResponse::SessionsCreated { .. }
667 | BacktestResponse::SessionsCreatedV2 { .. }
668 | BacktestResponse::ParallelSessionAttachedV2 { .. }
669 | BacktestResponse::SessionEventV1 { .. }
670 | BacktestResponse::SessionEventV2 { .. } => {
671 return Err(BacktestClientError::UnexpectedResponse {
672 context: "continuing",
673 response: Box::new(response),
674 });
675 }
676 }
677
678 if wait_for_slots && state.slot_notifications > state.expected_slots {
679 tracing::warn!(
680 "received {} slot notifications (expected {})",
681 state.slot_notifications,
682 state.expected_slots
683 );
684 }
685
686 Ok(())
687 }
688
689 pub async fn continue_until_ready<F>(
691 &mut self,
692 cont: Continue,
693 timeout: Option<Duration>,
694 mut on_event: F,
695 ) -> BacktestClientResult<ContinueResult>
696 where
697 F: FnMut(&BacktestResponse),
698 {
699 let expected_slots = cont.advance_count;
700 self.advance_internal(
701 cont.into_params(),
702 expected_slots,
703 false,
704 timeout,
705 &mut on_event,
706 )
707 .await
708 }
709
710 pub async fn advance<F>(
712 &mut self,
713 cont: Continue,
714 timeout: Option<Duration>,
715 mut on_event: F,
716 ) -> BacktestClientResult<ContinueResult>
717 where
718 F: FnMut(&BacktestResponse),
719 {
720 let expected_slots = cont.advance_count;
721 self.advance_internal(
722 cont.into_params(),
723 expected_slots,
724 true,
725 timeout,
726 &mut on_event,
727 )
728 .await
729 }
730
731 async fn advance_internal<F>(
732 &mut self,
733 params: ContinueParams,
734 expected_slots: u64,
735 wait_for_slots: bool,
736 timeout: Option<Duration>,
737 on_event: &mut F,
738 ) -> BacktestClientResult<ContinueResult>
739 where
740 F: FnMut(&BacktestResponse),
741 {
742 self.send_continue(params, timeout).await?;
743
744 let mut state = AdvanceState::new(expected_slots);
745 while !state.is_done(wait_for_slots) {
746 self.advance_step(&mut state, wait_for_slots, timeout, on_event)
747 .await?;
748 }
749
750 Ok(ContinueResult {
751 slot_notifications: state.slot_notifications,
752 last_slot: state.last_slot,
753 statuses: state.statuses,
754 ready_for_continue: state.ready_for_continue,
755 completed: state.completed,
756 })
757 }
758
759 pub async fn modify_program(
768 &self,
769 program_id: &str,
770 elf: &[u8],
771 ) -> Result<BTreeMap<Address, AccountData>, ProgramModError> {
772 let rpc = self.rpc.as_ref().ok_or(ProgramModError::NoRpcEndpoint)?;
773
774 let program_addr: Address =
775 program_id
776 .parse()
777 .map_err(|_| ProgramModError::InvalidProgramId {
778 id: program_id.to_string(),
779 })?;
780 let programdata_addr = solana_loader_v3_interface::get_program_data_address(&program_addr);
781
782 let slot = rpc.get_slot().await.map_err(|e| ProgramModError::Rpc {
783 source: Box::new(e),
784 })?;
785 let deploy_slot = slot.saturating_sub(1);
786
787 let existing =
790 rpc.get_account(&programdata_addr)
791 .await
792 .map_err(|e| ProgramModError::Rpc {
793 source: Box::new(e),
794 })?;
795
796 let upgrade_authority = if existing.data.get(12).copied() == Some(1) {
797 existing.data.get(13..45).and_then(|b| {
798 let bytes: [u8; 32] = b.try_into().ok()?;
799 Some(Address::from(bytes))
800 })
801 } else {
802 None
803 };
804
805 let data_len = upgrade_authority.map_or(13, |_| 45) + elf.len();
806 let lamports = rpc
807 .get_minimum_balance_for_rent_exemption(data_len)
808 .await
809 .map_err(|e| ProgramModError::Rpc {
810 source: Box::new(e),
811 })?;
812
813 Ok(build_program_injection(
814 programdata_addr,
815 elf,
816 deploy_slot,
817 upgrade_authority,
818 lamports,
819 ))
820 }
821
822 pub async fn modify_accounts(
826 &self,
827 modifications: &AccountModifications,
828 ) -> BacktestClientResult<usize> {
829 let rpc_endpoint =
830 self.rpc_endpoint
831 .as_deref()
832 .ok_or_else(|| BacktestClientError::Closed {
833 reason: "no RPC endpoint available".to_string(),
834 })?;
835
836 crate::rpc::modify_accounts(rpc_endpoint, modifications).await
837 }
838
839 pub async fn subscribe_program_logs<F, Fut>(
845 &self,
846 program_id: &str,
847 commitment: CommitmentConfig,
848 on_notification: F,
849 ) -> Result<LogSubscriptionHandle, SubscriptionError>
850 where
851 F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
852 Fut: Future<Output = ()> + Send + 'static,
853 {
854 let rpc_endpoint = self
855 .rpc_endpoint
856 .as_deref()
857 .ok_or(SubscriptionError::NoRpcEndpoint)?;
858 crate::subscriptions::subscribe_program_logs(
859 rpc_endpoint,
860 program_id,
861 commitment,
862 on_notification,
863 )
864 .await
865 }
866
867 pub async fn subscribe_account_diffs<F, Fut>(
873 &self,
874 account: &str,
875 on_notification: F,
876 ) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
877 where
878 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
879 Fut: Future<Output = ()> + Send + 'static,
880 {
881 let rpc_endpoint = self
882 .rpc_endpoint
883 .as_deref()
884 .ok_or(SubscriptionError::NoRpcEndpoint)?;
885 crate::subscriptions::subscribe_account_diffs(rpc_endpoint, account, on_notification).await
886 }
887
888 pub async fn close(&mut self, timeout: Option<Duration>) -> BacktestClientResult<()> {
892 self.close_with_frame(timeout, None).await
893 }
894
895 pub async fn close_with_frame(
897 &mut self,
898 timeout: Option<Duration>,
899 frame: Option<CloseFrame<'static>>,
900 ) -> BacktestClientResult<()> {
901 if self.ws.is_none() {
902 return Ok(());
903 }
904
905 let mut sent = false;
906 match self
907 .send(&BacktestRequest::CloseBacktestSession, timeout)
908 .await
909 {
910 Ok(()) => sent = true,
911 Err(err) if is_close_ok(&err) => {}
912 Err(err) => return Err(err),
913 }
914
915 if sent {
916 let response = match self.next_response(timeout).await {
917 Ok(Some(r)) => r,
918 Ok(None) => {
919 self.ws.take();
920 return Ok(());
921 }
922 Err(BacktestClientError::Closed { .. }) => {
923 self.ws.take();
924 return Ok(());
925 }
926 Err(BacktestClientError::WebSocket {
927 action: "receiving",
928 source,
929 }) if is_reset_without_close(&source) => {
930 self.ws.take();
931 return Ok(());
932 }
933 Err(err) => return Err(err),
934 };
935
936 match response {
937 BacktestResponse::Success | BacktestResponse::Completed { .. } => {}
938 BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
939 other => {
940 return Err(BacktestClientError::UnexpectedResponse {
941 context: "closing session",
942 response: Box::new(other),
943 });
944 }
945 }
946 }
947
948 if let Some(ws) = self.ws.as_mut() {
949 let close_result = ws.close(frame).await;
950 if let Err(source) = close_result
951 && !is_ws_closed_error(&source)
952 {
953 return Err(BacktestClientError::WebSocket {
954 action: "closing",
955 source: Box::new(source),
956 });
957 }
958 }
959
960 match self.next_response(timeout).await {
962 Ok(_) => {}
963 Err(BacktestClientError::Closed { .. }) => {}
964 Err(BacktestClientError::WebSocket {
965 action: "receiving",
966 source,
967 }) if is_reset_without_close(&source) => {}
968 Err(err) => return Err(err),
969 }
970
971 tokio::time::sleep(Duration::from_millis(100)).await;
973 self.ws.take();
974 Ok(())
975 }
976
977 pub async fn close_with_reason(
979 &mut self,
980 timeout: Option<Duration>,
981 code: CloseCode,
982 reason: impl Into<String>,
983 ) -> BacktestClientResult<()> {
984 let frame = CloseFrame {
985 code,
986 reason: Cow::Owned(reason.into()),
987 };
988 self.close_with_frame(timeout, Some(frame)).await
989 }
990
991 async fn next_text(
992 &mut self,
993 timeout: Option<Duration>,
994 ) -> BacktestClientResult<Option<String>> {
995 loop {
996 let request_timeout = self.request_timeout;
997 let timeout = timeout.or(request_timeout);
998
999 let next_fut = self.ws_mut()?.next();
1000 let msg = match timeout {
1001 Some(duration) => tokio::time::timeout(duration, next_fut)
1002 .await
1003 .map_err(|_| BacktestClientError::Timeout {
1004 action: "receiving",
1005 duration,
1006 })?,
1007 None => next_fut.await,
1008 };
1009
1010 let Some(msg) = msg else {
1011 return Ok(None);
1012 };
1013
1014 let msg = match msg {
1015 Ok(msg) => msg,
1016 Err(source) => {
1017 return Err(BacktestClientError::WebSocket {
1018 action: "receiving",
1019 source: Box::new(source),
1020 });
1021 }
1022 };
1023
1024 match msg {
1025 Message::Text(text) => {
1026 if self.log_raw {
1027 tracing::debug!("<- raw: {text}");
1028 }
1029 return Ok(Some(text));
1030 }
1031 Message::Binary(bin) => match String::from_utf8(bin) {
1032 Ok(text) => {
1033 if self.log_raw {
1034 tracing::debug!("<- raw(bin): {text}");
1035 }
1036 return Ok(Some(text));
1037 }
1038 Err(err) => {
1039 tracing::warn!("discarding non-utf8 binary message: {err}");
1040 continue;
1041 }
1042 },
1043 Message::Close(frame) => {
1044 let reason = close_reason(frame);
1045 return Err(BacktestClientError::Closed { reason });
1046 }
1047 Message::Ping(_) | Message::Pong(_) => continue,
1048 Message::Frame(_) => continue,
1049 }
1050 }
1051 }
1052}
1053
1054impl std::fmt::Debug for BacktestSession {
1055 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1056 f.debug_struct("BacktestSession")
1057 .field("session_id", &self.session_id)
1058 .field("rpc_endpoint", &self.rpc_endpoint)
1059 .field(
1060 "rpc",
1061 &self
1062 .rpc
1063 .as_ref()
1064 .map(|_| "<RpcClient>")
1065 .unwrap_or("<not set>"),
1066 )
1067 .field("ready_for_continue", &self.ready_for_continue)
1068 .field("request_timeout", &self.request_timeout)
1069 .finish_non_exhaustive()
1070 }
1071}
1072
/// Outcome of [`BacktestSession::create_with_request`].
#[derive(Debug)]
pub(crate) enum CreateRequestResult {
    /// One session was created and bound to the requesting connection.
    Single {
        /// Identifier of the created session.
        session_id: String,
    },
    /// Several sessions were created; none is bound to the requesting connection.
    Parallel {
        /// Identifiers of the created sessions.
        session_ids: Vec<String>,
    },
}
1078
1079impl Drop for BacktestSession {
1080 fn drop(&mut self) {
1081 let Some(ws) = self.ws.take() else {
1082 return;
1083 };
1084
1085 if let Ok(handle) = tokio::runtime::Handle::try_current() {
1086 handle.spawn(async move {
1087 let mut ws = ws;
1088 let _ = ws.close(None).await;
1089 });
1090 }
1091 }
1092}
1093
/// Resolves the RPC endpoint reported by the server against `base`.
///
/// An `endpoint` that is already an absolute `http://` or `https://` URL is
/// returned unchanged. Otherwise it is treated as a path and joined to `base`
/// with exactly one `/` between them: trailing slashes on `base` and leading
/// slashes on `endpoint` are trimmed, so a base such as `http://host/` no
/// longer produces `http://host//path`.
fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        return endpoint.to_string();
    }

    format!(
        "{}/{}",
        base.trim_end_matches('/'),
        endpoint.trim_start_matches('/')
    )
}
1101
1102fn close_reason(frame: Option<CloseFrame<'static>>) -> String {
1103 match frame {
1104 Some(frame) => format!("{:?}: {}", frame.code, frame.reason),
1105 None => "no close frame".to_string(),
1106 }
1107}
1108
1109fn is_reset_without_close(err: &WsError) -> bool {
1110 matches!(
1111 err,
1112 WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1113 )
1114}
1115
1116fn is_ws_closed_error(err: &WsError) -> bool {
1117 matches!(
1118 err,
1119 WsError::ConnectionClosed
1120 | WsError::AlreadyClosed
1121 | WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1122 )
1123}
1124
1125fn is_close_ok(err: &BacktestClientError) -> bool {
1126 match err {
1127 BacktestClientError::Closed { .. } => true,
1128 BacktestClientError::WebSocket { source, .. } => is_ws_closed_error(source),
1129 _ => false,
1130 }
1131}
1132
#[cfg(test)]
mod tests {
    use super::*;

    /// Slot notifications raise the high-water mark and `Completed` marks the
    /// coverage as completed.
    #[test]
    fn coverage_tracks_slot_and_completion_from_responses() {
        let events = [
            BacktestResponse::SlotNotification(10),
            BacktestResponse::SlotNotification(12),
            BacktestResponse::Completed {
                summary: None,
                agent_stats: None,
            },
        ];

        let mut coverage = SessionCoverage::default();
        for event in &events {
            coverage.observe_response(event);
        }

        assert!(coverage.is_completed());
        assert_eq!(coverage.highest_slot_seen(), Some(12));
    }

    /// `validate_end_slot` reports each failure in turn as the coverage is
    /// filled in, and succeeds once the expected slot is reached.
    #[test]
    fn coverage_validate_end_slot_checks_completion_and_range() {
        const EXPECTED_END: u64 = 5;
        let mut coverage = SessionCoverage::default();

        // Nothing recorded yet.
        let not_completed = coverage.validate_end_slot(EXPECTED_END);
        assert_eq!(not_completed, Err(CoverageError::NotCompleted));

        // Completed, but no slot was ever seen.
        coverage.mark_completed();
        let no_slots = coverage.validate_end_slot(EXPECTED_END);
        assert_eq!(no_slots, Err(CoverageError::NoSlotsObserved));

        // Highest slot is still short of the target.
        coverage.observe_slot(4);
        let short = coverage.validate_end_slot(EXPECTED_END);
        assert_eq!(
            short,
            Err(CoverageError::RangeNotReached {
                actual_end_slot: 4,
                expected_end_slot: EXPECTED_END,
            })
        );

        // Target reached or passed.
        coverage.observe_slot(6);
        assert_eq!(coverage.validate_end_slot(EXPECTED_END), Ok(()));
    }
}