use crate::error;
use net2::{UdpBuilder, UdpSocketExt};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::error::Error;
use std::future::Future;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::{Arc, Weak};
use tokio::net::udp::{RecvHalf, SendHalf};
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use tokio::time::{delay_for, Duration};
use futures::executor::block_on;
#[cfg(not(target_os = "windows"))]
use net2::unix::UnixUdpBuilderExt;
/// Size in bytes of the buffer each listener task receives one datagram into.
///
/// NOTE(review): per `std::net::UdpSocket::recv_from`, bytes of a message that do not fit
/// the buffer may be discarded — confirm 4 KiB covers the largest expected datagram.
pub const BUFF_MAX_SIZE: usize = 4 * 1024;
/// Per-socket state: one bound UDP socket split into its receive and send halves,
/// together with the peers that have been seen on it.
pub struct UdpContext<T: Send> {
    /// Identifier of this socket (`UdpServer::new_inner` numbers sockets from 1).
    pub id: usize,
    /// Receive half, read by the listener task spawned in `UdpServer::start`.
    recv: Arc<Mutex<RecvHalf>>,
    /// Send half; each [`Peer`] created for this socket holds a weak handle to it.
    pub send: Arc<Mutex<SendHalf>>,
    /// Peers that sent datagrams to this socket, keyed by their remote address.
    ///
    /// NOTE(review): `UdpServer::start` only ever inserts into this map; confirm entries
    /// are evicted somewhere, otherwise it grows with every distinct sender address.
    pub peers: Arc<Mutex<HashMap<SocketAddr, Arc<Peer<T>>>>>,
}
pub struct UdpServer<I, R, T,S>
where
I: Fn(Weak<S>,Arc<Peer<T>>, Vec<u8>) -> R + Send + Sync + 'static,
R: Future<Output = Result<(), Box<dyn Error>>>,
T: Send + 'static,
S: Sync +Send+'static
{
inner:Arc<S>,
udp_contexts: Vec<UdpContext<T>>,
input: Option<Arc<I>>,
error_input: Option<Arc<Mutex<dyn Fn(Option<Arc<Peer<T>>>, Box<dyn Error>)->bool + Send>>>,
}
/// Optional slot holding one user-defined value (the per-peer "token").
#[derive(Debug)]
pub struct TokenStore<T: Send>(pub Option<T>);

impl<T: Send> TokenStore<T> {
    /// Returns `true` when a value is currently stored.
    pub fn have(&self) -> bool {
        self.0.is_some()
    }

    /// Mutably borrows the stored value, or returns `None` when the slot is empty.
    pub fn get(&mut self) -> Option<&mut T> {
        self.0.as_mut()
    }

    /// Stores `v`, dropping whatever value was stored before.
    pub fn set(&mut self, v: T) {
        *self = Self(Some(v));
    }
}
#[derive(Debug)]
pub struct UdpSend(pub Weak<Mutex<SendHalf>>,pub SocketAddr);
impl UdpSend{
pub async fn send(&self,buf: &[u8])->std::io::Result<usize>{
if let Some(ref sock) =self.0.upgrade(){
let mut un_sock= sock.lock().await;
return un_sock.send_to(buf,&self.1).await;
}
Ok(0)
}
}
/// Blocking [`std::io::Write`] adapter: each `write` call sends `buf` as one datagram.
///
/// NOTE(review): `write` drives the async send with `futures::executor::block_on`, which
/// blocks the calling thread; calling it from inside an async runtime task may stall or
/// deadlock that worker — confirm it is only used from synchronous code.
impl std::io::Write for UdpSend {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let pending = self.send(buf);
        block_on(pending)
    }

    /// `write` sends immediately and nothing is buffered here, so flushing is a no-op.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
/// A remote endpoint talking to one of the server's sockets.
///
/// `UdpServer::start` creates one the first time a datagram arrives from `addr`.
#[derive(Debug)]
pub struct Peer<T: Send> {
    /// `id` of the `UdpContext` whose socket received this peer's traffic.
    pub socket_id: usize,
    /// Remote address of the peer.
    pub addr: SocketAddr,
    /// User-defined per-peer value; starts out empty.
    pub token: Arc<Mutex<TokenStore<T>>>,
    /// Handle used to send datagrams back to `addr`.
    pub udp_sock: Arc<Mutex<UdpSend>>,
}

impl<T: Send> Peer<T> {
    /// Sends `data` to this peer, holding the `udp_sock` lock for the duration of the send.
    ///
    /// Returns the number of bytes sent (`0` if the socket is already gone).
    pub async fn send(&self, data: &[u8]) -> Result<usize, std::io::Error> {
        self.udp_sock.lock().await.send(data).await
    }
}
impl<I, R, T> UdpServer<I, R, T, ()>
where
    I: Fn(Weak<()>, Arc<Peer<T>>, Vec<u8>) -> R + Send + Sync + 'static,
    R: Future<Output = Result<(), Box<dyn Error>>> + Send,
    T: Send + 'static,
{
    /// Creates a server bound to `addr` that has no shared state (`S = ()`).
    ///
    /// Shorthand for `new_inner(addr, Arc::new(()))`; see `new_inner` for errors.
    pub async fn new<A: ToSocketAddrs>(addr: A) -> Result<Self, Box<dyn Error>> {
        let unit_state = Arc::new(());
        Self::new_inner(addr, unit_state).await
    }
}
impl<I, R, T, S> UdpServer<I, R, T, S>
where
I: Fn(Weak<S>,Arc<Peer<T>>, Vec<u8>) -> R + Send + Sync + 'static,
R: Future<Output = Result<(), Box<dyn Error>>> + Send,
T: Send + 'static,
S: Sync +Send + 'static{
#[cfg(not(target_os = "windows"))]
fn make_udp_client<A: ToSocketAddrs>(addr: &A) -> Result<std::net::UdpSocket, Box<dyn Error>> {
let res = UdpBuilder::new_v4()?
.reuse_address(true)?
.reuse_port(true)?
.bind(addr)?;
Ok(res)
}
#[cfg(target_os = "windows")]
fn make_udp_client<A: ToSocketAddrs>(addr: &A) -> Result<std::net::UdpSocket, Box<dyn Error>> {
let res = UdpBuilder::new_v4()?.reuse_address(true)?.bind(addr)?;
Ok(res)
}
fn create_udp_socket<A: ToSocketAddrs>(
addr: &A,
) -> Result<std::net::UdpSocket, Box<dyn Error>> {
let res = Self::make_udp_client(addr)?;
res.set_send_buffer_size(1784 * 10000)?;
res.set_recv_buffer_size(1784 * 10000)?;
Ok(res)
}
fn create_async_udp_socket<A: ToSocketAddrs>(addr: &A) -> Result<UdpSocket, Box<dyn Error>> {
let std_sock = Self::create_udp_socket(&addr)?;
let sock = UdpSocket::try_from(std_sock)?;
Ok(sock)
}
fn create_udp_socket_list<A: ToSocketAddrs>(
addr: &A,
listen_count: usize,
) -> Result<Vec<UdpSocket>, Box<dyn Error>> {
println!("cpus:{}", listen_count);
let mut listens = vec![];
for _ in 0..listen_count {
let sock = Self::create_async_udp_socket(addr)?;
listens.push(sock);
}
Ok(listens)
}
#[cfg(not(target_os = "windows"))]
fn get_cpu_count() -> usize {
num_cpus::get()
}
#[cfg(target_os = "windows")]
fn get_cpu_count() -> usize {
1
}
pub async fn new_inner<A: ToSocketAddrs>(addr: A, inner:Arc<S>) -> Result<Self, Box<dyn Error>> {
let udp_list = Self::create_udp_socket_list(&addr, Self::get_cpu_count())?;
let mut udp_map = vec![];
let mut id = 1;
for udp in udp_list {
let (recv, send) = udp.split();
udp_map.push(UdpContext {
id,
recv: Arc::new(Mutex::new(recv)),
send: Arc::new(Mutex::new( send)),
peers: Arc::new(Mutex::new(HashMap::new())),
});
id += 1;
}
Ok(UdpServer {
inner,
udp_contexts: udp_map,
input: None,
error_input: None,
})
}
pub fn set_input(&mut self, input: I) {
self.input = Some(Arc::new(input));
}
pub fn set_err_input<P: Fn(Option<Arc<Peer<T>>>, Box<dyn Error>)->bool + Send + 'static>(&mut self, err_input: P) {
self.error_input = Some(Arc::new(Mutex::new(err_input)));
}
pub async fn start(&self) -> Result<(), Box<dyn Error>> {
if let Some(input) = &self.input {
let mut tasks = vec![];
for udp_sock in self.udp_contexts.iter() {
let recv_sock = Arc::downgrade(&udp_sock.recv);
let send_sock = udp_sock.send.clone();
let id = udp_sock.id;
let input_fn = Arc::downgrade(input);
let pees_ptr = udp_sock.peers.clone();
let inner= Arc::downgrade(&self.inner);
let err_input = {
if let Some(err) = &self.error_input {
let x = err;
x.clone()
} else {
Arc::new(Mutex::new(|peer:Option<Arc<Peer<T>>>, err:Box<dyn Error>| {
match peer {
Some(peer) => {
println!("{}-{}", peer.addr, err);
}
None => {
println!("{}", err);
}
}
true
}))
}
};
let pd = tokio::spawn(async move {
let wk = recv_sock.upgrade();
if let Some(sock_mutex) = wk {
let mut buff = [0; BUFF_MAX_SIZE];
loop {
let res = {
let mut sock = sock_mutex.lock().await;
sock.recv_from(&mut buff).await
};
if let Ok((size, addr)) = res {
let peer = {
let mut lock_pees = pees_ptr.lock().await;
let res = lock_pees.entry(addr).or_insert_with(|| {
Arc::new(Peer {
socket_id: id,
addr,
token: Arc::new(Mutex::new(TokenStore(None))),
udp_sock: Arc::new(Mutex::new( UdpSend( Arc::downgrade( &send_sock),addr)))
})
});
res.clone()
};
let input_wk = input_fn.upgrade();
if let Some(input_in) = input_wk {
let err ={
let res =
input_in(inner.clone(),peer.clone(), buff[0..size].to_vec()).await;
match res {
Err(er) => Some(format!("{}", er)),
Ok(()) => None,
}
};
if let Some(err_msg) = err {
let error = err_input.lock().await;
let stop=error(
Some(peer),
err_msg.into(),
);
if stop{
return;
}
}
} else {
let error = err_input.lock().await;
let stop=error(
None,
format!("{} input in null?",addr).into(),
);
if stop{
return;
}
}
} else if let Err(er) = res {
let error = err_input.lock().await;
let stop= error(None, error::Error::IOError(er.into()).into());
if stop{
return;
}
}
}
} else {
delay_for(Duration::from_millis(1)).await;
}
});
tasks.push(pd);
}
for task in tasks {
task.await?;
}
Ok(())
} else {
panic!("not found input")
}
}
}