Skip to main content

simulator_client/
session.rs

1use std::{
2    borrow::Cow,
3    collections::{BTreeMap, VecDeque},
4    future::Future,
5    time::Duration,
6};
7
8use futures::{SinkExt, StreamExt, stream};
9use simulator_api::{
10    AccountData, AccountModifications, AgentStatsReport, BacktestError, BacktestRequest,
11    BacktestResponse, BacktestStatus, ContinueParams, CreateBacktestSessionRequest,
12    CreateBacktestSessionRequestV1, SessionSummary,
13};
14use solana_address::Address;
15use solana_client::{
16    nonblocking::rpc_client::RpcClient,
17    rpc_response::{Response, RpcLogsResponse},
18};
19use solana_commitment_config::CommitmentConfig;
20use thiserror::Error;
21use tokio::net::TcpStream;
22use tokio_tungstenite::{
23    MaybeTlsStream, WebSocketStream,
24    tungstenite::{
25        Error as WsError, Message,
26        error::ProtocolError,
27        protocol::{CloseFrame, frame::coding::CloseCode},
28    },
29};
30
31use crate::{
32    BacktestClientError, BacktestClientResult, Continue,
33    injection::{ProgramModError, build_program_injection},
34    subscriptions::{
35        AccountDiffNotification, AccountDiffSubscriptionHandle, LogSubscriptionHandle,
36        SubscriptionError,
37    },
38};
39
/// Result of [`BacktestSession::ensure_ready`]: either the session can accept a
/// `Continue` request, or it finished first.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReadyOutcome {
    /// A `ReadyForContinue` was observed; the next `Continue` may be sent.
    Ready,
    /// The server reported `Completed` before the session became ready again.
    Completed,
}
48
49/// Summary of responses collected while advancing a backtest.
50#[derive(Debug, Default)]
51pub struct ContinueResult {
52    /// Number of slot notifications observed.
53    pub slot_notifications: u64,
54    /// Last slot received via notification.
55    pub last_slot: Option<u64>,
56    /// Status messages received.
57    pub statuses: Vec<BacktestStatus>,
58    /// Whether the session became ready for the next `Continue`.
59    pub ready_for_continue: bool,
60    /// Whether the session completed while advancing.
61    pub completed: bool,
62}
63
64/// Mutable state for driving a session with `advance_step`.
65#[derive(Debug)]
66pub struct AdvanceState {
67    /// Expected number of slot notifications for this step.
68    pub expected_slots: u64,
69    /// Count of slot notifications received so far.
70    pub slot_notifications: u64,
71    /// Most recent slot notification.
72    pub last_slot: Option<u64>,
73    /// Status messages received so far.
74    pub statuses: Vec<BacktestStatus>,
75    /// Whether the session is ready for another `Continue`.
76    pub ready_for_continue: bool,
77    /// Whether the session completed while advancing.
78    pub completed: bool,
79    /// Session summary received on completion (if send_summary was enabled).
80    pub summary: Option<SessionSummary>,
81    /// Agent stats received on completion.
82    pub agent_stats: Option<Vec<AgentStatsReport>>,
83}
84
85impl AdvanceState {
86    /// Create new tracking state for a step that expects `expected_slots` notifications.
87    pub fn new(expected_slots: u64) -> Self {
88        Self {
89            expected_slots,
90            slot_notifications: 0,
91            last_slot: None,
92            statuses: Vec::new(),
93            ready_for_continue: false,
94            completed: false,
95            summary: None,
96            agent_stats: None,
97        }
98    }
99
100    /// Return true when the step is complete based on readiness and slot count.
101    pub fn is_done(&self, wait_for_slots: bool) -> bool {
102        if self.completed {
103            return true;
104        }
105
106        if !self.ready_for_continue {
107            return false;
108        }
109
110        !wait_for_slots || self.slot_notifications >= self.expected_slots
111    }
112}
113
/// Tracks how far a session got: the highest slot notified and whether it completed.
///
/// Feed it responses with [`observe_response`](Self::observe_response), then check
/// the result with [`validate_end_slot`](Self::validate_end_slot).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct SessionCoverage {
    /// Set by `mark_completed`, which `observe_response` calls on `Completed`.
    completed: bool,
    /// Largest slot passed to `observe_slot`, if any.
    highest_slot_seen: Option<u64>,
}
120
121impl SessionCoverage {
122    /// Record a slot notification.
123    pub fn observe_slot(&mut self, slot: u64) {
124        self.highest_slot_seen = Some(
125            self.highest_slot_seen
126                .map_or(slot, |current| current.max(slot)),
127        );
128    }
129
130    /// Mark the session as completed.
131    pub fn mark_completed(&mut self) {
132        self.completed = true;
133    }
134
135    /// Update coverage state from a backtest response.
136    pub fn observe_response(&mut self, response: &BacktestResponse) {
137        match response {
138            BacktestResponse::SlotNotification(slot) => self.observe_slot(*slot),
139            BacktestResponse::Completed { .. } => self.mark_completed(),
140            _ => {}
141        }
142    }
143
144    /// Return whether completion has been observed.
145    pub fn is_completed(&self) -> bool {
146        self.completed
147    }
148
149    /// Return the highest slot observed via slot notifications.
150    pub fn highest_slot_seen(&self) -> Option<u64> {
151        self.highest_slot_seen
152    }
153
154    /// Validate that the session completed and reached `expected_end_slot`.
155    pub fn validate_end_slot(&self, expected_end_slot: u64) -> Result<(), CoverageError> {
156        if !self.completed {
157            return Err(CoverageError::NotCompleted);
158        }
159
160        let Some(actual_end_slot) = self.highest_slot_seen else {
161            return Err(CoverageError::NoSlotsObserved);
162        };
163
164        if actual_end_slot < expected_end_slot {
165            return Err(CoverageError::RangeNotReached {
166                actual_end_slot,
167                expected_end_slot,
168            });
169        }
170
171        Ok(())
172    }
173}
174
175/// Coverage validation failures for a backtest session.
176#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
177pub enum CoverageError {
178    #[error("ended before completion")]
179    NotCompleted,
180    #[error("completed without slot notifications")]
181    NoSlotsObserved,
182    #[error("completed at slot {actual_end_slot} but expected at least {expected_end_slot}")]
183    RangeNotReached {
184        actual_end_slot: u64,
185        expected_end_slot: u64,
186    },
187}
188
189/// Active backtest session over a WebSocket connection.
190///
191/// Dropping the session sends a best-effort close frame if a Tokio runtime is
192/// available. Call [`BacktestSession::close`] to request server-side cleanup.
193pub struct BacktestSession {
194    ws: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
195    session_id: Option<String>,
196    rpc_endpoint: Option<String>,
197    rpc: Option<RpcClient>,
198    ready_for_continue: bool,
199    request_timeout: Option<Duration>,
200    log_raw: bool,
201    backlog: VecDeque<BacktestResponse>,
202}
203
204impl BacktestSession {
205    pub(crate) fn new(
206        ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
207        request_timeout: Option<Duration>,
208        log_raw: bool,
209    ) -> Self {
210        Self {
211            ws: Some(ws),
212            session_id: None,
213            rpc_endpoint: None,
214            rpc: None,
215            ready_for_continue: false,
216            request_timeout,
217            log_raw,
218            backlog: VecDeque::new(),
219        }
220    }
221
222    /// Return the server-assigned session id, if known.
223    pub fn session_id(&self) -> Option<&str> {
224        self.session_id.as_deref()
225    }
226
227    /// Return the session-scoped RPC endpoint if provided.
228    pub fn rpc_endpoint(&self) -> Option<&str> {
229        self.rpc_endpoint.as_deref()
230    }
231
232    /// Return the RPC client for this session's endpoint.
233    ///
234    /// Always available after [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
235    pub fn rpc(&self) -> &RpcClient {
236        self.rpc
237            .as_ref()
238            .expect("rpc is set during session creation")
239    }
240
241    /// Return whether the session is currently ready to accept `Continue`.
242    pub fn is_ready_for_continue(&self) -> bool {
243        self.ready_for_continue
244    }
245
246    /// Update internal readiness state based on a response.
247    pub fn apply_response(&mut self, response: &BacktestResponse) {
248        match response {
249            BacktestResponse::ReadyForContinue => {
250                self.ready_for_continue = true;
251            }
252            BacktestResponse::Completed { .. } => {
253                self.ready_for_continue = false;
254            }
255            _ => {}
256        }
257    }
258
259    fn ws_mut(&mut self) -> BacktestClientResult<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
260        self.ws.as_mut().ok_or_else(|| BacktestClientError::Closed {
261            reason: "websocket closed".to_string(),
262        })
263    }
264
265    pub(crate) async fn create_with_request(
266        &mut self,
267        request: CreateBacktestSessionRequest,
268        rpc_base_url: String,
269        mut on_parallel_session_created: Option<&mut (dyn FnMut(String) + Send)>,
270    ) -> BacktestClientResult<CreateRequestResult> {
271        let expect_parallel = matches!(
272            &request,
273            CreateBacktestSessionRequest::V1(CreateBacktestSessionRequestV1 { parallel: true, .. })
274        );
275        self.send(&BacktestRequest::CreateBacktestSession(request), None)
276            .await?;
277        let mut streamed_parallel_session_ids = Vec::new();
278
279        loop {
280            let response =
281                self.next_response(None)
282                    .await?
283                    .ok_or_else(|| BacktestClientError::Closed {
284                        reason: "websocket ended before SessionCreated".to_string(),
285                    })?;
286
287            match response {
288                BacktestResponse::SessionCreated {
289                    session_id,
290                    rpc_endpoint,
291                } => {
292                    if expect_parallel {
293                        if let Some(callback) = on_parallel_session_created.as_mut() {
294                            (**callback)(session_id.clone());
295                        }
296                        streamed_parallel_session_ids.push(session_id);
297                        continue;
298                    }
299                    let created_session_id = session_id.clone();
300                    self.session_id = Some(session_id);
301                    let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
302                    self.rpc = Some(RpcClient::new_with_commitment(
303                        resolved.clone(),
304                        CommitmentConfig::confirmed(),
305                    ));
306                    self.rpc_endpoint = Some(resolved);
307                    return Ok(CreateRequestResult::Single {
308                        session_id: created_session_id,
309                    });
310                }
311                BacktestResponse::SessionsCreated { session_ids } => {
312                    if expect_parallel && session_ids.is_empty() {
313                        return Ok(CreateRequestResult::Parallel {
314                            session_ids: streamed_parallel_session_ids,
315                        });
316                    }
317                    return Ok(CreateRequestResult::Parallel { session_ids });
318                }
319                BacktestResponse::SessionsCreatedV2 { session_ids, .. } => {
320                    if expect_parallel && session_ids.is_empty() {
321                        return Ok(CreateRequestResult::Parallel {
322                            session_ids: streamed_parallel_session_ids,
323                        });
324                    }
325                    return Ok(CreateRequestResult::Parallel { session_ids });
326                }
327                BacktestResponse::ReadyForContinue => {
328                    self.ready_for_continue = true;
329                }
330                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
331                other => {
332                    self.backlog.push_back(other);
333                }
334            }
335        }
336    }
337
338    pub(crate) async fn attach(
339        &mut self,
340        session_id: String,
341        last_sequence: Option<u64>,
342        rpc_base_url: String,
343    ) -> BacktestClientResult<()> {
344        self.send(
345            &BacktestRequest::AttachBacktestSession {
346                session_id,
347                last_sequence,
348            },
349            None,
350        )
351        .await?;
352
353        loop {
354            let response =
355                self.next_response(None)
356                    .await?
357                    .ok_or_else(|| BacktestClientError::Closed {
358                        reason: "websocket ended before SessionAttached".to_string(),
359                    })?;
360
361            match response {
362                BacktestResponse::SessionAttached {
363                    session_id,
364                    rpc_endpoint,
365                } => {
366                    self.session_id = Some(session_id);
367                    let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
368                    self.rpc = Some(RpcClient::new_with_commitment(
369                        resolved.clone(),
370                        CommitmentConfig::confirmed(),
371                    ));
372                    self.rpc_endpoint = Some(resolved);
373                    return Ok(());
374                }
375                BacktestResponse::ReadyForContinue => {
376                    self.ready_for_continue = true;
377                }
378                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
379                other => {
380                    self.backlog.push_back(other);
381                }
382            }
383        }
384    }
385
386    /// Send a raw backtest request over the WebSocket.
387    pub async fn send(
388        &mut self,
389        request: &BacktestRequest,
390        timeout: Option<Duration>,
391    ) -> BacktestClientResult<()> {
392        let text = serde_json::to_string(request)
393            .map_err(|source| BacktestClientError::SerializeRequest { source })?;
394
395        let request_timeout = self.request_timeout;
396        let timeout = timeout.or(request_timeout);
397
398        let send_fut = self.ws_mut()?.send(Message::Text(text));
399        let send_result = match timeout {
400            Some(duration) => tokio::time::timeout(duration, send_fut)
401                .await
402                .map_err(|_| BacktestClientError::Timeout {
403                    action: "sending",
404                    duration,
405                })?,
406            None => send_fut.await,
407        };
408
409        send_result.map_err(|source| BacktestClientError::WebSocket {
410            action: "sending",
411            source: Box::new(source),
412        })?;
413
414        Ok(())
415    }
416
417    /// Receive the next response, using the backlog first.
418    pub async fn next_response(
419        &mut self,
420        timeout: Option<Duration>,
421    ) -> BacktestClientResult<Option<BacktestResponse>> {
422        if let Some(response) = self.backlog.pop_front() {
423            return Ok(Some(response));
424        }
425
426        let text = match self.next_text(timeout).await? {
427            Some(text) => text,
428            None => return Ok(None),
429        };
430
431        let response = serde_json::from_str::<BacktestResponse>(&text).map_err(|source| {
432            BacktestClientError::DeserializeResponse {
433                raw: text.clone(),
434                source,
435            }
436        })?;
437
438        Ok(Some(response))
439    }
440
441    /// Receive the next response and update readiness state.
442    pub async fn next_event(
443        &mut self,
444        timeout: Option<Duration>,
445    ) -> BacktestClientResult<Option<BacktestResponse>> {
446        let response = self.next_response(timeout).await?;
447        if let Some(ref response) = response {
448            self.apply_response(response);
449        }
450        Ok(response)
451    }
452
453    /// Stream responses, updating readiness state as items arrive.
454    ///
455    /// This consumes the session and yields responses until the connection ends.
456    pub fn responses(
457        self,
458        timeout: Option<Duration>,
459    ) -> impl futures::Stream<Item = BacktestClientResult<BacktestResponse>> {
460        stream::unfold(Some(self), move |state| async move {
461            let mut session = match state {
462                Some(session) => session,
463                None => return None,
464            };
465
466            match session.next_response(timeout).await {
467                Ok(Some(response)) => {
468                    session.apply_response(&response);
469                    Some((Ok(response), Some(session)))
470                }
471                Ok(None) => None,
472                Err(err) => Some((Err(err), None)),
473            }
474        })
475    }
476
477    /// Wait for the session to become ready or completed.
478    pub async fn ensure_ready(
479        &mut self,
480        timeout: Option<Duration>,
481    ) -> BacktestClientResult<ReadyOutcome> {
482        if self.ready_for_continue {
483            return Ok(ReadyOutcome::Ready);
484        }
485
486        loop {
487            let response =
488                self.next_response(timeout)
489                    .await?
490                    .ok_or_else(|| BacktestClientError::Closed {
491                        reason: "websocket ended while waiting for ReadyForContinue".to_string(),
492                    })?;
493
494            match response {
495                BacktestResponse::ReadyForContinue => {
496                    self.ready_for_continue = true;
497                    return Ok(ReadyOutcome::Ready);
498                }
499                BacktestResponse::Completed { .. } => return Ok(ReadyOutcome::Completed),
500                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
501                _ => {}
502            }
503        }
504    }
505
506    /// Wait for a specific status to be emitted.
507    pub async fn wait_for_status(
508        &mut self,
509        desired: BacktestStatus,
510        timeout: Option<Duration>,
511    ) -> BacktestClientResult<()> {
512        let desired = std::mem::discriminant(&desired);
513
514        loop {
515            let response =
516                self.next_response(timeout)
517                    .await?
518                    .ok_or_else(|| BacktestClientError::Closed {
519                        reason: "websocket ended while waiting for status".to_string(),
520                    })?;
521
522            match response {
523                BacktestResponse::Status { status }
524                    if std::mem::discriminant(&status) == desired =>
525                {
526                    return Ok(());
527                }
528                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
529                BacktestResponse::Completed {
530                    summary,
531                    agent_stats,
532                } => {
533                    return Err(BacktestClientError::UnexpectedResponse {
534                        context: "waiting for status",
535                        response: Box::new(BacktestResponse::Completed {
536                            summary,
537                            agent_stats,
538                        }),
539                    });
540                }
541                _ => {}
542            }
543        }
544    }
545
546    /// Send a `Continue` request and reset readiness.
547    pub async fn send_continue(
548        &mut self,
549        params: ContinueParams,
550        timeout: Option<Duration>,
551    ) -> BacktestClientResult<()> {
552        self.ready_for_continue = false;
553        self.send(&BacktestRequest::Continue(params), timeout).await
554    }
555
556    /// Read and apply a single response while advancing.
557    pub async fn advance_step<F>(
558        &mut self,
559        state: &mut AdvanceState,
560        wait_for_slots: bool,
561        timeout: Option<Duration>,
562        on_event: &mut F,
563    ) -> BacktestClientResult<()>
564    where
565        F: FnMut(&BacktestResponse),
566    {
567        let Some(response) = self.next_response(timeout).await? else {
568            return Err(BacktestClientError::Closed {
569                reason: "websocket ended while awaiting continue responses".to_string(),
570            });
571        };
572
573        if self.log_raw {
574            tracing::debug!("<- {response:?}");
575        }
576
577        on_event(&response);
578
579        match response {
580            BacktestResponse::ReadyForContinue => {
581                self.ready_for_continue = true;
582                state.ready_for_continue = true;
583            }
584            BacktestResponse::SlotNotification(slot) => {
585                state.slot_notifications += 1;
586                state.last_slot = Some(slot);
587            }
588            BacktestResponse::Status { status } => {
589                state.statuses.push(status);
590            }
591            BacktestResponse::Success => {}
592            BacktestResponse::Completed {
593                summary,
594                agent_stats,
595            } => {
596                state.completed = true;
597                state.summary = summary;
598                state.agent_stats = agent_stats;
599            }
600            BacktestResponse::Error(err @ BacktestError::SimulationError { .. }) => {
601                tracing::warn!(error = %err, "simulation error");
602            }
603            BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
604            BacktestResponse::SessionCreated { .. }
605            | BacktestResponse::SessionAttached { .. }
606            | BacktestResponse::SessionsCreated { .. }
607            | BacktestResponse::SessionsCreatedV2 { .. }
608            | BacktestResponse::ParallelSessionAttachedV2 { .. }
609            | BacktestResponse::SessionEventV1 { .. }
610            | BacktestResponse::SessionEventV2 { .. } => {
611                return Err(BacktestClientError::UnexpectedResponse {
612                    context: "continuing",
613                    response: Box::new(response),
614                });
615            }
616        }
617
618        if wait_for_slots && state.slot_notifications > state.expected_slots {
619            tracing::warn!(
620                "received {} slot notifications (expected {})",
621                state.slot_notifications,
622                state.expected_slots
623            );
624        }
625
626        Ok(())
627    }
628
629    /// Advance until the session becomes ready for another `Continue`.
630    pub async fn continue_until_ready<F>(
631        &mut self,
632        cont: Continue,
633        timeout: Option<Duration>,
634        mut on_event: F,
635    ) -> BacktestClientResult<ContinueResult>
636    where
637        F: FnMut(&BacktestResponse),
638    {
639        let expected_slots = cont.advance_count;
640        self.advance_internal(
641            cont.into_params(),
642            expected_slots,
643            false,
644            timeout,
645            &mut on_event,
646        )
647        .await
648    }
649
650    /// Advance and wait for both readiness and slot notifications.
651    pub async fn advance<F>(
652        &mut self,
653        cont: Continue,
654        timeout: Option<Duration>,
655        mut on_event: F,
656    ) -> BacktestClientResult<ContinueResult>
657    where
658        F: FnMut(&BacktestResponse),
659    {
660        let expected_slots = cont.advance_count;
661        self.advance_internal(
662            cont.into_params(),
663            expected_slots,
664            true,
665            timeout,
666            &mut on_event,
667        )
668        .await
669    }
670
671    async fn advance_internal<F>(
672        &mut self,
673        params: ContinueParams,
674        expected_slots: u64,
675        wait_for_slots: bool,
676        timeout: Option<Duration>,
677        on_event: &mut F,
678    ) -> BacktestClientResult<ContinueResult>
679    where
680        F: FnMut(&BacktestResponse),
681    {
682        self.send_continue(params, timeout).await?;
683
684        let mut state = AdvanceState::new(expected_slots);
685        while !state.is_done(wait_for_slots) {
686            self.advance_step(&mut state, wait_for_slots, timeout, on_event)
687                .await?;
688        }
689
690        Ok(ContinueResult {
691            slot_notifications: state.slot_notifications,
692            last_slot: state.last_slot,
693            statuses: state.statuses,
694            ready_for_continue: state.ready_for_continue,
695            completed: state.completed,
696        })
697    }
698
699    /// Build a program-data account modification from raw ELF bytes.
700    ///
701    /// Derives the `ProgramData` address from `program_id`, queries the session's RPC endpoint
702    /// for the current slot and rent-exempt minimum, then returns a modification map ready to
703    /// pass to [`Continue::builder().modify_accounts(...)`](crate::Continue).
704    ///
705    /// The deploy slot is set to `current_slot - 1` so the program appears deployed before the
706    /// next executed slot.
707    pub async fn modify_program(
708        &self,
709        program_id: &str,
710        elf: &[u8],
711    ) -> Result<BTreeMap<Address, AccountData>, ProgramModError> {
712        let rpc = self.rpc.as_ref().ok_or(ProgramModError::NoRpcEndpoint)?;
713
714        let program_addr: Address =
715            program_id
716                .parse()
717                .map_err(|_| ProgramModError::InvalidProgramId {
718                    id: program_id.to_string(),
719                })?;
720        let programdata_addr = solana_loader_v3_interface::get_program_data_address(&program_addr);
721
722        let slot = rpc.get_slot().await.map_err(|e| ProgramModError::Rpc {
723            source: Box::new(e),
724        })?;
725        let deploy_slot = slot.saturating_sub(1);
726
727        // Fetch the existing programdata account to read the upgrade authority.
728        // The discriminant at byte 12 tells us whether an authority is present.
729        let existing =
730            rpc.get_account(&programdata_addr)
731                .await
732                .map_err(|e| ProgramModError::Rpc {
733                    source: Box::new(e),
734                })?;
735
736        let upgrade_authority = if existing.data.get(12).copied() == Some(1) {
737            existing.data.get(13..45).and_then(|b| {
738                let bytes: [u8; 32] = b.try_into().ok()?;
739                Some(Address::from(bytes))
740            })
741        } else {
742            None
743        };
744
745        let data_len = upgrade_authority.map_or(13, |_| 45) + elf.len();
746        let lamports = rpc
747            .get_minimum_balance_for_rent_exemption(data_len)
748            .await
749            .map_err(|e| ProgramModError::Rpc {
750                source: Box::new(e),
751            })?;
752
753        Ok(build_program_injection(
754            programdata_addr,
755            elf,
756            deploy_slot,
757            upgrade_authority,
758            lamports,
759        ))
760    }
761
762    /// Modify accounts on the session's RPC endpoint via the custom `modifyAccounts` method.
763    ///
764    /// Returns the number of accounts modified on success.
765    pub async fn modify_accounts(
766        &self,
767        modifications: &AccountModifications,
768    ) -> BacktestClientResult<usize> {
769        let rpc_endpoint =
770            self.rpc_endpoint
771                .as_deref()
772                .ok_or_else(|| BacktestClientError::Closed {
773                    reason: "no RPC endpoint available".to_string(),
774                })?;
775
776        crate::rpc::modify_accounts(rpc_endpoint, modifications).await
777    }
778
779    /// Subscribe to program log notifications using the session's RPC endpoint.
780    ///
781    /// Equivalent to calling [`subscribe_program_logs`](crate::subscribe_program_logs) with
782    /// the endpoint from [`rpc_endpoint`](Self::rpc_endpoint), which is set after
783    /// [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
784    pub async fn subscribe_program_logs<F, Fut>(
785        &self,
786        program_id: &str,
787        commitment: CommitmentConfig,
788        on_notification: F,
789    ) -> Result<LogSubscriptionHandle, SubscriptionError>
790    where
791        F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
792        Fut: Future<Output = ()> + Send + 'static,
793    {
794        let rpc_endpoint = self
795            .rpc_endpoint
796            .as_deref()
797            .ok_or(SubscriptionError::NoRpcEndpoint)?;
798        crate::subscriptions::subscribe_program_logs(
799            rpc_endpoint,
800            program_id,
801            commitment,
802            on_notification,
803        )
804        .await
805    }
806
807    /// Subscribe to account diff notifications using the session's RPC endpoint.
808    ///
809    /// Equivalent to calling [`subscribe_account_diffs`](crate::subscribe_account_diffs) with
810    /// the endpoint from [`rpc_endpoint`](Self::rpc_endpoint), which is set after
811    /// [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
812    pub async fn subscribe_account_diffs<F, Fut>(
813        &self,
814        account: &str,
815        on_notification: F,
816    ) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
817    where
818        F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
819        Fut: Future<Output = ()> + Send + 'static,
820    {
821        let rpc_endpoint = self
822            .rpc_endpoint
823            .as_deref()
824            .ok_or(SubscriptionError::NoRpcEndpoint)?;
825        crate::subscriptions::subscribe_account_diffs(rpc_endpoint, account, on_notification).await
826    }
827
828    /// Request server cleanup and close the underlying WebSocket.
829    ///
830    /// This is idempotent and will return `Ok(())` if the connection is already closed.
831    pub async fn close(&mut self, timeout: Option<Duration>) -> BacktestClientResult<()> {
832        self.close_with_frame(timeout, None).await
833    }
834
835    /// Close the session with a specific WebSocket close frame.
836    pub async fn close_with_frame(
837        &mut self,
838        timeout: Option<Duration>,
839        frame: Option<CloseFrame<'static>>,
840    ) -> BacktestClientResult<()> {
841        if self.ws.is_none() {
842            return Ok(());
843        }
844
845        let mut sent = false;
846        match self
847            .send(&BacktestRequest::CloseBacktestSession, timeout)
848            .await
849        {
850            Ok(()) => sent = true,
851            Err(err) if is_close_ok(&err) => {}
852            Err(err) => return Err(err),
853        }
854
855        if sent {
856            let response = match self.next_response(timeout).await {
857                Ok(Some(r)) => r,
858                Ok(None) => {
859                    self.ws.take();
860                    return Ok(());
861                }
862                Err(BacktestClientError::Closed { .. }) => {
863                    self.ws.take();
864                    return Ok(());
865                }
866                Err(BacktestClientError::WebSocket {
867                    action: "receiving",
868                    source,
869                }) if is_reset_without_close(&source) => {
870                    self.ws.take();
871                    return Ok(());
872                }
873                Err(err) => return Err(err),
874            };
875
876            match response {
877                BacktestResponse::Success | BacktestResponse::Completed { .. } => {}
878                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
879                other => {
880                    return Err(BacktestClientError::UnexpectedResponse {
881                        context: "closing session",
882                        response: Box::new(other),
883                    });
884                }
885            }
886        }
887
888        if let Some(ws) = self.ws.as_mut() {
889            let close_result = ws.close(frame).await;
890            if let Err(source) = close_result
891                && !is_ws_closed_error(&source)
892            {
893                return Err(BacktestClientError::WebSocket {
894                    action: "closing",
895                    source: Box::new(source),
896                });
897            }
898        }
899
900        // Await the close confirmation
901        match self.next_response(timeout).await {
902            Ok(_) => {}
903            Err(BacktestClientError::Closed { .. }) => {}
904            Err(BacktestClientError::WebSocket {
905                action: "receiving",
906                source,
907            }) if is_reset_without_close(&source) => {}
908            Err(err) => return Err(err),
909        }
910
911        // Give some time for the close to propagate
912        tokio::time::sleep(Duration::from_millis(100)).await;
913        self.ws.take();
914        Ok(())
915    }
916
917    /// Close the session with a close code and reason.
918    pub async fn close_with_reason(
919        &mut self,
920        timeout: Option<Duration>,
921        code: CloseCode,
922        reason: impl Into<String>,
923    ) -> BacktestClientResult<()> {
924        let frame = CloseFrame {
925            code,
926            reason: Cow::Owned(reason.into()),
927        };
928        self.close_with_frame(timeout, Some(frame)).await
929    }
930
931    async fn next_text(
932        &mut self,
933        timeout: Option<Duration>,
934    ) -> BacktestClientResult<Option<String>> {
935        loop {
936            let request_timeout = self.request_timeout;
937            let timeout = timeout.or(request_timeout);
938
939            let next_fut = self.ws_mut()?.next();
940            let msg = match timeout {
941                Some(duration) => tokio::time::timeout(duration, next_fut)
942                    .await
943                    .map_err(|_| BacktestClientError::Timeout {
944                        action: "receiving",
945                        duration,
946                    })?,
947                None => next_fut.await,
948            };
949
950            let Some(msg) = msg else {
951                return Ok(None);
952            };
953
954            let msg = match msg {
955                Ok(msg) => msg,
956                Err(source) => {
957                    return Err(BacktestClientError::WebSocket {
958                        action: "receiving",
959                        source: Box::new(source),
960                    });
961                }
962            };
963
964            match msg {
965                Message::Text(text) => {
966                    if self.log_raw {
967                        tracing::debug!("<- raw: {text}");
968                    }
969                    return Ok(Some(text));
970                }
971                Message::Binary(bin) => match String::from_utf8(bin) {
972                    Ok(text) => {
973                        if self.log_raw {
974                            tracing::debug!("<- raw(bin): {text}");
975                        }
976                        return Ok(Some(text));
977                    }
978                    Err(err) => {
979                        tracing::warn!("discarding non-utf8 binary message: {err}");
980                        continue;
981                    }
982                },
983                Message::Close(frame) => {
984                    let reason = close_reason(frame);
985                    return Err(BacktestClientError::Closed { reason });
986                }
987                Message::Ping(_) | Message::Pong(_) => continue,
988                Message::Frame(_) => continue,
989            }
990        }
991    }
992}
993
994impl std::fmt::Debug for BacktestSession {
995    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
996        f.debug_struct("BacktestSession")
997            .field("session_id", &self.session_id)
998            .field("rpc_endpoint", &self.rpc_endpoint)
999            .field(
1000                "rpc",
1001                &self
1002                    .rpc
1003                    .as_ref()
1004                    .map(|_| "<RpcClient>")
1005                    .unwrap_or("<not set>"),
1006            )
1007            .field("ready_for_continue", &self.ready_for_continue)
1008            .field("request_timeout", &self.request_timeout)
1009            .finish_non_exhaustive()
1010    }
1011}
1012
1013#[derive(Debug)]
1014pub(crate) enum CreateRequestResult {
1015    Single { session_id: String },
1016    Parallel { session_ids: Vec<String> },
1017}
1018
1019impl Drop for BacktestSession {
1020    fn drop(&mut self) {
1021        let Some(ws) = self.ws.take() else {
1022            return;
1023        };
1024
1025        if let Ok(handle) = tokio::runtime::Handle::try_current() {
1026            handle.spawn(async move {
1027                let mut ws = ws;
1028                let _ = ws.close(None).await;
1029            });
1030        }
1031    }
1032}
1033
1034fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
1035    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
1036        endpoint.to_string()
1037    } else {
1038        format!("{}/{}", base, endpoint.trim_start_matches('/'))
1039    }
1040}
1041
1042fn close_reason(frame: Option<CloseFrame<'static>>) -> String {
1043    match frame {
1044        Some(frame) => format!("{:?}: {}", frame.code, frame.reason),
1045        None => "no close frame".to_string(),
1046    }
1047}
1048
1049fn is_reset_without_close(err: &WsError) -> bool {
1050    matches!(
1051        err,
1052        WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1053    )
1054}
1055
1056fn is_ws_closed_error(err: &WsError) -> bool {
1057    matches!(
1058        err,
1059        WsError::ConnectionClosed
1060            | WsError::AlreadyClosed
1061            | WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1062    )
1063}
1064
1065fn is_close_ok(err: &BacktestClientError) -> bool {
1066    match err {
1067        BacktestClientError::Closed { .. } => true,
1068        BacktestClientError::WebSocket { source, .. } => is_ws_closed_error(source),
1069        _ => false,
1070    }
1071}
1072
1073#[cfg(test)]
1074mod tests {
1075    use super::*;
1076
1077    #[test]
1078    fn coverage_tracks_slot_and_completion_from_responses() {
1079        let mut coverage = SessionCoverage::default();
1080        coverage.observe_response(&BacktestResponse::SlotNotification(10));
1081        coverage.observe_response(&BacktestResponse::SlotNotification(12));
1082        coverage.observe_response(&BacktestResponse::Completed {
1083            summary: None,
1084            agent_stats: None,
1085        });
1086
1087        assert!(coverage.is_completed());
1088        assert_eq!(coverage.highest_slot_seen(), Some(12));
1089    }
1090
1091    #[test]
1092    fn coverage_validate_end_slot_checks_completion_and_range() {
1093        let mut coverage = SessionCoverage::default();
1094        assert_eq!(
1095            coverage.validate_end_slot(5),
1096            Err(CoverageError::NotCompleted)
1097        );
1098
1099        coverage.mark_completed();
1100        assert_eq!(
1101            coverage.validate_end_slot(5),
1102            Err(CoverageError::NoSlotsObserved)
1103        );
1104
1105        coverage.observe_slot(4);
1106        assert_eq!(
1107            coverage.validate_end_slot(5),
1108            Err(CoverageError::RangeNotReached {
1109                actual_end_slot: 4,
1110                expected_end_slot: 5,
1111            })
1112        );
1113
1114        coverage.observe_slot(6);
1115        assert_eq!(coverage.validate_end_slot(5), Ok(()));
1116    }
1117}