1#![allow(async_fn_in_trait)]
7
8use futures_util::{SinkExt, StreamExt};
9use rustbac_datalink::{DataLink, DataLinkAddress, DataLinkError};
10use std::io;
11use std::net::{IpAddr, SocketAddr};
12use std::sync::Arc;
13use tokio::net::lookup_host;
14use tokio::sync::{broadcast, mpsc};
15use tokio_tungstenite::connect_async;
16use tokio_tungstenite::tungstenite::Message;
17
/// Capacity of the outbound frame queue that feeds the socket-writer task.
/// `DataLink::send` waits for a free slot once this many frames are pending.
const CHANNEL_DEPTH: usize = 128;
/// Capacity of the inbound broadcast buffer filled by the socket-reader task.
/// A receiver that falls further behind than this skips the oldest frames.
const BROADCAST_DEPTH: usize = 64;
26
27#[derive(Clone)]
39pub struct BacnetScTransport {
40 endpoint: String,
41 peer_address: DataLinkAddress,
42 outbound: mpsc::Sender<Vec<u8>>,
43 inbound: Arc<broadcast::Sender<Vec<u8>>>,
45}
46
47impl std::fmt::Debug for BacnetScTransport {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("BacnetScTransport")
50 .field("endpoint", &self.endpoint)
51 .field("peer_address", &self.peer_address)
52 .finish()
53 }
54}
55
56impl BacnetScTransport {
57 pub async fn connect(endpoint: impl Into<String>) -> Result<Self, DataLinkError> {
58 let endpoint = endpoint.into();
59 let peer_address = resolve_peer_address(&endpoint).await?;
60
61 let (socket, _) = connect_async(endpoint.as_str())
62 .await
63 .map_err(|err| ws_io_error(io::ErrorKind::ConnectionRefused, err))?;
64 let (mut writer, mut reader) = socket.split();
65
66 let (outbound_tx, mut outbound_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_DEPTH);
67 let (inbound_tx, _) = broadcast::channel::<Vec<u8>>(BROADCAST_DEPTH);
68 let inbound_tx = Arc::new(inbound_tx);
69 let inbound_tx_clone = inbound_tx.clone();
70
71 tokio::spawn(async move {
72 while let Some(frame) = outbound_rx.recv().await {
73 if writer.send(Message::Binary(frame)).await.is_err() {
74 return;
75 }
76 }
77 let _ = writer.close().await;
78 });
79
80 tokio::spawn(async move {
81 while let Some(next) = reader.next().await {
82 let message = match next {
83 Ok(message) => message,
84 Err(_) => break,
85 };
86
87 match message {
88 Message::Binary(payload) => {
89 let _ = inbound_tx_clone.send(payload.to_vec());
92 }
93 Message::Text(text) => {
94 log::debug!("ignoring non-binary BACnet/SC websocket frame: {text}");
95 }
96 _ => {}
97 }
98 }
99 });
100
101 Ok(Self {
102 endpoint,
103 peer_address,
104 outbound: outbound_tx,
105 inbound: inbound_tx,
106 })
107 }
108
109 pub fn endpoint(&self) -> &str {
110 &self.endpoint
111 }
112
113 pub fn peer_address(&self) -> DataLinkAddress {
114 self.peer_address
115 }
116}
117
118impl DataLink for BacnetScTransport {
119 async fn send(&self, _address: DataLinkAddress, payload: &[u8]) -> Result<(), DataLinkError> {
120 self.outbound.send(payload.to_vec()).await.map_err(|_| {
121 DataLinkError::Io(io::Error::new(
122 io::ErrorKind::BrokenPipe,
123 "BACnet/SC websocket sender task stopped",
124 ))
125 })
126 }
127
128 async fn recv(&self, buf: &mut [u8]) -> Result<(usize, DataLinkAddress), DataLinkError> {
129 let mut rx = self.inbound.subscribe();
133 loop {
134 match rx.recv().await {
135 Ok(payload) => {
136 if payload.len() > buf.len() {
137 return Err(DataLinkError::FrameTooLarge);
138 }
139 buf[..payload.len()].copy_from_slice(&payload);
140 return Ok((payload.len(), self.peer_address));
141 }
142 Err(broadcast::error::RecvError::Lagged(n)) => {
143 log::debug!("BACnet/SC recv lagged by {n} frames; skipping");
147 continue;
148 }
149 Err(broadcast::error::RecvError::Closed) => {
150 return Err(DataLinkError::Io(io::Error::new(
151 io::ErrorKind::UnexpectedEof,
152 "BACnet/SC websocket receiver task stopped",
153 )));
154 }
155 }
156 }
157 }
158}
159
160fn ws_io_error(kind: io::ErrorKind, err: impl std::fmt::Display) -> DataLinkError {
161 DataLinkError::Io(io::Error::new(
162 kind,
163 format!("BACnet/SC websocket error: {err}"),
164 ))
165}
166
167async fn resolve_peer_address(endpoint: &str) -> Result<DataLinkAddress, DataLinkError> {
168 let (scheme, remainder) = endpoint.split_once("://").ok_or_else(|| {
169 DataLinkError::Io(io::Error::new(
170 io::ErrorKind::InvalidInput,
171 format!("invalid BACnet/SC endpoint '{endpoint}'"),
172 ))
173 })?;
174 let default_port = match scheme {
175 "ws" => 80,
176 "wss" => 443,
177 _ => {
178 return Err(DataLinkError::Io(io::Error::new(
179 io::ErrorKind::InvalidInput,
180 format!("unsupported BACnet/SC endpoint scheme '{scheme}'"),
181 )))
182 }
183 };
184 let authority = remainder.split('/').next().unwrap_or_default();
185 if authority.is_empty() {
186 return Err(DataLinkError::Io(io::Error::new(
187 io::ErrorKind::InvalidInput,
188 format!("BACnet/SC endpoint '{endpoint}' is missing host"),
189 )));
190 }
191 let authority = authority.rsplit('@').next().unwrap_or(authority);
192 if authority.is_empty() {
193 return Err(DataLinkError::Io(io::Error::new(
194 io::ErrorKind::InvalidInput,
195 format!("BACnet/SC endpoint '{endpoint}' is missing host"),
196 )));
197 }
198
199 let (host, port) = if let Some(rest) = authority.strip_prefix('[') {
200 let (ipv6_host, suffix) = rest.split_once(']').ok_or_else(|| {
201 DataLinkError::Io(io::Error::new(
202 io::ErrorKind::InvalidInput,
203 format!("invalid IPv6 host in BACnet/SC endpoint '{endpoint}'"),
204 ))
205 })?;
206 let port = if suffix.is_empty() {
207 default_port
208 } else if let Some(raw_port) = suffix.strip_prefix(':') {
209 raw_port.parse::<u16>().map_err(|_| {
210 DataLinkError::Io(io::Error::new(
211 io::ErrorKind::InvalidInput,
212 format!("invalid BACnet/SC endpoint port in '{endpoint}'"),
213 ))
214 })?
215 } else {
216 return Err(DataLinkError::Io(io::Error::new(
217 io::ErrorKind::InvalidInput,
218 format!("invalid BACnet/SC endpoint authority '{authority}'"),
219 )));
220 };
221 (ipv6_host.to_string(), port)
222 } else {
223 match authority.rsplit_once(':') {
224 Some((host, raw_port)) if !host.is_empty() && !raw_port.is_empty() => {
225 let port = raw_port.parse::<u16>().map_err(|_| {
226 DataLinkError::Io(io::Error::new(
227 io::ErrorKind::InvalidInput,
228 format!("invalid BACnet/SC endpoint port in '{endpoint}'"),
229 ))
230 })?;
231 (host.to_string(), port)
232 }
233 _ => (authority.to_string(), default_port),
234 }
235 };
236
237 if let Ok(ip) = host.parse::<IpAddr>() {
238 return Ok(DataLinkAddress::Ip(SocketAddr::new(ip, port)));
239 }
240
241 let mut addrs = lookup_host((host.as_str(), port))
242 .await
243 .map_err(DataLinkError::Io)?;
244 addrs.next().map(DataLinkAddress::Ip).ok_or_else(|| {
245 DataLinkError::Io(io::Error::new(
246 io::ErrorKind::NotFound,
247 format!("unable to resolve BACnet/SC host '{host}'"),
248 ))
249 })
250}
251
#[cfg(test)]
mod tests {
    use super::{resolve_peer_address, BacnetScTransport};
    use futures_util::{SinkExt, StreamExt};
    use rustbac_datalink::{DataLink, DataLinkAddress, DataLinkError};
    use std::net::SocketAddr;
    use tokio::net::TcpListener;
    use tokio::time::{timeout, Duration};
    use tokio_tungstenite::accept_async;
    use tokio_tungstenite::tungstenite::Message;

    /// Upper bound for any single network step.
    ///
    /// A broken transport then fails the test instead of hanging the whole suite.
    const STEP_TIMEOUT: Duration = Duration::from_secs(1);

    /// Starts a one-connection WebSocket server that echoes binary frames back to the client.
    async fn spawn_echo_server() -> (SocketAddr, tokio::task::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut ws = accept_async(stream).await.unwrap();
            while let Some(next) = ws.next().await {
                let msg = next.unwrap();
                match msg {
                    Message::Binary(payload) => {
                        ws.send(Message::Binary(payload)).await.unwrap();
                    }
                    Message::Ping(payload) => {
                        ws.send(Message::Pong(payload)).await.unwrap();
                    }
                    Message::Close(frame) => {
                        let _ = ws.close(frame).await;
                        break;
                    }
                    Message::Pong(_) | Message::Text(_) => {}
                    _ => {}
                }
            }
        });
        (addr, task)
    }

    #[tokio::test]
    async fn connect_sets_endpoint_and_peer_address() {
        let (addr, server) = spawn_echo_server().await;
        let endpoint = format!("ws://{addr}/hub");
        let transport = BacnetScTransport::connect(endpoint.clone()).await.unwrap();
        assert_eq!(transport.endpoint(), endpoint);
        assert_eq!(transport.peer_address(), DataLinkAddress::Ip(addr));
        drop(transport);
        server.abort();
    }

    #[tokio::test]
    async fn send_and_recv_binary_payload() {
        let (addr, server) = spawn_echo_server().await;
        let transport = BacnetScTransport::connect(format!("ws://{addr}/hub"))
            .await
            .unwrap();

        transport
            .send(DataLinkAddress::Ip(addr), &[1, 2, 3, 4])
            .await
            .unwrap();

        let mut out = [0u8; 16];
        let (n, src) = timeout(STEP_TIMEOUT, transport.recv(&mut out))
            .await
            .expect("recv timed out")
            .unwrap();
        assert_eq!(n, 4);
        assert_eq!(&out[..4], &[1, 2, 3, 4]);
        assert_eq!(src, DataLinkAddress::Ip(addr));

        drop(transport);
        server.abort();
    }

    #[tokio::test]
    async fn recv_reports_frame_too_large() {
        let (addr, server) = spawn_echo_server().await;
        let transport = BacnetScTransport::connect(format!("ws://{addr}/hub"))
            .await
            .unwrap();
        transport
            .send(DataLinkAddress::Ip(addr), &[9, 8, 7, 6])
            .await
            .unwrap();

        let mut out = [0u8; 2];
        let err = timeout(STEP_TIMEOUT, transport.recv(&mut out))
            .await
            .expect("recv timed out")
            .unwrap_err();
        assert!(matches!(err, DataLinkError::FrameTooLarge));

        drop(transport);
        server.abort();
    }

    #[tokio::test]
    async fn connect_rejects_invalid_endpoint() {
        let err = BacnetScTransport::connect("not a url").await.unwrap_err();
        assert!(matches!(err, DataLinkError::Io(_)));
    }

    #[tokio::test]
    async fn resolve_uses_explicit_port_or_scheme_default() {
        let cases = [
            ("ws://127.0.0.1:47808/hub", "127.0.0.1:47808"),
            ("ws://127.0.0.1/hub", "127.0.0.1:80"),
            ("wss://127.0.0.1", "127.0.0.1:443"),
            ("wss://user@127.0.0.1:8443/hub", "127.0.0.1:8443"),
            ("ws://[::1]:47808/hub", "[::1]:47808"),
            ("wss://[::1]", "[::1]:443"),
        ];
        for (endpoint, expected) in cases {
            let expected: SocketAddr = expected.parse().unwrap();
            assert_eq!(
                resolve_peer_address(endpoint).await.unwrap(),
                DataLinkAddress::Ip(expected),
                "{endpoint}"
            );
        }
    }

    #[tokio::test]
    async fn resolve_rejects_malformed_endpoints() {
        let endpoints = [
            "127.0.0.1:47808",
            "http://127.0.0.1/hub",
            "ws:///hub",
            "ws://user@/hub",
            "ws://[::1/hub",
            "ws://[::1]x/hub",
            "ws://127.0.0.1:notaport/hub",
        ];
        for endpoint in endpoints {
            let err = resolve_peer_address(endpoint).await.unwrap_err();
            assert!(
                matches!(
                    err,
                    DataLinkError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidInput
                ),
                "{endpoint}"
            );
        }
    }
}