Skip to main content

simulator_client/
session.rs

1use std::{
2    borrow::Cow,
3    collections::{BTreeMap, VecDeque},
4    future::Future,
5    time::Duration,
6};
7
8use futures::{SinkExt, StreamExt, stream};
9use simulator_api::{
10    AccountData, AccountModifications, AgentStatsReport, BacktestError, BacktestRequest,
11    BacktestResponse, BacktestStatus, ContinueParams, CreateBacktestSessionRequest,
12    CreateBacktestSessionRequestV1, SequencedResponse, SessionSummary,
13};
14use solana_address::Address;
15use solana_client::{
16    nonblocking::rpc_client::RpcClient,
17    rpc_response::{Response, RpcLogsResponse},
18};
19use solana_commitment_config::CommitmentConfig;
20use thiserror::Error;
21use tokio::net::TcpStream;
22use tokio_tungstenite::{
23    MaybeTlsStream, WebSocketStream,
24    tungstenite::{
25        Error as WsError, Message,
26        error::ProtocolError,
27        protocol::{CloseFrame, frame::coding::CloseCode},
28    },
29};
30
31use crate::{
32    BacktestClientError, BacktestClientResult, Continue,
33    injection::ProgramModError,
34    subscriptions::{
35        AccountDiffNotification, AccountDiffSubscriptionHandle, LogSubscriptionHandle,
36        SubscriptionError,
37    },
38};
39
/// What happened while waiting for a backtest session to accept the next `Continue`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReadyOutcome {
    /// A `Continue` request may be sent now.
    Ready,
    /// The session finished before it ever became ready.
    Completed,
}
48
/// Summary of the responses collected while advancing a backtest.
///
/// Returned by [`BacktestSession::continue_until_ready`] and
/// [`BacktestSession::advance`].
#[derive(Debug, Default)]
pub struct ContinueResult {
    /// Number of slot notifications observed.
    pub slot_notifications: u64,
    /// Slot carried by the most recent slot notification.
    pub last_slot: Option<u64>,
    /// Status messages received, in arrival order.
    pub statuses: Vec<BacktestStatus>,
    /// Whether the session became ready for the next `Continue`.
    pub ready_for_continue: bool,
    /// Whether the session completed while advancing.
    pub completed: bool,
}
63
/// Mutable state for driving a session one response at a time with `advance_step`.
///
/// Create it with [`AdvanceState::new`], feed it through repeated
/// [`BacktestSession::advance_step`] calls, and stop once
/// [`AdvanceState::is_done`] returns `true`.
#[derive(Debug)]
pub struct AdvanceState {
    /// Expected number of slot notifications for this step.
    pub expected_slots: u64,
    /// Count of slot notifications received so far.
    pub slot_notifications: u64,
    /// Slot carried by the most recent slot notification.
    pub last_slot: Option<u64>,
    /// Status messages received so far, in arrival order.
    pub statuses: Vec<BacktestStatus>,
    /// Whether a `ReadyForContinue` has been received during this step.
    pub ready_for_continue: bool,
    /// Whether the session completed while advancing.
    pub completed: bool,
    /// Session summary carried by the `Completed` response, if the server sent one
    /// (e.g. when send_summary was enabled).
    pub summary: Option<SessionSummary>,
    /// Agent stats carried by the `Completed` response, if any.
    pub agent_stats: Option<Vec<AgentStatsReport>>,
}
84
85impl AdvanceState {
86    /// Create new tracking state for a step that expects `expected_slots` notifications.
87    pub fn new(expected_slots: u64) -> Self {
88        Self {
89            expected_slots,
90            slot_notifications: 0,
91            last_slot: None,
92            statuses: Vec::new(),
93            ready_for_continue: false,
94            completed: false,
95            summary: None,
96            agent_stats: None,
97        }
98    }
99
100    /// Return true when the step is complete based on readiness and slot count.
101    pub fn is_done(&self, wait_for_slots: bool) -> bool {
102        if self.completed {
103            return true;
104        }
105
106        if !self.ready_for_continue {
107            return false;
108        }
109
110        !wait_for_slots || self.slot_notifications >= self.expected_slots
111    }
112}
113
/// Coverage of a session's observed slot notifications and completion state.
///
/// Feed responses in with [`SessionCoverage::observe_response`], then check the
/// outcome with [`SessionCoverage::validate_end_slot`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SessionCoverage {
    /// Set once completion has been recorded.
    completed: bool,
    /// Largest slot seen in any slot notification, regardless of arrival order.
    highest_slot_seen: Option<u64>,
}
120
121impl SessionCoverage {
122    /// Record a slot notification.
123    pub fn observe_slot(&mut self, slot: u64) {
124        self.highest_slot_seen = Some(
125            self.highest_slot_seen
126                .map_or(slot, |current| current.max(slot)),
127        );
128    }
129
130    /// Mark the session as completed.
131    pub fn mark_completed(&mut self) {
132        self.completed = true;
133    }
134
135    /// Update coverage state from a backtest response.
136    pub fn observe_response(&mut self, response: &BacktestResponse) {
137        match response {
138            BacktestResponse::SlotNotification(slot) => self.observe_slot(*slot),
139            BacktestResponse::Completed { .. } => self.mark_completed(),
140            _ => {}
141        }
142    }
143
144    /// Return whether completion has been observed.
145    pub fn is_completed(&self) -> bool {
146        self.completed
147    }
148
149    /// Return the highest slot observed via slot notifications.
150    pub fn highest_slot_seen(&self) -> Option<u64> {
151        self.highest_slot_seen
152    }
153
154    /// Validate that the session completed and reached `expected_end_slot`.
155    pub fn validate_end_slot(&self, expected_end_slot: u64) -> Result<(), CoverageError> {
156        if !self.completed {
157            return Err(CoverageError::NotCompleted);
158        }
159
160        let Some(actual_end_slot) = self.highest_slot_seen else {
161            return Err(CoverageError::NoSlotsObserved);
162        };
163
164        if actual_end_slot < expected_end_slot {
165            return Err(CoverageError::RangeNotReached {
166                actual_end_slot,
167                expected_end_slot,
168            });
169        }
170
171        Ok(())
172    }
173}
174
/// Coverage validation failures for a backtest session.
///
/// Produced by [`SessionCoverage::validate_end_slot`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
pub enum CoverageError {
    /// Completion was never recorded.
    #[error("ended before completion")]
    NotCompleted,
    /// Completion was recorded but no slot notification ever arrived.
    #[error("completed without slot notifications")]
    NoSlotsObserved,
    /// The highest observed slot is below the slot the caller required.
    #[error("completed at slot {actual_end_slot} but expected at least {expected_end_slot}")]
    RangeNotReached {
        /// Highest slot seen in slot notifications.
        actual_end_slot: u64,
        /// Minimum slot the session was required to reach.
        expected_end_slot: u64,
    },
}
188
/// Active backtest session over a WebSocket connection.
///
/// Dropping the session sends a best-effort close frame if a Tokio runtime is
/// available. Call [`BacktestSession::close`] to request server-side cleanup.
pub struct BacktestSession {
    /// Control WebSocket. Taken (left as `None`) once the connection has been closed.
    ws: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    /// Server-assigned session id, recorded on creation or attach.
    session_id: Option<String>,
    /// Session-scoped RPC endpoint URL, recorded on creation or attach.
    rpc_endpoint: Option<String>,
    /// RPC client for `rpc_endpoint`, built at the same time.
    rpc: Option<RpcClient>,
    /// Sequence id of the most recent sequenced response returned by `next_response`.
    last_sequence: Option<u64>,
    /// Whether the server has signalled that it will accept a `Continue`.
    pub(crate) ready_for_continue: bool,
    /// Fallback timeout for `send` when the caller passes `None`.
    request_timeout: Option<Duration>,
    /// When set, `advance_step` logs each received response at debug level.
    log_raw: bool,
    /// Responses read (or queued via `push_backlog`) but not yet returned. Each is
    /// paired with the sequence id to restore when it is replayed. `next_response`
    /// drains this before reading the socket.
    backlog: VecDeque<(Option<u64>, BacktestResponse)>,
}
204
205impl BacktestSession {
206    pub(crate) fn new(
207        ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
208        request_timeout: Option<Duration>,
209        log_raw: bool,
210    ) -> Self {
211        Self {
212            ws: Some(ws),
213            session_id: None,
214            rpc_endpoint: None,
215            rpc: None,
216            last_sequence: None,
217            ready_for_continue: false,
218            request_timeout,
219            log_raw,
220            backlog: VecDeque::new(),
221        }
222    }
223
224    /// Return the server-assigned session id, if known.
225    pub fn session_id(&self) -> Option<&str> {
226        self.session_id.as_deref()
227    }
228
229    /// Return the session-scoped RPC endpoint if provided.
230    pub fn rpc_endpoint(&self) -> Option<&str> {
231        self.rpc_endpoint.as_deref()
232    }
233
234    /// Return the highest sequenced control-websocket response observed so far.
235    pub fn last_sequence(&self) -> Option<u64> {
236        self.last_sequence
237    }
238
239    /// Return the RPC client for this session's endpoint.
240    ///
241    /// Always available after [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
242    pub fn rpc(&self) -> &RpcClient {
243        self.rpc
244            .as_ref()
245            .expect("rpc is set during session creation")
246    }
247
248    /// Return whether the session is currently ready to accept `Continue`.
249    pub fn is_ready_for_continue(&self) -> bool {
250        self.ready_for_continue
251    }
252
253    /// Update internal readiness state based on a response.
254    pub fn apply_response(&mut self, response: &BacktestResponse) {
255        match response {
256            BacktestResponse::ReadyForContinue | BacktestResponse::Paused(_) => {
257                self.ready_for_continue = true;
258            }
259            BacktestResponse::Completed { .. } => {
260                self.ready_for_continue = false;
261            }
262            _ => {}
263        }
264    }
265
266    fn ws_mut(&mut self) -> BacktestClientResult<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
267        self.ws.as_mut().ok_or_else(|| BacktestClientError::Closed {
268            reason: "websocket closed".to_string(),
269        })
270    }
271
272    pub(crate) async fn create_with_request(
273        &mut self,
274        request: CreateBacktestSessionRequest,
275        rpc_base_url: String,
276        mut on_parallel_session_created: Option<&mut (dyn FnMut(String) + Send)>,
277    ) -> BacktestClientResult<CreateRequestResult> {
278        let expect_parallel = matches!(
279            &request,
280            CreateBacktestSessionRequest::V1(CreateBacktestSessionRequestV1 { parallel: true, .. })
281        );
282        self.send(&BacktestRequest::CreateBacktestSession(request), None)
283            .await?;
284        let mut streamed_parallel_session_ids = Vec::new();
285
286        loop {
287            let response =
288                self.next_response(None)
289                    .await?
290                    .ok_or_else(|| BacktestClientError::Closed {
291                        reason: "websocket ended before SessionCreated".to_string(),
292                    })?;
293
294            match response {
295                BacktestResponse::SessionCreated {
296                    session_id,
297                    rpc_endpoint,
298                } => {
299                    if expect_parallel {
300                        if let Some(callback) = on_parallel_session_created.as_mut() {
301                            (**callback)(session_id.clone());
302                        }
303                        streamed_parallel_session_ids.push(session_id);
304                        continue;
305                    }
306                    let created_session_id = session_id.clone();
307                    self.session_id = Some(session_id);
308                    let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
309                    self.rpc = Some(RpcClient::new_with_commitment(
310                        resolved.clone(),
311                        CommitmentConfig::confirmed(),
312                    ));
313                    self.rpc_endpoint = Some(resolved);
314                    return Ok(CreateRequestResult::Single {
315                        session_id: created_session_id,
316                    });
317                }
318                BacktestResponse::SessionsCreated { session_ids } => {
319                    if expect_parallel && session_ids.is_empty() {
320                        return Ok(CreateRequestResult::Parallel {
321                            session_ids: streamed_parallel_session_ids,
322                        });
323                    }
324                    return Ok(CreateRequestResult::Parallel { session_ids });
325                }
326                BacktestResponse::SessionsCreatedV2 { session_ids, .. } => {
327                    if expect_parallel && session_ids.is_empty() {
328                        return Ok(CreateRequestResult::Parallel {
329                            session_ids: streamed_parallel_session_ids,
330                        });
331                    }
332                    return Ok(CreateRequestResult::Parallel { session_ids });
333                }
334                BacktestResponse::ReadyForContinue => {
335                    self.ready_for_continue = true;
336                }
337                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
338                other => {
339                    self.backlog.push_back((self.last_sequence, other));
340                }
341            }
342        }
343    }
344
345    pub(crate) async fn attach(
346        &mut self,
347        session_id: String,
348        last_sequence: Option<u64>,
349        rpc_base_url: String,
350    ) -> BacktestClientResult<()> {
351        self.send(
352            &BacktestRequest::AttachBacktestSession {
353                session_id,
354                last_sequence,
355            },
356            None,
357        )
358        .await?;
359
360        self.wait_for_response(
361            || BacktestClientError::Closed {
362                reason: "websocket ended before SessionAttached".to_string(),
363            },
364            move |session, response| match response {
365                BacktestResponse::SessionAttached {
366                    session_id,
367                    rpc_endpoint,
368                } => {
369                    session.session_id = Some(session_id);
370                    let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
371                    session.rpc = Some(RpcClient::new_with_commitment(
372                        resolved.clone(),
373                        CommitmentConfig::confirmed(),
374                    ));
375                    session.rpc_endpoint = Some(resolved);
376                    Ok(Some(()))
377                }
378                BacktestResponse::ReadyForContinue => {
379                    session.ready_for_continue = true;
380                    Ok(None)
381                }
382                BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
383                other => {
384                    session.backlog.push_back((session.last_sequence, other));
385                    Ok(None)
386                }
387            },
388        )
389        .await
390    }
391
392    /// Sent after reattaching and rebuilding any dependent subscriptions so the
393    /// manager can resume a session that was paused for handoff.
394    pub async fn resume_attached_session(&mut self) -> BacktestClientResult<()> {
395        self.send(&BacktestRequest::ResumeAttachedSession, None)
396            .await?;
397
398        self.wait_for_response(
399            || BacktestClientError::Closed {
400                reason: "websocket ended before ResumeAttachedSession acknowledgement".to_string(),
401            },
402            |session, response| match response {
403                BacktestResponse::Success => Ok(Some(())),
404                BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
405                other => {
406                    session.backlog.push_back((session.last_sequence, other));
407                    Ok(None)
408                }
409            },
410        )
411        .await
412    }
413
414    async fn wait_for_response<T, E, F>(
415        &mut self,
416        closed_error: E,
417        mut handle_response: F,
418    ) -> BacktestClientResult<T>
419    where
420        E: FnOnce() -> BacktestClientError,
421        F: FnMut(&mut Self, BacktestResponse) -> BacktestClientResult<Option<T>>,
422    {
423        let mut closed_error = Some(closed_error);
424
425        loop {
426            let response = self
427                .next_response(None)
428                .await?
429                .ok_or_else(|| closed_error.take().expect("closed error set")())?;
430
431            if let Some(result) = handle_response(self, response)? {
432                return Ok(result);
433            }
434        }
435    }
436
437    /// Send a raw backtest request over the WebSocket.
438    pub async fn send(
439        &mut self,
440        request: &BacktestRequest,
441        timeout: Option<Duration>,
442    ) -> BacktestClientResult<()> {
443        let text = serde_json::to_string(request)
444            .map_err(|source| BacktestClientError::SerializeRequest { source })?;
445
446        let request_timeout = self.request_timeout;
447        let timeout = timeout.or(request_timeout);
448
449        let send_fut = self.ws_mut()?.send(Message::Text(text));
450        let send_result = match timeout {
451            Some(duration) => tokio::time::timeout(duration, send_fut)
452                .await
453                .map_err(|_| BacktestClientError::Timeout {
454                    action: "sending",
455                    duration,
456                })?,
457            None => send_fut.await,
458        };
459
460        send_result.map_err(|source| BacktestClientError::WebSocket {
461            action: "sending",
462            source: Box::new(source),
463        })?;
464
465        Ok(())
466    }
467
468    /// Receive the next response, using the backlog first.
469    pub async fn next_response(
470        &mut self,
471        timeout: Option<Duration>,
472    ) -> BacktestClientResult<Option<BacktestResponse>> {
473        if let Some((sequence, response)) = self.backlog.pop_front() {
474            self.last_sequence = sequence.or(self.last_sequence);
475            return Ok(Some(response));
476        }
477
478        let text = match self.next_text(timeout).await? {
479            Some(text) => text,
480            None => return Ok(None),
481        };
482
483        let (sequence, response) = match serde_json::from_str::<SequencedResponse>(&text) {
484            Ok(sequenced) => (Some(sequenced.seq_id), sequenced.response),
485            Err(_) => {
486                let response =
487                    serde_json::from_str::<BacktestResponse>(&text).map_err(|source| {
488                        BacktestClientError::DeserializeResponse {
489                            raw: text.clone(),
490                            source,
491                        }
492                    })?;
493                (None, response)
494            }
495        };
496        self.last_sequence = sequence.or(self.last_sequence);
497
498        Ok(Some(response))
499    }
500
501    /// Receive the next response and update readiness state.
502    pub async fn next_event(
503        &mut self,
504        timeout: Option<Duration>,
505    ) -> BacktestClientResult<Option<BacktestResponse>> {
506        let response = self.next_response(timeout).await?;
507        if let Some(ref response) = response {
508            self.apply_response(response);
509        }
510        Ok(response)
511    }
512
513    /// Stream responses, updating readiness state as items arrive.
514    ///
515    /// This consumes the session and yields responses until the connection ends.
516    pub fn responses(
517        self,
518        timeout: Option<Duration>,
519    ) -> impl futures::Stream<Item = BacktestClientResult<BacktestResponse>> {
520        stream::unfold(Some(self), move |state| async move {
521            let mut session = match state {
522                Some(session) => session,
523                None => return None,
524            };
525
526            match session.next_response(timeout).await {
527                Ok(Some(response)) => {
528                    session.apply_response(&response);
529                    Some((Ok(response), Some(session)))
530                }
531                Ok(None) => None,
532                Err(err) => Some((Err(err), None)),
533            }
534        })
535    }
536
537    /// Wait for the session to become ready or completed.
538    pub async fn ensure_ready(
539        &mut self,
540        timeout: Option<Duration>,
541    ) -> BacktestClientResult<ReadyOutcome> {
542        if self.ready_for_continue {
543            return Ok(ReadyOutcome::Ready);
544        }
545
546        loop {
547            let response =
548                self.next_response(timeout)
549                    .await?
550                    .ok_or_else(|| BacktestClientError::Closed {
551                        reason: "websocket ended while waiting for ReadyForContinue".to_string(),
552                    })?;
553
554            match response {
555                BacktestResponse::ReadyForContinue => {
556                    self.ready_for_continue = true;
557                    return Ok(ReadyOutcome::Ready);
558                }
559                BacktestResponse::Completed { .. } => return Ok(ReadyOutcome::Completed),
560                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
561                _ => {}
562            }
563        }
564    }
565
566    /// Wait for a specific status to be emitted.
567    pub async fn wait_for_status(
568        &mut self,
569        desired: BacktestStatus,
570        timeout: Option<Duration>,
571    ) -> BacktestClientResult<()> {
572        let desired = std::mem::discriminant(&desired);
573
574        loop {
575            let response =
576                self.next_response(timeout)
577                    .await?
578                    .ok_or_else(|| BacktestClientError::Closed {
579                        reason: "websocket ended while waiting for status".to_string(),
580                    })?;
581
582            match response {
583                BacktestResponse::Status { status }
584                    if std::mem::discriminant(&status) == desired =>
585                {
586                    return Ok(());
587                }
588                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
589                BacktestResponse::Completed {
590                    summary,
591                    agent_stats,
592                } => {
593                    return Err(BacktestClientError::UnexpectedResponse {
594                        context: "waiting for status",
595                        response: Box::new(BacktestResponse::Completed {
596                            summary,
597                            agent_stats,
598                        }),
599                    });
600                }
601                _ => {}
602            }
603        }
604    }
605
606    /// Push a response onto the end of the receive backlog.
607    pub(crate) fn push_backlog(&mut self, response: BacktestResponse) {
608        self.backlog.push_back((None, response));
609    }
610
611    /// Send a `Continue` request and reset readiness.
612    pub async fn send_continue(
613        &mut self,
614        params: ContinueParams,
615        timeout: Option<Duration>,
616    ) -> BacktestClientResult<()> {
617        self.ready_for_continue = false;
618        self.send(&BacktestRequest::Continue(params), timeout).await
619    }
620
621    /// Read and apply a single response while advancing.
622    pub async fn advance_step<F>(
623        &mut self,
624        state: &mut AdvanceState,
625        wait_for_slots: bool,
626        timeout: Option<Duration>,
627        on_event: &mut F,
628    ) -> BacktestClientResult<()>
629    where
630        F: FnMut(&BacktestResponse),
631    {
632        let Some(response) = self.next_response(timeout).await? else {
633            return Err(BacktestClientError::Closed {
634                reason: "websocket ended while awaiting continue responses".to_string(),
635            });
636        };
637
638        if self.log_raw {
639            tracing::debug!("<- {response:?}");
640        }
641
642        on_event(&response);
643
644        match response {
645            BacktestResponse::ReadyForContinue => {
646                self.ready_for_continue = true;
647                state.ready_for_continue = true;
648            }
649            BacktestResponse::SlotNotification(slot) => {
650                state.slot_notifications += 1;
651                state.last_slot = Some(slot);
652            }
653            BacktestResponse::Status { status } => {
654                state.statuses.push(status);
655            }
656            BacktestResponse::Success => {}
657            BacktestResponse::Completed {
658                summary,
659                agent_stats,
660            } => {
661                state.completed = true;
662                state.summary = summary;
663                state.agent_stats = agent_stats;
664            }
665            BacktestResponse::Error(err @ BacktestError::SimulationError { .. }) => {
666                tracing::warn!(error = %err, "simulation error");
667            }
668            BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
669            BacktestResponse::SessionCreated { .. }
670            | BacktestResponse::SessionAttached { .. }
671            | BacktestResponse::SessionsCreated { .. }
672            | BacktestResponse::SessionsCreatedV2 { .. }
673            | BacktestResponse::ParallelSessionAttachedV2 { .. }
674            | BacktestResponse::SessionEventV1 { .. }
675            | BacktestResponse::SessionEventV2 { .. }
676            | BacktestResponse::Paused(_)
677            | BacktestResponse::DiscoveryBatch(_) => {
678                return Err(BacktestClientError::UnexpectedResponse {
679                    context: "continuing",
680                    response: Box::new(response),
681                });
682            }
683        }
684
685        if wait_for_slots && state.slot_notifications > state.expected_slots {
686            tracing::warn!(
687                "received {} slot notifications (expected {})",
688                state.slot_notifications,
689                state.expected_slots
690            );
691        }
692
693        Ok(())
694    }
695
696    /// Advance until the session becomes ready for another `Continue`.
697    pub async fn continue_until_ready<F>(
698        &mut self,
699        cont: Continue,
700        timeout: Option<Duration>,
701        mut on_event: F,
702    ) -> BacktestClientResult<ContinueResult>
703    where
704        F: FnMut(&BacktestResponse),
705    {
706        let expected_slots = cont.advance_count;
707        self.advance_internal(
708            cont.into_params(),
709            expected_slots,
710            false,
711            timeout,
712            &mut on_event,
713        )
714        .await
715    }
716
717    /// Advance and wait for both readiness and slot notifications.
718    pub async fn advance<F>(
719        &mut self,
720        cont: Continue,
721        timeout: Option<Duration>,
722        mut on_event: F,
723    ) -> BacktestClientResult<ContinueResult>
724    where
725        F: FnMut(&BacktestResponse),
726    {
727        let expected_slots = cont.advance_count;
728        self.advance_internal(
729            cont.into_params(),
730            expected_slots,
731            true,
732            timeout,
733            &mut on_event,
734        )
735        .await
736    }
737
738    async fn advance_internal<F>(
739        &mut self,
740        params: ContinueParams,
741        expected_slots: u64,
742        wait_for_slots: bool,
743        timeout: Option<Duration>,
744        on_event: &mut F,
745    ) -> BacktestClientResult<ContinueResult>
746    where
747        F: FnMut(&BacktestResponse),
748    {
749        self.send_continue(params, timeout).await?;
750
751        let mut state = AdvanceState::new(expected_slots);
752        while !state.is_done(wait_for_slots) {
753            self.advance_step(&mut state, wait_for_slots, timeout, on_event)
754                .await?;
755        }
756
757        Ok(ContinueResult {
758            slot_notifications: state.slot_notifications,
759            last_slot: state.last_slot,
760            statuses: state.statuses,
761            ready_for_continue: state.ready_for_continue,
762            completed: state.completed,
763        })
764    }
765
766    /// Build a program-data account modification from raw ELF bytes.
767    ///
768    /// Derives the `ProgramData` address from `program_id`, queries the session's RPC endpoint
769    /// for the current slot and rent-exempt minimum, then returns a modification map ready to
770    /// pass to [`Continue::builder().modify_accounts(...)`](crate::Continue).
771    ///
772    /// The deploy slot is set to `current_slot - 1` so the program appears deployed before the
773    /// next executed slot.
774    pub async fn modify_program(
775        &self,
776        program_id: &str,
777        elf: &[u8],
778    ) -> Result<BTreeMap<Address, AccountData>, ProgramModError> {
779        let rpc = self.rpc.as_ref().ok_or(ProgramModError::NoRpcEndpoint)?;
780        crate::injection::modify_program_via_rpc(rpc, program_id, elf).await
781    }
782
783    /// Modify accounts on the session's RPC endpoint via the custom `modifyAccounts` method.
784    ///
785    /// Returns the number of accounts modified on success.
786    pub async fn modify_accounts(
787        &self,
788        modifications: &AccountModifications,
789    ) -> BacktestClientResult<usize> {
790        let rpc_endpoint =
791            self.rpc_endpoint
792                .as_deref()
793                .ok_or_else(|| BacktestClientError::Closed {
794                    reason: "no RPC endpoint available".to_string(),
795                })?;
796
797        crate::rpc::modify_accounts(rpc_endpoint, modifications).await
798    }
799
800    /// Subscribe to program log notifications using the session's RPC endpoint.
801    ///
802    /// Equivalent to calling [`subscribe_program_logs`](crate::subscribe_program_logs) with
803    /// the endpoint from [`rpc_endpoint`](Self::rpc_endpoint), which is set after
804    /// [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
805    pub async fn subscribe_program_logs<F, Fut>(
806        &self,
807        program_id: &str,
808        commitment: CommitmentConfig,
809        on_notification: F,
810    ) -> Result<LogSubscriptionHandle, SubscriptionError>
811    where
812        F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
813        Fut: Future<Output = ()> + Send + 'static,
814    {
815        let rpc_endpoint = self
816            .rpc_endpoint
817            .as_deref()
818            .ok_or(SubscriptionError::NoRpcEndpoint)?;
819        crate::subscriptions::subscribe_program_logs(
820            rpc_endpoint,
821            program_id,
822            commitment,
823            on_notification,
824        )
825        .await
826    }
827
828    /// Subscribe to account diff notifications using the session's RPC endpoint.
829    ///
830    /// Equivalent to calling [`subscribe_account_diffs`](crate::subscribe_account_diffs) with
831    /// the endpoint from [`rpc_endpoint`](Self::rpc_endpoint), which is set after
832    /// [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
833    pub async fn subscribe_account_diffs<F, Fut>(
834        &self,
835        account: &str,
836        on_notification: F,
837    ) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
838    where
839        F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
840        Fut: Future<Output = ()> + Send + 'static,
841    {
842        let rpc_endpoint = self
843            .rpc_endpoint
844            .as_deref()
845            .ok_or(SubscriptionError::NoRpcEndpoint)?;
846        crate::subscriptions::subscribe_account_diffs(rpc_endpoint, account, on_notification).await
847    }
848
849    /// Request server cleanup and close the underlying WebSocket.
850    ///
851    /// This is idempotent and will return `Ok(())` if the connection is already closed.
852    pub async fn close(&mut self, timeout: Option<Duration>) -> BacktestClientResult<()> {
853        self.close_with_frame(timeout, None).await
854    }
855
856    /// Close the session with a specific WebSocket close frame.
857    pub async fn close_with_frame(
858        &mut self,
859        timeout: Option<Duration>,
860        frame: Option<CloseFrame<'static>>,
861    ) -> BacktestClientResult<()> {
862        if self.ws.is_none() {
863            return Ok(());
864        }
865
866        let mut sent = false;
867        match self
868            .send(&BacktestRequest::CloseBacktestSession, timeout)
869            .await
870        {
871            Ok(()) => sent = true,
872            Err(err) if is_close_ok(&err) => {}
873            Err(err) => return Err(err),
874        }
875
876        if sent {
877            let response = match self.next_response(timeout).await {
878                Ok(Some(r)) => r,
879                Ok(None) => {
880                    self.ws.take();
881                    return Ok(());
882                }
883                Err(BacktestClientError::Closed { .. }) => {
884                    self.ws.take();
885                    return Ok(());
886                }
887                Err(BacktestClientError::WebSocket {
888                    action: "receiving",
889                    source,
890                }) if is_reset_without_close(&source) => {
891                    self.ws.take();
892                    return Ok(());
893                }
894                Err(err) => return Err(err),
895            };
896
897            match response {
898                BacktestResponse::Success | BacktestResponse::Completed { .. } => {}
899                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
900                other => {
901                    return Err(BacktestClientError::UnexpectedResponse {
902                        context: "closing session",
903                        response: Box::new(other),
904                    });
905                }
906            }
907        }
908
909        if let Some(ws) = self.ws.as_mut() {
910            let close_result = ws.close(frame).await;
911            if let Err(source) = close_result
912                && !is_ws_closed_error(&source)
913            {
914                return Err(BacktestClientError::WebSocket {
915                    action: "closing",
916                    source: Box::new(source),
917                });
918            }
919        }
920
921        // Await the close confirmation
922        match self.next_response(timeout).await {
923            Ok(_) => {}
924            Err(BacktestClientError::Closed { .. }) => {}
925            Err(BacktestClientError::WebSocket {
926                action: "receiving",
927                source,
928            }) if is_reset_without_close(&source) => {}
929            Err(err) => return Err(err),
930        }
931
932        // Give some time for the close to propagate
933        tokio::time::sleep(Duration::from_millis(100)).await;
934        self.ws.take();
935        Ok(())
936    }
937
938    /// Close the session with a close code and reason.
939    pub async fn close_with_reason(
940        &mut self,
941        timeout: Option<Duration>,
942        code: CloseCode,
943        reason: impl Into<String>,
944    ) -> BacktestClientResult<()> {
945        let frame = CloseFrame {
946            code,
947            reason: Cow::Owned(reason.into()),
948        };
949        self.close_with_frame(timeout, Some(frame)).await
950    }
951
952    async fn next_text(
953        &mut self,
954        timeout: Option<Duration>,
955    ) -> BacktestClientResult<Option<String>> {
956        loop {
957            let request_timeout = self.request_timeout;
958            let timeout = timeout.or(request_timeout);
959
960            let next_fut = self.ws_mut()?.next();
961            let msg = match timeout {
962                Some(duration) => tokio::time::timeout(duration, next_fut)
963                    .await
964                    .map_err(|_| BacktestClientError::Timeout {
965                        action: "receiving",
966                        duration,
967                    })?,
968                None => next_fut.await,
969            };
970
971            let Some(msg) = msg else {
972                return Ok(None);
973            };
974
975            let msg = match msg {
976                Ok(msg) => msg,
977                Err(source) => {
978                    return Err(BacktestClientError::WebSocket {
979                        action: "receiving",
980                        source: Box::new(source),
981                    });
982                }
983            };
984
985            match msg {
986                Message::Text(text) => {
987                    if self.log_raw {
988                        tracing::debug!("<- raw: {text}");
989                    }
990                    return Ok(Some(text));
991                }
992                Message::Binary(bin) => match String::from_utf8(bin) {
993                    Ok(text) => {
994                        if self.log_raw {
995                            tracing::debug!("<- raw(bin): {text}");
996                        }
997                        return Ok(Some(text));
998                    }
999                    Err(err) => {
1000                        tracing::warn!("discarding non-utf8 binary message: {err}");
1001                        continue;
1002                    }
1003                },
1004                Message::Close(frame) => {
1005                    let reason = close_reason(frame);
1006                    return Err(BacktestClientError::Closed { reason });
1007                }
1008                Message::Ping(_) | Message::Pong(_) => continue,
1009                Message::Frame(_) => continue,
1010            }
1011        }
1012    }
1013}
1014
1015impl std::fmt::Debug for BacktestSession {
1016    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1017        f.debug_struct("BacktestSession")
1018            .field("session_id", &self.session_id)
1019            .field("rpc_endpoint", &self.rpc_endpoint)
1020            .field(
1021                "rpc",
1022                &self
1023                    .rpc
1024                    .as_ref()
1025                    .map(|_| "<RpcClient>")
1026                    .unwrap_or("<not set>"),
1027            )
1028            .field("ready_for_continue", &self.ready_for_continue)
1029            .field("request_timeout", &self.request_timeout)
1030            .finish_non_exhaustive()
1031    }
1032}
1033
/// Session identifier(s) produced by a create-session request.
///
/// NOTE(review): constructed and consumed outside this chunk. `Parallel` presumably
/// corresponds to a request that creates several sessions at once; confirm at the call site.
#[derive(Debug)]
pub(crate) enum CreateRequestResult {
    /// A single session id.
    Single { session_id: String },
    /// Several session ids (presumably one per created session).
    Parallel { session_ids: Vec<String> },
}
1039
1040impl Drop for BacktestSession {
1041    fn drop(&mut self) {
1042        let Some(ws) = self.ws.take() else {
1043            return;
1044        };
1045
1046        if let Ok(handle) = tokio::runtime::Handle::try_current() {
1047            handle.spawn(async move {
1048                let mut ws = ws;
1049                let _ = ws.close(None).await;
1050            });
1051        }
1052    }
1053}
1054
/// Resolve an RPC endpoint string into an absolute URL.
///
/// Endpoints that already start with `http://` or `https://` are returned unchanged.
/// Anything else is treated as a path relative to `base` and joined with exactly one `/`,
/// whether or not `base` ends with a slash or `endpoint` starts with one.
fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        return endpoint.to_string();
    }
    // Trim on both sides so "http://host/" + "/rpc" does not yield "http://host//rpc".
    let base = base.trim_end_matches('/');
    let path = endpoint.trim_start_matches('/');
    format!("{base}/{path}")
}
1062
1063fn close_reason(frame: Option<CloseFrame<'static>>) -> String {
1064    match frame {
1065        Some(frame) => format!("{:?}: {}", frame.code, frame.reason),
1066        None => "no close frame".to_string(),
1067    }
1068}
1069
1070fn is_reset_without_close(err: &WsError) -> bool {
1071    matches!(
1072        err,
1073        WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1074    )
1075}
1076
1077fn is_ws_closed_error(err: &WsError) -> bool {
1078    matches!(
1079        err,
1080        WsError::ConnectionClosed
1081            | WsError::AlreadyClosed
1082            | WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1083    )
1084}
1085
1086fn is_close_ok(err: &BacktestClientError) -> bool {
1087    match err {
1088        BacktestClientError::Closed { .. } => true,
1089        BacktestClientError::WebSocket { source, .. } => is_ws_closed_error(source),
1090        _ => false,
1091    }
1092}
1093
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn coverage_tracks_slot_and_completion_from_responses() {
        let responses = [
            BacktestResponse::SlotNotification(10),
            BacktestResponse::SlotNotification(12),
            BacktestResponse::Completed {
                summary: None,
                agent_stats: None,
            },
        ];

        let mut coverage = SessionCoverage::default();
        for response in &responses {
            coverage.observe_response(response);
        }

        assert!(coverage.is_completed());
        assert_eq!(coverage.highest_slot_seen(), Some(12));
    }

    #[test]
    fn coverage_validate_end_slot_checks_completion_and_range() {
        let mut coverage = SessionCoverage::default();

        // Not completed yet: completion is checked first.
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::NotCompleted)
        );

        // Completed, but no slot was ever observed.
        coverage.mark_completed();
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::NoSlotsObserved)
        );

        // Highest observed slot falls short of the requested end.
        coverage.observe_slot(4);
        assert_eq!(
            coverage.validate_end_slot(5),
            Err(CoverageError::RangeNotReached {
                actual_end_slot: 4,
                expected_end_slot: 5,
            })
        );

        // Observing a slot past the requested end satisfies the check.
        coverage.observe_slot(6);
        assert_eq!(coverage.validate_end_slot(5), Ok(()));
    }
}