rustenium_core/
session.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4use rand::Rng;
5use crate::network::NetworkRequestHandledState;
6use rustenium_bidi_commands::{Command, CommandData, ResultData, EmptyParams};
7use rustenium_bidi_commands::session::commands::{New as SessionNew, SessionNewMethod, NewParameters as SessionNewParameters, SessionCommand, SessionResult, End, SessionEndMethod};
8use rustenium_bidi_commands::session::types::CapabilitiesRequest;
9use tokio::sync::oneshot;
10use tokio::time::timeout;
11use crate::listeners::CommandResponseState;
12use crate::{
13    connection::Connection,
14    transport::{ConnectionTransport, ConnectionTransportConfig, WebsocketConnectionTransport},
15};
16use crate::error::{ResponseReceiveTimeoutError, SessionSendError};
17use crate::events::{BidiEvent, EventManagement};
18
19pub struct Session<T: ConnectionTransport> {
20    id: Option<String>,
21    connection: Connection<T>,
22    bidi_events: Arc<Mutex<Vec<BidiEvent>>>,
23    /// Tracks network requests that have been handled, keyed by request ID
24    pub handled_network_requests: Arc<Mutex<HashMap<String, NetworkRequestHandledState>>>,
25}
26
/// Selects how `Session::create_new_bidi_session` establishes the BiDi session.
///
/// Only a WebSocket connection is currently defined. The enum is a plain tag, so it
/// derives the common value traits so callers can log, copy and compare it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SessionConnectionType {
    /// Establish the session over a WebSocket connection.
    WebSocket,
}
30
31impl<T: ConnectionTransport> Session<T> {
32    pub async fn ws_new(
33        connection_config: &ConnectionTransportConfig,
34    ) -> Session<WebsocketConnectionTransport> {
35        let connection_transport = WebsocketConnectionTransport::new(connection_config)
36            .await
37            .unwrap();
38        let connection = Connection::new(connection_transport);
39        connection.start_listeners();
40        Session {
41            id: None,
42            connection,
43            bidi_events: Arc::new(Mutex::new(Vec::new())),
44            handled_network_requests: Arc::new(Mutex::new(HashMap::new())),
45        }
46    }
47
48    pub async fn create_new_bidi_session(&mut self, connection_type: SessionConnectionType, capabilities: Option<CapabilitiesRequest>) -> () {
49        match connection_type {
50            SessionConnectionType::WebSocket => {
51                let command = SessionNew {
52                    method: SessionNewMethod::SessionNew,
53                    params: SessionNewParameters {
54                        capabilities: capabilities.unwrap_or(CapabilitiesRequest {
55                            always_match: None,
56                            first_match: None,
57                        }),
58                    }
59                };
60                let (_, event_tx) = self.event_dispatch().await;
61                self.connection.register_event_listener_channel(event_tx).await;
62                let command_result = self.send(CommandData::SessionCommand(SessionCommand::New(command.clone()))).await;
63                match command_result {
64                    Ok(command_result) => {
65                            match command_result {
66                            ResultData::SessionResult(session_result) => {
67                                match session_result {
68                                    SessionResult::NewResult(new_session_result) => {
69                                        self.id = Some(new_session_result.session_id);
70                                    }
71                                    _ => panic!("Invalid session result: {:?}", session_result)
72                                }
73                            }
74                            _ => panic!("Invalid command result: {:?}", command_result)
75                        }
76                    }
77                    Err(e) => panic!("Error creating new session: {}", e)
78                }
79            }
80        }
81    }
82
83    /// Send a command and return the receiver to wait for response
84    /// This allows the caller to release locks before waiting for the response
85    pub async fn send_and_get_receiver(&mut self, command_data: CommandData) -> oneshot::Receiver<CommandResponseState> {
86        let command_id = loop {
87            let id = rand::rng().random::<u32>() as u64;
88            if !self.connection.commands_response_subscriptions.lock().await.contains_key(&id) {
89                break id;
90            }
91        };
92
93        let command = Command {
94            id : command_id,
95            command_data,
96            extensible: HashMap::new(),
97        };
98        let (tx, rx) = oneshot::channel::<CommandResponseState>();
99        self.connection.commands_response_subscriptions.lock().await.insert(command_id, tx);
100        let raw_message = serde_json::to_string(&command).unwrap();
101        self.connection.send(raw_message).await;
102
103        rx
104    }
105
106    pub async fn send(&mut self, command_data: CommandData) -> Result<ResultData, SessionSendError>  {
107        let rx = self.send_and_get_receiver(command_data).await;
108        match timeout(Duration::from_secs(100), rx).await {
109            Ok(Ok(command_result)) => match command_result {
110                CommandResponseState::Success(response) => Ok(response.result),
111                CommandResponseState::Error(err) => Err(SessionSendError::ErrorResponse(err))
112            }
113            Ok(Err(err)) => panic!("A recv error occurred: {}", err),
114            // I might need to remove command from commands response subscriptions
115            Err(_) => Err(SessionSendError::ResponseReceiveTimeoutError(ResponseReceiveTimeoutError))
116        }
117    }
118
119    pub async fn end_session(&mut self) -> Result<ResultData, SessionSendError> {
120        let command = End {
121            method: SessionEndMethod::SessionEnd,
122            params: EmptyParams { extensible: Default::default() },
123        };
124
125        let result = self.send(CommandData::SessionCommand(SessionCommand::End(command))).await;
126
127        // Close the connection after ending the session
128        self.connection.close();
129
130        result
131    }
132}
133
134impl <T: ConnectionTransport>EventManagement for Session<T> {
135    async fn send_event(&mut self, command_data: CommandData) -> Result<ResultData, SessionSendError> {
136        self.send(command_data).await
137    }
138
139    fn get_bidi_events(&mut self) -> &mut Arc<Mutex<Vec<BidiEvent>>> {
140        &mut self.bidi_events
141    }
142
143    fn push_event(&mut self, event: BidiEvent) {
144        self.bidi_events.lock().unwrap().push(event);
145    }
146}