Skip to main content

simulator_client/
session.rs

1use std::{
2    borrow::Cow,
3    collections::{BTreeMap, VecDeque},
4    future::Future,
5    time::Duration,
6};
7
8use futures::{SinkExt, StreamExt, stream};
9use simulator_api::{
10    AccountData, AccountModifications, AgentStatsReport, BacktestError, BacktestRequest,
11    BacktestResponse, BacktestStatus, ContinueParams, CreateBacktestSessionRequest,
12    CreateBacktestSessionRequestV1, SequencedResponse, SessionSummary,
13};
14use solana_address::Address;
15use solana_client::{
16    nonblocking::rpc_client::RpcClient,
17    rpc_response::{Response, RpcLogsResponse},
18};
19use solana_commitment_config::CommitmentConfig;
20use thiserror::Error;
21use tokio::net::TcpStream;
22use tokio_tungstenite::{
23    MaybeTlsStream, WebSocketStream,
24    tungstenite::{
25        Error as WsError, Message,
26        error::ProtocolError,
27        protocol::{CloseFrame, frame::coding::CloseCode},
28    },
29};
30
31use crate::{
32    BacktestClientError, BacktestClientResult, Continue,
33    injection::{ProgramModError, build_program_injection},
34    subscriptions::{
35        AccountDiffNotification, AccountDiffSubscriptionHandle, LogSubscriptionHandle,
36        SubscriptionError,
37    },
38};
39
/// Why a wait for readiness (see `BacktestSession::ensure_ready`) ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadyOutcome {
    /// A `ReadyForContinue` was observed; a `Continue` request may be sent now.
    Ready,
    /// A `Completed` response arrived before any readiness signal.
    Completed,
}
48
49/// Summary of responses collected while advancing a backtest.
50#[derive(Debug, Default)]
51pub struct ContinueResult {
52    /// Number of slot notifications observed.
53    pub slot_notifications: u64,
54    /// Last slot received via notification.
55    pub last_slot: Option<u64>,
56    /// Status messages received.
57    pub statuses: Vec<BacktestStatus>,
58    /// Whether the session became ready for the next `Continue`.
59    pub ready_for_continue: bool,
60    /// Whether the session completed while advancing.
61    pub completed: bool,
62}
63
64/// Mutable state for driving a session with `advance_step`.
65#[derive(Debug)]
66pub struct AdvanceState {
67    /// Expected number of slot notifications for this step.
68    pub expected_slots: u64,
69    /// Count of slot notifications received so far.
70    pub slot_notifications: u64,
71    /// Most recent slot notification.
72    pub last_slot: Option<u64>,
73    /// Status messages received so far.
74    pub statuses: Vec<BacktestStatus>,
75    /// Whether the session is ready for another `Continue`.
76    pub ready_for_continue: bool,
77    /// Whether the session completed while advancing.
78    pub completed: bool,
79    /// Session summary received on completion (if send_summary was enabled).
80    pub summary: Option<SessionSummary>,
81    /// Agent stats received on completion.
82    pub agent_stats: Option<Vec<AgentStatsReport>>,
83}
84
85impl AdvanceState {
86    /// Create new tracking state for a step that expects `expected_slots` notifications.
87    pub fn new(expected_slots: u64) -> Self {
88        Self {
89            expected_slots,
90            slot_notifications: 0,
91            last_slot: None,
92            statuses: Vec::new(),
93            ready_for_continue: false,
94            completed: false,
95            summary: None,
96            agent_stats: None,
97        }
98    }
99
100    /// Return true when the step is complete based on readiness and slot count.
101    pub fn is_done(&self, wait_for_slots: bool) -> bool {
102        if self.completed {
103            return true;
104        }
105
106        if !self.ready_for_continue {
107            return false;
108        }
109
110        !wait_for_slots || self.slot_notifications >= self.expected_slots
111    }
112}
113
/// Tracks how far a session got: whether it completed, and the highest slot
/// reported by slot notifications.
///
/// Feed it responses through its `observe_*` methods and check the outcome with
/// `validate_end_slot`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SessionCoverage {
    /// Set once a `Completed` response has been observed.
    completed: bool,
    /// Maximum slot seen in any slot notification; `None` until the first one.
    highest_slot_seen: Option<u64>,
}
120
121impl SessionCoverage {
122    /// Record a slot notification.
123    pub fn observe_slot(&mut self, slot: u64) {
124        self.highest_slot_seen = Some(
125            self.highest_slot_seen
126                .map_or(slot, |current| current.max(slot)),
127        );
128    }
129
130    /// Mark the session as completed.
131    pub fn mark_completed(&mut self) {
132        self.completed = true;
133    }
134
135    /// Update coverage state from a backtest response.
136    pub fn observe_response(&mut self, response: &BacktestResponse) {
137        match response {
138            BacktestResponse::SlotNotification(slot) => self.observe_slot(*slot),
139            BacktestResponse::Completed { .. } => self.mark_completed(),
140            _ => {}
141        }
142    }
143
144    /// Return whether completion has been observed.
145    pub fn is_completed(&self) -> bool {
146        self.completed
147    }
148
149    /// Return the highest slot observed via slot notifications.
150    pub fn highest_slot_seen(&self) -> Option<u64> {
151        self.highest_slot_seen
152    }
153
154    /// Validate that the session completed and reached `expected_end_slot`.
155    pub fn validate_end_slot(&self, expected_end_slot: u64) -> Result<(), CoverageError> {
156        if !self.completed {
157            return Err(CoverageError::NotCompleted);
158        }
159
160        let Some(actual_end_slot) = self.highest_slot_seen else {
161            return Err(CoverageError::NoSlotsObserved);
162        };
163
164        if actual_end_slot < expected_end_slot {
165            return Err(CoverageError::RangeNotReached {
166                actual_end_slot,
167                expected_end_slot,
168            });
169        }
170
171        Ok(())
172    }
173}
174
175/// Coverage validation failures for a backtest session.
176#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
177pub enum CoverageError {
178    #[error("ended before completion")]
179    NotCompleted,
180    #[error("completed without slot notifications")]
181    NoSlotsObserved,
182    #[error("completed at slot {actual_end_slot} but expected at least {expected_end_slot}")]
183    RangeNotReached {
184        actual_end_slot: u64,
185        expected_end_slot: u64,
186    },
187}
188
189/// Active backtest session over a WebSocket connection.
190///
191/// Dropping the session sends a best-effort close frame if a Tokio runtime is
192/// available. Call [`BacktestSession::close`] to request server-side cleanup.
193pub struct BacktestSession {
194    ws: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
195    session_id: Option<String>,
196    rpc_endpoint: Option<String>,
197    rpc: Option<RpcClient>,
198    last_sequence: Option<u64>,
199    ready_for_continue: bool,
200    request_timeout: Option<Duration>,
201    log_raw: bool,
202    backlog: VecDeque<(Option<u64>, BacktestResponse)>,
203}
204
205impl BacktestSession {
206    pub(crate) fn new(
207        ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
208        request_timeout: Option<Duration>,
209        log_raw: bool,
210    ) -> Self {
211        Self {
212            ws: Some(ws),
213            session_id: None,
214            rpc_endpoint: None,
215            rpc: None,
216            last_sequence: None,
217            ready_for_continue: false,
218            request_timeout,
219            log_raw,
220            backlog: VecDeque::new(),
221        }
222    }
223
224    /// Return the server-assigned session id, if known.
225    pub fn session_id(&self) -> Option<&str> {
226        self.session_id.as_deref()
227    }
228
229    /// Return the session-scoped RPC endpoint if provided.
230    pub fn rpc_endpoint(&self) -> Option<&str> {
231        self.rpc_endpoint.as_deref()
232    }
233
234    /// Return the highest sequenced control-websocket response observed so far.
235    pub fn last_sequence(&self) -> Option<u64> {
236        self.last_sequence
237    }
238
239    /// Return the RPC client for this session's endpoint.
240    ///
241    /// Always available after [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
242    pub fn rpc(&self) -> &RpcClient {
243        self.rpc
244            .as_ref()
245            .expect("rpc is set during session creation")
246    }
247
248    /// Return whether the session is currently ready to accept `Continue`.
249    pub fn is_ready_for_continue(&self) -> bool {
250        self.ready_for_continue
251    }
252
253    /// Update internal readiness state based on a response.
254    pub fn apply_response(&mut self, response: &BacktestResponse) {
255        match response {
256            BacktestResponse::ReadyForContinue => {
257                self.ready_for_continue = true;
258            }
259            BacktestResponse::Completed { .. } => {
260                self.ready_for_continue = false;
261            }
262            _ => {}
263        }
264    }
265
266    fn ws_mut(&mut self) -> BacktestClientResult<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
267        self.ws.as_mut().ok_or_else(|| BacktestClientError::Closed {
268            reason: "websocket closed".to_string(),
269        })
270    }
271
272    pub(crate) async fn create_with_request(
273        &mut self,
274        request: CreateBacktestSessionRequest,
275        rpc_base_url: String,
276        mut on_parallel_session_created: Option<&mut (dyn FnMut(String) + Send)>,
277    ) -> BacktestClientResult<CreateRequestResult> {
278        let expect_parallel = matches!(
279            &request,
280            CreateBacktestSessionRequest::V1(CreateBacktestSessionRequestV1 { parallel: true, .. })
281        );
282        self.send(&BacktestRequest::CreateBacktestSession(request), None)
283            .await?;
284        let mut streamed_parallel_session_ids = Vec::new();
285
286        loop {
287            let response =
288                self.next_response(None)
289                    .await?
290                    .ok_or_else(|| BacktestClientError::Closed {
291                        reason: "websocket ended before SessionCreated".to_string(),
292                    })?;
293
294            match response {
295                BacktestResponse::SessionCreated {
296                    session_id,
297                    rpc_endpoint,
298                } => {
299                    if expect_parallel {
300                        if let Some(callback) = on_parallel_session_created.as_mut() {
301                            (**callback)(session_id.clone());
302                        }
303                        streamed_parallel_session_ids.push(session_id);
304                        continue;
305                    }
306                    let created_session_id = session_id.clone();
307                    self.session_id = Some(session_id);
308                    let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
309                    self.rpc = Some(RpcClient::new_with_commitment(
310                        resolved.clone(),
311                        CommitmentConfig::confirmed(),
312                    ));
313                    self.rpc_endpoint = Some(resolved);
314                    return Ok(CreateRequestResult::Single {
315                        session_id: created_session_id,
316                    });
317                }
318                BacktestResponse::SessionsCreated { session_ids } => {
319                    if expect_parallel && session_ids.is_empty() {
320                        return Ok(CreateRequestResult::Parallel {
321                            session_ids: streamed_parallel_session_ids,
322                        });
323                    }
324                    return Ok(CreateRequestResult::Parallel { session_ids });
325                }
326                BacktestResponse::SessionsCreatedV2 { session_ids, .. } => {
327                    if expect_parallel && session_ids.is_empty() {
328                        return Ok(CreateRequestResult::Parallel {
329                            session_ids: streamed_parallel_session_ids,
330                        });
331                    }
332                    return Ok(CreateRequestResult::Parallel { session_ids });
333                }
334                BacktestResponse::ReadyForContinue => {
335                    self.ready_for_continue = true;
336                }
337                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
338                other => {
339                    self.backlog.push_back((self.last_sequence, other));
340                }
341            }
342        }
343    }
344
345    pub(crate) async fn attach(
346        &mut self,
347        session_id: String,
348        last_sequence: Option<u64>,
349        rpc_base_url: String,
350    ) -> BacktestClientResult<()> {
351        self.send(
352            &BacktestRequest::AttachBacktestSession {
353                session_id,
354                last_sequence,
355            },
356            None,
357        )
358        .await?;
359
360        self.wait_for_response(
361            || BacktestClientError::Closed {
362                reason: "websocket ended before SessionAttached".to_string(),
363            },
364            move |session, response| match response {
365                BacktestResponse::SessionAttached {
366                    session_id,
367                    rpc_endpoint,
368                } => {
369                    session.session_id = Some(session_id);
370                    let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
371                    session.rpc = Some(RpcClient::new_with_commitment(
372                        resolved.clone(),
373                        CommitmentConfig::confirmed(),
374                    ));
375                    session.rpc_endpoint = Some(resolved);
376                    Ok(Some(()))
377                }
378                BacktestResponse::ReadyForContinue => {
379                    session.ready_for_continue = true;
380                    Ok(None)
381                }
382                BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
383                other => {
384                    session.backlog.push_back((session.last_sequence, other));
385                    Ok(None)
386                }
387            },
388        )
389        .await
390    }
391
392    /// Sent after reattaching and rebuilding any dependent subscriptions so the
393    /// manager can resume a session that was paused for handoff.
394    pub async fn resume_attached_session(&mut self) -> BacktestClientResult<()> {
395        self.send(&BacktestRequest::ResumeAttachedSession, None)
396            .await?;
397
398        self.wait_for_response(
399            || BacktestClientError::Closed {
400                reason: "websocket ended before ResumeAttachedSession acknowledgement".to_string(),
401            },
402            |session, response| match response {
403                BacktestResponse::Success => Ok(Some(())),
404                BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
405                other => {
406                    session.backlog.push_back((session.last_sequence, other));
407                    Ok(None)
408                }
409            },
410        )
411        .await
412    }
413
414    async fn wait_for_response<T, E, F>(
415        &mut self,
416        closed_error: E,
417        mut handle_response: F,
418    ) -> BacktestClientResult<T>
419    where
420        E: FnOnce() -> BacktestClientError,
421        F: FnMut(&mut Self, BacktestResponse) -> BacktestClientResult<Option<T>>,
422    {
423        let mut closed_error = Some(closed_error);
424
425        loop {
426            let response = self
427                .next_response(None)
428                .await?
429                .ok_or_else(|| closed_error.take().expect("closed error set")())?;
430
431            if let Some(result) = handle_response(self, response)? {
432                return Ok(result);
433            }
434        }
435    }
436
437    /// Send a raw backtest request over the WebSocket.
438    pub async fn send(
439        &mut self,
440        request: &BacktestRequest,
441        timeout: Option<Duration>,
442    ) -> BacktestClientResult<()> {
443        let text = serde_json::to_string(request)
444            .map_err(|source| BacktestClientError::SerializeRequest { source })?;
445
446        let request_timeout = self.request_timeout;
447        let timeout = timeout.or(request_timeout);
448
449        let send_fut = self.ws_mut()?.send(Message::Text(text));
450        let send_result = match timeout {
451            Some(duration) => tokio::time::timeout(duration, send_fut)
452                .await
453                .map_err(|_| BacktestClientError::Timeout {
454                    action: "sending",
455                    duration,
456                })?,
457            None => send_fut.await,
458        };
459
460        send_result.map_err(|source| BacktestClientError::WebSocket {
461            action: "sending",
462            source: Box::new(source),
463        })?;
464
465        Ok(())
466    }
467
468    /// Receive the next response, using the backlog first.
469    pub async fn next_response(
470        &mut self,
471        timeout: Option<Duration>,
472    ) -> BacktestClientResult<Option<BacktestResponse>> {
473        if let Some((sequence, response)) = self.backlog.pop_front() {
474            self.last_sequence = sequence.or(self.last_sequence);
475            return Ok(Some(response));
476        }
477
478        let text = match self.next_text(timeout).await? {
479            Some(text) => text,
480            None => return Ok(None),
481        };
482
483        let (sequence, response) = match serde_json::from_str::<SequencedResponse>(&text) {
484            Ok(sequenced) => (Some(sequenced.seq_id), sequenced.response),
485            Err(_) => {
486                let response =
487                    serde_json::from_str::<BacktestResponse>(&text).map_err(|source| {
488                        BacktestClientError::DeserializeResponse {
489                            raw: text.clone(),
490                            source,
491                        }
492                    })?;
493                (None, response)
494            }
495        };
496        self.last_sequence = sequence.or(self.last_sequence);
497
498        Ok(Some(response))
499    }
500
501    /// Receive the next response and update readiness state.
502    pub async fn next_event(
503        &mut self,
504        timeout: Option<Duration>,
505    ) -> BacktestClientResult<Option<BacktestResponse>> {
506        let response = self.next_response(timeout).await?;
507        if let Some(ref response) = response {
508            self.apply_response(response);
509        }
510        Ok(response)
511    }
512
513    /// Stream responses, updating readiness state as items arrive.
514    ///
515    /// This consumes the session and yields responses until the connection ends.
516    pub fn responses(
517        self,
518        timeout: Option<Duration>,
519    ) -> impl futures::Stream<Item = BacktestClientResult<BacktestResponse>> {
520        stream::unfold(Some(self), move |state| async move {
521            let mut session = match state {
522                Some(session) => session,
523                None => return None,
524            };
525
526            match session.next_response(timeout).await {
527                Ok(Some(response)) => {
528                    session.apply_response(&response);
529                    Some((Ok(response), Some(session)))
530                }
531                Ok(None) => None,
532                Err(err) => Some((Err(err), None)),
533            }
534        })
535    }
536
537    /// Wait for the session to become ready or completed.
538    pub async fn ensure_ready(
539        &mut self,
540        timeout: Option<Duration>,
541    ) -> BacktestClientResult<ReadyOutcome> {
542        if self.ready_for_continue {
543            return Ok(ReadyOutcome::Ready);
544        }
545
546        loop {
547            let response =
548                self.next_response(timeout)
549                    .await?
550                    .ok_or_else(|| BacktestClientError::Closed {
551                        reason: "websocket ended while waiting for ReadyForContinue".to_string(),
552                    })?;
553
554            match response {
555                BacktestResponse::ReadyForContinue => {
556                    self.ready_for_continue = true;
557                    return Ok(ReadyOutcome::Ready);
558                }
559                BacktestResponse::Completed { .. } => return Ok(ReadyOutcome::Completed),
560                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
561                _ => {}
562            }
563        }
564    }
565
566    /// Wait for a specific status to be emitted.
567    pub async fn wait_for_status(
568        &mut self,
569        desired: BacktestStatus,
570        timeout: Option<Duration>,
571    ) -> BacktestClientResult<()> {
572        let desired = std::mem::discriminant(&desired);
573
574        loop {
575            let response =
576                self.next_response(timeout)
577                    .await?
578                    .ok_or_else(|| BacktestClientError::Closed {
579                        reason: "websocket ended while waiting for status".to_string(),
580                    })?;
581
582            match response {
583                BacktestResponse::Status { status }
584                    if std::mem::discriminant(&status) == desired =>
585                {
586                    return Ok(());
587                }
588                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
589                BacktestResponse::Completed {
590                    summary,
591                    agent_stats,
592                } => {
593                    return Err(BacktestClientError::UnexpectedResponse {
594                        context: "waiting for status",
595                        response: Box::new(BacktestResponse::Completed {
596                            summary,
597                            agent_stats,
598                        }),
599                    });
600                }
601                _ => {}
602            }
603        }
604    }
605
606    /// Send a `Continue` request and reset readiness.
607    pub async fn send_continue(
608        &mut self,
609        params: ContinueParams,
610        timeout: Option<Duration>,
611    ) -> BacktestClientResult<()> {
612        self.ready_for_continue = false;
613        self.send(&BacktestRequest::Continue(params), timeout).await
614    }
615
616    /// Read and apply a single response while advancing.
617    pub async fn advance_step<F>(
618        &mut self,
619        state: &mut AdvanceState,
620        wait_for_slots: bool,
621        timeout: Option<Duration>,
622        on_event: &mut F,
623    ) -> BacktestClientResult<()>
624    where
625        F: FnMut(&BacktestResponse),
626    {
627        let Some(response) = self.next_response(timeout).await? else {
628            return Err(BacktestClientError::Closed {
629                reason: "websocket ended while awaiting continue responses".to_string(),
630            });
631        };
632
633        if self.log_raw {
634            tracing::debug!("<- {response:?}");
635        }
636
637        on_event(&response);
638
639        match response {
640            BacktestResponse::ReadyForContinue => {
641                self.ready_for_continue = true;
642                state.ready_for_continue = true;
643            }
644            BacktestResponse::SlotNotification(slot) => {
645                state.slot_notifications += 1;
646                state.last_slot = Some(slot);
647            }
648            BacktestResponse::Status { status } => {
649                state.statuses.push(status);
650            }
651            BacktestResponse::Success => {}
652            BacktestResponse::Completed {
653                summary,
654                agent_stats,
655            } => {
656                state.completed = true;
657                state.summary = summary;
658                state.agent_stats = agent_stats;
659            }
660            BacktestResponse::Error(err @ BacktestError::SimulationError { .. }) => {
661                tracing::warn!(error = %err, "simulation error");
662            }
663            BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
664            BacktestResponse::SessionCreated { .. }
665            | BacktestResponse::SessionAttached { .. }
666            | BacktestResponse::SessionsCreated { .. }
667            | BacktestResponse::SessionsCreatedV2 { .. }
668            | BacktestResponse::ParallelSessionAttachedV2 { .. }
669            | BacktestResponse::SessionEventV1 { .. }
670            | BacktestResponse::SessionEventV2 { .. } => {
671                return Err(BacktestClientError::UnexpectedResponse {
672                    context: "continuing",
673                    response: Box::new(response),
674                });
675            }
676        }
677
678        if wait_for_slots && state.slot_notifications > state.expected_slots {
679            tracing::warn!(
680                "received {} slot notifications (expected {})",
681                state.slot_notifications,
682                state.expected_slots
683            );
684        }
685
686        Ok(())
687    }
688
689    /// Advance until the session becomes ready for another `Continue`.
690    pub async fn continue_until_ready<F>(
691        &mut self,
692        cont: Continue,
693        timeout: Option<Duration>,
694        mut on_event: F,
695    ) -> BacktestClientResult<ContinueResult>
696    where
697        F: FnMut(&BacktestResponse),
698    {
699        let expected_slots = cont.advance_count;
700        self.advance_internal(
701            cont.into_params(),
702            expected_slots,
703            false,
704            timeout,
705            &mut on_event,
706        )
707        .await
708    }
709
710    /// Advance and wait for both readiness and slot notifications.
711    pub async fn advance<F>(
712        &mut self,
713        cont: Continue,
714        timeout: Option<Duration>,
715        mut on_event: F,
716    ) -> BacktestClientResult<ContinueResult>
717    where
718        F: FnMut(&BacktestResponse),
719    {
720        let expected_slots = cont.advance_count;
721        self.advance_internal(
722            cont.into_params(),
723            expected_slots,
724            true,
725            timeout,
726            &mut on_event,
727        )
728        .await
729    }
730
731    async fn advance_internal<F>(
732        &mut self,
733        params: ContinueParams,
734        expected_slots: u64,
735        wait_for_slots: bool,
736        timeout: Option<Duration>,
737        on_event: &mut F,
738    ) -> BacktestClientResult<ContinueResult>
739    where
740        F: FnMut(&BacktestResponse),
741    {
742        self.send_continue(params, timeout).await?;
743
744        let mut state = AdvanceState::new(expected_slots);
745        while !state.is_done(wait_for_slots) {
746            self.advance_step(&mut state, wait_for_slots, timeout, on_event)
747                .await?;
748        }
749
750        Ok(ContinueResult {
751            slot_notifications: state.slot_notifications,
752            last_slot: state.last_slot,
753            statuses: state.statuses,
754            ready_for_continue: state.ready_for_continue,
755            completed: state.completed,
756        })
757    }
758
759    /// Build a program-data account modification from raw ELF bytes.
760    ///
761    /// Derives the `ProgramData` address from `program_id`, queries the session's RPC endpoint
762    /// for the current slot and rent-exempt minimum, then returns a modification map ready to
763    /// pass to [`Continue::builder().modify_accounts(...)`](crate::Continue).
764    ///
765    /// The deploy slot is set to `current_slot - 1` so the program appears deployed before the
766    /// next executed slot.
767    pub async fn modify_program(
768        &self,
769        program_id: &str,
770        elf: &[u8],
771    ) -> Result<BTreeMap<Address, AccountData>, ProgramModError> {
772        let rpc = self.rpc.as_ref().ok_or(ProgramModError::NoRpcEndpoint)?;
773
774        let program_addr: Address =
775            program_id
776                .parse()
777                .map_err(|_| ProgramModError::InvalidProgramId {
778                    id: program_id.to_string(),
779                })?;
780        let programdata_addr = solana_loader_v3_interface::get_program_data_address(&program_addr);
781
782        let slot = rpc.get_slot().await.map_err(|e| ProgramModError::Rpc {
783            source: Box::new(e),
784        })?;
785        let deploy_slot = slot.saturating_sub(1);
786
787        // Fetch the existing programdata account to read the upgrade authority.
788        // The discriminant at byte 12 tells us whether an authority is present.
789        let existing =
790            rpc.get_account(&programdata_addr)
791                .await
792                .map_err(|e| ProgramModError::Rpc {
793                    source: Box::new(e),
794                })?;
795
796        let upgrade_authority = if existing.data.get(12).copied() == Some(1) {
797            existing.data.get(13..45).and_then(|b| {
798                let bytes: [u8; 32] = b.try_into().ok()?;
799                Some(Address::from(bytes))
800            })
801        } else {
802            None
803        };
804
805        let data_len = upgrade_authority.map_or(13, |_| 45) + elf.len();
806        let lamports = rpc
807            .get_minimum_balance_for_rent_exemption(data_len)
808            .await
809            .map_err(|e| ProgramModError::Rpc {
810                source: Box::new(e),
811            })?;
812
813        Ok(build_program_injection(
814            programdata_addr,
815            elf,
816            deploy_slot,
817            upgrade_authority,
818            lamports,
819        ))
820    }
821
822    /// Modify accounts on the session's RPC endpoint via the custom `modifyAccounts` method.
823    ///
824    /// Returns the number of accounts modified on success.
825    pub async fn modify_accounts(
826        &self,
827        modifications: &AccountModifications,
828    ) -> BacktestClientResult<usize> {
829        let rpc_endpoint =
830            self.rpc_endpoint
831                .as_deref()
832                .ok_or_else(|| BacktestClientError::Closed {
833                    reason: "no RPC endpoint available".to_string(),
834                })?;
835
836        crate::rpc::modify_accounts(rpc_endpoint, modifications).await
837    }
838
839    /// Subscribe to program log notifications using the session's RPC endpoint.
840    ///
841    /// Equivalent to calling [`subscribe_program_logs`](crate::subscribe_program_logs) with
842    /// the endpoint from [`rpc_endpoint`](Self::rpc_endpoint), which is set after
843    /// [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
844    pub async fn subscribe_program_logs<F, Fut>(
845        &self,
846        program_id: &str,
847        commitment: CommitmentConfig,
848        on_notification: F,
849    ) -> Result<LogSubscriptionHandle, SubscriptionError>
850    where
851        F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
852        Fut: Future<Output = ()> + Send + 'static,
853    {
854        let rpc_endpoint = self
855            .rpc_endpoint
856            .as_deref()
857            .ok_or(SubscriptionError::NoRpcEndpoint)?;
858        crate::subscriptions::subscribe_program_logs(
859            rpc_endpoint,
860            program_id,
861            commitment,
862            on_notification,
863        )
864        .await
865    }
866
867    /// Subscribe to account diff notifications using the session's RPC endpoint.
868    ///
869    /// Equivalent to calling [`subscribe_account_diffs`](crate::subscribe_account_diffs) with
870    /// the endpoint from [`rpc_endpoint`](Self::rpc_endpoint), which is set after
871    /// [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
872    pub async fn subscribe_account_diffs<F, Fut>(
873        &self,
874        account: &str,
875        on_notification: F,
876    ) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
877    where
878        F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
879        Fut: Future<Output = ()> + Send + 'static,
880    {
881        let rpc_endpoint = self
882            .rpc_endpoint
883            .as_deref()
884            .ok_or(SubscriptionError::NoRpcEndpoint)?;
885        crate::subscriptions::subscribe_account_diffs(rpc_endpoint, account, on_notification).await
886    }
887
888    /// Request server cleanup and close the underlying WebSocket.
889    ///
890    /// This is idempotent and will return `Ok(())` if the connection is already closed.
891    pub async fn close(&mut self, timeout: Option<Duration>) -> BacktestClientResult<()> {
892        self.close_with_frame(timeout, None).await
893    }
894
895    /// Close the session with a specific WebSocket close frame.
896    pub async fn close_with_frame(
897        &mut self,
898        timeout: Option<Duration>,
899        frame: Option<CloseFrame<'static>>,
900    ) -> BacktestClientResult<()> {
901        if self.ws.is_none() {
902            return Ok(());
903        }
904
905        let mut sent = false;
906        match self
907            .send(&BacktestRequest::CloseBacktestSession, timeout)
908            .await
909        {
910            Ok(()) => sent = true,
911            Err(err) if is_close_ok(&err) => {}
912            Err(err) => return Err(err),
913        }
914
915        if sent {
916            let response = match self.next_response(timeout).await {
917                Ok(Some(r)) => r,
918                Ok(None) => {
919                    self.ws.take();
920                    return Ok(());
921                }
922                Err(BacktestClientError::Closed { .. }) => {
923                    self.ws.take();
924                    return Ok(());
925                }
926                Err(BacktestClientError::WebSocket {
927                    action: "receiving",
928                    source,
929                }) if is_reset_without_close(&source) => {
930                    self.ws.take();
931                    return Ok(());
932                }
933                Err(err) => return Err(err),
934            };
935
936            match response {
937                BacktestResponse::Success | BacktestResponse::Completed { .. } => {}
938                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
939                other => {
940                    return Err(BacktestClientError::UnexpectedResponse {
941                        context: "closing session",
942                        response: Box::new(other),
943                    });
944                }
945            }
946        }
947
948        if let Some(ws) = self.ws.as_mut() {
949            let close_result = ws.close(frame).await;
950            if let Err(source) = close_result
951                && !is_ws_closed_error(&source)
952            {
953                return Err(BacktestClientError::WebSocket {
954                    action: "closing",
955                    source: Box::new(source),
956                });
957            }
958        }
959
960        // Await the close confirmation
961        match self.next_response(timeout).await {
962            Ok(_) => {}
963            Err(BacktestClientError::Closed { .. }) => {}
964            Err(BacktestClientError::WebSocket {
965                action: "receiving",
966                source,
967            }) if is_reset_without_close(&source) => {}
968            Err(err) => return Err(err),
969        }
970
971        // Give some time for the close to propagate
972        tokio::time::sleep(Duration::from_millis(100)).await;
973        self.ws.take();
974        Ok(())
975    }
976
977    /// Close the session with a close code and reason.
978    pub async fn close_with_reason(
979        &mut self,
980        timeout: Option<Duration>,
981        code: CloseCode,
982        reason: impl Into<String>,
983    ) -> BacktestClientResult<()> {
984        let frame = CloseFrame {
985            code,
986            reason: Cow::Owned(reason.into()),
987        };
988        self.close_with_frame(timeout, Some(frame)).await
989    }
990
991    async fn next_text(
992        &mut self,
993        timeout: Option<Duration>,
994    ) -> BacktestClientResult<Option<String>> {
995        loop {
996            let request_timeout = self.request_timeout;
997            let timeout = timeout.or(request_timeout);
998
999            let next_fut = self.ws_mut()?.next();
1000            let msg = match timeout {
1001                Some(duration) => tokio::time::timeout(duration, next_fut)
1002                    .await
1003                    .map_err(|_| BacktestClientError::Timeout {
1004                        action: "receiving",
1005                        duration,
1006                    })?,
1007                None => next_fut.await,
1008            };
1009
1010            let Some(msg) = msg else {
1011                return Ok(None);
1012            };
1013
1014            let msg = match msg {
1015                Ok(msg) => msg,
1016                Err(source) => {
1017                    return Err(BacktestClientError::WebSocket {
1018                        action: "receiving",
1019                        source: Box::new(source),
1020                    });
1021                }
1022            };
1023
1024            match msg {
1025                Message::Text(text) => {
1026                    if self.log_raw {
1027                        tracing::debug!("<- raw: {text}");
1028                    }
1029                    return Ok(Some(text));
1030                }
1031                Message::Binary(bin) => match String::from_utf8(bin) {
1032                    Ok(text) => {
1033                        if self.log_raw {
1034                            tracing::debug!("<- raw(bin): {text}");
1035                        }
1036                        return Ok(Some(text));
1037                    }
1038                    Err(err) => {
1039                        tracing::warn!("discarding non-utf8 binary message: {err}");
1040                        continue;
1041                    }
1042                },
1043                Message::Close(frame) => {
1044                    let reason = close_reason(frame);
1045                    return Err(BacktestClientError::Closed { reason });
1046                }
1047                Message::Ping(_) | Message::Pong(_) => continue,
1048                Message::Frame(_) => continue,
1049            }
1050        }
1051    }
1052}
1053
1054impl std::fmt::Debug for BacktestSession {
1055    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1056        f.debug_struct("BacktestSession")
1057            .field("session_id", &self.session_id)
1058            .field("rpc_endpoint", &self.rpc_endpoint)
1059            .field(
1060                "rpc",
1061                &self
1062                    .rpc
1063                    .as_ref()
1064                    .map(|_| "<RpcClient>")
1065                    .unwrap_or("<not set>"),
1066            )
1067            .field("ready_for_continue", &self.ready_for_continue)
1068            .field("request_timeout", &self.request_timeout)
1069            .finish_non_exhaustive()
1070    }
1071}
1072
/// Outcome of a create request, carrying either one session id or a list of them.
///
/// NOTE(review): the producer and consumer are outside this view. Presumably `Single` answers
/// a single-session create and `Parallel` a multi-session create; confirm against the caller.
#[derive(Debug)]
pub(crate) enum CreateRequestResult {
    /// Exactly one session id was returned.
    Single { session_id: String },
    /// Several session ids were returned, one per created session (presumably created in
    /// parallel — verify).
    Parallel { session_ids: Vec<String> },
}
1078
1079impl Drop for BacktestSession {
1080    fn drop(&mut self) {
1081        let Some(ws) = self.ws.take() else {
1082            return;
1083        };
1084
1085        if let Ok(handle) = tokio::runtime::Handle::try_current() {
1086            handle.spawn(async move {
1087                let mut ws = ws;
1088                let _ = ws.close(None).await;
1089            });
1090        }
1091    }
1092}
1093
/// Resolve an RPC `endpoint` against `base`.
///
/// Endpoints that already start with `http://` or `https://` are returned unchanged.
/// Anything else is treated as a path relative to `base` and joined with exactly one `/`.
/// Neither a trailing slash on `base` nor a leading slash on `endpoint` produces a `//`
/// in the result.
fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        return endpoint.to_string();
    }
    // Normalise both sides of the join so exactly one separator ends up between them.
    let base = base.trim_end_matches('/');
    let path = endpoint.trim_start_matches('/');
    format!("{base}/{path}")
}
1101
1102fn close_reason(frame: Option<CloseFrame<'static>>) -> String {
1103    match frame {
1104        Some(frame) => format!("{:?}: {}", frame.code, frame.reason),
1105        None => "no close frame".to_string(),
1106    }
1107}
1108
1109fn is_reset_without_close(err: &WsError) -> bool {
1110    matches!(
1111        err,
1112        WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1113    )
1114}
1115
1116fn is_ws_closed_error(err: &WsError) -> bool {
1117    matches!(
1118        err,
1119        WsError::ConnectionClosed
1120            | WsError::AlreadyClosed
1121            | WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1122    )
1123}
1124
1125fn is_close_ok(err: &BacktestClientError) -> bool {
1126    match err {
1127        BacktestClientError::Closed { .. } => true,
1128        BacktestClientError::WebSocket { source, .. } => is_ws_closed_error(source),
1129        _ => false,
1130    }
1131}
1132
#[cfg(test)]
mod tests {
    use super::*;

    /// Slot notifications raise the high-water mark and `Completed` marks completion.
    #[test]
    fn coverage_tracks_slot_and_completion_from_responses() {
        let mut tracker = SessionCoverage::default();
        let responses = [
            BacktestResponse::SlotNotification(10),
            BacktestResponse::SlotNotification(12),
            BacktestResponse::Completed {
                summary: None,
                agent_stats: None,
            },
        ];
        for response in &responses {
            tracker.observe_response(response);
        }

        assert!(tracker.is_completed());
        assert_eq!(tracker.highest_slot_seen(), Some(12));
    }

    /// `validate_end_slot` reports, in order: not completed, no slots, range short, then ok.
    #[test]
    fn coverage_validate_end_slot_checks_completion_and_range() {
        let expected_end = 5;
        let mut tracker = SessionCoverage::default();

        let before_completion = tracker.validate_end_slot(expected_end);
        assert_eq!(before_completion, Err(CoverageError::NotCompleted));

        tracker.mark_completed();
        let without_slots = tracker.validate_end_slot(expected_end);
        assert_eq!(without_slots, Err(CoverageError::NoSlotsObserved));

        tracker.observe_slot(4);
        let short_of_range = tracker.validate_end_slot(expected_end);
        assert_eq!(
            short_of_range,
            Err(CoverageError::RangeNotReached {
                actual_end_slot: 4,
                expected_end_slot: 5,
            })
        );

        tracker.observe_slot(6);
        let satisfied = tracker.validate_end_slot(expected_end);
        assert_eq!(satisfied, Ok(()));
    }
}