use crate::error;
use net2::{UdpBuilder, UdpSocketExt};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::error::Error;
use std::future::Future;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use tokio::net::udp::{RecvHalf, SendHalf};
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use tokio::time::{delay_for, Duration};
#[cfg(not(target_os = "windows"))]
use net2::unix::UnixUdpBuilderExt;
/// Size in bytes of the stack buffer each receive task reads datagrams into (4 KiB).
///
/// NOTE(review): datagrams larger than this are presumably truncated by the OS on
/// `recv_from` (platform-dependent) — confirm this limit suits the protocol.
pub const BUFF_MAX_SIZE: usize = 4 * 1024;
/// Per-socket state owned by a `UdpServer`: the two halves of one bound socket and
/// the table of peers that have sent datagrams to it.
pub struct UdpContext<T>
where
    T: Send,
{
    /// Identifier of this socket, assigned by the server when the socket is created.
    pub id: usize,
    /// Receive half, locked by the receive task for the duration of each `recv_from`.
    recv: Arc<Mutex<RecvHalf>>,
    /// Send half, shared with every `Peer` created on this socket.
    pub send: Arc<Mutex<SendHalf>>,
    /// Known peers on this socket, keyed by their remote address.
    pub peers: Arc<Mutex<HashMap<SocketAddr, Arc<Peer<T>>>>>,
}
pub struct UdpServer<I, R, T,S>
where
I: Fn(Arc<S>,Arc<Peer<T>>, Vec<u8>) -> R + Send + Sync + 'static,
R: Future<Output = Result<(), Box<dyn Error>>>,
T: Send + 'static,
S: Sync +Send+'static
{
inner:Arc<S>,
udp_contexts: Vec<UdpContext<T>>,
input: Option<Arc<I>>,
error_input: Option<Arc<Mutex<dyn Fn(Option<Arc<Peer<T>>>, Box<dyn Error>)->bool + Send>>>,
}
/// Optional per-peer token slot.
///
/// The inner `Option` is public; the methods below are thin conveniences over it.
#[derive(Debug)]
pub struct TokenStore<T: Send>(pub Option<T>);

impl<T: Send> TokenStore<T> {
    /// Returns `true` when a token is currently stored.
    pub fn have(&self) -> bool {
        self.0.is_some()
    }

    /// Returns a mutable reference to the stored token, or `None` when the slot is empty.
    pub fn get(&mut self) -> Option<&mut T> {
        self.0.as_mut()
    }

    /// Replaces the stored token with `v` (pass `None` to clear the slot).
    pub fn set(&mut self, v: Option<T>) {
        self.0 = v;
    }
}
/// A socket send half paired with a fixed destination address.
///
/// Field `0` is the shared send half; field `1` is the address every `send` targets.
#[derive(Debug)]
pub struct UdpSend(pub Arc<Mutex<SendHalf>>, pub SocketAddr);

impl UdpSend {
    /// Sends `buf` as one datagram to the stored address and returns the byte count
    /// reported by `send_to`.
    ///
    /// The send half's mutex stays locked until the send completes.
    pub async fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
        let UdpSend(half, target) = self;
        let mut guard = half.lock().await;
        guard.send_to(buf, target).await
    }
}
/// A remote endpoint that has sent at least one datagram to the server.
#[derive(Debug)]
pub struct Peer<T>
where
    T: Send,
{
    /// `id` of the socket this peer was first seen on.
    pub socket_id: usize,
    /// Remote address of the peer.
    pub addr: SocketAddr,
    /// Application-defined token slot for this peer.
    pub token: Arc<Mutex<TokenStore<T>>>,
    /// Sender bound to `addr`, used for replies.
    pub udp_sock: Arc<UdpSend>,
}

impl<T> Peer<T>
where
    T: Send,
{
    /// Sends `data` to this peer's address and returns the number of bytes sent.
    pub async fn send(&self, data: &[u8]) -> Result<usize, std::io::Error> {
        let sender = &self.udp_sock;
        sender.send(data).await
    }
}
impl<I, R, T> UdpServer<I, R, T, ()>
where
    I: Fn(Arc<()>, Arc<Peer<T>>, Vec<u8>) -> R + Send + Sync + 'static,
    R: Future<Output = Result<(), Box<dyn Error>>> + Send,
    T: Send + 'static,
{
    /// Creates a server bound to `addr` that carries no shared state (`S = ()`).
    ///
    /// This is shorthand for `new_inner(addr, Arc::new(()))`.
    ///
    /// # Errors
    /// Propagates any error returned by `new_inner`.
    pub async fn new<A: ToSocketAddrs>(addr: A) -> Result<Self, Box<dyn Error>> {
        let no_state: Arc<()> = Arc::new(());
        Self::new_inner(addr, no_state).await
    }
}
impl<I, R, T, S> UdpServer<I, R, T, S>
where
I: Fn(Arc<S>,Arc<Peer<T>>, Vec<u8>) -> R + Send + Sync + 'static,
R: Future<Output = Result<(), Box<dyn Error>>> + Send,
T: Send + 'static,
S: Sync +Send + 'static{
#[cfg(not(target_os = "windows"))]
fn make_udp_client<A: ToSocketAddrs>(addr: &A) -> Result<std::net::UdpSocket, Box<dyn Error>> {
let res = UdpBuilder::new_v4()?
.reuse_address(true)?
.reuse_port(true)?
.bind(addr)?;
Ok(res)
}
#[cfg(target_os = "windows")]
fn make_udp_client<A: ToSocketAddrs>(addr: &A) -> Result<std::net::UdpSocket, Box<dyn Error>> {
let res = UdpBuilder::new_v4()?.reuse_address(true)?.bind(addr)?;
Ok(res)
}
fn create_udp_socket<A: ToSocketAddrs>(
addr: &A,
) -> Result<std::net::UdpSocket, Box<dyn Error>> {
let res = Self::make_udp_client(addr)?;
res.set_send_buffer_size(1784 * 10000)?;
res.set_recv_buffer_size(1784 * 10000)?;
Ok(res)
}
fn create_async_udp_socket<A: ToSocketAddrs>(addr: &A) -> Result<UdpSocket, Box<dyn Error>> {
let std_sock = Self::create_udp_socket(&addr)?;
let sock = UdpSocket::try_from(std_sock)?;
Ok(sock)
}
fn create_udp_socket_list<A: ToSocketAddrs>(
addr: &A,
listen_count: usize,
) -> Result<Vec<UdpSocket>, Box<dyn Error>> {
println!("cpus:{}", listen_count);
let mut listens = vec![];
for _ in 0..listen_count {
let sock = Self::create_async_udp_socket(addr)?;
listens.push(sock);
}
Ok(listens)
}
#[cfg(not(target_os = "windows"))]
fn get_cpu_count() -> usize {
num_cpus::get()
}
#[cfg(target_os = "windows")]
fn get_cpu_count() -> usize {
1
}
pub async fn new_inner<A: ToSocketAddrs>(addr: A, inner:Arc<S>) -> Result<Self, Box<dyn Error>> {
let udp_list = Self::create_udp_socket_list(&addr, Self::get_cpu_count())?;
let mut udp_map = vec![];
let mut id = 1;
for udp in udp_list {
let (recv, send) = udp.split();
udp_map.push(UdpContext {
id,
recv: Arc::new(Mutex::new(recv)),
send: Arc::new(Mutex::new( send)),
peers: Arc::new(Mutex::new(HashMap::new())),
});
id += 1;
}
Ok(UdpServer {
inner,
udp_contexts: udp_map,
input: None,
error_input: None,
})
}
pub fn set_input(&mut self, input: I) {
self.input = Some(Arc::new(input));
}
pub fn set_err_input<P: Fn(Option<Arc<Peer<T>>>, Box<dyn Error>)->bool + Send + 'static>(&mut self, err_input: P) {
self.error_input = Some(Arc::new(Mutex::new(err_input)));
}
pub async fn remove_peer(&self,addr:SocketAddr)->bool{
for udp_server in self.udp_contexts.iter() {
let mut peer_dict= udp_server.peers.lock().await;
if let Some(_)= peer_dict.remove(&addr){
return true;
}
}
false
}
pub async fn start(&self) -> Result<(), Box<dyn Error>> {
if let Some(input) = &self.input {
let mut tasks = vec![];
for udp_sock in self.udp_contexts.iter() {
let recv_sock = Arc::downgrade(&udp_sock.recv);
let send_sock = udp_sock.send.clone();
let input=input.clone();
let id = udp_sock.id;
let pees_ptr = udp_sock.peers.clone();
let inner= self.inner.clone();
let err_input = {
if let Some(err) = &self.error_input {
let x = err;
x.clone()
} else {
Arc::new(Mutex::new(|peer:Option<Arc<Peer<T>>>, err:Box<dyn Error>| {
match peer {
Some(peer) => {
println!("{}-{}", peer.addr, err);
}
None => {
println!("{}", err);
}
}
true
}))
}
};
let pd = tokio::spawn(async move {
let wk = recv_sock.upgrade();
if let Some(sock_mutex) = wk {
let mut buff = [0; BUFF_MAX_SIZE];
loop {
let res = {
let mut sock = sock_mutex.lock().await;
sock.recv_from(&mut buff).await
};
if let Ok((size, addr)) = res {
let peer = {
let mut lock_pees = pees_ptr.lock().await;
let res = lock_pees.entry(addr).or_insert_with(|| {
Arc::new(Peer {
socket_id: id,
addr,
token: Arc::new(Mutex::new(TokenStore(None))),
udp_sock: Arc::new(UdpSend(send_sock.clone(), addr))
})
});
res.clone()
};
let err = {
let res =
input(inner.clone(), peer.clone(), buff[0..size].to_vec()).await;
match res {
Err(er) => Some(format!("{}", er)),
Ok(()) => None,
}
};
if let Some(err_msg) = err {
let error = err_input.lock().await;
let stop = error(
Some(peer),
err_msg.into(),
);
if stop {
return;
}
}
} else if let Err(er) = res {
let error = err_input.lock().await;
let stop= error(None, error::Error::IOError(er.into()).into());
if stop{
return;
}
}
}
} else {
delay_for(Duration::from_millis(1)).await;
}
});
tasks.push(pd);
}
for task in tasks {
task.await?;
}
Ok(())
} else {
panic!("not found input")
}
}
}