// wsclient/ws_client.rs

1use super::Handler;
2// use futures_util::FutureExt;
3use futures_util::{future, pin_mut, StreamExt};
4use std::sync::{Arc, Mutex};
5use tokio::io::AsyncReadExt;
6use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
7use tracing::{error, info, warn};
8
9#[derive(Clone, Default)]
10struct ChatHandler {}
11impl Handler for ChatHandler {
12    fn process(&self, data: String) {
13        println!("{:?}", data);
14    }
15}
16
17#[derive(Clone, Debug)]
18pub struct WsClient {
19    pub tx: Arc<Mutex<Option<futures_channel::mpsc::UnboundedSender<Message>>>>,
20    pub runtime: Arc<Mutex<tokio::runtime::Runtime>>,
21}
22
23impl WsClient {
24    pub fn new() -> Self {
25        Self {
26            tx: Arc::new(Mutex::new(None)),
27            runtime: Arc::new(Mutex::new(
28                tokio::runtime::Builder::new_multi_thread()
29                    .worker_threads(4)
30                    .thread_name("wsclient")
31                    .enable_all()
32                    .build()
33                    .unwrap(),
34            )),
35        }
36    }
37
38    pub fn send(&self, data: Vec<u8>) {
39        if let Some(tx) = self.tx.lock().unwrap().clone() {
40            // tx.unbounded_send(Message::Text(data)).unwrap();
41            let _ = tx.unbounded_send(Message::Binary(data));
42        } else {
43            warn!("tx is none, not valid");
44        }
45    }
46
47    pub fn send_ignore_error(&self, data: String) {
48        if let Some(tx) = self.tx.lock().unwrap().clone() {
49            let _ = tx.unbounded_send(Message::Text(data));
50        } else {
51            warn!("tx is none, not valid");
52        }
53    }
54
55    pub async fn start(&mut self, url: String, handler: Box<dyn Handler>) -> &mut Self {
56        let url = url::Url::parse(&url).unwrap();
57
58        // info!("start: {}", url);
59        let (stdin_tx, stdin_rx) = futures_channel::mpsc::unbounded();
60        self.tx = Arc::new(Mutex::new(Some(stdin_tx.clone())));
61
62        let func = |url: url::Url,
63                    stdin_rx: futures_channel::mpsc::UnboundedReceiver<Message>,
64                    handler: Box<dyn Handler>| async move {
65            // let (ws_stream, _resp) = connect_async(url).await.expect("Failed to connect");
66            let tmp_conn = connect_async(url).await;
67            if tmp_conn.is_err() {
68                println!("connect with error");
69                for (key, value) in std::env::vars() {
70                    info!("{key}: {value}");
71                }
72                error!("error: {}", tmp_conn.as_ref().err().unwrap());
73                handler.process("connect_error".to_string());
74            }
75            let (ws_stream, _resp) = tmp_conn.unwrap();
76
77            info!("websocket handshake has been successfully completed");
78            let (write, read) = ws_stream.split();
79            let stdin_to_ws = stdin_rx.map(Ok).forward(write);
80            let ws_to_stdout = {
81                read.for_each(|message| async {
82                    match message {
83                        Ok(message) => {
84                            let data = message.into_data();
85                            let data_string = String::from_utf8_lossy(&data).to_string();
86                            if !data_string.eq("ping") {
87                                handler.process(data_string);
88                            }
89                        }
90                        Err(e) => {
91                            println!("ws read error: {}", e);
92                            handler.process("connect_error".to_string());
93                            // std::process::exit(1);
94                        }
95                    }
96                })
97            };
98            pin_mut!(stdin_to_ws, ws_to_stdout);
99            future::select(stdin_to_ws, ws_to_stdout).await;
100        };
101
102        // info!("tokio::spawn: {}", url);
103        // tokio::spawn(func(url, stdin_rx, handler));
104        self.runtime
105            .lock()
106            .unwrap()
107            .spawn(func(url, stdin_rx, handler));
108        // info!("tokio::spawn finish");
109        self
110    }
111
112    pub async fn chat(&mut self, url: String) {
113        self.start(url, Box::new(ChatHandler::default())).await;
114
115        let mut stdin = tokio::io::stdin();
116        loop {
117            let mut buf = vec![0; 65536];
118            let n = match stdin.read(&mut buf).await {
119                Err(_) | Ok(0) => break,
120                Ok(n) => n,
121            };
122            buf.truncate(n);
123
124            if let Some(tx) = self.tx.lock().unwrap().clone() {
125                tx.unbounded_send(Message::binary(buf)).unwrap();
126            }
127        }
128    }
129}