1use crate::error::Result;
2use bytes::Bytes;
3use std::collections::{HashMap, HashSet};
4use std::sync::atomic::{AtomicU32, Ordering};
5use tokio::net::TcpStream;
6use tokio::sync::{mpsc, oneshot, Mutex};
7use tokio_rustls::client::TlsStream as ClientTlsStream;
8use tokio_rustls::server::TlsStream as ServerTlsStream;
9use tokio_tungstenite::WebSocketStream;
10
11pub type ServerWsStream = WebSocketStream<ServerTlsStream<TcpStream>>;
12pub type ClientWsStream = WebSocketStream<ClientTlsStream<TcpStream>>;
13
14pub async fn accept_ws(tls_stream: ServerTlsStream<TcpStream>) -> Result<ServerWsStream> {
16 let ws = tokio_tungstenite::accept_async(tls_stream).await?;
17 Ok(ws)
18}
19
20pub async fn connect_ws(
22 tls_stream: ClientTlsStream<TcpStream>,
23 url: &str,
24) -> Result<ClientWsStream> {
25 let (ws, _response) = tokio_tungstenite::client_async(url, tls_stream).await?;
26 Ok(ws)
27}
28
29struct ChannelState {
31 channels: HashMap<u32, mpsc::Sender<Bytes>>,
32 ready_signals: HashMap<u32, oneshot::Sender<()>>,
33 tunnel_channels: HashMap<u32, HashSet<u32>>,
34}
35
36pub struct ChannelMap {
41 state: Mutex<ChannelState>,
42 next_id: AtomicU32,
43}
44
45impl ChannelMap {
46 pub fn new(start_id: u32) -> Self {
48 Self {
49 state: Mutex::new(ChannelState {
50 channels: HashMap::new(),
51 ready_signals: HashMap::new(),
52 tunnel_channels: HashMap::new(),
53 }),
54 next_id: AtomicU32::new(start_id),
55 }
56 }
57
58 pub fn alloc_id(&self) -> u32 {
60 self.next_id.fetch_add(2, Ordering::Relaxed)
61 }
62
63 pub async fn has(&self, channel_id: u32) -> bool {
65 self.state.lock().await.channels.contains_key(&channel_id)
66 }
67
68 pub async fn insert(&self, channel_id: u32, sender: mpsc::Sender<Bytes>) {
70 self.state.lock().await.channels.insert(channel_id, sender);
71 }
72
73 pub async fn insert_with_tunnel(
75 &self,
76 channel_id: u32,
77 tunnel_id: u32,
78 sender: mpsc::Sender<Bytes>,
79 ) {
80 let mut s = self.state.lock().await;
81 s.channels.insert(channel_id, sender);
82 s.tunnel_channels
83 .entry(tunnel_id)
84 .or_default()
85 .insert(channel_id);
86 }
87
88 pub async fn send(&self, channel_id: u32, data: Bytes) -> bool {
92 let tx = {
93 let s = self.state.lock().await;
94 s.channels.get(&channel_id).cloned()
95 };
96 if let Some(tx) = tx {
97 match tx.try_send(data) {
98 Ok(()) => true,
99 Err(mpsc::error::TrySendError::Full(_)) => {
100 self.remove(channel_id).await;
103 false
104 }
105 Err(mpsc::error::TrySendError::Closed(_)) => false,
106 }
107 } else {
108 false
109 }
110 }
111
112 pub async fn remove(&self, channel_id: u32) {
114 let mut s = self.state.lock().await;
115 s.channels.remove(&channel_id);
116 s.ready_signals.remove(&channel_id);
117 for set in s.tunnel_channels.values_mut() {
118 set.remove(&channel_id);
119 }
120 }
121
122 pub async fn close_all(&self) {
124 let mut s = self.state.lock().await;
125 s.channels.clear();
126 s.ready_signals.clear();
127 s.tunnel_channels.clear();
128 }
129
130 pub async fn close_tunnel(&self, tunnel_id: u32) -> Vec<u32> {
132 let mut s = self.state.lock().await;
133 let channel_ids: Vec<u32> = s
134 .tunnel_channels
135 .remove(&tunnel_id)
136 .unwrap_or_default()
137 .into_iter()
138 .collect();
139 for &id in &channel_ids {
140 s.channels.remove(&id);
141 s.ready_signals.remove(&id);
142 }
143 channel_ids
144 }
145
146 pub async fn wait_ready(&self, channel_id: u32) -> oneshot::Receiver<()> {
148 let (tx, rx) = oneshot::channel();
149 self.state
150 .lock()
151 .await
152 .ready_signals
153 .insert(channel_id, tx);
154 rx
155 }
156
157 pub async fn signal_ready(&self, channel_id: u32) -> bool {
159 if let Some(tx) = self.state.lock().await.ready_signals.remove(&channel_id) {
160 tx.send(()).is_ok()
161 } else {
162 false
163 }
164 }
165}