1use std::net::{TcpStream, Shutdown};
2use std::io::{Write, ErrorKind};
3use std::collections::{HashMap, VecDeque};
4use std::time::{Instant, Duration};
5use std::format;
6use core::marker::Send;
7use crate::core::net::read_into_buffer;
8use crate::result::WebSocketError;
9use crate::ws_basic::header::{OPCODE, FLAG};
10use crate::ws_basic::frame::{DataFrame, ControlFrame, Frame, FrameKind, bytes_to_frame};
11use crate::ws_basic::status_code::{WSStatus, evaulate_status_code};
12use crate::core::traits::{Serialize, Parse};
13use crate::core::binary::bytes_to_u16;
14use super::super::result::WebSocketResult;
15use crate::http::request::{Request, Method};
16use crate::http::response::Response;
17use crate::ws_basic::key::{gen_key, verify_key};
18use crate::extension::Extension;
19
/// Default maximum payload size, in bytes, of each frame produced by `WSClient::send`.
const DEFAULT_MESSAGE_SIZE: u64 = 1_024;
/// Default value of the client's timeout (bounds, e.g., how long `Drop` waits for the closing handshake).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// HTTP status the server must answer the upgrade request with (`101 Switching Protocols`).
const SWITCHING_PROTOCOLS: u16 = 101;
23
/// Lifecycle of a `WSClient` connection, advanced by `init`, `event_loop` and `Drop`.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
enum ConnectionStatus {
    /// Created with `new`; `init` has not been called yet.
    NOT_INIT,
    /// `init` was called; the next `event_loop` call opens the TCP connection.
    START_INIT,
    /// Connected, upgrade request queued or sent; waiting for the server's response.
    HANDSHAKE,
    /// Handshake complete; data frames flow in both directions.
    OPEN,
    /// The client queued its own Close frame and waits for the server's reply.
    CLIENT_WANTS_TO_CLOSE,
    /// The server sent a Close frame; the client's reply is queued.
    SERVER_WANTS_TO_CLOSE,
    /// The closing handshake finished (or the connection was torn down).
    CLOSE,
}
36
37#[allow(non_camel_case_types)]
38#[repr(C)]
39enum Event {
40 WEBSOCKET_DATA(Box<dyn Frame>),
41 HTTP_RESPONSE(Response),
42 HTTP_REQUEST(Request),
43 NO_DATA,
44}
45
46fn is_websocket_data(event: &Event) -> bool {
47 match event {
48 Event::WEBSOCKET_DATA(_) => true,
49 _ => false
50 }
51}
52
/// Direction of an event handed to `handle_event`.
#[repr(C)]
enum EventIO {
    /// Something received from the peer.
    INPUT,
    /// Something to be written to the peer.
    OUTPUT,
}
58
59#[derive(Clone)]
60pub struct Config<'a, T: Clone> {
61 pub callback: Option<fn(&mut WSClient<'a, T>, &WSEvent, Option<T>)>,
62 pub data: Option<T>,
63 pub protocols: Option<&'a[&'a str]>,
64}
65
/// Who initiated the closing handshake, as reported through `WSEvent::ON_CLOSE`.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(C)]
pub enum Reason {
    /// The server started the close; carries the status code of the Close frame this client
    /// sent back (or `0` if that frame had none).
    SERVER_CLOSE(u16),
    /// The client started the close; carries the status code reported by the server's Close reply.
    CLIENT_CLOSE(u16),
}
72
73#[allow(non_camel_case_types)]
74pub enum WSEvent {
75 ON_CONNECT(Option<String>),
76 ON_TEXT(String),
77 ON_CLOSE(Reason),
78}
79
/// A WebSocket client built on a non-blocking `TcpStream`.
///
/// Create it with `WSClient::new`, configure it with `WSClient::init`, then call
/// `WSClient::event_loop` repeatedly. Each call performs at most one socket read and handles at
/// most one queued input event and one queued output event. Dropping the client attempts a
/// closing handshake.
#[repr(C)]
#[allow(dead_code)]
pub struct WSClient<'a, T: Clone> {
    /// Host part of the address passed to `TcpStream::connect`.
    host: &'a str,
    /// Port part of the address passed to `TcpStream::connect`.
    port: u16,
    /// Request path used in the HTTP upgrade request.
    path: &'a str,
    /// Current position in the connection lifecycle.
    connection_status: ConnectionStatus,
    /// Maximum payload size, in bytes, of each frame produced by `send`.
    message_size: u64,
    /// Timeout used while tearing the connection down.
    timeout: Duration,
    /// The socket; `None` until `start_init` connects.
    stream: Option<TcpStream>,
    /// Raw bytes received from the socket that have not been parsed yet.
    recv_storage: Vec<u8>,
    /// Payload of the fragments of a message that is still being received.
    recv_data: Vec<u8>,
    /// User data cloned into every callback invocation.
    cb_data: Option<T>,
    /// User callback for connect, text and close events.
    callback: Option<fn(&mut Self, &WSEvent, Option<T>)>,
    /// Subprotocol selected by the server in its handshake response.
    protocol: Option<String>,
    /// Subprotocols offered in the handshake request.
    acceptable_protocols: Option<&'a [&'a str]>,
    /// Negotiated extensions. NOTE(review): not read anywhere in this file.
    extensions: Vec<Extension>,
    /// Received events waiting to be handled; control frames are queued at the front.
    input_events: VecDeque<Event>,
    /// Events waiting to be written to the socket.
    output_events: VecDeque<Event>,
    /// `Sec-WebSocket-Key` sent in the request, used to verify `Sec-WebSocket-Accept`.
    websocket_key: String,
    /// NOTE(review): only initialised in this file, never read.
    close_iters: usize,
}

impl<'a, T> WSClient<'a, T> where T: Clone {
105 pub fn new() -> Self {
106 WSClient {
107 host: "",
108 port: 0,
109 path: "",
110 connection_status: ConnectionStatus::NOT_INIT,
111 message_size: DEFAULT_MESSAGE_SIZE,
112 stream: None,
113 recv_storage: Vec::new(),
114 recv_data: Vec::new(),
115 timeout: DEFAULT_TIMEOUT,
116 cb_data: None,
117 callback: None,
118 protocol: None,
119 acceptable_protocols: None,
120 extensions: Vec::new(),
121 close_iters: 0,
122 input_events: VecDeque::new(),
123 output_events: VecDeque::new(),
124 websocket_key: String::new(),
125 }
126 }
127
128 pub fn init(&mut self, host: &'a str, port: u16, path: &'a str, config: Option<Config<'a, T>>) {
129 self.host = host;
130 self.port = port;
131 self.path = path;
132
133 if let Some(conf) = config {
134 self.cb_data = conf.data;
135 self.callback = conf.callback;
136 self.acceptable_protocols = conf.protocols;
137 }
138
139 self.connection_status = ConnectionStatus::START_INIT;
140 }
141
142 fn start_init(&mut self) -> WebSocketResult<()> {
143 let socket = TcpStream::connect(format!("{}:{}", self.host, self.port.to_string()));
144 if socket.is_err() { return Err(WebSocketError::UnreachableHost)}
145 let sec_websocket_key = gen_key();
146
147 let mut headers: HashMap<String, String> = HashMap::from([
148 (String::from("Upgrade"), String::from("websocket")),
149 (String::from("Connection"), String::from("Upgrade")),
150 (String::from("Sec-WebSocket-Key"), sec_websocket_key.clone()),
151 (String::from("Sec-WebSocket-Version"), String::from("13")),
152 (String::from("User-agent"), String::from("rust-websocket-std")),
153 ]);
154
155 let mut protocols_value = String::new();
157 if let Some(protocols) = self.acceptable_protocols {
158 for p in protocols {
159 protocols_value.push_str(p);
160 protocols_value.push_str(", ");
161 }
162 headers.insert(String::from("Sec-WebSocket-Protocol"), (&(protocols_value)[0..protocols_value.len()-2]).to_string());
163 }
164
165 let request = Request::new(Method::GET, self.path, "HTTP/1.1", Some(headers));
166
167 self.output_events.push_front(Event::HTTP_REQUEST(request)); self.websocket_key = sec_websocket_key;
169 let socket = socket.unwrap();
170 socket.set_nonblocking(true)?;
171 self.stream = Some(socket);
172 self.connection_status = ConnectionStatus::HANDSHAKE;
173
174 Ok(())
175 }
176
177 pub fn protocol(&self) -> Option<&str> {
179 if self.protocol.is_none() { return None };
180 return Some(self.protocol.as_ref().unwrap().as_str());
181 }
182
183 pub fn set_message_size(&mut self, size: u64) {
184 self.message_size = size;
185 }
186
187 pub fn set_timeout(&mut self, timeout: Duration) {
188 self.timeout = timeout;
189 }
190
191 pub fn send(&mut self, payload: &str) {
192 if self.connection_status == ConnectionStatus::CLOSE { return }
194 let mut data_sent = 0;
195 let mut _i: usize = 0;
196
197 while data_sent < payload.len() {
198 _i = data_sent + self.message_size as usize;
199 if _i >= payload.len() { _i = payload.len() };
200 let payload_chunk = payload[data_sent.._i].as_bytes();
201 let flag = if data_sent + self.message_size as usize >= payload.len() { FLAG::FIN } else { FLAG::NOFLAG };
202 let code = if data_sent == 0 { OPCODE::TEXT } else { OPCODE::CONTINUATION };
203 let frame = DataFrame::new(flag, code, payload_chunk.to_vec(), true, None);
204 self.output_events.push_back(Event::WEBSOCKET_DATA(Box::new(frame)));
205 data_sent += self.message_size as usize;
206 }
207 }
208
209 pub fn event_loop(&mut self) -> WebSocketResult<()> {
210 if self.connection_status == ConnectionStatus::NOT_INIT { return Ok(()) }
211 if self.connection_status == ConnectionStatus::START_INIT { return self.start_init()}
212 if self.connection_status == ConnectionStatus::CLOSE { return Err(WebSocketError::ConnectionClose) }
213
214 let event = self.read_bytes_from_socket()?;
215 self.insert_input_event(event);
216
217 let in_event = self.input_events.pop_front();
218 let out_event = self.pop_output_event();
221
222 if in_event.is_some() { self.handle_event(in_event.unwrap(), EventIO::INPUT)? };
223 if out_event.is_some() { self.handle_event(out_event.unwrap(), EventIO::OUTPUT)? };
224
225 return Ok(())
226 }
227
228 fn pop_output_event(&mut self) -> Option<Event> {
229 let mut out_event = self.output_events.pop_front();
230 if out_event.is_some() &&
231 self.connection_status == ConnectionStatus::HANDSHAKE &&
232 is_websocket_data(out_event.as_ref().unwrap())
233 {
234 self.output_events.push_front(out_event.unwrap());
235 out_event = None;
236 }
237 return out_event;
238 }
239
240 fn handle_recv_bytes_frame(&mut self) -> WebSocketResult<Event> {
241 let frame = bytes_to_frame(&self.recv_storage)?;
242 if frame.is_none() { return Ok(Event::NO_DATA) };
243
244 let (frame, offset) = frame.unwrap();
245
246 let event = Event::WEBSOCKET_DATA(frame);
247 self.recv_storage.drain(0..offset);
248
249 Ok(event)
250 }
251
252 fn handle_recv_frame(&mut self, frame: Box<dyn Frame>) -> WebSocketResult<()> {
253 match frame.kind() {
254 FrameKind::Data => {
255 if frame.get_header().get_flag() != FLAG::FIN {
256 self.recv_data.extend_from_slice(frame.get_data());
257 }
258
259 if self.callback.is_some() {
260 let callback = self.callback.unwrap();
261
262 let res = String::from_utf8(frame.get_data().to_vec());
263 if res.is_err() { return Err(WebSocketError::DecodingFromUTF8) }
264
265 let msg = res.unwrap();
266
267 if self.recv_data.is_empty() {
269 callback(self, &WSEvent::ON_TEXT(msg), self.cb_data.clone());
270
271 } else {
273 let previous_data = self.recv_data.clone();
274 let res = String::from_utf8(previous_data);
275 if res.is_err() { return Err(WebSocketError::DecodingFromUTF8); }
276
277 let mut completed_msg = res.unwrap();
278 completed_msg.push_str(msg.as_str());
279
280 callback(self, &WSEvent::ON_TEXT(completed_msg), self.cb_data.clone());
282
283 self.recv_data = Vec::new();
287
288 }
292 }
293 return Ok(());
294 },
295 FrameKind::Control => { return self.handle_control_frame(frame.as_any().downcast_ref::<ControlFrame>().unwrap()); },
296 FrameKind::NotDefine => return Err(WebSocketError::InvalidFrame)
297 };
298 }
299
300 fn handle_recv_bytes_http_response(&mut self) -> WebSocketResult<Event> {
301 let response = Response::parse(&self.recv_storage);
302 if response.is_err() { return Ok(Event::NO_DATA); } let response = response.unwrap();
305 let event = Event::HTTP_RESPONSE(response);
306 self.recv_storage.clear();
308
309 Ok(event)
310 }
311
312 fn handle_recv_http_response(&mut self, response: Response) -> WebSocketResult<()> {
313 match self.connection_status {
314 ConnectionStatus::HANDSHAKE => {
315 let sec_websocket_accept = response.header("Sec-WebSocket-Accept");
316
317 if sec_websocket_accept.is_none() { return Err(WebSocketError::HandShake) }
318 let sec_websocket_accept = sec_websocket_accept.unwrap();
319
320 let accepted = verify_key(&self.websocket_key, &sec_websocket_accept);
322 if !accepted {
323 return Err(WebSocketError::HandShake);
324 }
325
326 if response.get_status_code() == 0 ||
327 response.get_status_code() != SWITCHING_PROTOCOLS {
328 return Err(WebSocketError::HandShake)
329 }
330
331 self.protocol = response.header("Sec-WebSocket-Protocol");
332
333 let mut response_msg = None;
334
335 if let Some(body) = response.body() {
336 response_msg = Some(body.clone());
337 }
338
339 self.connection_status = ConnectionStatus::OPEN;
340
341 if let Some(callback) = self.callback {
342 callback(self, &WSEvent::ON_CONNECT(response_msg), self.cb_data.clone());
343 }
344 }
345 _ => {} }
347
348 Ok(())
349 }
350
351 fn handle_send_frame(&mut self, frame: Box<dyn Frame>) -> WebSocketResult<()> {
352 let sent = self.try_write(frame.serialize().as_slice())?;
353 let kind = frame.kind();
354 let mut status = None;
355
356 if frame.kind() == FrameKind::Control {
357 status = frame.as_any().downcast_ref::<ControlFrame>().unwrap().get_status_code();
358 }
359
360 if !sent { self.output_events.push_front(Event::WEBSOCKET_DATA(frame)) };
361
362 if sent && kind == FrameKind::Control && self.connection_status == ConnectionStatus::SERVER_WANTS_TO_CLOSE {
363 self.connection_status = ConnectionStatus::CLOSE;
364 self.stream.as_mut().unwrap().shutdown(Shutdown::Both)?;
365 self.stream = None;
366
367 if let Some(callback) = self.callback {
368 let reason = Reason::SERVER_CLOSE(status.unwrap_or(0));
369 callback(self, &WSEvent::ON_CLOSE(reason), self.cb_data.clone());
370 }
371 }
372
373 Ok(())
374 }
375
376 fn handle_send_http_request(&mut self, request: Request) -> WebSocketResult<()> {
377 let sent = self.try_write(request.serialize().as_slice())?;
378 if !sent {
379 self.output_events.push_front(Event::HTTP_REQUEST(request))
380 }
381 Ok(())
382 }
383
384 fn handle_event(&mut self, event: Event, kind: EventIO) -> WebSocketResult<()> {
385
386 match kind {
387 EventIO::INPUT => {
388 match event {
389 Event::WEBSOCKET_DATA(frame) => self.handle_recv_frame(frame)?,
390 Event::HTTP_RESPONSE(response) => self.handle_recv_http_response(response)?,
391 Event::HTTP_REQUEST(_) => {} Event::NO_DATA => {} }
394 },
395
396 EventIO::OUTPUT => {
397 match event {
398 Event::WEBSOCKET_DATA(frame) => self.handle_send_frame(frame)?,
399 Event::HTTP_REQUEST(request) => self.handle_send_http_request(request)?,
400 Event::HTTP_RESPONSE(_) => {} Event::NO_DATA => {} }
403 }
404 }
405
406 return Ok(());
407 }
408
409 fn read_bytes_from_socket(&mut self) -> WebSocketResult<Event> {
410 let mut buffer = [0u8; 1024];
412 let reader = self.stream.as_mut().unwrap();
413 let bytes_readed = read_into_buffer(reader, &mut buffer)?;
414
415 if bytes_readed > 0 {
416 self.recv_storage.extend_from_slice(&buffer[0..bytes_readed]);
417 }
418
419 let mut event = Event::NO_DATA;
421 if self.recv_storage.len() > 0 {
422 match self.connection_status {
423 ConnectionStatus::HANDSHAKE => event = self.handle_recv_bytes_http_response()?,
424 ConnectionStatus::OPEN | ConnectionStatus::CLIENT_WANTS_TO_CLOSE | ConnectionStatus::SERVER_WANTS_TO_CLOSE => {
425 event = self.handle_recv_bytes_frame()?;
426 },
427
428 ConnectionStatus::CLOSE => {}, ConnectionStatus::NOT_INIT => {}, ConnectionStatus::START_INIT => {} };
432 }
433 Ok(event)
434 }
435
436 fn insert_input_event(&mut self, event: Event) {
437 match &event {
438 Event::WEBSOCKET_DATA(frame) => {
439 if frame.kind() == FrameKind::Control {
440 self.input_events.push_front(event);
441 } else {
442 self.input_events.push_back(event)
443 }
444 },
445
446 Event::HTTP_RESPONSE(_) => self.input_events.push_back(event),
447 Event::HTTP_REQUEST(_) => {} Event::NO_DATA => {}
449 }
450 }
451
452 fn try_write(&mut self, bytes: &[u8]) -> WebSocketResult<bool> {
453 let res = self.stream.as_mut().unwrap().write_all(bytes);
454 if res.is_err(){
455 let error = res.err().unwrap();
456
457 if error.kind() == ErrorKind::WouldBlock {
459 return Ok(false);
460
461 } else {
462 return Err(WebSocketError::IOError);
463 }
464 }
465 Ok(true)
466 }
467
468 fn handle_control_frame(&mut self, frame: &ControlFrame) -> WebSocketResult<()> {
469 match frame.get_header().get_opcode() {
470 OPCODE::PING=> {
471 let data = frame.get_data();
472 let pong_frame = ControlFrame::new(FLAG::FIN, OPCODE::PONG, None, data.to_vec(), true, None);
473 self.output_events.push_front(Event::WEBSOCKET_DATA(Box::new(pong_frame)));
474 },
475 OPCODE::PONG => { todo!("Not implemented handle PONG") },
476 OPCODE::CLOSE => {
477 let data = frame.get_data();
478 let status_code = &data[0..2];
479 let res = bytes_to_u16(status_code);
480
481 let status_code = if res.is_ok() { res.unwrap() } else { WSStatus::EXPECTED_STATUS_CODE.bits() };
482
483 match self.connection_status {
484 ConnectionStatus::OPEN => {
486 let status_code = WSStatus::from_bits(status_code);
487
488 let reason = &data[2..data.len()];
489 let mut status_code = if status_code.is_some() { status_code.unwrap() } else { WSStatus::PROTOCOL_ERROR };
490
491 let (error, _) = evaulate_status_code(status_code);
492 if error { status_code = WSStatus::PROTOCOL_ERROR }
493
494 self.output_events.clear();
496 self.input_events.clear();
497 let close_frame = ControlFrame::new(FLAG::FIN, OPCODE::CLOSE, Some(status_code.bits()), reason.to_vec(), true, None);
498 self.output_events.push_front(Event::WEBSOCKET_DATA(Box::new(close_frame)));
499
500 self.connection_status = ConnectionStatus::SERVER_WANTS_TO_CLOSE;
501
502 },
504 ConnectionStatus::CLIENT_WANTS_TO_CLOSE => {
505 self.connection_status = ConnectionStatus::CLOSE;
509 self.stream.as_mut().unwrap().shutdown(Shutdown::Both)?;
510
511 if let Some(callback) = self.callback {
512 let reason = Reason::CLIENT_CLOSE(frame.get_status_code().unwrap());
513 callback(self, &WSEvent::ON_CLOSE(reason), self.cb_data.clone());
514 }
515 },
516 ConnectionStatus::SERVER_WANTS_TO_CLOSE => {} ConnectionStatus::CLOSE => {} ConnectionStatus::HANDSHAKE => {} ConnectionStatus::NOT_INIT => {} ConnectionStatus::START_INIT => {} }
522 },
523 _ => return Err(WebSocketError::InvalidFrame)
524 }
525
526 Ok(())
527 }
528}
529
530impl<'a, T> Drop for WSClient<'a, T> where T: Clone {
531 fn drop(&mut self) {
532 if self.connection_status != ConnectionStatus::NOT_INIT &&
533 self.connection_status != ConnectionStatus::HANDSHAKE &&
534 self.connection_status != ConnectionStatus::CLOSE &&
535 self.stream.is_some() {
536
537 let msg = "Done";
538 let status_code: u16 = 1000;
539 let close_frame = ControlFrame::new(FLAG::FIN, OPCODE::CLOSE, Some(status_code), msg.as_bytes().to_vec(), true, None);
540
541 self.output_events.clear();
544 self.input_events.clear();
545 self.output_events.push_back(Event::WEBSOCKET_DATA(Box::new(close_frame)));
546 self.connection_status = ConnectionStatus::CLIENT_WANTS_TO_CLOSE;
547
548 let timeout = Instant::now();
549
550 while self.connection_status != ConnectionStatus::CLOSE {
552 if timeout.elapsed().as_secs() >= self.timeout.as_secs() { break } let result = self.event_loop();
554 if result.is_ok() { continue }
555 let err = result.err().unwrap();
556
557 match err {
559 _ => { break }
560 }
561
562 }
563 let _ = self.stream.as_mut().unwrap().shutdown(Shutdown::Both); }
565 }
566}
567
568unsafe impl<'a, T> Send for WSClient<'a, T> where T: Clone {}