1use ratchet_rs::deflate::DeflateExtProvider;
2use ratchet_rs::{Error as RatchetError, ProtocolRegistry, WebSocketConfig};
3use tokio::net::{TcpListener, TcpStream};
4
5use nrpc::_helpers::futures::StreamExt;
6
7use crate::rpc::StaticServiceRegistry;
8
/// The target of a websocket request, parsed from its URI path `/<service>/<method>`.
///
/// Both names borrow from the request path they were parsed from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MethodDescriptor<'a> {
    /// Name of the registered service to call.
    service: &'a str,
    /// Name of the method to call on that service.
    method: &'a str,
}
13
14pub struct WebsocketServer {
16 services: StaticServiceRegistry,
17 port: u16,
18}
19
20impl WebsocketServer {
21 pub fn new(port_usdpl: u16) -> Self {
23 Self {
24 services: StaticServiceRegistry::with_builtins(),
25 port: port_usdpl,
26 }
27 }
28
29 pub fn registry(&mut self) -> &'_ mut StaticServiceRegistry {
31 &mut self.services
32 }
33
34 pub fn register<S: nrpc::ServerService<'static> + Send + 'static>(mut self, service: S) -> Self {
36 self.services.register(service);
37 self
38 }
39
40 pub async fn run(&self) -> std::io::Result<()> {
42 #[cfg(debug_assertions)]
43 let addr = (std::net::Ipv4Addr::UNSPECIFIED, self.port);
44 #[cfg(not(debug_assertions))]
45 let addr = (std::net::Ipv4Addr::LOCALHOST, self.port);
46
47 let tcp = TcpListener::bind(addr).await?;
48
49 while let Ok((stream, _addr_do_not_use)) = tcp.accept().await {
50 tokio::spawn(error_logger("USDPL websocket server error", Self::connection_handler(self.services.clone(), stream)));
51 }
52
53 Ok(())
54 }
55
56 #[cfg(feature = "blocking")]
57 pub fn run_blocking(self) -> std::io::Result<()> {
59 let runner = tokio::runtime::Builder::new_multi_thread()
60 .enable_all()
61 .build()?;
62 runner.block_on(self.run())
63 }
64
65 async fn connection_handler(
66 mut services: StaticServiceRegistry,
67 stream: TcpStream,
68 ) -> Result<(), RatchetError> {
69 log::debug!("connection_handler invoked!");
70 let upgraded = ratchet_rs::accept_with(
71 stream,
72 WebSocketConfig::default(),
73 DeflateExtProvider::default(),
74 ProtocolRegistry::new(["usdpl-nrpc"])?,
75 )
76 .await?
77 .upgrade()
78 .await?;
79
80 let request_path = upgraded.request.uri().path();
81
82 log::debug!("accepted new connection on uri {}", request_path);
83
84 let websocket = std::sync::Arc::new(tokio::sync::Mutex::new(upgraded.websocket));
85
86 let descriptor = Self::parse_uri_path(request_path)
87 .map_err(|e| RatchetError::with_cause(ratchet_rs::ErrorKind::Protocol, e))?;
88
89 let input_stream = Box::new(nrpc::_helpers::futures::stream::StreamExt::boxed(crate::rpc::ws_stream(websocket.clone())));
90 let output_stream = services
91 .call_descriptor(
92 descriptor.service,
93 descriptor.method,
94 input_stream,
95 )
96 .await
97 .map_err(|e| {
98 RatchetError::with_cause(ratchet_rs::ErrorKind::Protocol, e.to_string())
99 })?;
100
101 output_stream.for_each(|result| async {
102 match result {
103 Ok(msg) => {
104 let mut ws_lock = websocket.lock().await;
105 if let Err(e) = ws_lock.write_binary(msg).await {
106 log::error!("websocket error while writing response on uri {}: {}", request_path, e);
107 }
108 },
109 Err(e) => {
110 log::error!("service error while writing response on uri {}: {}", request_path, e);
111 }
112 }
113 }).await;
114
115 websocket.lock().await.close(ratchet_rs::CloseReason {
116 code: ratchet_rs::CloseCode::Normal,
117 description: None,
118 }).await?;
119
120 log::debug!("ws connection {} closed", request_path);
150 Ok(())
151 }
152
153 fn parse_uri_path<'a>(path: &'a str) -> Result<MethodDescriptor<'a>, &'static str> {
154 let mut iter = path.trim_matches('/').split('/');
155 if let Some(service) = iter.next() {
156 if let Some(method) = iter.next() {
157 if iter.next().is_none() {
158 return Ok(MethodDescriptor { service, method });
159 } else {
160 Err("URL path has too many separators")
161 }
162 } else {
163 Err("URL path has no method")
164 }
165 } else {
166 Err("URL path has no service")
167 }
168 }
169}
170
171async fn error_logger<E: std::error::Error>(msg: &'static str, f: impl core::future::Future<Output=Result<(), E>>) {
172 if let Err(e) = f.await {
173 log::error!("{}: {}", msg, e);
174 }
175}