Skip to main content

simulator_client/
session.rs

1use std::{
2    borrow::Cow,
3    collections::{BTreeMap, VecDeque},
4    future::Future,
5    time::Duration,
6};
7
8use futures::{SinkExt, StreamExt, stream};
9use simulator_api::{
10    AccountData, AccountModifications, AgentStatsReport, BacktestError, BacktestRequest,
11    BacktestResponse, BacktestStatus, ContinueParams, CreateBacktestSessionRequest,
12    CreateBacktestSessionRequestV1, SequencedResponse, SessionSummary,
13};
14use solana_address::Address;
15use solana_client::{
16    nonblocking::rpc_client::RpcClient,
17    rpc_response::{Response, RpcLogsResponse},
18};
19use solana_commitment_config::CommitmentConfig;
20use thiserror::Error;
21use tokio::net::TcpStream;
22use tokio_tungstenite::{
23    MaybeTlsStream, WebSocketStream,
24    tungstenite::{
25        Error as WsError, Message,
26        error::ProtocolError,
27        protocol::{CloseFrame, frame::coding::CloseCode},
28    },
29};
30
31use crate::{
32    BacktestClientError, BacktestClientResult, Continue,
33    injection::ProgramModError,
34    subscriptions::{
35        AccountDiffNotification, AccountDiffSubscriptionHandle, LogSubscriptionHandle,
36        SubscriptionError,
37    },
38};
39
/// How a wait for session readiness ended (see [`BacktestSession::ensure_ready`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadyOutcome {
    /// The session will accept a `Continue` request.
    Ready,
    /// A `Completed` response arrived before any readiness signal.
    Completed,
}
48
49/// Summary of responses collected while advancing a backtest.
50#[derive(Debug, Default)]
51pub struct ContinueResult {
52    /// Number of slot notifications observed.
53    pub slot_notifications: u64,
54    /// Last slot received via notification.
55    pub last_slot: Option<u64>,
56    /// Status messages received.
57    pub statuses: Vec<BacktestStatus>,
58    /// Whether the session became ready for the next `Continue`.
59    pub ready_for_continue: bool,
60    /// Whether the session completed while advancing.
61    pub completed: bool,
62}
63
64/// Mutable state for driving a session with `advance_step`.
65#[derive(Debug)]
66pub struct AdvanceState {
67    /// Expected number of slot notifications for this step.
68    pub expected_slots: u64,
69    /// Count of slot notifications received so far.
70    pub slot_notifications: u64,
71    /// Most recent slot notification.
72    pub last_slot: Option<u64>,
73    /// Status messages received so far.
74    pub statuses: Vec<BacktestStatus>,
75    /// Whether the session is ready for another `Continue`.
76    pub ready_for_continue: bool,
77    /// Whether the session completed while advancing.
78    pub completed: bool,
79    /// Session summary received on completion (if send_summary was enabled).
80    pub summary: Option<SessionSummary>,
81    /// Agent stats received on completion.
82    pub agent_stats: Option<Vec<AgentStatsReport>>,
83}
84
85impl AdvanceState {
86    /// Create new tracking state for a step that expects `expected_slots` notifications.
87    pub fn new(expected_slots: u64) -> Self {
88        Self {
89            expected_slots,
90            slot_notifications: 0,
91            last_slot: None,
92            statuses: Vec::new(),
93            ready_for_continue: false,
94            completed: false,
95            summary: None,
96            agent_stats: None,
97        }
98    }
99
100    /// Return true when the step is complete based on readiness and slot count.
101    pub fn is_done(&self, wait_for_slots: bool) -> bool {
102        if self.completed {
103            return true;
104        }
105
106        if !self.ready_for_continue {
107            return false;
108        }
109
110        !wait_for_slots || self.slot_notifications >= self.expected_slots
111    }
112}
113
/// Tracks how far a session progressed: the highest notified slot and whether it completed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SessionCoverage {
    /// Set once completion has been observed.
    completed: bool,
    /// Largest slot seen in any slot notification so far.
    highest_slot_seen: Option<u64>,
}
120
121impl SessionCoverage {
122    /// Record a slot notification.
123    pub fn observe_slot(&mut self, slot: u64) {
124        self.highest_slot_seen = Some(
125            self.highest_slot_seen
126                .map_or(slot, |current| current.max(slot)),
127        );
128    }
129
130    /// Mark the session as completed.
131    pub fn mark_completed(&mut self) {
132        self.completed = true;
133    }
134
135    /// Update coverage state from a backtest response.
136    pub fn observe_response(&mut self, response: &BacktestResponse) {
137        match response {
138            BacktestResponse::SlotNotification(slot) => self.observe_slot(*slot),
139            BacktestResponse::Completed { .. } => self.mark_completed(),
140            _ => {}
141        }
142    }
143
144    /// Return whether completion has been observed.
145    pub fn is_completed(&self) -> bool {
146        self.completed
147    }
148
149    /// Return the highest slot observed via slot notifications.
150    pub fn highest_slot_seen(&self) -> Option<u64> {
151        self.highest_slot_seen
152    }
153
154    /// Validate that the session completed and reached `expected_end_slot`.
155    pub fn validate_end_slot(&self, expected_end_slot: u64) -> Result<(), CoverageError> {
156        if !self.completed {
157            return Err(CoverageError::NotCompleted);
158        }
159
160        let Some(actual_end_slot) = self.highest_slot_seen else {
161            return Err(CoverageError::NoSlotsObserved);
162        };
163
164        if actual_end_slot < expected_end_slot {
165            return Err(CoverageError::RangeNotReached {
166                actual_end_slot,
167                expected_end_slot,
168            });
169        }
170
171        Ok(())
172    }
173}
174
175/// Coverage validation failures for a backtest session.
176#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
177pub enum CoverageError {
178    #[error("ended before completion")]
179    NotCompleted,
180    #[error("completed without slot notifications")]
181    NoSlotsObserved,
182    #[error("completed at slot {actual_end_slot} but expected at least {expected_end_slot}")]
183    RangeNotReached {
184        actual_end_slot: u64,
185        expected_end_slot: u64,
186    },
187}
188
189/// Active backtest session over a WebSocket connection.
190///
191/// Dropping the session sends a best-effort close frame if a Tokio runtime is
192/// available. Call [`BacktestSession::close`] to request server-side cleanup.
193pub struct BacktestSession {
194    ws: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
195    session_id: Option<String>,
196    rpc_endpoint: Option<String>,
197    task_id: Option<String>,
198    rpc: Option<RpcClient>,
199    last_sequence: Option<u64>,
200    pub(crate) ready_for_continue: bool,
201    request_timeout: Option<Duration>,
202    log_raw: bool,
203    backlog: VecDeque<(Option<u64>, BacktestResponse)>,
204}
205
206impl BacktestSession {
207    pub(crate) fn new(
208        ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
209        request_timeout: Option<Duration>,
210        log_raw: bool,
211    ) -> Self {
212        Self {
213            ws: Some(ws),
214            session_id: None,
215            rpc_endpoint: None,
216            task_id: None,
217            rpc: None,
218            last_sequence: None,
219            ready_for_continue: false,
220            request_timeout,
221            log_raw,
222            backlog: VecDeque::new(),
223        }
224    }
225
226    /// Return the server-assigned session id, if known.
227    pub fn session_id(&self) -> Option<&str> {
228        self.session_id.as_deref()
229    }
230
231    /// Return the session-scoped RPC endpoint if provided.
232    pub fn rpc_endpoint(&self) -> Option<&str> {
233        self.rpc_endpoint.as_deref()
234    }
235
236    /// Return the opaque `task_id` reported by the server for this session,
237    /// if any.
238    pub fn task_id(&self) -> Option<&str> {
239        self.task_id.as_deref()
240    }
241
242    /// Return the highest sequenced control-websocket response observed so far.
243    pub fn last_sequence(&self) -> Option<u64> {
244        self.last_sequence
245    }
246
247    /// Return the RPC client for this session's endpoint.
248    ///
249    /// Always available after [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
250    pub fn rpc(&self) -> &RpcClient {
251        self.rpc
252            .as_ref()
253            .expect("rpc is set during session creation")
254    }
255
256    /// Return whether the session is currently ready to accept `Continue`.
257    pub fn is_ready_for_continue(&self) -> bool {
258        self.ready_for_continue
259    }
260
261    /// Update internal readiness state based on a response.
262    pub fn apply_response(&mut self, response: &BacktestResponse) {
263        match response {
264            BacktestResponse::ReadyForContinue | BacktestResponse::Paused(_) => {
265                self.ready_for_continue = true;
266            }
267            BacktestResponse::Completed { .. } => {
268                self.ready_for_continue = false;
269            }
270            _ => {}
271        }
272    }
273
274    fn ws_mut(&mut self) -> BacktestClientResult<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
275        self.ws.as_mut().ok_or_else(|| BacktestClientError::Closed {
276            reason: "websocket closed".to_string(),
277        })
278    }
279
280    pub(crate) async fn create_with_request(
281        &mut self,
282        request: CreateBacktestSessionRequest,
283        rpc_base_url: String,
284        mut on_parallel_session_created: Option<&mut (dyn FnMut(String) + Send)>,
285    ) -> BacktestClientResult<CreateRequestResult> {
286        let expect_parallel = matches!(
287            &request,
288            CreateBacktestSessionRequest::V1(CreateBacktestSessionRequestV1 { parallel: true, .. })
289        );
290        self.send(&BacktestRequest::CreateBacktestSession(request), None)
291            .await?;
292        let mut streamed_parallel_session_ids = Vec::new();
293        let mut streamed_parallel_task_ids: Vec<Option<String>> = Vec::new();
294        // Collect intermediate messages (e.g. Status, SlotNotification) in a local buffer.
295        // Pushing them back on self.backlog can cause re-reading of the same message,
296        // triggering an infinite loop.
297        let mut pending: Vec<(Option<u64>, BacktestResponse)> = Vec::new();
298
299        loop {
300            let response =
301                self.next_response(None)
302                    .await?
303                    .ok_or_else(|| BacktestClientError::Closed {
304                        reason: "websocket ended before SessionCreated".to_string(),
305                    })?;
306
307            match response {
308                BacktestResponse::SessionCreated {
309                    session_id,
310                    rpc_endpoint,
311                    task_id,
312                } => {
313                    if expect_parallel {
314                        if let Some(callback) = on_parallel_session_created.as_mut() {
315                            (**callback)(session_id.clone());
316                        }
317                        streamed_parallel_session_ids.push(session_id);
318                        streamed_parallel_task_ids.push(task_id);
319                        continue;
320                    }
321                    let created_session_id = session_id.clone();
322                    let created_task_id = task_id.clone();
323                    self.session_id = Some(session_id);
324                    self.task_id = task_id;
325                    let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
326                    self.rpc = Some(RpcClient::new_with_commitment(
327                        resolved.clone(),
328                        CommitmentConfig::confirmed(),
329                    ));
330                    self.rpc_endpoint = Some(resolved);
331                    self.backlog.extend(pending);
332                    return Ok(CreateRequestResult::Single {
333                        session_id: created_session_id,
334                        task_id: created_task_id,
335                    });
336                }
337                BacktestResponse::SessionsCreated { session_ids } => {
338                    if expect_parallel && session_ids.is_empty() {
339                        self.backlog.extend(pending);
340                        return Ok(CreateRequestResult::Parallel {
341                            session_ids: streamed_parallel_session_ids,
342                            task_ids: streamed_parallel_task_ids,
343                        });
344                    }
345                    let len = session_ids.len();
346                    self.backlog.extend(pending);
347                    return Ok(CreateRequestResult::Parallel {
348                        session_ids,
349                        task_ids: vec![None; len],
350                    });
351                }
352                BacktestResponse::SessionsCreatedV2 {
353                    session_ids,
354                    task_ids,
355                    ..
356                } => {
357                    if expect_parallel && session_ids.is_empty() {
358                        self.backlog.extend(pending);
359                        return Ok(CreateRequestResult::Parallel {
360                            session_ids: streamed_parallel_session_ids,
361                            task_ids: streamed_parallel_task_ids,
362                        });
363                    }
364                    let task_ids = align_task_ids(task_ids, session_ids.len());
365                    self.backlog.extend(pending);
366                    return Ok(CreateRequestResult::Parallel {
367                        session_ids,
368                        task_ids,
369                    });
370                }
371                BacktestResponse::ReadyForContinue => {
372                    self.ready_for_continue = true;
373                }
374                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
375                other => {
376                    pending.push((self.last_sequence, other));
377                }
378            }
379        }
380    }
381
382    pub(crate) async fn attach(
383        &mut self,
384        session_id: String,
385        last_sequence: Option<u64>,
386        rpc_base_url: String,
387    ) -> BacktestClientResult<()> {
388        self.send(
389            &BacktestRequest::AttachBacktestSession {
390                session_id,
391                last_sequence,
392            },
393            None,
394        )
395        .await?;
396
397        self.wait_for_response(
398            || BacktestClientError::Closed {
399                reason: "websocket ended before SessionAttached".to_string(),
400            },
401            move |session, response| match response {
402                BacktestResponse::SessionAttached {
403                    session_id,
404                    rpc_endpoint,
405                    task_id,
406                } => {
407                    session.session_id = Some(session_id);
408                    session.task_id = task_id;
409                    let resolved = resolve_rpc_url(&rpc_base_url, &rpc_endpoint);
410                    session.rpc = Some(RpcClient::new_with_commitment(
411                        resolved.clone(),
412                        CommitmentConfig::confirmed(),
413                    ));
414                    session.rpc_endpoint = Some(resolved);
415                    Ok(Some(()))
416                }
417                BacktestResponse::ReadyForContinue => {
418                    session.ready_for_continue = true;
419                    Ok(None)
420                }
421                BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
422                other => {
423                    session.backlog.push_back((session.last_sequence, other));
424                    Ok(None)
425                }
426            },
427        )
428        .await
429    }
430
431    /// Sent after reattaching and rebuilding any dependent subscriptions so the
432    /// manager can resume a session that was paused for handoff.
433    pub async fn resume_attached_session(&mut self) -> BacktestClientResult<()> {
434        self.send(&BacktestRequest::ResumeAttachedSession, None)
435            .await?;
436
437        self.wait_for_response(
438            || BacktestClientError::Closed {
439                reason: "websocket ended before ResumeAttachedSession acknowledgement".to_string(),
440            },
441            |session, response| match response {
442                BacktestResponse::Success => Ok(Some(())),
443                BacktestResponse::Error(err) => Err(BacktestClientError::Remote(err)),
444                other => {
445                    session.backlog.push_back((session.last_sequence, other));
446                    Ok(None)
447                }
448            },
449        )
450        .await
451    }
452
453    async fn wait_for_response<T, E, F>(
454        &mut self,
455        closed_error: E,
456        mut handle_response: F,
457    ) -> BacktestClientResult<T>
458    where
459        E: FnOnce() -> BacktestClientError,
460        F: FnMut(&mut Self, BacktestResponse) -> BacktestClientResult<Option<T>>,
461    {
462        let mut closed_error = Some(closed_error);
463
464        loop {
465            let response = self
466                .next_response(None)
467                .await?
468                .ok_or_else(|| closed_error.take().expect("closed error set")())?;
469
470            if let Some(result) = handle_response(self, response)? {
471                return Ok(result);
472            }
473        }
474    }
475
476    /// Send a raw backtest request over the WebSocket.
477    pub async fn send(
478        &mut self,
479        request: &BacktestRequest,
480        timeout: Option<Duration>,
481    ) -> BacktestClientResult<()> {
482        let text = serde_json::to_string(request)
483            .map_err(|source| BacktestClientError::SerializeRequest { source })?;
484
485        let request_timeout = self.request_timeout;
486        let timeout = timeout.or(request_timeout);
487
488        let send_fut = self.ws_mut()?.send(Message::Text(text));
489        let send_result = match timeout {
490            Some(duration) => tokio::time::timeout(duration, send_fut)
491                .await
492                .map_err(|_| BacktestClientError::Timeout {
493                    action: "sending",
494                    duration,
495                })?,
496            None => send_fut.await,
497        };
498
499        send_result.map_err(|source| BacktestClientError::WebSocket {
500            action: "sending",
501            source: Box::new(source),
502        })?;
503
504        Ok(())
505    }
506
507    /// Receive the next response, using the backlog first.
508    pub async fn next_response(
509        &mut self,
510        timeout: Option<Duration>,
511    ) -> BacktestClientResult<Option<BacktestResponse>> {
512        if let Some((sequence, response)) = self.backlog.pop_front() {
513            self.last_sequence = sequence.or(self.last_sequence);
514            return Ok(Some(response));
515        }
516
517        let text = match self.next_text(timeout).await? {
518            Some(text) => text,
519            None => return Ok(None),
520        };
521
522        let (sequence, response) = match serde_json::from_str::<SequencedResponse>(&text) {
523            Ok(sequenced) => (Some(sequenced.seq_id), sequenced.response),
524            Err(_) => {
525                let response =
526                    serde_json::from_str::<BacktestResponse>(&text).map_err(|source| {
527                        BacktestClientError::DeserializeResponse {
528                            raw: text.clone(),
529                            source,
530                        }
531                    })?;
532                (None, response)
533            }
534        };
535        self.last_sequence = sequence.or(self.last_sequence);
536
537        Ok(Some(response))
538    }
539
540    /// Receive the next response and update readiness state.
541    pub async fn next_event(
542        &mut self,
543        timeout: Option<Duration>,
544    ) -> BacktestClientResult<Option<BacktestResponse>> {
545        let response = self.next_response(timeout).await?;
546        if let Some(ref response) = response {
547            self.apply_response(response);
548        }
549        Ok(response)
550    }
551
552    /// Stream responses, updating readiness state as items arrive.
553    ///
554    /// This consumes the session and yields responses until the connection ends.
555    pub fn responses(
556        self,
557        timeout: Option<Duration>,
558    ) -> impl futures::Stream<Item = BacktestClientResult<BacktestResponse>> {
559        stream::unfold(Some(self), move |state| async move {
560            let mut session = match state {
561                Some(session) => session,
562                None => return None,
563            };
564
565            match session.next_response(timeout).await {
566                Ok(Some(response)) => {
567                    session.apply_response(&response);
568                    Some((Ok(response), Some(session)))
569                }
570                Ok(None) => None,
571                Err(err) => Some((Err(err), None)),
572            }
573        })
574    }
575
576    /// Wait for the session to become ready or completed.
577    pub async fn ensure_ready(
578        &mut self,
579        timeout: Option<Duration>,
580    ) -> BacktestClientResult<ReadyOutcome> {
581        if self.ready_for_continue {
582            return Ok(ReadyOutcome::Ready);
583        }
584
585        loop {
586            let response =
587                self.next_response(timeout)
588                    .await?
589                    .ok_or_else(|| BacktestClientError::Closed {
590                        reason: "websocket ended while waiting for ReadyForContinue".to_string(),
591                    })?;
592
593            match response {
594                BacktestResponse::ReadyForContinue => {
595                    self.ready_for_continue = true;
596                    return Ok(ReadyOutcome::Ready);
597                }
598                BacktestResponse::Completed { .. } => return Ok(ReadyOutcome::Completed),
599                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
600                _ => {}
601            }
602        }
603    }
604
605    /// Wait for a specific status to be emitted.
606    pub async fn wait_for_status(
607        &mut self,
608        desired: BacktestStatus,
609        timeout: Option<Duration>,
610    ) -> BacktestClientResult<()> {
611        let desired = std::mem::discriminant(&desired);
612
613        loop {
614            let response =
615                self.next_response(timeout)
616                    .await?
617                    .ok_or_else(|| BacktestClientError::Closed {
618                        reason: "websocket ended while waiting for status".to_string(),
619                    })?;
620
621            match response {
622                BacktestResponse::Status { status }
623                    if std::mem::discriminant(&status) == desired =>
624                {
625                    return Ok(());
626                }
627                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
628                BacktestResponse::Completed {
629                    summary,
630                    agent_stats,
631                } => {
632                    return Err(BacktestClientError::UnexpectedResponse {
633                        context: "waiting for status",
634                        response: Box::new(BacktestResponse::Completed {
635                            summary,
636                            agent_stats,
637                        }),
638                    });
639                }
640                _ => {}
641            }
642        }
643    }
644
645    /// Push a response onto the end of the receive backlog.
646    pub(crate) fn push_backlog(&mut self, response: BacktestResponse) {
647        self.backlog.push_back((None, response));
648    }
649
650    /// Send a `Continue` request and reset readiness.
651    pub async fn send_continue(
652        &mut self,
653        params: ContinueParams,
654        timeout: Option<Duration>,
655    ) -> BacktestClientResult<()> {
656        self.ready_for_continue = false;
657        self.send(&BacktestRequest::Continue(params), timeout).await
658    }
659
660    /// Read and apply a single response while advancing.
661    pub async fn advance_step<F>(
662        &mut self,
663        state: &mut AdvanceState,
664        wait_for_slots: bool,
665        timeout: Option<Duration>,
666        on_event: &mut F,
667    ) -> BacktestClientResult<()>
668    where
669        F: FnMut(&BacktestResponse),
670    {
671        let Some(response) = self.next_response(timeout).await? else {
672            return Err(BacktestClientError::Closed {
673                reason: "websocket ended while awaiting continue responses".to_string(),
674            });
675        };
676
677        if self.log_raw {
678            tracing::debug!("<- {response:?}");
679        }
680
681        on_event(&response);
682
683        match response {
684            BacktestResponse::ReadyForContinue => {
685                self.ready_for_continue = true;
686                state.ready_for_continue = true;
687            }
688            BacktestResponse::SlotNotification(slot) => {
689                state.slot_notifications += 1;
690                state.last_slot = Some(slot);
691            }
692            BacktestResponse::Status { status } => {
693                state.statuses.push(status);
694            }
695            BacktestResponse::Success => {}
696            BacktestResponse::Completed {
697                summary,
698                agent_stats,
699            } => {
700                state.completed = true;
701                state.summary = summary;
702                state.agent_stats = agent_stats;
703            }
704            BacktestResponse::Error(err @ BacktestError::SimulationError { .. }) => {
705                tracing::warn!(error = %crate::error::err_chain(&err), "simulation error");
706            }
707            BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
708            BacktestResponse::SessionCreated { .. }
709            | BacktestResponse::SessionAttached { .. }
710            | BacktestResponse::SessionsCreated { .. }
711            | BacktestResponse::SessionsCreatedV2 { .. }
712            | BacktestResponse::ParallelSessionAttachedV2 { .. }
713            | BacktestResponse::SessionEventV1 { .. }
714            | BacktestResponse::SessionEventV2 { .. }
715            | BacktestResponse::Paused(_)
716            | BacktestResponse::DiscoveryBatch(_) => {
717                return Err(BacktestClientError::UnexpectedResponse {
718                    context: "continuing",
719                    response: Box::new(response),
720                });
721            }
722        }
723
724        if wait_for_slots && state.slot_notifications > state.expected_slots {
725            tracing::warn!(
726                "received {} slot notifications (expected {})",
727                state.slot_notifications,
728                state.expected_slots
729            );
730        }
731
732        Ok(())
733    }
734
735    /// Advance until the session becomes ready for another `Continue`.
736    pub async fn continue_until_ready<F>(
737        &mut self,
738        cont: Continue,
739        timeout: Option<Duration>,
740        mut on_event: F,
741    ) -> BacktestClientResult<ContinueResult>
742    where
743        F: FnMut(&BacktestResponse),
744    {
745        let expected_slots = cont.advance_count;
746        self.advance_internal(
747            cont.into_params(),
748            expected_slots,
749            false,
750            timeout,
751            &mut on_event,
752        )
753        .await
754    }
755
756    /// Advance and wait for both readiness and slot notifications.
757    pub async fn advance<F>(
758        &mut self,
759        cont: Continue,
760        timeout: Option<Duration>,
761        mut on_event: F,
762    ) -> BacktestClientResult<ContinueResult>
763    where
764        F: FnMut(&BacktestResponse),
765    {
766        let expected_slots = cont.advance_count;
767        self.advance_internal(
768            cont.into_params(),
769            expected_slots,
770            true,
771            timeout,
772            &mut on_event,
773        )
774        .await
775    }
776
777    async fn advance_internal<F>(
778        &mut self,
779        params: ContinueParams,
780        expected_slots: u64,
781        wait_for_slots: bool,
782        timeout: Option<Duration>,
783        on_event: &mut F,
784    ) -> BacktestClientResult<ContinueResult>
785    where
786        F: FnMut(&BacktestResponse),
787    {
788        self.send_continue(params, timeout).await?;
789
790        let mut state = AdvanceState::new(expected_slots);
791        while !state.is_done(wait_for_slots) {
792            self.advance_step(&mut state, wait_for_slots, timeout, on_event)
793                .await?;
794        }
795
796        Ok(ContinueResult {
797            slot_notifications: state.slot_notifications,
798            last_slot: state.last_slot,
799            statuses: state.statuses,
800            ready_for_continue: state.ready_for_continue,
801            completed: state.completed,
802        })
803    }
804
805    /// Build a program-data account modification from raw ELF bytes.
806    ///
807    /// Derives the `ProgramData` address from `program_id`, queries the session's RPC endpoint
808    /// for the current slot and rent-exempt minimum, then returns a modification map ready to
809    /// pass to [`Continue::builder().modify_accounts(...)`](crate::Continue).
810    ///
811    /// The deploy slot is set to `current_slot - 1` so the program appears deployed before the
812    /// next executed slot.
813    pub async fn modify_program(
814        &self,
815        program_id: &str,
816        elf: &[u8],
817    ) -> Result<BTreeMap<Address, AccountData>, ProgramModError> {
818        let rpc = self.rpc.as_ref().ok_or(ProgramModError::NoRpcEndpoint)?;
819        crate::injection::modify_program_via_rpc(rpc, program_id, elf).await
820    }
821
822    /// Modify accounts on the session's RPC endpoint via the custom `modifyAccounts` method.
823    ///
824    /// Returns the number of accounts modified on success.
825    pub async fn modify_accounts(
826        &self,
827        modifications: &AccountModifications,
828    ) -> BacktestClientResult<usize> {
829        let rpc_endpoint =
830            self.rpc_endpoint
831                .as_deref()
832                .ok_or_else(|| BacktestClientError::Closed {
833                    reason: "no RPC endpoint available".to_string(),
834                })?;
835
836        crate::rpc::modify_accounts(rpc_endpoint, modifications).await
837    }
838
839    /// Subscribe to program log notifications using the session's RPC endpoint.
840    ///
841    /// Equivalent to calling [`subscribe_program_logs`](crate::subscribe_program_logs) with
842    /// the endpoint from [`rpc_endpoint`](Self::rpc_endpoint), which is set after
843    /// [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
844    pub async fn subscribe_program_logs<F, Fut>(
845        &self,
846        program_id: &str,
847        commitment: CommitmentConfig,
848        on_notification: F,
849    ) -> Result<LogSubscriptionHandle, SubscriptionError>
850    where
851        F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
852        Fut: Future<Output = ()> + Send + 'static,
853    {
854        let rpc_endpoint = self
855            .rpc_endpoint
856            .as_deref()
857            .ok_or(SubscriptionError::NoRpcEndpoint)?;
858        crate::subscriptions::subscribe_program_logs(
859            rpc_endpoint,
860            program_id,
861            commitment,
862            on_notification,
863        )
864        .await
865    }
866
867    /// Subscribe to account diff notifications using the session's RPC endpoint.
868    ///
869    /// Equivalent to calling [`subscribe_account_diffs`](crate::subscribe_account_diffs) with
870    /// the endpoint from [`rpc_endpoint`](Self::rpc_endpoint), which is set after
871    /// [`BacktestClient::create_session`](crate::BacktestClient::create_session) completes.
872    pub async fn subscribe_account_diffs<F, Fut>(
873        &self,
874        account: &str,
875        on_notification: F,
876    ) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
877    where
878        F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
879        Fut: Future<Output = ()> + Send + 'static,
880    {
881        let rpc_endpoint = self
882            .rpc_endpoint
883            .as_deref()
884            .ok_or(SubscriptionError::NoRpcEndpoint)?;
885        crate::subscriptions::subscribe_account_diffs(rpc_endpoint, account, on_notification).await
886    }
887
888    /// Request server cleanup and close the underlying WebSocket.
889    ///
890    /// This is idempotent and will return `Ok(())` if the connection is already closed.
891    pub async fn close(&mut self, timeout: Option<Duration>) -> BacktestClientResult<()> {
892        self.close_with_frame(timeout, None).await
893    }
894
    /// Close the session with a specific WebSocket close frame.
    ///
    /// Performs a two-stage shutdown:
    ///
    /// 1. Sends [`BacktestRequest::CloseBacktestSession`] and waits for one reply.
    ///    - `Success` and `Completed` are accepted.
    ///    - `Error` is returned as [`BacktestClientError::Remote`].
    ///    - Any other reply is returned as [`BacktestClientError::UnexpectedResponse`].
    ///    - If the connection turns out to be gone already (end of stream, a close frame, or a
    ///      reset without a closing handshake), the socket is released and `Ok(())` is returned
    ///      immediately.
    ///    - If sending the request fails only because the connection is already closed, this
    ///      wait is skipped.
    /// 2. Sends the WebSocket close frame (`frame`; with `None`, the library's default close
    ///    frame is sent, presumably one without a code or reason — confirm against
    ///    tokio-tungstenite). Then reads one more message as the close confirmation, pauses for
    ///    100 ms, and releases the socket.
    ///
    /// If the socket was already released by an earlier call, this returns `Ok(())` without doing
    /// anything. `timeout` is forwarded to every send and receive. NOTE(review): `None`
    /// presumably falls back to the session's `request_timeout`, as `next_text` does; confirm in
    /// `send`/`next_response`.
    ///
    /// # Errors
    ///
    /// Send/receive failures that do not indicate an already-closed connection are returned
    /// as-is. This includes a timeout while waiting for either reply. On any error return the
    /// socket stays in `self.ws`. The `Drop` impl then still attempts a best-effort close when a
    /// Tokio runtime is available.
    pub async fn close_with_frame(
        &mut self,
        timeout: Option<Duration>,
        frame: Option<CloseFrame<'static>>,
    ) -> BacktestClientResult<()> {
        // Socket already released (e.g. by an earlier close): nothing to do.
        if self.ws.is_none() {
            return Ok(());
        }

        let mut sent = false;
        match self
            .send(&BacktestRequest::CloseBacktestSession, timeout)
            .await
        {
            Ok(()) => sent = true,
            // The connection is already closed: there is no server reply to wait for, but the
            // WebSocket-level close below is still attempted.
            Err(err) if is_close_ok(&err) => {}
            Err(err) => return Err(err),
        }

        if sent {
            // Wait for the server's acknowledgement of CloseBacktestSession. Every "connection
            // already gone" outcome releases the socket and counts as a successful close.
            let response = match self.next_response(timeout).await {
                Ok(Some(r)) => r,
                Ok(None) => {
                    self.ws.take();
                    return Ok(());
                }
                Err(BacktestClientError::Closed { .. }) => {
                    self.ws.take();
                    return Ok(());
                }
                Err(BacktestClientError::WebSocket {
                    action: "receiving",
                    source,
                }) if is_reset_without_close(&source) => {
                    self.ws.take();
                    return Ok(());
                }
                Err(err) => return Err(err),
            };

            match response {
                BacktestResponse::Success | BacktestResponse::Completed { .. } => {}
                BacktestResponse::Error(err) => return Err(BacktestClientError::Remote(err)),
                other => {
                    return Err(BacktestClientError::UnexpectedResponse {
                        context: "closing session",
                        response: Box::new(other),
                    });
                }
            }
        }

        // Start the WebSocket closing handshake. Errors meaning "already closed/reset" are
        // tolerated; anything else is reported.
        if let Some(ws) = self.ws.as_mut() {
            let close_result = ws.close(frame).await;
            if let Err(source) = close_result
                && !is_ws_closed_error(&source)
            {
                return Err(BacktestClientError::WebSocket {
                    action: "closing",
                    source: Box::new(source),
                });
            }
        }

        // Await the close confirmation
        // Only one message is read. Any successful read (including end of stream), the peer's
        // close frame, or a reset without handshake is treated as confirmation.
        match self.next_response(timeout).await {
            Ok(_) => {}
            Err(BacktestClientError::Closed { .. }) => {}
            Err(BacktestClientError::WebSocket {
                action: "receiving",
                source,
            }) if is_reset_without_close(&source) => {}
            Err(err) => return Err(err),
        }

        // Give some time for the close to propagate
        tokio::time::sleep(Duration::from_millis(100)).await;
        self.ws.take();
        Ok(())
    }
976
977    /// Close the session with a close code and reason.
978    pub async fn close_with_reason(
979        &mut self,
980        timeout: Option<Duration>,
981        code: CloseCode,
982        reason: impl Into<String>,
983    ) -> BacktestClientResult<()> {
984        let frame = CloseFrame {
985            code,
986            reason: Cow::Owned(reason.into()),
987        };
988        self.close_with_frame(timeout, Some(frame)).await
989    }
990
991    async fn next_text(
992        &mut self,
993        timeout: Option<Duration>,
994    ) -> BacktestClientResult<Option<String>> {
995        loop {
996            let request_timeout = self.request_timeout;
997            let timeout = timeout.or(request_timeout);
998
999            let next_fut = self.ws_mut()?.next();
1000            let msg = match timeout {
1001                Some(duration) => tokio::time::timeout(duration, next_fut)
1002                    .await
1003                    .map_err(|_| BacktestClientError::Timeout {
1004                        action: "receiving",
1005                        duration,
1006                    })?,
1007                None => next_fut.await,
1008            };
1009
1010            let Some(msg) = msg else {
1011                return Ok(None);
1012            };
1013
1014            let msg = match msg {
1015                Ok(msg) => msg,
1016                Err(source) => {
1017                    return Err(BacktestClientError::WebSocket {
1018                        action: "receiving",
1019                        source: Box::new(source),
1020                    });
1021                }
1022            };
1023
1024            match msg {
1025                Message::Text(text) => {
1026                    if self.log_raw {
1027                        tracing::debug!("<- raw: {text}");
1028                    }
1029                    return Ok(Some(text));
1030                }
1031                Message::Binary(bin) => match String::from_utf8(bin) {
1032                    Ok(text) => {
1033                        if self.log_raw {
1034                            tracing::debug!("<- raw(bin): {text}");
1035                        }
1036                        return Ok(Some(text));
1037                    }
1038                    Err(err) => {
1039                        tracing::warn!("discarding non-utf8 binary message: {err}");
1040                        continue;
1041                    }
1042                },
1043                Message::Close(frame) => {
1044                    let reason = close_reason(frame);
1045                    return Err(BacktestClientError::Closed { reason });
1046                }
1047                Message::Ping(_) | Message::Pong(_) => continue,
1048                Message::Frame(_) => continue,
1049            }
1050        }
1051    }
1052}
1053
1054impl std::fmt::Debug for BacktestSession {
1055    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1056        f.debug_struct("BacktestSession")
1057            .field("session_id", &self.session_id)
1058            .field("rpc_endpoint", &self.rpc_endpoint)
1059            .field(
1060                "rpc",
1061                &self
1062                    .rpc
1063                    .as_ref()
1064                    .map(|_| "<RpcClient>")
1065                    .unwrap_or("<not set>"),
1066            )
1067            .field("ready_for_continue", &self.ready_for_continue)
1068            .field("request_timeout", &self.request_timeout)
1069            .finish_non_exhaustive()
1070    }
1071}
1072
1073#[derive(Debug)]
1074pub(crate) enum CreateRequestResult {
1075    Single {
1076        session_id: String,
1077        task_id: Option<String>,
1078    },
1079    Parallel {
1080        session_ids: Vec<String>,
1081        task_ids: Vec<Option<String>>,
1082    },
1083}
1084
1085/// Pad or truncate a server-supplied `task_ids` slice to match `session_ids` length.
1086/// Older servers omit `task_ids` entirely; treat that as an all-`None` slice.
1087fn align_task_ids(mut task_ids: Vec<Option<String>>, expected_len: usize) -> Vec<Option<String>> {
1088    if task_ids.len() < expected_len {
1089        task_ids.resize(expected_len, None);
1090    } else if task_ids.len() > expected_len {
1091        task_ids.truncate(expected_len);
1092    }
1093    task_ids
1094}
1095
1096impl Drop for BacktestSession {
1097    fn drop(&mut self) {
1098        let Some(ws) = self.ws.take() else {
1099            return;
1100        };
1101
1102        if let Ok(handle) = tokio::runtime::Handle::try_current() {
1103            handle.spawn(async move {
1104                let mut ws = ws;
1105                let _ = ws.close(None).await;
1106            });
1107        }
1108    }
1109}
1110
1111fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
1112    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
1113        endpoint.to_string()
1114    } else {
1115        format!("{}/{}", base, endpoint.trim_start_matches('/'))
1116    }
1117}
1118
1119fn close_reason(frame: Option<CloseFrame<'static>>) -> String {
1120    match frame {
1121        Some(frame) => format!("{:?}: {}", frame.code, frame.reason),
1122        None => "no close frame".to_string(),
1123    }
1124}
1125
1126fn is_reset_without_close(err: &WsError) -> bool {
1127    matches!(
1128        err,
1129        WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1130    )
1131}
1132
1133fn is_ws_closed_error(err: &WsError) -> bool {
1134    matches!(
1135        err,
1136        WsError::ConnectionClosed
1137            | WsError::AlreadyClosed
1138            | WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
1139    )
1140}
1141
1142fn is_close_ok(err: &BacktestClientError) -> bool {
1143    match err {
1144        BacktestClientError::Closed { .. } => true,
1145        BacktestClientError::WebSocket { source, .. } => is_ws_closed_error(source),
1146        _ => false,
1147    }
1148}
1149
1150#[cfg(test)]
1151mod tests {
1152    use super::*;
1153
1154    #[test]
1155    fn coverage_tracks_slot_and_completion_from_responses() {
1156        let mut coverage = SessionCoverage::default();
1157        coverage.observe_response(&BacktestResponse::SlotNotification(10));
1158        coverage.observe_response(&BacktestResponse::SlotNotification(12));
1159        coverage.observe_response(&BacktestResponse::Completed {
1160            summary: None,
1161            agent_stats: None,
1162        });
1163
1164        assert!(coverage.is_completed());
1165        assert_eq!(coverage.highest_slot_seen(), Some(12));
1166    }
1167
1168    #[test]
1169    fn coverage_validate_end_slot_checks_completion_and_range() {
1170        let mut coverage = SessionCoverage::default();
1171        assert_eq!(
1172            coverage.validate_end_slot(5),
1173            Err(CoverageError::NotCompleted)
1174        );
1175
1176        coverage.mark_completed();
1177        assert_eq!(
1178            coverage.validate_end_slot(5),
1179            Err(CoverageError::NoSlotsObserved)
1180        );
1181
1182        coverage.observe_slot(4);
1183        assert_eq!(
1184            coverage.validate_end_slot(5),
1185            Err(CoverageError::RangeNotReached {
1186                actual_end_slot: 4,
1187                expected_end_slot: 5,
1188            })
1189        );
1190
1191        coverage.observe_slot(6);
1192        assert_eq!(coverage.validate_end_slot(5), Ok(()));
1193    }
1194}