rustenium_core/
session.rs1use crate::error::{ResponseReceiveTimeoutError, SessionSendError};
2use crate::events::{BidiEvent, BidiEventManagement};
3use crate::listeners::CommandResponseState;
4use crate::network::NetworkRequestHandledState;
5use crate::{
6 connection::Connection,
7 transport::{ConnectionTransport, ConnectionTransportConfig, WebsocketConnectionTransport},
8};
9use rand::Rng;
10use rustenium_bidi_definitions::Command;
11use rustenium_bidi_definitions::base::{CommandMessage, CommandResponse};
12use rustenium_bidi_definitions::session::command_builders::{EndBuilder, NewBuilder};
13use rustenium_bidi_definitions::session::results::NewResult;
14use rustenium_bidi_definitions::session::types::CapabilitiesRequest;
15use serde_json;
16use std::collections::HashMap;
17use std::sync::{Arc, Mutex};
18use std::time::Duration;
19use tokio::sync::oneshot;
20use tokio::time::timeout;
21use tracing;
22
23pub struct BidiSession<T: ConnectionTransport> {
24 id: Option<String>,
25 connection: Connection<T>,
26 events: Arc<Mutex<Vec<BidiEvent>>>,
27 pub handled_network_requests: Arc<Mutex<HashMap<String, NetworkRequestHandledState>>>,
29}
30
/// Transport over which a BiDi session is established.
///
/// The enum is fieldless, so it is cheap to copy and can be compared,
/// hashed and debug-printed by callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SessionConnectionType {
    /// A WebSocket connection to the browser's BiDi endpoint.
    WebSocket,
}
34
35impl<T: ConnectionTransport> BidiSession<T> {
36 pub async fn ws_new(
37 connection_config: &ConnectionTransportConfig,
38 ) -> BidiSession<WebsocketConnectionTransport> {
39 let connection_transport = WebsocketConnectionTransport::new(connection_config)
40 .await
41 .unwrap();
42 let connection = Connection::new(connection_transport);
43 connection.start_listeners();
44 BidiSession {
45 id: None,
46 connection,
47 events: Arc::new(Mutex::new(Vec::new())),
48 handled_network_requests: Arc::new(Mutex::new(HashMap::new())),
49 }
50 }
51
52 pub async fn create_new_bidi_session(
53 &mut self,
54 connection_type: SessionConnectionType,
55 capabilities: CapabilitiesRequest,
56 ) -> () {
57 match connection_type {
58 SessionConnectionType::WebSocket => {
59 let command = NewBuilder::default()
60 .capabilities(capabilities)
61 .build()
62 .unwrap();
63 let (_, event_tx) = self.event_dispatch().await;
64 self.connection
65 .register_event_listener_channel(event_tx)
66 .await;
67 let command_result = self.send(command).await;
68 match command_result {
69 Ok(command_result) => {
70 let command_result: NewResult =
71 command_result.result.clone().try_into().expect(
72 format!("Invalid command result: {:?}", command_result).as_str(),
73 );
74 self.id = Option::from(command_result.session_id);
75 }
76 Err(e) => panic!("Error creating new session: {}", e),
77 }
78 }
79 }
80 }
81
82 pub async fn send_and_get_receiver(
85 &mut self,
86 command: impl Into<Command>,
87 ) -> oneshot::Receiver<CommandResponseState> {
88 let command_id = loop {
89 let id = rand::rng().random::<u32>() as u64;
90 if !self
91 .connection
92 .commands_response_subscriptions
93 .lock()
94 .await
95 .contains_key(&id)
96 {
97 break id;
98 }
99 };
100
101 let command = CommandMessage {
102 id: command_id,
103 command_data: command.into(),
104 extensible: HashMap::new(),
105 };
106 let (tx, rx) = oneshot::channel::<CommandResponseState>();
107 self.connection
108 .commands_response_subscriptions
109 .lock()
110 .await
111 .insert(command_id, tx);
112 let raw_message = serde_json::to_string(&command).unwrap();
113 tracing::debug!(command_id = %command_id, raw_message = %raw_message, "Sending command");
114
115 self.connection.send(raw_message).await;
116
117 rx
118 }
119
120 pub async fn send(
121 &mut self,
122 command: impl Into<Command>,
123 ) -> Result<CommandResponse, SessionSendError> {
124 let rx = self.send_and_get_receiver(command).await;
125 match timeout(Duration::from_secs(100), rx).await {
126 Ok(Ok(command_result)) => match command_result {
127 CommandResponseState::Success(response) => {
128 tracing::debug!(id = response.id, raw_message = %response.result, "Command response success");
129 Ok(response)
130 }
131 CommandResponseState::Error(err) => {
132 tracing::debug!(id = err.id, stacktrace = err.stacktrace, code = %err.error, "Command response failed");
133 Err(SessionSendError::ErrorResponse(err))
134 }
135 },
136 Ok(Err(err)) => panic!("A recv error occurred: {}", err),
137 Err(_) => Err(SessionSendError::ResponseReceiveTimeoutError(
139 ResponseReceiveTimeoutError,
140 )),
141 }
142 }
143
144 pub async fn end_session(&mut self) -> Result<CommandResponse, SessionSendError> {
145 let result = self.send(EndBuilder::default().build()).await;
146
147 self.connection.close();
149
150 result
151 }
152}
153
154impl<T: ConnectionTransport> BidiEventManagement for BidiSession<T> {
155 async fn send_event(
156 &mut self,
157 command: impl Into<Command>,
158 ) -> Result<CommandResponse, SessionSendError> {
159 self.send(command).await
160 }
161
162 fn get_events(&mut self) -> &mut Arc<Mutex<Vec<BidiEvent>>> {
163 &mut self.events
164 }
165
166 fn push_event(&mut self, event: BidiEvent) {
167 self.events.lock().unwrap().push(event);
168 }
169}