udp_streamify/
lib.rs

1use std::{
2    collections::HashMap,
3    net::SocketAddr,
4    sync::{Arc, Weak},
5    time::{Duration, UNIX_EPOCH},
6};
7
8use tokio::{
9    net::UdpSocket,
10    sync::{
11        mpsc::{channel, Receiver, Sender, UnboundedReceiver},
12        Mutex,
13    },
14};
15use tokio::{
16    select,
17    sync::{
18        mpsc::{unbounded_channel, UnboundedSender},
19        RwLock,
20    },
21};
22
23type ClientMap = Arc<RwLock<HashMap<SocketAddr, UdpStream>>>;
24type WeakClientMap = Weak<RwLock<HashMap<SocketAddr, UdpStream>>>;
25pub struct UdpListener {
26    clients: ClientMap,
27    client_returner: UnboundedReceiver<UdpStream>,
28    no_more_clients: Receiver<String>,
29}
30#[derive(Clone)]
31pub struct UdpStream {
32    pub addr: SocketAddr,
33    to_main_sock: UnboundedSender<Vec<u8>>,
34    from_main_sock_listener: Arc<Mutex<UnboundedReceiver<Vec<u8>>>>,
35    from_main_sock_sender: UnboundedSender<Vec<u8>>,
36    last_msg_time: u64,
37    disconnect_notifier_main_tx: Sender<()>,
38    disconnect_notifier_client_notifier: Arc<Mutex<Receiver<()>>>,
39}
40impl UdpStream {
41    pub async fn read(&mut self) -> Option<Vec<u8>> {
42        let mut data = self.from_main_sock_listener.lock().await;
43        let mut disconnect = self.disconnect_notifier_client_notifier.lock().await;
44        select! {
45            d = data.recv()=>d,
46            _ = disconnect.recv() => None
47        }
48    }
49    pub async fn send(&mut self, data: &[u8]) -> Option<usize> {
50        let len = data.len();
51        self.to_main_sock.send(data.to_vec()).ok().map(|_| len)
52    }
53}
54impl UdpListener {
55    //{{{
56    pub fn new(sock: UdpSocket, buf_len: usize, client_timeout: usize) -> Self {
57        //{{{
58        let clients: ClientMap = Default::default();
59        let AcceptResult {
60            new_client,
61            no_more_accepts,
62        } = internal_accept_start(clients.clone(), sock, buf_len, client_timeout);
63        let l = UdpListener {
64            clients,
65            client_returner: new_client,
66            no_more_clients: no_more_accepts,
67        };
68        l //}}}
69    }
70    pub async fn active_connections(&self) -> usize {
71        self.clients.read().await.len()
72    }
73    pub async fn accept(&mut self) -> Result<UdpStream, String> {
74        select! {
75            client = self.client_returner.recv() =>{
76                client.ok_or(String::from("client channel has been closed"))
77            }
78            err = self.no_more_clients.recv() => Err(err.unwrap_or(String::from("an error occured but for some reason the channel too died on the way")))
79        }
80    }
81} //}}}
82struct AcceptResult {
83    new_client: UnboundedReceiver<UdpStream>,
84    no_more_accepts: Receiver<String>,
85}
86fn internal_accept_start(
87    clients: ClientMap,
88    sock: UdpSocket,
89    buf_len: usize,
90    client_timeout: usize,
91) -> AcceptResult {
92    //{{{
93    let clients = Arc::downgrade(&clients);
94    let (client_notifier_tx, client_notifier_rx) = unbounded_channel();
95    let (notx, norx) = channel(1);
96    let sock = Arc::new(sock);
97    // let c = clients.clone();
98    // tokio::spawn(async move {
99    //     let mut i = tokio::time::interval(Duration::from_secs(2));
100    //     loop {
101    //         i.tick().await;
102    //         if let Some(c) = c.upgrade() {
103    //             dbg!(c.read().await.len());
104    //         } else {
105    //             dbg!("socket dead");
106    //             break;
107    //         }
108    //     }
109    // });
110    reader(
111        sock.clone(),
112        buf_len,
113        clients.clone(),
114        client_notifier_tx,
115        notx.clone(),
116        client_timeout,
117    );
118    AcceptResult {
119        new_client: client_notifier_rx,
120        no_more_accepts: norx,
121    }
122} //}}}
123fn reader(
124    sock: Arc<UdpSocket>,
125    buf_len: usize,
126    clients: WeakClientMap,
127    client_notifier_tx: UnboundedSender<UdpStream>,
128    notx: Sender<String>,
129    client_timeout: usize,
130) {
131    //{{{
132    tokio::task::spawn(async move {
133        let mut buf = vec![0; buf_len];
134        while let Ok((n, addr)) = sock.recv_from(&mut buf).await {
135            if let Some(clients_strong) = clients.upgrade() {
136                let mut clients_map = clients_strong.write().await;
137                let main_sock = sock.clone();
138                if let Some(sock) = clients_map.get_mut(&addr) {
139                    sock.last_msg_time = UNIX_EPOCH.elapsed().unwrap().as_secs();
140                    if let Err(_) = sock.from_main_sock_sender.send(buf[..n].to_vec()) {
141                        clients_map.remove(&addr);
142                    }
143                } else {
144                    //new Sock
145                    let (from_self_tx, mut from_self_rx) = unbounded_channel();
146                    let (dis_tx, mut dis_rx) = channel(1);
147                    let (from_main_sock_tx, from_main_sock_rx) = unbounded_channel();
148                    let (dis_client_tx, dis_client_rx) = channel(1);
149                    let sock = UdpStream {
150                        addr: addr.clone(),
151                        to_main_sock: from_self_tx.clone(),
152                        last_msg_time: UNIX_EPOCH.elapsed().unwrap().as_secs(),
153                        disconnect_notifier_main_tx: dis_tx,
154                        from_main_sock_listener: Arc::new(Mutex::new(from_main_sock_rx)),
155                        from_main_sock_sender: from_main_sock_tx.clone(),
156                        disconnect_notifier_client_notifier: Arc::new(Mutex::new(dis_client_rx)),
157                    };
158                    let _ = client_notifier_tx.send(sock.clone());
159                    clients_map.insert(addr.clone(), sock);
160                    let _ = from_main_sock_tx.send(buf[..n].to_vec());
161                    drop(clients_map);
162                    let clients = clients.clone();
163                    tokio::spawn(async move {
164                        //self to remote {{{
165
166                        loop {
167                            select! {
168                                data = from_self_rx.recv() =>{
169                                    if let Some(data) = data{
170                                        if let Err(_) = main_sock.send_to(&data,addr).await{
171                                            cleanup(clients,&addr,dis_client_tx.clone()).await;
172                                            break;
173                                        };
174                                    }
175                                    if let Some(clients) = clients.upgrade(){
176                                        //increase sock time since it sock did something
177                                        if let Some(sock)=clients.write().await.get_mut(&addr){
178                                            sock.last_msg_time = UNIX_EPOCH.elapsed().unwrap().as_secs();
179                                        }
180                                    }
181
182                                }
183                                _ = dis_rx.recv() =>{
184                                    from_self_rx.close();
185                                    cleanup(clients,&addr,dis_client_tx).await;
186                                    break
187                                }
188                                _ =  tokio::time::sleep(Duration::from_secs(30)) ,if client_timeout>0 =>{
189                                    //no read for 30 seconds if no writes as well close the socket
190                                    if let Some(clients) = clients.upgrade(){
191                                        let  clients = clients.read().await;
192                                        if let Some(sock)=clients.get(&addr){
193                                            if  UNIX_EPOCH.elapsed().unwrap().as_secs() - sock.last_msg_time>=client_timeout as u64{
194                                                let _ = sock.disconnect_notifier_main_tx.send(()).await;
195                                                //dis_rx.recv above will handle the cleanup
196                                            }
197
198                                        }
199                                    }else{
200                                        break;
201                                    }
202                                }
203                            }
204                        }
205                    }); //}}}
206                }
207            } else {
208                let _ = notx.send("Socket Closed".into());
209                break;
210            }
211        }
212    });
213} //}}}
214async fn cleanup(clients: WeakClientMap, addr: &SocketAddr, notifier: Sender<()>) {
215    if let Some(clients) = clients.upgrade() {
216        if let Some(sock) = clients.write().await.remove(addr) {
217            let _ = sock.disconnect_notifier_main_tx.send(()).await;
218            let _ = notifier.send_timeout((), Duration::from_secs(5)).await;
219        }
220    }
221}