// rustenium_core/session.rs

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use rand::Rng;
use rustenium_bidi_definitions::Command;
use rustenium_bidi_definitions::base::{CommandMessage, CommandResponse};
use rustenium_bidi_definitions::session::command_builders::{EndBuilder, NewBuilder};
use rustenium_bidi_definitions::session::results::NewResult;
use rustenium_bidi_definitions::session::types::CapabilitiesRequest;
use serde_json;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tracing;

use crate::error::{ResponseReceiveTimeoutError, SessionSendError};
use crate::events::{BidiEvent, BidiEventManagement};
use crate::listeners::CommandResponseState;
use crate::network::NetworkRequestHandledState;
use crate::{
    connection::Connection,
    transport::{ConnectionTransport, ConnectionTransportConfig, WebsocketConnectionTransport},
};

23pub struct BidiSession<T: ConnectionTransport> {
24    id: Option<String>,
25    connection: Connection<T>,
26    events: Arc<Mutex<Vec<BidiEvent>>>,
27    /// Tracks network requests that have been handled, keyed by request ID
28    pub handled_network_requests: Arc<Mutex<HashMap<String, NetworkRequestHandledState>>>,
29}
30
/// Transport used when establishing a new BiDi session.
///
/// WebSocket is currently the only variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SessionConnectionType {
    /// Talk to the browser over a WebSocket connection.
    WebSocket,
}

35impl<T: ConnectionTransport> BidiSession<T> {
36    pub async fn ws_new(
37        connection_config: &ConnectionTransportConfig,
38    ) -> BidiSession<WebsocketConnectionTransport> {
39        let connection_transport = WebsocketConnectionTransport::new(connection_config)
40            .await
41            .unwrap();
42        let connection = Connection::new(connection_transport);
43        connection.start_listeners();
44        BidiSession {
45            id: None,
46            connection,
47            events: Arc::new(Mutex::new(Vec::new())),
48            handled_network_requests: Arc::new(Mutex::new(HashMap::new())),
49        }
50    }
51
52    pub async fn create_new_bidi_session(
53        &mut self,
54        connection_type: SessionConnectionType,
55        capabilities: CapabilitiesRequest,
56    ) -> () {
57        match connection_type {
58            SessionConnectionType::WebSocket => {
59                let command = NewBuilder::default()
60                    .capabilities(capabilities)
61                    .build()
62                    .unwrap();
63                let (_, event_tx) = self.event_dispatch().await;
64                self.connection
65                    .register_event_listener_channel(event_tx)
66                    .await;
67                let command_result = self.send(command).await;
68                match command_result {
69                    Ok(command_result) => {
70                        let command_result: NewResult =
71                            command_result.result.clone().try_into().expect(
72                                format!("Invalid command result: {:?}", command_result).as_str(),
73                            );
74                        self.id = Option::from(command_result.session_id);
75                    }
76                    Err(e) => panic!("Error creating new session: {}", e),
77                }
78            }
79        }
80    }
81
82    /// Send a command and return the receiver to wait for response
83    /// This allows the caller to release locks before waiting for the response
84    pub async fn send_and_get_receiver(
85        &mut self,
86        command: impl Into<Command>,
87    ) -> oneshot::Receiver<CommandResponseState> {
88        let command_id = loop {
89            let id = rand::rng().random::<u32>() as u64;
90            if !self
91                .connection
92                .commands_response_subscriptions
93                .lock()
94                .await
95                .contains_key(&id)
96            {
97                break id;
98            }
99        };
100
101        let command = CommandMessage {
102            id: command_id,
103            command_data: command.into(),
104            extensible: HashMap::new(),
105        };
106        let (tx, rx) = oneshot::channel::<CommandResponseState>();
107        self.connection
108            .commands_response_subscriptions
109            .lock()
110            .await
111            .insert(command_id, tx);
112        let raw_message = serde_json::to_string(&command).unwrap();
113        tracing::debug!(command_id = %command_id, raw_message = %raw_message, "Sending command");
114
115        self.connection.send(raw_message).await;
116
117        rx
118    }
119
120    pub async fn send(
121        &mut self,
122        command: impl Into<Command>,
123    ) -> Result<CommandResponse, SessionSendError> {
124        let rx = self.send_and_get_receiver(command).await;
125        match timeout(Duration::from_secs(100), rx).await {
126            Ok(Ok(command_result)) => match command_result {
127                CommandResponseState::Success(response) => {
128                    tracing::debug!(id = response.id, raw_message = %response.result, "Command response success");
129                    Ok(response)
130                }
131                CommandResponseState::Error(err) => {
132                    tracing::debug!(id = err.id, stacktrace = err.stacktrace, code = %err.error, "Command response failed");
133                    Err(SessionSendError::ErrorResponse(err))
134                }
135            },
136            Ok(Err(err)) => panic!("A recv error occurred: {}", err),
137            // I might need to remove command from commands response subscriptions
138            Err(_) => Err(SessionSendError::ResponseReceiveTimeoutError(
139                ResponseReceiveTimeoutError,
140            )),
141        }
142    }
143
144    pub async fn end_session(&mut self) -> Result<CommandResponse, SessionSendError> {
145        let result = self.send(EndBuilder::default().build()).await;
146
147        // Close the connection after ending the session
148        self.connection.close();
149
150        result
151    }
152}
153
154impl<T: ConnectionTransport> BidiEventManagement for BidiSession<T> {
155    async fn send_event(
156        &mut self,
157        command: impl Into<Command>,
158    ) -> Result<CommandResponse, SessionSendError> {
159        self.send(command).await
160    }
161
162    fn get_events(&mut self) -> &mut Arc<Mutex<Vec<BidiEvent>>> {
163        &mut self.events
164    }
165
166    fn push_event(&mut self, event: BidiEvent) {
167        self.events.lock().unwrap().push(event);
168    }
169}