// rustenium_core/session.rs

1use crate::error::{CdpSessionSendError, ResponseReceiveTimeoutError, SessionSendError};
2use crate::events::{BidiEvent, BidiEventManagement, CdpEvent, CdpEventManagement};
3use crate::listeners::{
4    CdpCommandResponseState,
5    CommandResponseState,
6};
7use crate::network::NetworkRequestHandledState;
8use crate::{
9    connection::{BidiConnection, CdpConnection},
10    transport::{ConnectionTransport, ConnectionTransportConfig, WebsocketConnectionTransport},
11};
12use rand::Rng;
13use rustenium_bidi_definitions::Command;
14use rustenium_bidi_definitions::base::{CommandMessage, CommandResponse};
15use rustenium_bidi_definitions::session::command_builders::{EndBuilder, NewBuilder};
16use rustenium_bidi_definitions::session::results::NewResult;
17use rustenium_bidi_definitions::session::types::CapabilitiesRequest;
18use rustenium_cdp_definitions::Command as CdpCommand;
19use rustenium_cdp_definitions::base as cdp_base;
20use serde_json;
21use std::collections::HashMap;
22use std::sync::{Arc, Mutex};
23use std::time::Duration;
24use tokio::sync::oneshot;
25use tokio::time::timeout;
26use tracing;
27
28pub struct BidiSession<T: ConnectionTransport> {
29    id: String,
30    connection: BidiConnection<T>,
31    events: Arc<Mutex<Vec<BidiEvent>>>,
32    /// Tracks network requests that have been handled, keyed by request ID
33    pub handled_network_requests: Arc<Mutex<HashMap<String, NetworkRequestHandledState>>>,
34}
35
36impl BidiSession<WebsocketConnectionTransport> {
37    pub async fn new(
38        connection_config: &ConnectionTransportConfig,
39        capabilities: CapabilitiesRequest,
40    ) -> Self {
41        let transport = WebsocketConnectionTransport::new(connection_config).await.unwrap();
42        tracing::info!("Connected to WebSocket at {}", connection_config.full_endpoint());
43        let connection = BidiConnection::new(transport);
44        connection.start_listeners();
45
46        let mut session = Self {
47            id: String::new(),
48            connection,
49            events: Arc::new(Mutex::new(Vec::new())),
50            handled_network_requests: Arc::new(Mutex::new(HashMap::new())),
51        };
52
53        let (_, event_tx) = session.event_dispatch().await;
54        session.connection.register_event_listener_channel(event_tx).await;
55
56        let command = NewBuilder::default()
57            .capabilities(capabilities)
58            .build()
59            .unwrap();
60        let command_result = session.send(command).await;
61        match command_result {
62            Ok(command_result) => {
63                let result: NewResult = command_result.result.clone().try_into().expect(
64                    format!("Invalid command result: {:?}", command_result).as_str(),
65                );
66                session.id = result.session_id;
67            }
68            Err(e) => panic!("Error creating new session: {}", e),
69        }
70
71        session
72    }
73}
74
75impl<T: ConnectionTransport> BidiSession<T> {
76    /// Send a command and return the receiver to wait for response.
77    /// This allows the caller to release locks before waiting for the response.
78    pub async fn send_and_get_receiver(
79        &mut self,
80        command: impl Into<Command>,
81    ) -> oneshot::Receiver<CommandResponseState> {
82        let command_id = loop {
83            let id = rand::rng().random::<u32>() as u64;
84            if !self
85                .connection
86                .commands_response_subscriptions
87                .lock()
88                .await
89                .contains_key(&id)
90            {
91                break id;
92            }
93        };
94
95        let command = CommandMessage {
96            id: command_id,
97            command_data: command.into(),
98            extensible: HashMap::new(),
99        };
100        let (tx, rx) = oneshot::channel::<CommandResponseState>();
101        self.connection
102            .commands_response_subscriptions
103            .lock()
104            .await
105            .insert(command_id, tx);
106        let raw_message = serde_json::to_string(&command).unwrap();
107        tracing::debug!(command_id = %command_id, raw_message = %raw_message, "Sending command");
108
109        self.connection.send(raw_message).await;
110
111        rx
112    }
113
114    pub async fn send(
115        &mut self,
116        command: impl Into<Command>,
117    ) -> Result<CommandResponse, SessionSendError> {
118        let rx = self.send_and_get_receiver(command).await;
119        let response = timeout(Duration::from_secs(5), rx).await;
120        match response {
121            Ok(Ok(command_result)) => match command_result {
122                CommandResponseState::Success(response) => {
123                    tracing::debug!(id = response.id, raw_message = %response.result, "Command response success");
124                    Ok(response)
125                }
126                CommandResponseState::Error(err) => {
127                    tracing::debug!(id = err.id, stacktrace = err.stacktrace, code = %err.error, "Command response failed");
128                    Err(SessionSendError::ErrorResponse(err))
129                }
130            },
131            Ok(Err(err)) => panic!("A recv error occurred: {}", err),
132            Err(_) => Err(SessionSendError::ResponseReceiveTimeoutError(
133                ResponseReceiveTimeoutError,
134            )),
135        }
136    }
137
138    pub async fn end_session(&mut self) -> Result<CommandResponse, SessionSendError> {
139        let result = self.send(EndBuilder::default().build()).await;
140        self.connection.close().await;
141        result
142    }
143}
144
145impl<T: ConnectionTransport> BidiEventManagement for BidiSession<T> {
146    async fn send_event(
147        &mut self,
148        command: impl Into<Command>,
149    ) -> Result<CommandResponse, SessionSendError> {
150        self.send(command).await
151    }
152
153    fn get_events(&mut self) -> &mut Arc<Mutex<Vec<BidiEvent>>> {
154        &mut self.events
155    }
156
157    fn push_event(&mut self, event: BidiEvent) {
158        self.events.lock().unwrap().push(event);
159    }
160}
161
// ── CDP Session ──────────────────────────────────────────────────────────────

164pub struct CdpSession<T: ConnectionTransport> {
165    connection: CdpConnection<T>,
166    events: Arc<Mutex<Vec<CdpEvent>>>,
167    pub session_id: Option<String>,
168}
169
170impl<T: ConnectionTransport> CdpSession<T> {
171    pub async fn ws_new(
172        config: &ConnectionTransportConfig,
173    ) -> CdpSession<WebsocketConnectionTransport> {
174        let transport = WebsocketConnectionTransport::new(config).await.unwrap();
175        tracing::info!("Successfully connected to Browser CDP");
176        let connection = CdpConnection::new(transport);
177        connection.start_listeners();
178        let events = Arc::new(Mutex::new(Vec::new()));
179
180        let mut session = CdpSession {
181            connection,
182            events,
183            session_id: None,
184        };
185
186        let (_, dispatch_tx) = session.event_dispatch().await;
187        session
188            .connection
189            .register_event_listener_channel(dispatch_tx)
190            .await;
191
192        session
193    }
194
195    pub async fn register_event_listener(
196        &mut self,
197        tx: tokio::sync::mpsc::UnboundedSender<cdp_base::EventResponse>,
198    ) {
199        self.connection.register_event_listener_channel(tx).await;
200    }
201
202    pub async fn send_and_get_receiver(
203        &mut self,
204        command: impl Into<CdpCommand>,
205    ) -> oneshot::Receiver<CdpCommandResponseState> {
206        let command_id = loop {
207            let id = rand::rng().random::<u16>();
208            if !self.connection.commands_response_subscriptions.lock().await.contains_key(&id) {
209                break id;
210            }
211        };
212
213        let command: CdpCommand = command.into();
214        let msg = cdp_base::CommandMessage {
215            id: command_id,
216            command_data: command.into(),
217        };
218
219        let (tx, rx) = oneshot::channel::<CdpCommandResponseState>();
220        self.connection.commands_response_subscriptions.lock().await.insert(command_id, tx);
221
222        let raw = serde_json::to_string(&msg).unwrap();
223        tracing::debug!(command_id = %command_id, raw_message = %raw, "Sending CDP command");
224        self.connection.send(raw).await;
225
226        rx
227    }
228
229    pub async fn send(
230        &mut self,
231        command: impl Into<CdpCommand>,
232    ) -> Result<cdp_base::CommandResponse, CdpSessionSendError> {
233        let rx = self.send_and_get_receiver(command).await;
234        match timeout(Duration::from_secs(20), rx).await {
235            Ok(Ok(state)) => match state {
236                CdpCommandResponseState::Success(response) => {
237                    tracing::debug!(id = response.id, raw_message = %response.result, "CDP command response success");
238                    Ok(response)
239                }
240                CdpCommandResponseState::Error(err) => {
241                    tracing::debug!(id = ?err.id, error = %err.error, "CDP command response failed");
242                    Err(CdpSessionSendError::ErrorResponse(err))
243                }
244            },
245            Ok(Err(e)) => panic!("CDP recv error: {}", e),
246            Err(_) => Err(CdpSessionSendError::ResponseReceiveTimeoutError(
247                ResponseReceiveTimeoutError,
248            )),
249        }
250    }
251
252    pub async fn close(&self) {
253        self.connection.close().await;
254    }
255}
256
257impl<T: ConnectionTransport> CdpEventManagement for CdpSession<T> {
258    fn get_events(&mut self) -> &mut Arc<Mutex<Vec<CdpEvent>>> {
259        &mut self.events
260    }
261
262    fn push_event(&mut self, event: CdpEvent) {
263        self.events.lock().unwrap().push(event);
264    }
265}