rustenium_core/
session.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4use rand::Rng;
5use tracing;
6use crate::network::NetworkRequestHandledState;
7use rustenium_bidi_commands::{Command, CommandData, ResultData, EmptyParams};
8use rustenium_bidi_commands::session::commands::{New as SessionNew, SessionNewMethod, NewParameters as SessionNewParameters, SessionCommand, SessionResult, End, SessionEndMethod};
9use rustenium_bidi_commands::session::types::CapabilitiesRequest;
10use tokio::sync::oneshot;
11use tokio::time::timeout;
12use crate::listeners::CommandResponseState;
13use crate::{
14    connection::Connection,
15    transport::{ConnectionTransport, ConnectionTransportConfig, WebsocketConnectionTransport},
16};
17use crate::error::{ResponseReceiveTimeoutError, SessionSendError};
18use crate::events::{BidiEvent, EventManagement};
19
20pub struct Session<T: ConnectionTransport> {
21    id: Option<String>,
22    connection: Connection<T>,
23    bidi_events: Arc<Mutex<Vec<BidiEvent>>>,
24    /// Tracks network requests that have been handled, keyed by request ID
25    pub handled_network_requests: Arc<Mutex<HashMap<String, NetworkRequestHandledState>>>,
26}
27
/// The kind of transport a new BiDi session is created over.
///
/// `WebSocket` is currently the only supported variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionConnectionType {
    /// Create the session over a WebSocket connection.
    WebSocket,
}
31
32impl<T: ConnectionTransport> Session<T> {
33    pub async fn ws_new(
34        connection_config: &ConnectionTransportConfig,
35    ) -> Session<WebsocketConnectionTransport> {
36        let connection_transport = WebsocketConnectionTransport::new(connection_config)
37            .await
38            .unwrap();
39        let connection = Connection::new(connection_transport);
40        connection.start_listeners();
41        Session {
42            id: None,
43            connection,
44            bidi_events: Arc::new(Mutex::new(Vec::new())),
45            handled_network_requests: Arc::new(Mutex::new(HashMap::new())),
46        }
47    }
48
49    pub async fn create_new_bidi_session(&mut self, connection_type: SessionConnectionType, capabilities: Option<CapabilitiesRequest>) -> () {
50        match connection_type {
51            SessionConnectionType::WebSocket => {
52                let command = SessionNew {
53                    method: SessionNewMethod::SessionNew,
54                    params: SessionNewParameters {
55                        capabilities: capabilities.unwrap_or(CapabilitiesRequest {
56                            always_match: None,
57                            first_match: None,
58                        }),
59                    }
60                };
61                let (_, event_tx) = self.event_dispatch().await;
62                self.connection.register_event_listener_channel(event_tx).await;
63                let command_result = self.send(CommandData::SessionCommand(SessionCommand::New(command.clone()))).await;
64                match command_result {
65                    Ok(command_result) => {
66                            match command_result {
67                            ResultData::SessionResult(session_result) => {
68                                match session_result {
69                                    SessionResult::NewResult(new_session_result) => {
70                                        self.id = Some(new_session_result.session_id);
71                                    }
72                                    _ => panic!("Invalid session result: {:?}", session_result)
73                                }
74                            }
75                            _ => panic!("Invalid command result: {:?}", command_result)
76                        }
77                    }
78                    Err(e) => panic!("Error creating new session: {}", e)
79                }
80            }
81        }
82    }
83
84    /// Send a command and return the receiver to wait for response
85    /// This allows the caller to release locks before waiting for the response
86    pub async fn send_and_get_receiver(&mut self, command_data: CommandData) -> oneshot::Receiver<CommandResponseState> {
87        let command_id = loop {
88            let id = rand::rng().random::<u32>() as u64;
89            if !self.connection.commands_response_subscriptions.lock().await.contains_key(&id) {
90                break id;
91            }
92        };
93
94        let command = Command {
95            id : command_id,
96            command_data,
97            extensible: HashMap::new(),
98        };
99        let (tx, rx) = oneshot::channel::<CommandResponseState>();
100        self.connection.commands_response_subscriptions.lock().await.insert(command_id, tx);
101        let raw_message = serde_json::to_string(&command).unwrap();
102        tracing::debug!(command_id = %command_id, raw_message = %raw_message, "Sending command");
103
104        self.connection.send(raw_message).await;
105
106        rx
107    }
108
109    pub async fn send(&mut self, command_data: CommandData) -> Result<ResultData, SessionSendError>  {
110        let rx = self.send_and_get_receiver(command_data).await;
111        match timeout(Duration::from_secs(100), rx).await {
112            Ok(Ok(command_result)) => match command_result {
113                CommandResponseState::Success(response) => Ok(response.result),
114                CommandResponseState::Error(err) => Err(SessionSendError::ErrorResponse(err))
115            }
116            Ok(Err(err)) => panic!("A recv error occurred: {}", err),
117            // I might need to remove command from commands response subscriptions
118            Err(_) => Err(SessionSendError::ResponseReceiveTimeoutError(ResponseReceiveTimeoutError))
119        }
120    }
121
122    pub async fn end_session(&mut self) -> Result<ResultData, SessionSendError> {
123        let command = End {
124            method: SessionEndMethod::SessionEnd,
125            params: EmptyParams { extensible: Default::default() },
126        };
127
128        let result = self.send(CommandData::SessionCommand(SessionCommand::End(command))).await;
129
130        // Close the connection after ending the session
131        self.connection.close();
132
133        result
134    }
135}
136
137impl <T: ConnectionTransport>EventManagement for Session<T> {
138    async fn send_event(&mut self, command_data: CommandData) -> Result<ResultData, SessionSendError> {
139        self.send(command_data).await
140    }
141
142    fn get_bidi_events(&mut self) -> &mut Arc<Mutex<Vec<BidiEvent>>> {
143        &mut self.bidi_events
144    }
145
146    fn push_event(&mut self, event: BidiEvent) {
147        self.bidi_events.lock().unwrap().push(event);
148    }
149}