1use crate::error::{CdpSessionSendError, ResponseReceiveTimeoutError, SessionSendError};
2use crate::events::{BidiEvent, BidiEventManagement, CdpEvent, CdpEventManagement};
3use crate::listeners::{
4 CdpCommandResponseState,
5 CommandResponseState,
6};
7use crate::network::NetworkRequestHandledState;
8use crate::{
9 connection::{BidiConnection, CdpConnection},
10 transport::{ConnectionTransport, ConnectionTransportConfig, WebsocketConnectionTransport},
11};
12use rand::Rng;
13use rustenium_bidi_definitions::Command;
14use rustenium_bidi_definitions::base::{CommandMessage, CommandResponse};
15use rustenium_bidi_definitions::session::command_builders::{EndBuilder, NewBuilder};
16use rustenium_bidi_definitions::session::results::NewResult;
17use rustenium_bidi_definitions::session::types::CapabilitiesRequest;
18use rustenium_cdp_definitions::Command as CdpCommand;
19use rustenium_cdp_definitions::base as cdp_base;
20use serde_json;
21use std::collections::HashMap;
22use std::sync::{Arc, Mutex};
23use std::time::Duration;
24use tokio::sync::oneshot;
25use tokio::time::timeout;
26use tracing;
27
28pub struct BidiSession<T: ConnectionTransport> {
29 id: String,
30 connection: BidiConnection<T>,
31 events: Arc<Mutex<Vec<BidiEvent>>>,
32 pub handled_network_requests: Arc<Mutex<HashMap<String, NetworkRequestHandledState>>>,
34}
35
36impl BidiSession<WebsocketConnectionTransport> {
37 pub async fn new(
38 connection_config: &ConnectionTransportConfig,
39 capabilities: CapabilitiesRequest,
40 ) -> Self {
41 let transport = WebsocketConnectionTransport::new(connection_config).await.unwrap();
42 tracing::info!("Connected to WebSocket at {}", connection_config.full_endpoint());
43 let connection = BidiConnection::new(transport);
44 connection.start_listeners();
45
46 let mut session = Self {
47 id: String::new(),
48 connection,
49 events: Arc::new(Mutex::new(Vec::new())),
50 handled_network_requests: Arc::new(Mutex::new(HashMap::new())),
51 };
52
53 let (_, event_tx) = session.event_dispatch().await;
54 session.connection.register_event_listener_channel(event_tx).await;
55
56 let command = NewBuilder::default()
57 .capabilities(capabilities)
58 .build()
59 .unwrap();
60 let command_result = session.send(command).await;
61 match command_result {
62 Ok(command_result) => {
63 let result: NewResult = command_result.result.clone().try_into().expect(
64 format!("Invalid command result: {:?}", command_result).as_str(),
65 );
66 session.id = result.session_id;
67 }
68 Err(e) => panic!("Error creating new session: {}", e),
69 }
70
71 session
72 }
73}
74
75impl<T: ConnectionTransport> BidiSession<T> {
76 pub async fn send_and_get_receiver(
79 &mut self,
80 command: impl Into<Command>,
81 ) -> oneshot::Receiver<CommandResponseState> {
82 let command_id = loop {
83 let id = rand::rng().random::<u32>() as u64;
84 if !self
85 .connection
86 .commands_response_subscriptions
87 .lock()
88 .await
89 .contains_key(&id)
90 {
91 break id;
92 }
93 };
94
95 let command = CommandMessage {
96 id: command_id,
97 command_data: command.into(),
98 extensible: HashMap::new(),
99 };
100 let (tx, rx) = oneshot::channel::<CommandResponseState>();
101 self.connection
102 .commands_response_subscriptions
103 .lock()
104 .await
105 .insert(command_id, tx);
106 let raw_message = serde_json::to_string(&command).unwrap();
107 tracing::debug!(command_id = %command_id, raw_message = %raw_message, "Sending command");
108
109 self.connection.send(raw_message).await;
110
111 rx
112 }
113
114 pub async fn send(
115 &mut self,
116 command: impl Into<Command>,
117 ) -> Result<CommandResponse, SessionSendError> {
118 let rx = self.send_and_get_receiver(command).await;
119 let response = timeout(Duration::from_secs(5), rx).await;
120 match response {
121 Ok(Ok(command_result)) => match command_result {
122 CommandResponseState::Success(response) => {
123 tracing::debug!(id = response.id, raw_message = %response.result, "Command response success");
124 Ok(response)
125 }
126 CommandResponseState::Error(err) => {
127 tracing::debug!(id = err.id, stacktrace = err.stacktrace, code = %err.error, "Command response failed");
128 Err(SessionSendError::ErrorResponse(err))
129 }
130 },
131 Ok(Err(err)) => panic!("A recv error occurred: {}", err),
132 Err(_) => Err(SessionSendError::ResponseReceiveTimeoutError(
133 ResponseReceiveTimeoutError,
134 )),
135 }
136 }
137
138 pub async fn end_session(&mut self) -> Result<CommandResponse, SessionSendError> {
139 let result = self.send(EndBuilder::default().build()).await;
140 self.connection.close().await;
141 result
142 }
143}
144
145impl<T: ConnectionTransport> BidiEventManagement for BidiSession<T> {
146 async fn send_event(
147 &mut self,
148 command: impl Into<Command>,
149 ) -> Result<CommandResponse, SessionSendError> {
150 self.send(command).await
151 }
152
153 fn get_events(&mut self) -> &mut Arc<Mutex<Vec<BidiEvent>>> {
154 &mut self.events
155 }
156
157 fn push_event(&mut self, event: BidiEvent) {
158 self.events.lock().unwrap().push(event);
159 }
160}
161
162pub struct CdpSession<T: ConnectionTransport> {
165 connection: CdpConnection<T>,
166 events: Arc<Mutex<Vec<CdpEvent>>>,
167 pub session_id: Option<String>,
168}
169
170impl<T: ConnectionTransport> CdpSession<T> {
171 pub async fn ws_new(
172 config: &ConnectionTransportConfig,
173 ) -> CdpSession<WebsocketConnectionTransport> {
174 let transport = WebsocketConnectionTransport::new(config).await.unwrap();
175 tracing::info!("Successfully connected to Browser CDP");
176 let connection = CdpConnection::new(transport);
177 connection.start_listeners();
178 let events = Arc::new(Mutex::new(Vec::new()));
179
180 let mut session = CdpSession {
181 connection,
182 events,
183 session_id: None,
184 };
185
186 let (_, dispatch_tx) = session.event_dispatch().await;
187 session
188 .connection
189 .register_event_listener_channel(dispatch_tx)
190 .await;
191
192 session
193 }
194
195 pub async fn register_event_listener(
196 &mut self,
197 tx: tokio::sync::mpsc::UnboundedSender<cdp_base::EventResponse>,
198 ) {
199 self.connection.register_event_listener_channel(tx).await;
200 }
201
202 pub async fn send_and_get_receiver(
203 &mut self,
204 command: impl Into<CdpCommand>,
205 ) -> oneshot::Receiver<CdpCommandResponseState> {
206 let command_id = loop {
207 let id = rand::rng().random::<u16>();
208 if !self.connection.commands_response_subscriptions.lock().await.contains_key(&id) {
209 break id;
210 }
211 };
212
213 let command: CdpCommand = command.into();
214 let msg = cdp_base::CommandMessage {
215 id: command_id,
216 command_data: command.into(),
217 };
218
219 let (tx, rx) = oneshot::channel::<CdpCommandResponseState>();
220 self.connection.commands_response_subscriptions.lock().await.insert(command_id, tx);
221
222 let raw = serde_json::to_string(&msg).unwrap();
223 tracing::debug!(command_id = %command_id, raw_message = %raw, "Sending CDP command");
224 self.connection.send(raw).await;
225
226 rx
227 }
228
229 pub async fn send(
230 &mut self,
231 command: impl Into<CdpCommand>,
232 ) -> Result<cdp_base::CommandResponse, CdpSessionSendError> {
233 let rx = self.send_and_get_receiver(command).await;
234 match timeout(Duration::from_secs(20), rx).await {
235 Ok(Ok(state)) => match state {
236 CdpCommandResponseState::Success(response) => {
237 tracing::debug!(id = response.id, raw_message = %response.result, "CDP command response success");
238 Ok(response)
239 }
240 CdpCommandResponseState::Error(err) => {
241 tracing::debug!(id = ?err.id, error = %err.error, "CDP command response failed");
242 Err(CdpSessionSendError::ErrorResponse(err))
243 }
244 },
245 Ok(Err(e)) => panic!("CDP recv error: {}", e),
246 Err(_) => Err(CdpSessionSendError::ResponseReceiveTimeoutError(
247 ResponseReceiveTimeoutError,
248 )),
249 }
250 }
251
252 pub async fn close(&self) {
253 self.connection.close().await;
254 }
255}
256
257impl<T: ConnectionTransport> CdpEventManagement for CdpSession<T> {
258 fn get_events(&mut self) -> &mut Arc<Mutex<Vec<CdpEvent>>> {
259 &mut self.events
260 }
261
262 fn push_event(&mut self, event: CdpEvent) {
263 self.events.lock().unwrap().push(event);
264 }
265}