1use std::{
2 collections::HashMap,
3 net::SocketAddr,
4 sync::{Arc, Weak},
5 time::{Duration, UNIX_EPOCH},
6};
7
8use tokio::{
9 net::UdpSocket,
10 sync::{
11 mpsc::{channel, Receiver, Sender, UnboundedReceiver},
12 Mutex,
13 },
14};
15use tokio::{
16 select,
17 sync::{
18 mpsc::{unbounded_channel, UnboundedSender},
19 RwLock,
20 },
21};
22
23type ClientMap = Arc<RwLock<HashMap<SocketAddr, UdpStream>>>;
24type WeakClientMap = Weak<RwLock<HashMap<SocketAddr, UdpStream>>>;
25pub struct UdpListener {
26 clients: ClientMap,
27 client_returner: UnboundedReceiver<UdpStream>,
28 no_more_clients: Receiver<String>,
29}
30#[derive(Clone)]
31pub struct UdpStream {
32 pub addr: SocketAddr,
33 to_main_sock: UnboundedSender<Vec<u8>>,
34 from_main_sock_listener: Arc<Mutex<UnboundedReceiver<Vec<u8>>>>,
35 from_main_sock_sender: UnboundedSender<Vec<u8>>,
36 last_msg_time: u64,
37 disconnect_notifier_main_tx: Sender<()>,
38 disconnect_notifier_client_notifier: Arc<Mutex<Receiver<()>>>,
39}
40impl UdpStream {
41 pub async fn read(&mut self) -> Option<Vec<u8>> {
42 let mut data = self.from_main_sock_listener.lock().await;
43 let mut disconnect = self.disconnect_notifier_client_notifier.lock().await;
44 select! {
45 d = data.recv()=>d,
46 _ = disconnect.recv() => None
47 }
48 }
49 pub async fn send(&mut self, data: &[u8]) -> Option<usize> {
50 let len = data.len();
51 self.to_main_sock.send(data.to_vec()).ok().map(|_| len)
52 }
53}
54impl UdpListener {
55 pub fn new(sock: UdpSocket, buf_len: usize, client_timeout: usize) -> Self {
57 let clients: ClientMap = Default::default();
59 let AcceptResult {
60 new_client,
61 no_more_accepts,
62 } = internal_accept_start(clients.clone(), sock, buf_len, client_timeout);
63 let l = UdpListener {
64 clients,
65 client_returner: new_client,
66 no_more_clients: no_more_accepts,
67 };
68 l }
70 pub async fn active_connections(&self) -> usize {
71 self.clients.read().await.len()
72 }
73 pub async fn accept(&mut self) -> Result<UdpStream, String> {
74 select! {
75 client = self.client_returner.recv() =>{
76 client.ok_or(String::from("client channel has been closed"))
77 }
78 err = self.no_more_clients.recv() => Err(err.unwrap_or(String::from("an error occured but for some reason the channel too died on the way")))
79 }
80 }
81} struct AcceptResult {
83 new_client: UnboundedReceiver<UdpStream>,
84 no_more_accepts: Receiver<String>,
85}
86fn internal_accept_start(
87 clients: ClientMap,
88 sock: UdpSocket,
89 buf_len: usize,
90 client_timeout: usize,
91) -> AcceptResult {
92 let clients = Arc::downgrade(&clients);
94 let (client_notifier_tx, client_notifier_rx) = unbounded_channel();
95 let (notx, norx) = channel(1);
96 let sock = Arc::new(sock);
97 reader(
111 sock.clone(),
112 buf_len,
113 clients.clone(),
114 client_notifier_tx,
115 notx.clone(),
116 client_timeout,
117 );
118 AcceptResult {
119 new_client: client_notifier_rx,
120 no_more_accepts: norx,
121 }
122} fn reader(
124 sock: Arc<UdpSocket>,
125 buf_len: usize,
126 clients: WeakClientMap,
127 client_notifier_tx: UnboundedSender<UdpStream>,
128 notx: Sender<String>,
129 client_timeout: usize,
130) {
131 tokio::task::spawn(async move {
133 let mut buf = vec![0; buf_len];
134 while let Ok((n, addr)) = sock.recv_from(&mut buf).await {
135 if let Some(clients_strong) = clients.upgrade() {
136 let mut clients_map = clients_strong.write().await;
137 let main_sock = sock.clone();
138 if let Some(sock) = clients_map.get_mut(&addr) {
139 sock.last_msg_time = UNIX_EPOCH.elapsed().unwrap().as_secs();
140 if let Err(_) = sock.from_main_sock_sender.send(buf[..n].to_vec()) {
141 clients_map.remove(&addr);
142 }
143 } else {
144 let (from_self_tx, mut from_self_rx) = unbounded_channel();
146 let (dis_tx, mut dis_rx) = channel(1);
147 let (from_main_sock_tx, from_main_sock_rx) = unbounded_channel();
148 let (dis_client_tx, dis_client_rx) = channel(1);
149 let sock = UdpStream {
150 addr: addr.clone(),
151 to_main_sock: from_self_tx.clone(),
152 last_msg_time: UNIX_EPOCH.elapsed().unwrap().as_secs(),
153 disconnect_notifier_main_tx: dis_tx,
154 from_main_sock_listener: Arc::new(Mutex::new(from_main_sock_rx)),
155 from_main_sock_sender: from_main_sock_tx.clone(),
156 disconnect_notifier_client_notifier: Arc::new(Mutex::new(dis_client_rx)),
157 };
158 let _ = client_notifier_tx.send(sock.clone());
159 clients_map.insert(addr.clone(), sock);
160 let _ = from_main_sock_tx.send(buf[..n].to_vec());
161 drop(clients_map);
162 let clients = clients.clone();
163 tokio::spawn(async move {
164 loop {
167 select! {
168 data = from_self_rx.recv() =>{
169 if let Some(data) = data{
170 if let Err(_) = main_sock.send_to(&data,addr).await{
171 cleanup(clients,&addr,dis_client_tx.clone()).await;
172 break;
173 };
174 }
175 if let Some(clients) = clients.upgrade(){
176 if let Some(sock)=clients.write().await.get_mut(&addr){
178 sock.last_msg_time = UNIX_EPOCH.elapsed().unwrap().as_secs();
179 }
180 }
181
182 }
183 _ = dis_rx.recv() =>{
184 from_self_rx.close();
185 cleanup(clients,&addr,dis_client_tx).await;
186 break
187 }
188 _ = tokio::time::sleep(Duration::from_secs(30)) ,if client_timeout>0 =>{
189 if let Some(clients) = clients.upgrade(){
191 let clients = clients.read().await;
192 if let Some(sock)=clients.get(&addr){
193 if UNIX_EPOCH.elapsed().unwrap().as_secs() - sock.last_msg_time>=client_timeout as u64{
194 let _ = sock.disconnect_notifier_main_tx.send(()).await;
195 }
197
198 }
199 }else{
200 break;
201 }
202 }
203 }
204 }
205 }); }
207 } else {
208 let _ = notx.send("Socket Closed".into());
209 break;
210 }
211 }
212 });
213} async fn cleanup(clients: WeakClientMap, addr: &SocketAddr, notifier: Sender<()>) {
215 if let Some(clients) = clients.upgrade() {
216 if let Some(sock) = clients.write().await.remove(addr) {
217 let _ = sock.disconnect_notifier_main_tx.send(()).await;
218 let _ = notifier.send_timeout((), Duration::from_secs(5)).await;
219 }
220 }
221}