roco_z21_driver/
station.rs

//! # Z21Station
//!
//! The `Z21Station` module provides asynchronous communication with a Roco/Fleischmann Z21
//! digital command control (DCC) station for model railways.
//!
//! ## Overview
//!
//! This module implements a complete UDP-based API for interacting with the Z21 station,
//! handling command transmission and event reception through an asynchronous architecture.
//! It supports:
//!
//! - Automatic connection management with keep-alive functionality
//! - Broadcast message handling for system state changes and locomotive information
//! - DCC command transmission for controlling locomotives and accessories
//! - XBus protocol implementation for low-level communication
//!
17
18use crate::messages::{self, SystemState, XBusMessage};
19use crate::packet::Packet;
20use std::cell::OnceCell;
21use std::convert::TryFrom;
22use std::io;
23use std::net::SocketAddr;
24use std::sync::atomic::{AtomicBool, Ordering};
25use std::sync::Arc;
26use std::time::Duration;
27use tokio::net::UdpSocket;
28use tokio::sync::broadcast;
29use tokio::time::{self, timeout};
30
31mod loco;
32pub use loco::Loco;
33
/// Z21 LAN header of the `LAN_SYSTEMSTATE_DATACHANGED` event, the station's reply carrying
/// its system state.
const LAN_SYSTEMSTATE_DATACHANGED: u16 = 0x0084;
/// Z21 LAN header of `LAN_SET_BROADCASTFLAGS`, used to (re)send this client's broadcast flags.
const LAN_SET_BROADCASTFLAGS: u16 = 0x0050;
/// Z21 LAN header of `LAN_SYSTEMSTATE_GETDATA`, the request answered by
/// `LAN_SYSTEMSTATE_DATACHANGED`.
const LAN_SYSTEMSTATE_GETDATA: u16 = 0x0085;
/// `(X-header, data byte)` pair passed to `XBusMessage::new_single` to switch track power off.
const X_SET_TRACK_POWER_OFF: (u8, u8) = (0x21, 0x80);
/// `(X-header, data byte)` pair passed to `XBusMessage::new_single` to switch track power on.
const X_SET_TRACK_POWER_ON: (u8, u8) = (0x21, 0x81);
/// X-header of the track-power broadcast expected in reply to the track-power commands.
const X_BC_TRACK_POWER: u8 = 0x61;

/// Default time, in milliseconds, to wait for a response from the station.
const DEFAULT_TIMEOUT_MS: u64 = 2_000;

/// Broadcast flags requested by default: only flag `0x0000_0001`, i.e. locomotive info and
/// turnout info broadcasts.
const DEFAULT_BROADCAST_FLAGS: u32 = 0x0000_0001;
47
48/// Represents an asynchronous connection to a Z21 station.
49///
50/// The `Z21Station` manages a UDP socket for communication with a Z21 station. It spawns a
51/// background task to continuously listen for incoming packets and proceed these packets
52/// over an internal logic.
53pub struct Z21Station {
54    socket: Arc<UdpSocket>,
55    message_sender: broadcast::Sender<Packet>,
56    message_receiver: broadcast::Receiver<Packet>,
57    timeout: Duration,
58    keep_alive: Arc<AtomicBool>,
59    broadcast_flags: u32,
60}
61
62impl Z21Station {
63    /// Creates a new connection to a Z21 station at the specified address.
64    ///
65    /// This method establishes a UDP connection to the Z21 station, performs the initial
66    /// handshake, and starts background tasks for maintaining the connection.
67    ///
68    /// # Arguments
69    ///
70    /// * `bind_addr` - Network address of the Z21 station (typically "192.168.0.111:21105")
71    ///
72    /// # Returns
73    ///
74    /// A new `Z21Station` instance if the connection is successful.
75    ///
76    /// # Errors
77    ///
78    /// Returns an `io::Error` if:
79    /// - The UDP socket cannot be bound or connected
80    /// - The initial handshake with the Z21 station fails
81    /// - The station does not respond within the timeout period
82    ///
83    /// # Example
84    ///
85    /// ```rust
86    /// let station = Z21Station::new("192.168.0.111:21105").await?;
87    /// ```
88    pub async fn new(bind_addr: &str) -> io::Result<Self> {
89        // Bind the socket to an available local port on all interfaces.
90        let socket = UdpSocket::bind("0.0.0.0:0").await?;
91        // Enable broadcast on the socket to allow sending messages to a broadcast address.
92        socket.set_broadcast(true)?;
93        // Connect the socket to the Z21 station address.
94        socket.connect(bind_addr).await?;
95        let socket = Arc::new(socket);
96
97        // Create a broadcast channel for propagating incoming packets.
98        let (tx, rx) = broadcast::channel(100);
99        let station = Z21Station {
100            socket,
101            message_sender: tx,
102            message_receiver: rx,
103            keep_alive: Arc::new(AtomicBool::new(true)),
104            broadcast_flags: DEFAULT_BROADCAST_FLAGS,
105            timeout: Duration::from_millis(DEFAULT_TIMEOUT_MS),
106        };
107        // Start the background receiver task.
108        station.start_receiver();
109
110        // Perform the initial handshake with the Z21 station.
111        let result = station.initial_handshake().await;
112        if let Err(e) = result {
113            eprintln!(
114                "There is no connection to the Z21 station, on the specified address: {}",
115                bind_addr
116            );
117            return Err(e);
118        }
119
120        // Start the keep-alive thread.
121        station.start_keep_alive_setup_broadcast_task();
122        Ok(station)
123    }
124
125    /// Starts a background asynchronous task that continuously listens for incoming UDP packets.
126    ///
127    /// The task reads data from the socket, converts it into a [`Packet`], and then sends it through
128    /// the internal broadcast channel so that subscribers can process the packet.
129    fn start_receiver(&self) {
130        let socket = Arc::clone(&self.socket);
131        let message_sender = self.message_sender.clone();
132
133        tokio::spawn(async move {
134            let mut buf = [0u8; 1024];
135            loop {
136                match socket.recv(&mut buf).await {
137                    Ok(size) => {
138                        // Copy the received data into a vector.
139                        let data = buf[..size].to_vec();
140                        // Convert the raw data into a Packet.
141                        let packet = Packet::from(data);
142                        //println!("Received packet with header: {:?}", packet.get_header());
143                        // if packet.get_header() == 64 {
144                        //     let xbus_msg = XBusMessage::try_from(
145                        //         &packet.get_data()[0..packet.get_data_len() as usize - 4],
146                        //     );
147                        //     if let Ok(msg) = xbus_msg {
148                        //         println!(
149                        //             "Received XBus message with header: {:02x}",
150                        //             msg.get_x_header()
151                        //         );
152                        //     } else {
153                        //         eprintln!("Failed to parse XBus message");
154                        //     }
155                        // }
156
157                        // Broadcast the packet to all subscribers.
158                        if let Err(e) = message_sender.send(packet) {
159                            eprintln!("Failed to send packet via broadcast channel: {:?}", e);
160                        }
161                    }
162                    Err(e) => {
163                        eprintln!("Error receiving packet: {:?}", e);
164                        break;
165                    }
166                }
167            }
168        });
169    }
170
171    async fn initial_handshake(&self) -> io::Result<()> {
172        let packet = Packet::with_header_and_data(LAN_SYSTEMSTATE_GETDATA, &[]);
173        self.send_packet(packet).await?;
174        let _ = self
175            .receive_packet_with_header(LAN_SYSTEMSTATE_DATACHANGED)
176            .await?;
177        Ok(())
178    }
179
180    async fn send_set_broadcast_flags(socket: &Arc<UdpSocket>, flags: u32) -> io::Result<()> {
181        let flags = flags.to_le_bytes();
182        let broadcast_packet = Packet::with_header_and_data(LAN_SET_BROADCASTFLAGS, &flags);
183        let broadcast_packet: Vec<_> = broadcast_packet.into();
184        socket.send(&broadcast_packet).await?;
185        Ok(())
186    }
187
188    /// Keeps connection alive by sending a broadcast packet to the Z21 station.
189    fn start_keep_alive_setup_broadcast_task(&self) {
190        let socket = Arc::clone(&self.socket);
191        let flags = self.broadcast_flags;
192        let keep_alive = Arc::clone(&self.keep_alive);
193        tokio::spawn(async move {
194            loop {
195                let _result = Self::send_set_broadcast_flags(&socket, flags).await;
196                tokio::time::sleep(Duration::from_secs(10)).await;
197
198                if !keep_alive.load(Ordering::Relaxed) {
199                    break;
200                }
201            }
202        });
203    }
204
205    /// Sends a [`Packet`] asynchronously to the connected Z21 station.
206    ///
207    /// The packet is serialized into a byte vector and sent through the UDP socket.
208    ///
209    /// # Arguments
210    ///
211    /// * `packet` - The [`Packet`] to be transmitted.
212    ///
213    /// # Errors
214    ///
215    /// Returns an `io::Error` if the packet fails to send.
216    async fn send_packet(&self, packet: Packet) -> io::Result<()> {
217        let data: Vec<u8> = packet.into();
218        // Send the serialized packet through the connected UDP socket.
219        self.socket.send(&data).await?;
220        Ok(())
221    }
222    async fn send_packet_external(socket: &Arc<UdpSocket>, packet: Packet) -> io::Result<()> {
223        let data: Vec<u8> = packet.into();
224        // Send the serialized packet through the connected UDP socket.
225        socket.send(&data).await?;
226        Ok(())
227    }
228
229    /// Sends an XBus packet without waiting for a response
230    ///
231    /// # Arguments
232    ///
233    /// * `xbus_message` - The XBus message to send
234    ///
235    /// # Errors
236    ///
237    /// Returns an `io::Error` if the packet fails to send
238    async fn send_xbus_packet(&self, xbus_message: XBusMessage) -> io::Result<()> {
239        let data: Vec<u8> = xbus_message.into();
240        let packet = Packet::with_header_and_data(messages::XBUS_HEADER, &data);
241        self.send_packet(packet).await
242    }
243
244    /// Sends an XBus command and waits for the expected response
245    ///
246    /// # Arguments
247    ///
248    /// * `xbus_message` - The XBus message to send
249    /// * `expected_response_xbus_header` - Optional expected response header. If None, uses the sent message header
250    ///
251    /// # Errors
252    ///
253    /// Returns an `io::Error` if:
254    /// - The packet fails to send
255    /// - No response is received within the timeout period
256    /// - The response has an invalid format
257    async fn send_xbus_command(
258        &self,
259        xbus_message: XBusMessage,
260        expected_response_xbus_header: Option<u8>,
261    ) -> io::Result<XBusMessage> {
262        let x_header = xbus_message.get_x_header();
263        self.send_xbus_packet(xbus_message).await?;
264
265        let expected_header = expected_response_xbus_header.unwrap_or(x_header);
266        let xbus_return = self.receive_xbus_packet(expected_header).await?;
267        Ok(xbus_return)
268    }
269
270    /// Asynchronously waits for a packet with the specified header.
271    ///
272    /// This function listens on the internal broadcast channel and filters incoming packets,
273    /// returning the first packet that matches the given header value.
274    ///
275    /// # Arguments
276    ///
277    /// * `header` - The header value to filter for.
278    ///
279    /// # Errors
280    ///
281    /// Returns an `io::Error` if the broadcast channel is closed or an error occurs while receiving.
282    async fn receive_packet_with_header(&self, header: u16) -> io::Result<Packet> {
283        let mut msg_rcv = self.message_receiver.resubscribe();
284        match timeout(self.timeout, async {
285            loop {
286                match msg_rcv.recv().await {
287                    Ok(packet) => {
288                        if packet.get_header() == header {
289                            return Ok(packet);
290                        }
291                    }
292                    Err(_) => {
293                        return Err(io::Error::new(io::ErrorKind::Other, "Channel closed"));
294                    }
295                }
296            }
297        })
298        .await
299        {
300            Ok(result) => result,
301            Err(_) => Err(io::Error::new(
302                io::ErrorKind::TimedOut,
303                format!("Timeout waiting for packet with header 0x{:04x}", header),
304            )),
305        }
306    }
307
308    async fn receive_xbus_packet(&self, expected_xbus_header: u8) -> io::Result<XBusMessage> {
309        let mut msg_rcv = self.message_receiver.resubscribe();
310        match timeout(self.timeout, async {
311            loop {
312                match msg_rcv.recv().await {
313                    Ok(packet) => {
314                        if packet.get_header() == messages::XBUS_HEADER {
315                            let end_payload = packet.get_data_len() as isize - 4;
316                            if end_payload <= 0 {
317                                continue;
318                            }
319                            let end_payload = end_payload as usize;
320                            let payload = &packet.get_data()[0..end_payload];
321                            let xbus_msg = XBusMessage::try_from(payload);
322                            if let Ok(msg) = xbus_msg {
323                                if msg.get_x_header() == expected_xbus_header {
324                                    return Ok(msg);
325                                }
326                            }
327                        }
328                    }
329                    Err(_) => {
330                        return Err(io::Error::new(io::ErrorKind::Other, "Channel closed"));
331                    }
332                }
333            }
334        })
335        .await
336        {
337            Ok(result) => result,
338            Err(_) => Err(io::Error::new(
339                io::ErrorKind::TimedOut,
340                format!(
341                    "Timeout waiting for XBus message with header 0x{:02x}",
342                    expected_xbus_header
343                ),
344            )),
345        }
346    }
347
348    /// Receives a single packet from the internal broadcast channel.
349    ///
350    /// This method awaits the next available packet regardless of its header.
351    ///
352    /// # Errors
353    ///
354    /// Returns an `io::Error` if the broadcast channel is closed.
355    async fn receive_packet(&self) -> io::Result<Packet> {
356        let mut msg_rcv = self.message_receiver.resubscribe();
357        match timeout(self.timeout, async {
358            match msg_rcv.recv().await {
359                Ok(packet) => Ok(packet),
360                Err(_) => Err(io::Error::new(io::ErrorKind::Other, "Channel closed")),
361            }
362        })
363        .await
364        {
365            Ok(result) => result,
366            Err(_) => Err(io::Error::new(
367                io::ErrorKind::TimedOut,
368                "Timeout waiting for packet",
369            )),
370        }
371    }
372
373    /// Turns off the track voltage.
374    ///
375    /// This is equivalent to pressing the STOP button on the Z21 station or the MultiMaus
376    /// controller. It cuts power to all tracks, stopping all locomotives immediately.
377    ///
378    /// # Returns
379    ///
380    /// `Ok(())` if the command was successfully sent and acknowledged.
381    ///
382    /// # Errors
383    ///
384    /// Returns an `io::Error` if the command fails to send or no acknowledgment is received.
385    ///
386    /// # Example
387    ///
388    /// ```rust
389    /// // Emergency stop all locomotives by cutting track power
390    /// station.voltage_off().await?;
391    /// ```
392    pub async fn voltage_off(&self) -> io::Result<()> {
393        self.send_xbus_command(
394            XBusMessage::new_single(X_SET_TRACK_POWER_OFF.0, X_SET_TRACK_POWER_OFF.1),
395            Some(X_BC_TRACK_POWER),
396        )
397        .await?;
398        Ok(())
399    }
400
401    /// Turns on the track voltage.
402    ///
403    /// This restores power to the tracks after an emergency stop or when the system
404    /// is first started. It also disables programming mode if it was active.
405    ///
406    /// # Returns
407    ///
408    /// `Ok(())` if the command was successfully sent and acknowledged.
409    ///
410    /// # Errors
411    ///
412    /// Returns an `io::Error` if the command fails to send or no acknowledgment is received.
413    ///
414    /// # Example
415    ///
416    /// ```rust
417    /// // Restore power to the tracks
418    /// station.voltage_on().await?;
419    /// ```
420    pub async fn voltage_on(&self) -> io::Result<()> {
421        self.send_xbus_command(
422            XBusMessage::new_single(X_SET_TRACK_POWER_ON.0, X_SET_TRACK_POWER_ON.1),
423            Some(X_BC_TRACK_POWER),
424        )
425        .await?;
426        Ok(())
427    }
428
429    /// Retrieves the serial number from the Z21 station.
430    ///
431    /// # Returns
432    ///
433    /// The Z21 station's serial number as a 32-bit unsigned integer.
434    ///
435    /// # Errors
436    ///
437    /// Returns an `io::Error` if:
438    /// - Sending the request fails
439    /// - The response times out
440    /// - The response data is invalid (e.g., too short)
441    ///
442    /// # Example
443    ///
444    /// ```rust
445    /// let serial = station.get_serial_number().await?;
446    /// println!("Z21 station serial number: {}", serial);
447    /// ```
448    pub async fn get_serial_number(&self) -> io::Result<u32> {
449        let packet = Packet::with_header_and_data(0x10, &[]);
450        self.send_packet(packet).await?;
451        let response = self.receive_packet_with_header(0x10).await?;
452        let data = response.get_data();
453        if data.len() < 4 {
454            return Err(io::Error::new(
455                io::ErrorKind::InvalidData,
456                "Response data too short",
457            ));
458        }
459        Ok(u32::from_le_bytes([data[0], data[1], data[2], data[3]]))
460    }
461
462    /// Subscribes to system state updates from the Z21 station.
463    ///
464    /// This method sets up a polling mechanism to regularly request system state updates
465    /// and calls the provided callback function whenever new state information is received.
466    ///
467    /// # Arguments
468    ///
469    /// * `freq_in_sec` - Polling frequency in Hz (updates per second)
470    /// * `subscriber` - Callback function that receives `SystemState` updates
471    ///
472    /// # Example
473    ///
474    /// ```rust
475    /// station.subscribe_system_state(1.0, Box::new(|state| {
476    ///     println!("Main track voltage: {:.2}V", state.main_track_voltage);
477    ///     println!("Temperature: {}°C", state.temperature);
478    ///     println!("Current: {}mA", state.current);
479    /// }));
480    /// ```
481
482    pub fn subscribe_system_state(
483        &self,
484        freq_in_sec: f64,
485        subscriber: Box<dyn Fn(SystemState) + Send + Sync>,
486    ) {
487        let mut receiver = self.message_receiver.resubscribe();
488        let socket = Arc::clone(&self.socket);
489        let keep_alive = Arc::clone(&self.keep_alive);
490        let packet = Packet::with_header_and_data(LAN_SYSTEMSTATE_GETDATA, &[]);
491        tokio::spawn(async move {
492            loop {
493                let result = Self::send_packet_external(&socket, packet.clone()).await;
494                if result.is_err() {
495                    break;
496                }
497
498                time::sleep(Duration::from_millis((1000. / freq_in_sec) as u64)).await;
499
500                if !keep_alive.load(Ordering::Relaxed) {
501                    break;
502                }
503            }
504        });
505        tokio::spawn(async move {
506            while let Ok(packet) = receiver.recv().await {
507                if packet.get_header() == LAN_SYSTEMSTATE_DATACHANGED {
508                    let state = SystemState::try_from(&packet.get_data()[..]);
509                    if let Ok(state) = state {
510                        subscriber(state);
511                    }
512                }
513            }
514        });
515    }
516
517    /// Logs out from the Z21 station.
518    ///
519    /// This method should be called at the end of a session to gracefully terminate
520    /// the connection with the Z21 station.
521    ///
522    /// # Returns
523    ///
524    /// `Ok(())` if the logout command was successfully sent.
525    ///
526    /// # Errors
527    ///
528    /// Returns an `io::Error` if the logout command fails to send.
529    ///
530    /// # Example
531    ///
532    /// ```rust
533    /// // Clean up and disconnect from the Z21 station
534    /// station.logout().await?;
535    /// ```
536    pub async fn logout(&self) -> io::Result<()> {
537        let packet = Packet::with_header_and_data(0x30, &[]);
538        self.send_packet(packet).await
539    }
540}
541
542impl Drop for Z21Station {
543    fn drop(&mut self) {
544        self.keep_alive.store(false, Ordering::Relaxed);
545    }
546}