1#![allow(async_fn_in_trait)]
7
8use futures_util::{SinkExt, StreamExt};
9use rustbac_datalink::{DataLink, DataLinkAddress, DataLinkError};
10use std::io;
11use std::net::{IpAddr, SocketAddr};
12use std::sync::Arc;
13use tokio::net::lookup_host;
14use tokio::sync::{mpsc, Mutex};
15use tokio_tungstenite::connect_async;
16use tokio_tungstenite::tungstenite::Message;
17
/// Capacity of each bounded frame queue created in `BacnetScTransport::connect`: one carries
/// outbound frames to the writer task, the other carries inbound frames from the reader task.
const CHANNEL_DEPTH: usize = 128;

/// BACnet/SC data link carried over a single websocket connection.
///
/// Created by `BacnetScTransport::connect`, which spawns background reader/writer tasks; this
/// handle only talks to those tasks through channels. Cloning the handle shares the same
/// connection: clones feed the same outbound queue and take turns on the same inbound queue,
/// so each received frame is delivered to exactly one caller of `recv`.
#[derive(Debug, Clone)]
pub struct BacnetScTransport {
    // Endpoint URL exactly as passed to `connect`.
    endpoint: String,
    // Peer address derived from `endpoint`; reported as the source of every received frame.
    peer_address: DataLinkAddress,
    // Queue drained by the writer task; each entry becomes one binary websocket message.
    outbound: mpsc::Sender<Vec<u8>>,
    // Queue filled by the reader task. Wrapped in an async mutex because receiving needs
    // `&mut Receiver` while `recv` only has `&self`; the `Arc` lets clones share it.
    inbound: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
}
29
30impl BacnetScTransport {
31 pub async fn connect(endpoint: impl Into<String>) -> Result<Self, DataLinkError> {
32 let endpoint = endpoint.into();
33 let peer_address = resolve_peer_address(&endpoint).await?;
34
35 let (socket, _) = connect_async(endpoint.as_str())
36 .await
37 .map_err(|err| ws_io_error(io::ErrorKind::ConnectionRefused, err))?;
38 let (mut writer, mut reader) = socket.split();
39
40 let (outbound_tx, mut outbound_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_DEPTH);
41 let (inbound_tx, inbound_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_DEPTH);
42
43 tokio::spawn(async move {
44 while let Some(frame) = outbound_rx.recv().await {
45 if writer.send(Message::Binary(frame)).await.is_err() {
46 return;
47 }
48 }
49 let _ = writer.close().await;
50 });
51
52 tokio::spawn(async move {
53 while let Some(next) = reader.next().await {
54 let message = match next {
55 Ok(message) => message,
56 Err(_) => break,
57 };
58
59 match message {
60 Message::Binary(payload) => {
61 if inbound_tx.send(payload.to_vec()).await.is_err() {
62 break;
63 }
64 }
65 Message::Text(text) => {
66 log::debug!("ignoring non-binary BACnet/SC websocket frame: {text}");
67 }
68 _ => {}
69 }
70 }
71 });
72
73 Ok(Self {
74 endpoint,
75 peer_address,
76 outbound: outbound_tx,
77 inbound: Arc::new(Mutex::new(inbound_rx)),
78 })
79 }
80
81 pub fn endpoint(&self) -> &str {
82 &self.endpoint
83 }
84
85 pub fn peer_address(&self) -> DataLinkAddress {
86 self.peer_address
87 }
88}
89
90impl DataLink for BacnetScTransport {
91 async fn send(&self, _address: DataLinkAddress, payload: &[u8]) -> Result<(), DataLinkError> {
92 self.outbound.send(payload.to_vec()).await.map_err(|_| {
93 DataLinkError::Io(io::Error::new(
94 io::ErrorKind::BrokenPipe,
95 "BACnet/SC websocket sender task stopped",
96 ))
97 })
98 }
99
100 async fn recv(&self, buf: &mut [u8]) -> Result<(usize, DataLinkAddress), DataLinkError> {
101 let mut inbound = self.inbound.lock().await;
102 let payload = inbound.recv().await.ok_or_else(|| {
103 DataLinkError::Io(io::Error::new(
104 io::ErrorKind::UnexpectedEof,
105 "BACnet/SC websocket receiver task stopped",
106 ))
107 })?;
108 if payload.len() > buf.len() {
109 return Err(DataLinkError::FrameTooLarge);
110 }
111 buf[..payload.len()].copy_from_slice(&payload);
112 Ok((payload.len(), self.peer_address))
113 }
114}
115
116fn ws_io_error(kind: io::ErrorKind, err: impl std::fmt::Display) -> DataLinkError {
117 DataLinkError::Io(io::Error::new(
118 kind,
119 format!("BACnet/SC websocket error: {err}"),
120 ))
121}
122
123async fn resolve_peer_address(endpoint: &str) -> Result<DataLinkAddress, DataLinkError> {
124 let (scheme, remainder) = endpoint.split_once("://").ok_or_else(|| {
125 DataLinkError::Io(io::Error::new(
126 io::ErrorKind::InvalidInput,
127 format!("invalid BACnet/SC endpoint '{endpoint}'"),
128 ))
129 })?;
130 let default_port = match scheme {
131 "ws" => 80,
132 "wss" => 443,
133 _ => {
134 return Err(DataLinkError::Io(io::Error::new(
135 io::ErrorKind::InvalidInput,
136 format!("unsupported BACnet/SC endpoint scheme '{scheme}'"),
137 )))
138 }
139 };
140 let authority = remainder.split('/').next().unwrap_or_default();
141 if authority.is_empty() {
142 return Err(DataLinkError::Io(io::Error::new(
143 io::ErrorKind::InvalidInput,
144 format!("BACnet/SC endpoint '{endpoint}' is missing host"),
145 )));
146 }
147 let authority = authority.rsplit('@').next().unwrap_or(authority);
148 if authority.is_empty() {
149 return Err(DataLinkError::Io(io::Error::new(
150 io::ErrorKind::InvalidInput,
151 format!("BACnet/SC endpoint '{endpoint}' is missing host"),
152 )));
153 }
154
155 let (host, port) = if let Some(rest) = authority.strip_prefix('[') {
156 let (ipv6_host, suffix) = rest.split_once(']').ok_or_else(|| {
157 DataLinkError::Io(io::Error::new(
158 io::ErrorKind::InvalidInput,
159 format!("invalid IPv6 host in BACnet/SC endpoint '{endpoint}'"),
160 ))
161 })?;
162 let port = if suffix.is_empty() {
163 default_port
164 } else if let Some(raw_port) = suffix.strip_prefix(':') {
165 raw_port.parse::<u16>().map_err(|_| {
166 DataLinkError::Io(io::Error::new(
167 io::ErrorKind::InvalidInput,
168 format!("invalid BACnet/SC endpoint port in '{endpoint}'"),
169 ))
170 })?
171 } else {
172 return Err(DataLinkError::Io(io::Error::new(
173 io::ErrorKind::InvalidInput,
174 format!("invalid BACnet/SC endpoint authority '{authority}'"),
175 )));
176 };
177 (ipv6_host.to_string(), port)
178 } else {
179 match authority.rsplit_once(':') {
180 Some((host, raw_port)) if !host.is_empty() && !raw_port.is_empty() => {
181 let port = raw_port.parse::<u16>().map_err(|_| {
182 DataLinkError::Io(io::Error::new(
183 io::ErrorKind::InvalidInput,
184 format!("invalid BACnet/SC endpoint port in '{endpoint}'"),
185 ))
186 })?;
187 (host.to_string(), port)
188 }
189 _ => (authority.to_string(), default_port),
190 }
191 };
192
193 if let Ok(ip) = host.parse::<IpAddr>() {
194 return Ok(DataLinkAddress::Ip(SocketAddr::new(ip, port)));
195 }
196
197 let mut addrs = lookup_host((host.as_str(), port))
198 .await
199 .map_err(DataLinkError::Io)?;
200 addrs.next().map(DataLinkAddress::Ip).ok_or_else(|| {
201 DataLinkError::Io(io::Error::new(
202 io::ErrorKind::NotFound,
203 format!("unable to resolve BACnet/SC host '{host}'"),
204 ))
205 })
206}
207
#[cfg(test)]
mod tests {
    use super::{resolve_peer_address, BacnetScTransport};
    use futures_util::{SinkExt, StreamExt};
    use rustbac_datalink::{DataLink, DataLinkAddress, DataLinkError};
    use std::io;
    use std::net::SocketAddr;
    use tokio::net::TcpListener;
    use tokio::time::{timeout, Duration};
    use tokio_tungstenite::accept_async;
    use tokio_tungstenite::tungstenite::Message;

    /// Upper bound for any single network round trip in these tests, so a bug hangs no test.
    const STEP_TIMEOUT: Duration = Duration::from_secs(1);

    /// Starts a one-connection websocket server on an ephemeral loopback port.
    ///
    /// The server echoes binary messages, answers pings with pongs, and replies to a close
    /// frame before stopping. It also stops, rather than panicking, on the first read or
    /// write error.
    async fn spawn_echo_server() -> (SocketAddr, tokio::task::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut ws = accept_async(stream).await.unwrap();
            while let Some(Ok(msg)) = ws.next().await {
                let reply = match msg {
                    Message::Binary(payload) => Message::Binary(payload),
                    Message::Ping(payload) => Message::Pong(payload),
                    Message::Close(frame) => {
                        let _ = ws.close(frame).await;
                        break;
                    }
                    _ => continue,
                };
                if ws.send(reply).await.is_err() {
                    break;
                }
            }
        });
        (addr, task)
    }

    /// Asserts that `err` is `DataLinkError::Io` carrying an error of the given `kind`.
    fn assert_io_kind(err: &DataLinkError, kind: io::ErrorKind) {
        assert!(
            matches!(err, DataLinkError::Io(e) if e.kind() == kind),
            "expected {kind:?} I/O error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn connect_sets_endpoint_and_peer_address() {
        let (addr, server) = spawn_echo_server().await;
        let endpoint = format!("ws://{addr}/hub");
        let transport = BacnetScTransport::connect(endpoint.clone()).await.unwrap();
        assert_eq!(transport.endpoint(), endpoint);
        assert_eq!(transport.peer_address(), DataLinkAddress::Ip(addr));
        drop(transport);
        server.abort();
    }

    #[tokio::test]
    async fn send_and_recv_binary_payload() {
        let (addr, server) = spawn_echo_server().await;
        let transport = BacnetScTransport::connect(format!("ws://{addr}/hub"))
            .await
            .unwrap();

        transport
            .send(DataLinkAddress::Ip(addr), &[1, 2, 3, 4])
            .await
            .unwrap();

        let mut out = [0u8; 16];
        let (n, src) = timeout(STEP_TIMEOUT, transport.recv(&mut out))
            .await
            .expect("recv timed out")
            .unwrap();
        assert_eq!(n, 4);
        assert_eq!(&out[..4], &[1, 2, 3, 4]);
        assert_eq!(src, DataLinkAddress::Ip(addr));

        drop(transport);
        server.abort();
    }

    #[tokio::test]
    async fn recv_reports_frame_too_large() {
        let (addr, server) = spawn_echo_server().await;
        let transport = BacnetScTransport::connect(format!("ws://{addr}/hub"))
            .await
            .unwrap();
        transport
            .send(DataLinkAddress::Ip(addr), &[9, 8, 7, 6])
            .await
            .unwrap();

        let mut out = [0u8; 2];
        // Bounded like the other round-trip test so a lost echo fails instead of hanging.
        let err = timeout(STEP_TIMEOUT, transport.recv(&mut out))
            .await
            .expect("recv timed out")
            .unwrap_err();
        assert!(matches!(err, DataLinkError::FrameTooLarge));

        drop(transport);
        server.abort();
    }

    #[tokio::test]
    async fn connect_rejects_invalid_endpoint() {
        let err = BacnetScTransport::connect("not a url").await.unwrap_err();
        assert_io_kind(&err, io::ErrorKind::InvalidInput);
    }

    #[tokio::test]
    async fn resolve_peer_address_uses_scheme_default_port() {
        let addr = resolve_peer_address("wss://127.0.0.1/hub").await.unwrap();
        assert_eq!(
            addr,
            DataLinkAddress::Ip("127.0.0.1:443".parse::<SocketAddr>().unwrap())
        );
        let addr = resolve_peer_address("ws://127.0.0.1").await.unwrap();
        assert_eq!(
            addr,
            DataLinkAddress::Ip("127.0.0.1:80".parse::<SocketAddr>().unwrap())
        );
    }

    #[tokio::test]
    async fn resolve_peer_address_handles_userinfo_and_ipv6() {
        let addr = resolve_peer_address("ws://user:secret@127.0.0.1:47808/hub")
            .await
            .unwrap();
        assert_eq!(
            addr,
            DataLinkAddress::Ip("127.0.0.1:47808".parse::<SocketAddr>().unwrap())
        );
        let addr = resolve_peer_address("wss://[::1]:4711/hub").await.unwrap();
        assert_eq!(
            addr,
            DataLinkAddress::Ip("[::1]:4711".parse::<SocketAddr>().unwrap())
        );
    }

    #[tokio::test]
    async fn resolve_peer_address_rejects_malformed_endpoints() {
        for endpoint in [
            "http://127.0.0.1/",
            "ws:///hub",
            "ws://127.0.0.1:99999/",
            "ws://[::1/hub",
        ] {
            let err = resolve_peer_address(endpoint).await.unwrap_err();
            assert_io_kind(&err, io::ErrorKind::InvalidInput);
        }
    }
}