// websocket_std/sync/client.rs

use std::net::{TcpStream, Shutdown};
use std::io::{Write, ErrorKind};
use std::collections::{HashMap, VecDeque};
use std::time::{Instant, Duration};
use std::format;
use core::marker::Send;
use crate::core::net::read_into_buffer;
use crate::result::WebSocketError;
use crate::ws_basic::header::{OPCODE, FLAG};
use crate::ws_basic::frame::{DataFrame, ControlFrame, Frame, FrameKind, bytes_to_frame};
use crate::ws_basic::status_code::{WSStatus, evaulate_status_code};
use crate::core::traits::{Serialize, Parse};
use crate::core::binary::bytes_to_u16;
use super::super::result::WebSocketResult;
use crate::http::request::{Request, Method};
use crate::http::response::Response;
use crate::ws_basic::key::{gen_key, verify_key};
use crate::extension::Extension;

/// Default maximum payload, in bytes, of each frame produced by `WSClient::send`.
const DEFAULT_MESSAGE_SIZE: u64 = 1024;

/// Default time budget for the closing handshake performed when a client is dropped.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// HTTP status code the server must answer the upgrade request with.
const SWITCHING_PROTOCOLS: u16 = 101;

/// Lifecycle state of a client connection; compared with `==` throughout the client.
#[allow(non_camel_case_types)]
#[derive(PartialEq)]
#[repr(C)]
enum ConnectionStatus {
    /// Freshly created; `init` has not been called yet.
    NOT_INIT,
    /// `init` stored the target; the next `event_loop` call opens the TCP connection.
    START_INIT,
    /// Connected and upgrade request queued; waiting for the server's HTTP response.
    HANDSHAKE,
    /// Handshake accepted; WebSocket frames are exchanged.
    OPEN,
    /// The client queued a close frame and waits for the server's close reply.
    CLIENT_WANTS_TO_CLOSE,
    /// The server sent a close frame; the client's reply is queued.
    SERVER_WANTS_TO_CLOSE,
    /// Closing handshake finished; `event_loop` reports `ConnectionClose`.
    CLOSE,
}

37#[allow(non_camel_case_types)]
38#[repr(C)]
39enum Event {
40    WEBSOCKET_DATA(Box<dyn Frame>),
41    HTTP_RESPONSE(Response),
42    HTTP_REQUEST(Request),
43    NO_DATA,
44}
45
46fn is_websocket_data(event: &Event) -> bool {
47    match event {
48        Event::WEBSOCKET_DATA(_) => true,
49        _ => false
50    }
51}
52
/// Direction of an event: received from the server or to be sent to it.
#[repr(C)]
enum EventIO {
    /// The event came from the socket.
    INPUT,
    /// The event must be written to the socket.
    OUTPUT,
}

59#[derive(Clone)]
60pub struct Config<'a, T: Clone> {
61    pub callback: Option<fn(&mut WSClient<'a, T>, &WSEvent, Option<T>)>,
62    pub data: Option<T>,
63    pub protocols: Option<&'a[&'a str]>,
64}
65
/// Which side started the closing handshake, with the associated status code.
#[allow(non_camel_case_types)]
#[repr(C)]
pub enum Reason {
    /// The server initiated the close; carries the status code of the reply
    /// the client sent back.
    SERVER_CLOSE(u16),
    /// The client initiated the close; carries the status code from the
    /// server's close reply.
    CLIENT_CLOSE(u16),
}

73#[allow(non_camel_case_types)]
74pub enum WSEvent { 
75    ON_CONNECT(Option<String>),
76    ON_TEXT(String),
77    ON_CLOSE(Reason),
78}
79
80#[allow(dead_code)]
81#[repr(C)]
82pub struct WSClient<'a, T: Clone> {
83    host: &'a str,
84    port: u16,
85    path: &'a str,
86    connection_status: ConnectionStatus,
87    message_size: u64,
88    timeout: Duration,
89    stream: Option<TcpStream>,
90    recv_storage: Vec<u8>,                                   // Storage to keep the bytes received from the socket (bytes that didn't use to create a frame)
91    recv_data: Vec<u8>,                                      // Store the data received from the Frames until the data is completelly received
92    cb_data: Option<T>,
93    callback: Option<fn(&mut Self, &WSEvent, Option<T>)>,
94    protocol: Option<String>,
95    acceptable_protocols: Option<&'a [&'a str]>,
96    extensions: Vec<Extension>,
97    input_events: VecDeque<Event>,
98    output_events: VecDeque<Event>,
99    websocket_key: String,
100    close_iters: usize,                                      // Count the number of times send_message tries to execute after the close. If <= 1 don't raise error, otherwise raise ConnectionClose error 
101}                                                            // The close connection depends on the order of the functions event_loop and is_connected
102                        
103
104impl<'a, T> WSClient<'a, T> where T: Clone {
105    pub fn new() -> Self {
106        WSClient { 
107            host: "", 
108            port: 0, 
109            path: "", 
110            connection_status: ConnectionStatus::NOT_INIT, 
111            message_size: DEFAULT_MESSAGE_SIZE, 
112            stream: None, 
113            recv_storage: Vec::new(), 
114            recv_data: Vec::new(), 
115            timeout: DEFAULT_TIMEOUT, 
116            cb_data: None,
117            callback: None,
118            protocol: None,
119            acceptable_protocols: None,
120            extensions: Vec::new(),
121            close_iters: 0,
122            input_events: VecDeque::new(),
123            output_events: VecDeque::new(),
124            websocket_key: String::new(),
125        }
126    }
127
128    pub fn init(&mut self, host: &'a str, port: u16, path: &'a str, config: Option<Config<'a, T>>) {
129        self.host = host;
130        self.port = port;
131        self.path = path; 
132
133        if let Some(conf) = config {
134            self.cb_data = conf.data;
135            self.callback = conf.callback;
136            self.acceptable_protocols = conf.protocols;
137        }
138
139        self.connection_status = ConnectionStatus::START_INIT;
140    }
141
142    fn start_init(&mut self) -> WebSocketResult<()> {
143        let socket = TcpStream::connect(format!("{}:{}", self.host, self.port.to_string()));
144        if socket.is_err() { return Err(WebSocketError::UnreachableHost)} 
145        let sec_websocket_key = gen_key();
146        
147        let mut headers: HashMap<String, String> = HashMap::from([
148            (String::from("Upgrade"), String::from("websocket")),
149            (String::from("Connection"), String::from("Upgrade")),
150            (String::from("Sec-WebSocket-Key"), sec_websocket_key.clone()),
151            (String::from("Sec-WebSocket-Version"), String::from("13")),
152            (String::from("User-agent"), String::from("rust-websocket-std")),
153        ]);
154
155        // Add protocols to request
156        let mut protocols_value = String::new();
157        if let Some(protocols) = self.acceptable_protocols {
158            for p in protocols {
159                protocols_value.push_str(p);
160                protocols_value.push_str(", ");
161            }
162            headers.insert(String::from("Sec-WebSocket-Protocol"), (&(protocols_value)[0..protocols_value.len()-2]).to_string());
163        }
164        
165        let request = Request::new(Method::GET, self.path, "HTTP/1.1", Some(headers));
166        
167        self.output_events.push_front(Event::HTTP_REQUEST(request)); // Push front, because the client could execute send before init (store the frames to send to do it later)
168        self.websocket_key = sec_websocket_key;
169        let socket = socket.unwrap();
170        socket.set_nonblocking(true)?;
171        self.stream = Some(socket);
172        self.connection_status = ConnectionStatus::HANDSHAKE;
173            
174        Ok(())
175    }
176
177    // Returns the protocol accepted by the server
178    pub fn protocol(&self) -> Option<&str> {
179        if self.protocol.is_none() { return None };
180        return Some(self.protocol.as_ref().unwrap().as_str());
181    }
182
183    pub fn set_message_size(&mut self, size: u64) {
184        self.message_size = size;
185    }
186
187    pub fn set_timeout(&mut self, timeout: Duration) {
188        self.timeout = timeout;
189    }
190
191    pub fn send(&mut self, payload: &str) {
192        // If connection is close do nothing
193        if self.connection_status == ConnectionStatus::CLOSE { return }
194        let mut data_sent = 0;
195        let mut _i: usize = 0;
196
197        while data_sent < payload.len() {
198            _i = data_sent + self.message_size as usize; 
199            if _i >= payload.len() { _i = payload.len() };
200            let payload_chunk = payload[data_sent.._i].as_bytes();
201            let flag = if data_sent + self.message_size as usize >= payload.len() { FLAG::FIN } else { FLAG::NOFLAG };
202            let code = if data_sent == 0 { OPCODE::TEXT } else { OPCODE::CONTINUATION };
203            let frame = DataFrame::new(flag, code, payload_chunk.to_vec(), true, None);
204            self.output_events.push_back(Event::WEBSOCKET_DATA(Box::new(frame)));
205            data_sent += self.message_size as usize;
206        }
207    }
208
209    pub fn event_loop(&mut self) -> WebSocketResult<()> {
210        if self.connection_status == ConnectionStatus::NOT_INIT { return Ok(()) }
211        if self.connection_status == ConnectionStatus::START_INIT { return self.start_init()}
212        if self.connection_status == ConnectionStatus::CLOSE { return Err(WebSocketError::ConnectionClose) }
213    
214        let event = self.read_bytes_from_socket()?;
215        self.insert_input_event(event);
216        
217        let in_event = self.input_events.pop_front();     
218        // Check that the message taken from the queue is not a websocket event and the state of the websocket is different
219        // - if the state is HANDSHAKE dont pop an event if is a websocket event
220        let out_event = self.pop_output_event();
221
222        if in_event.is_some() { self.handle_event(in_event.unwrap(), EventIO::INPUT)? };
223        if out_event.is_some() { self.handle_event(out_event.unwrap(), EventIO::OUTPUT)? };
224
225        return Ok(())
226    }
227
228    fn pop_output_event(&mut self) -> Option<Event> {
229        let mut out_event = self.output_events.pop_front();
230        if out_event.is_some() &&
231        self.connection_status == ConnectionStatus::HANDSHAKE && 
232        is_websocket_data(out_event.as_ref().unwrap())
233            {
234                self.output_events.push_front(out_event.unwrap());
235                out_event = None;
236            }
237        return out_event;
238    }
239
240    fn handle_recv_bytes_frame(&mut self) -> WebSocketResult<Event> {
241        let frame = bytes_to_frame(&self.recv_storage)?;
242        if frame.is_none() { return Ok(Event::NO_DATA) };
243
244        let (frame, offset) = frame.unwrap();
245
246        let event = Event::WEBSOCKET_DATA(frame);
247        self.recv_storage.drain(0..offset);
248
249        Ok(event)
250    }
251
252    fn handle_recv_frame(&mut self, frame: Box<dyn Frame>) -> WebSocketResult<()> {
253        match frame.kind()  {
254            FrameKind::Data => { 
255                if frame.get_header().get_flag() != FLAG::FIN {
256                    self.recv_data.extend_from_slice(frame.get_data());
257                }
258
259                if self.callback.is_some() {
260                    let callback = self.callback.unwrap();
261
262                    let res = String::from_utf8(frame.get_data().to_vec());
263                    if res.is_err() { return Err(WebSocketError::DecodingFromUTF8) }
264                    
265                    let msg = res.unwrap();
266
267                    // Message received in a single frame
268                    if self.recv_data.is_empty() {
269                        callback(self, &WSEvent::ON_TEXT(msg), self.cb_data.clone());
270
271                    // Message from a multiples frames     
272                    } else {
273                        let previous_data = self.recv_data.clone();
274                        let res = String::from_utf8(previous_data);
275                        if res.is_err() { return Err(WebSocketError::DecodingFromUTF8); }
276                        
277                        let mut completed_msg = res.unwrap();
278                        completed_msg.push_str(msg.as_str());
279
280                        // Send the message to the callback function
281                        callback(self, &WSEvent::ON_TEXT(completed_msg), self.cb_data.clone());
282                        
283                        // There is 2 ways to deal with the vector data:
284                        // 1 - Remove from memory (takes more time)
285                        //         Creating a new vector produces that the old vector will be dropped (deallocating the memory)
286                        self.recv_data = Vec::new();
287
288                        // // 2 - Use the clear method (takes more memory because we never drop it)
289                        // //         The vector does not remove memory that has already been allocated.
290                        // self.recv_data.clear();
291                    }
292                }
293                return Ok(());
294            },
295            FrameKind::Control => { return self.handle_control_frame(frame.as_any().downcast_ref::<ControlFrame>().unwrap()); },
296            FrameKind::NotDefine => return Err(WebSocketError::InvalidFrame)
297        }; 
298    }
299
300    fn handle_recv_bytes_http_response(&mut self) -> WebSocketResult<Event> {
301        let response = Response::parse(&self.recv_storage);
302        if response.is_err() { return Ok(Event::NO_DATA); } // TODO: Check for timeout to raise an error
303
304        let response = response.unwrap();
305        let event = Event::HTTP_RESPONSE(response);
306        // TODO: Drain bytes not used in response (maybe two responses comes at the same time)
307        self.recv_storage.clear();
308
309        Ok(event)
310    }
311
312    fn handle_recv_http_response(&mut self, response: Response) -> WebSocketResult<()> {
313        match self.connection_status {
314            ConnectionStatus::HANDSHAKE => {
315                let sec_websocket_accept = response.header("Sec-WebSocket-Accept");
316            
317                if sec_websocket_accept.is_none() { return Err(WebSocketError::HandShake) }
318                let sec_websocket_accept = sec_websocket_accept.unwrap();
319            
320                // Verify Sec-WebSocket-Accept
321                let accepted = verify_key(&self.websocket_key, &sec_websocket_accept);
322                if !accepted {
323                    return Err(WebSocketError::HandShake);
324                }
325            
326                if response.get_status_code() == 0 || 
327                   response.get_status_code() != SWITCHING_PROTOCOLS { 
328                    return Err(WebSocketError::HandShake) 
329                }
330
331                self.protocol = response.header("Sec-WebSocket-Protocol");
332
333                let mut response_msg = None;
334                
335                if let Some(body) = response.body() {
336                   response_msg = Some(body.clone()); 
337                }
338
339                self.connection_status = ConnectionStatus::OPEN;
340
341                if let Some(callback) = self.callback { 
342                    callback(self, &WSEvent::ON_CONNECT(response_msg), self.cb_data.clone());
343                }
344            }
345            _ =>  {} // Unreachable 
346        }
347
348        Ok(())
349    }
350
351    fn handle_send_frame(&mut self, frame: Box<dyn Frame>) -> WebSocketResult<()> {
352        let sent = self.try_write(frame.serialize().as_slice())?;
353        let kind = frame.kind();
354        let mut status = None;
355
356        if frame.kind() == FrameKind::Control {
357            status = frame.as_any().downcast_ref::<ControlFrame>().unwrap().get_status_code();
358        }
359
360        if !sent { self.output_events.push_front(Event::WEBSOCKET_DATA(frame)) };
361
362        if sent && kind == FrameKind::Control && self.connection_status == ConnectionStatus::SERVER_WANTS_TO_CLOSE {
363            self.connection_status = ConnectionStatus::CLOSE;
364            self.stream.as_mut().unwrap().shutdown(Shutdown::Both)?;
365            self.stream = None;
366
367            if let Some(callback) = self.callback {
368                let reason = Reason::SERVER_CLOSE(status.unwrap_or(0));
369                callback(self, &WSEvent::ON_CLOSE(reason), self.cb_data.clone());
370            }
371        }
372
373        Ok(())
374    }
375
376    fn handle_send_http_request(&mut self, request: Request) -> WebSocketResult<()> {
377        let sent = self.try_write(request.serialize().as_slice())?;
378        if !sent { 
379            self.output_events.push_front(Event::HTTP_REQUEST(request)) 
380        }
381        Ok(())
382    }
383
384    fn handle_event(&mut self, event: Event, kind: EventIO) -> WebSocketResult<()> {
385
386        match kind {
387            EventIO::INPUT => {
388                match event {
389                    Event::WEBSOCKET_DATA(frame) => self.handle_recv_frame(frame)?,
390                    Event::HTTP_RESPONSE(response) => self.handle_recv_http_response(response)?,
391                    Event::HTTP_REQUEST(_) => {} // Unreachable
392                    Event::NO_DATA => {} // Unreachable
393                }
394            },
395
396            EventIO::OUTPUT => {
397                match event { 
398                    Event::WEBSOCKET_DATA(frame) => self.handle_send_frame(frame)?,
399                    Event::HTTP_REQUEST(request) => self.handle_send_http_request(request)?,
400                    Event::HTTP_RESPONSE(_) => {} // Unreachable
401                    Event::NO_DATA => {} // Unreachable
402                }
403            }
404        }
405
406        return Ok(());
407    }
408
409    fn read_bytes_from_socket(&mut self) -> WebSocketResult<Event> {
410        // TODO: Add timeout attribute to self in order to raise an error if any op overflow the time required to finish
411        let mut buffer = [0u8; 1024];
412        let reader = self.stream.as_mut().unwrap();
413        let bytes_readed = read_into_buffer(reader, &mut buffer)?;
414
415        if bytes_readed > 0 {
416            self.recv_storage.extend_from_slice(&buffer[0..bytes_readed]);
417        }
418
419        // Input data
420        let mut event = Event::NO_DATA;
421        if self.recv_storage.len() > 0 {
422            match self.connection_status {
423                ConnectionStatus::HANDSHAKE => event = self.handle_recv_bytes_http_response()?,
424                ConnectionStatus::OPEN | ConnectionStatus::CLIENT_WANTS_TO_CLOSE | ConnectionStatus::SERVER_WANTS_TO_CLOSE => {
425                    event = self.handle_recv_bytes_frame()?;
426                },
427
428                ConnectionStatus::CLOSE => {}, // Unreachable
429                ConnectionStatus::NOT_INIT => {}, // Unreachable
430                ConnectionStatus::START_INIT => {} // Unreachable
431            };
432        }
433        Ok(event) 
434    }
435
436    fn insert_input_event(&mut self, event: Event) {
437        match &event {
438            Event::WEBSOCKET_DATA(frame) => { 
439                if frame.kind() == FrameKind::Control {
440                    self.input_events.push_front(event);
441                } else {
442                    self.input_events.push_back(event)
443                }
444            },
445
446            Event::HTTP_RESPONSE(_) => self.input_events.push_back(event),
447            Event::HTTP_REQUEST(_) => {} // Unreachable
448            Event::NO_DATA => {}
449        }
450    }
451
452    fn try_write(&mut self, bytes: &[u8]) -> WebSocketResult<bool> {
453        let res = self.stream.as_mut().unwrap().write_all(bytes);
454        if res.is_err(){
455            let error = res.err().unwrap();
456
457            // Try to send next iteration
458            if error.kind() == ErrorKind::WouldBlock { 
459                return Ok(false);
460
461            } else {
462                return Err(WebSocketError::IOError);
463            }
464        }
465        Ok(true)
466    }
467
468    fn handle_control_frame(&mut self, frame: &ControlFrame) -> WebSocketResult<()> {
469        match frame.get_header().get_opcode() {
470            OPCODE::PING=> { 
471                let data = frame.get_data();
472                let pong_frame = ControlFrame::new(FLAG::FIN, OPCODE::PONG, None, data.to_vec(), true, None);
473                self.output_events.push_front(Event::WEBSOCKET_DATA(Box::new(pong_frame)));
474            },
475            OPCODE::PONG => { todo!("Not implemented handle PONG") },
476            OPCODE::CLOSE => {
477                let data = frame.get_data();
478                let status_code = &data[0..2];
479                let res = bytes_to_u16(status_code);
480
481                let status_code = if res.is_ok() { res.unwrap() } else { WSStatus::EXPECTED_STATUS_CODE.bits() };
482
483                match self.connection_status {
484                    // Server wants to close the connection
485                    ConnectionStatus::OPEN => {
486                        let status_code = WSStatus::from_bits(status_code);
487
488                        let reason = &data[2..data.len()];
489                        let mut status_code = if status_code.is_some() { status_code.unwrap() } else { WSStatus::PROTOCOL_ERROR };
490                        
491                        let (error, _) = evaulate_status_code(status_code);
492                        if error { status_code = WSStatus::PROTOCOL_ERROR }
493
494                        // Enqueue close frame to response to the server
495                        self.output_events.clear();
496                        self.input_events.clear();
497                        let close_frame = ControlFrame::new(FLAG::FIN, OPCODE::CLOSE, Some(status_code.bits()), reason.to_vec(), true, None);
498                        self.output_events.push_front(Event::WEBSOCKET_DATA(Box::new(close_frame)));
499
500                        self.connection_status = ConnectionStatus::SERVER_WANTS_TO_CLOSE;
501                        
502                        // TODO: Create and on close cb to handle this situation, send the status code an the reason
503                    },
504                    ConnectionStatus::CLIENT_WANTS_TO_CLOSE => {
505                        // TODO: ?
506                        // Received a response to the client close handshake
507                        // Verify the status of close handshake
508                        self.connection_status = ConnectionStatus::CLOSE;
509                        self.stream.as_mut().unwrap().shutdown(Shutdown::Both)?;
510                        
511                        if let Some(callback) = self.callback {
512                            let reason = Reason::CLIENT_CLOSE(frame.get_status_code().unwrap());
513                            callback(self, &WSEvent::ON_CLOSE(reason), self.cb_data.clone());
514                        }
515                    },
516                    ConnectionStatus::SERVER_WANTS_TO_CLOSE => {}  // Unreachable  
517                    ConnectionStatus::CLOSE => {}                  // Unreachable
518                    ConnectionStatus::HANDSHAKE => {}              // Unreachable
519                    ConnectionStatus::NOT_INIT => {}               // Unreachable
520                    ConnectionStatus::START_INIT => {}             // Unreachable
521                }
522            },
523            _ => return Err(WebSocketError::InvalidFrame)
524        }
525
526        Ok(())
527    }
528}
529
530impl<'a, T> Drop for WSClient<'a, T> where T: Clone {
531    fn drop(&mut self) {
532        if self.connection_status != ConnectionStatus::NOT_INIT &&
533            self.connection_status != ConnectionStatus::HANDSHAKE &&
534            self.connection_status != ConnectionStatus::CLOSE &&
535            self.stream.is_some() {
536
537                let msg = "Done";
538                let status_code: u16 = 1000;
539                let close_frame = ControlFrame::new(FLAG::FIN, OPCODE::CLOSE, Some(status_code), msg.as_bytes().to_vec(), true, None);
540        
541                // Add close frame at the end of the queue.
542                // Clear both queues
543                self.output_events.clear();
544                self.input_events.clear();
545                self.output_events.push_back(Event::WEBSOCKET_DATA(Box::new(close_frame)));
546                self.connection_status = ConnectionStatus::CLIENT_WANTS_TO_CLOSE;
547        
548                let timeout = Instant::now();
549        
550                // Process a response for all the events and confirm that the connection was closed.
551                while self.connection_status != ConnectionStatus::CLOSE {
552                    if timeout.elapsed().as_secs() >= self.timeout.as_secs() { break } // Close handshake timeout.
553                    let result = self.event_loop();
554                    if result.is_ok() { continue }
555                    let err = result.err().unwrap();
556
557                    // TODO: Decide what to do if an error ocurred while consuming the rest of the messages
558                    match err {
559                        _ => { break }
560                    }
561        
562                    }
563                let _ = self.stream.as_mut().unwrap().shutdown(Shutdown::Both); // Ignore result from shutdown method.
564            }
565        }
566}
567
568unsafe impl<'a, T> Send for WSClient<'a, T> where T: Clone {}