rustenium_core/
session.rs1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4use rand::Rng;
5use tracing;
6use crate::network::NetworkRequestHandledState;
7use rustenium_bidi_commands::{Command, CommandData, ResultData, EmptyParams};
8use rustenium_bidi_commands::session::commands::{New as SessionNew, SessionNewMethod, NewParameters as SessionNewParameters, SessionCommand, SessionResult, End, SessionEndMethod};
9use rustenium_bidi_commands::session::types::CapabilitiesRequest;
10use tokio::sync::oneshot;
11use tokio::time::timeout;
12use crate::listeners::CommandResponseState;
13use crate::{
14 connection::Connection,
15 transport::{ConnectionTransport, ConnectionTransportConfig, WebsocketConnectionTransport},
16};
17use crate::error::{ResponseReceiveTimeoutError, SessionSendError};
18use crate::events::{BidiEvent, EventManagement};
19
20pub struct Session<T: ConnectionTransport> {
21 id: Option<String>,
22 connection: Connection<T>,
23 bidi_events: Arc<Mutex<Vec<BidiEvent>>>,
24 pub handled_network_requests: Arc<Mutex<HashMap<String, NetworkRequestHandledState>>>,
26}
27
/// Kind of connection a new BiDi session is created over.
pub enum SessionConnectionType {
    /// Session established over a WebSocket connection.
    WebSocket,
}
31
32impl<T: ConnectionTransport> Session<T> {
33 pub async fn ws_new(
34 connection_config: &ConnectionTransportConfig,
35 ) -> Session<WebsocketConnectionTransport> {
36 let connection_transport = WebsocketConnectionTransport::new(connection_config)
37 .await
38 .unwrap();
39 let connection = Connection::new(connection_transport);
40 connection.start_listeners();
41 Session {
42 id: None,
43 connection,
44 bidi_events: Arc::new(Mutex::new(Vec::new())),
45 handled_network_requests: Arc::new(Mutex::new(HashMap::new())),
46 }
47 }
48
49 pub async fn create_new_bidi_session(&mut self, connection_type: SessionConnectionType, capabilities: Option<CapabilitiesRequest>) -> () {
50 match connection_type {
51 SessionConnectionType::WebSocket => {
52 let command = SessionNew {
53 method: SessionNewMethod::SessionNew,
54 params: SessionNewParameters {
55 capabilities: capabilities.unwrap_or(CapabilitiesRequest {
56 always_match: None,
57 first_match: None,
58 }),
59 }
60 };
61 let (_, event_tx) = self.event_dispatch().await;
62 self.connection.register_event_listener_channel(event_tx).await;
63 let command_result = self.send(CommandData::SessionCommand(SessionCommand::New(command.clone()))).await;
64 match command_result {
65 Ok(command_result) => {
66 match command_result {
67 ResultData::SessionResult(session_result) => {
68 match session_result {
69 SessionResult::NewResult(new_session_result) => {
70 self.id = Some(new_session_result.session_id);
71 }
72 _ => panic!("Invalid session result: {:?}", session_result)
73 }
74 }
75 _ => panic!("Invalid command result: {:?}", command_result)
76 }
77 }
78 Err(e) => panic!("Error creating new session: {}", e)
79 }
80 }
81 }
82 }
83
84 pub async fn send_and_get_receiver(&mut self, command_data: CommandData) -> oneshot::Receiver<CommandResponseState> {
87 let command_id = loop {
88 let id = rand::rng().random::<u32>() as u64;
89 if !self.connection.commands_response_subscriptions.lock().await.contains_key(&id) {
90 break id;
91 }
92 };
93
94 let command = Command {
95 id : command_id,
96 command_data,
97 extensible: HashMap::new(),
98 };
99 let (tx, rx) = oneshot::channel::<CommandResponseState>();
100 self.connection.commands_response_subscriptions.lock().await.insert(command_id, tx);
101 let raw_message = serde_json::to_string(&command).unwrap();
102 tracing::debug!(command_id = %command_id, raw_message = %raw_message, "Sending command");
103
104 self.connection.send(raw_message).await;
105
106 rx
107 }
108
109 pub async fn send(&mut self, command_data: CommandData) -> Result<ResultData, SessionSendError> {
110 let rx = self.send_and_get_receiver(command_data).await;
111 match timeout(Duration::from_secs(100), rx).await {
112 Ok(Ok(command_result)) => match command_result {
113 CommandResponseState::Success(response) => Ok(response.result),
114 CommandResponseState::Error(err) => Err(SessionSendError::ErrorResponse(err))
115 }
116 Ok(Err(err)) => panic!("A recv error occurred: {}", err),
117 Err(_) => Err(SessionSendError::ResponseReceiveTimeoutError(ResponseReceiveTimeoutError))
119 }
120 }
121
122 pub async fn end_session(&mut self) -> Result<ResultData, SessionSendError> {
123 let command = End {
124 method: SessionEndMethod::SessionEnd,
125 params: EmptyParams { extensible: Default::default() },
126 };
127
128 let result = self.send(CommandData::SessionCommand(SessionCommand::End(command))).await;
129
130 self.connection.close();
132
133 result
134 }
135}
136
137impl <T: ConnectionTransport>EventManagement for Session<T> {
138 async fn send_event(&mut self, command_data: CommandData) -> Result<ResultData, SessionSendError> {
139 self.send(command_data).await
140 }
141
142 fn get_bidi_events(&mut self) -> &mut Arc<Mutex<Vec<BidiEvent>>> {
143 &mut self.bidi_events
144 }
145
146 fn push_event(&mut self, event: BidiEvent) {
147 self.bidi_events.lock().unwrap().push(event);
148 }
149}