rust_rocket/
client.rs

1//! This module contains the main client code, including the [`RocketClient`] type.
2use crate::interpolation::*;
3use crate::track::*;
4use crate::Tracks;
5
6use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
7use std::{
8    convert::TryFrom,
9    io::{Cursor, Read, Write},
10    net::{TcpStream, ToSocketAddrs},
11};
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15/// The `Error` Type. This is the main error type.
16pub enum Error {
17    #[error("Failed to establish a TCP connection with the Rocket server")]
18    /// Failure to connect to a rocket tracker. This can happen if the tracker is not running, the
19    /// address isn't correct or other network-related reasons.
20    Connect(#[source] std::io::Error),
21    #[error("Handshake with the Rocket server failed")]
22    /// Failure to transmit or receive greetings with the tracker
23    Handshake(#[source] std::io::Error),
24    #[error("The Rocket server greeting {0:?} wasn't correct")]
25    /// Handshake was performed but the the received greeting wasn't correct
26    HandshakeGreetingMismatch([u8; 12]),
27    #[error("Cannot set Rocket's TCP connection to nonblocking mode")]
28    /// Error from [`TcpStream::set_nonblocking`]
29    SetNonblocking(#[source] std::io::Error),
30    #[error("Rocket server disconnected")]
31    /// Network IO error during operation
32    IOError(#[source] std::io::Error),
33}
34
/// Where the receive-side command parser currently is.
#[derive(Debug)]
enum ClientState {
    /// Waiting for the command byte of the next message.
    New,
    /// The command byte has been read; this many payload bytes are still outstanding.
    Incomplete(usize),
    /// The whole command (command byte plus payload) is buffered and ready to be processed.
    Complete,
}
41
/// The `Event` Type. These are the various events from the tracker.
///
/// Events are plain values: they can be copied and compared, e.g.
/// `event == Event::SaveTracks`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Event {
    /// The tracker changes row. Carries the new row.
    SetRow(u32),
    /// The tracker pauses or unpauses: `true` when it pauses, `false` when it resumes.
    Pause(bool),
    /// The tracker asks us to save our track data.
    /// You may want to call [`RocketClient::save_tracks`] after receiving this event.
    SaveTracks,
}
53
54enum ReceiveResult {
55    Some(Event),
56    None,
57    Incomplete,
58}
59
60#[derive(Debug)]
61/// The `RocketClient` type. This contains the connected socket and other fields.
62pub struct RocketClient {
63    stream: TcpStream,
64    state: ClientState,
65    cmd: Vec<u8>,
66    tracks: Vec<Track>,
67}
68
69impl RocketClient {
70    /// Construct a new RocketClient.
71    ///
72    /// This constructs a new Rocket client and connects to localhost on port 1338.
73    ///
74    /// # Errors
75    ///
76    /// [`Error::Connect`] if connection cannot be established, or [`Error::Handshake`]
77    /// if the handshake fails.
78    ///
79    /// # Examples
80    ///
81    /// ```rust,no_run
82    /// # use rust_rocket::RocketClient;
83    /// let mut rocket = RocketClient::new()?;
84    /// # Ok::<(), rust_rocket::client::Error>(())
85    /// ```
86    pub fn new() -> Result<Self, Error> {
87        Self::connect(("localhost", 1338))
88    }
89
90    /// Construct a new RocketClient.
91    ///
92    /// This constructs a new Rocket client and connects to a specified host and port.
93    ///
94    /// # Errors
95    ///
96    /// [`Error::Connect`] if connection cannot be established, or [`Error::Handshake`]
97    /// if the handshake fails.
98    ///
99    /// # Examples
100    ///
101    /// ```rust,no_run
102    /// # use rust_rocket::RocketClient;
103    /// let mut rocket = RocketClient::connect(("localhost", 1338))?;
104    /// # Ok::<(), rust_rocket::client::Error>(())
105    /// ```
106    pub fn connect(addr: impl ToSocketAddrs) -> Result<Self, Error> {
107        let stream = TcpStream::connect(addr).map_err(Error::Connect)?;
108
109        let mut rocket = Self {
110            stream,
111            state: ClientState::New,
112            cmd: Vec::new(),
113            tracks: Vec::new(),
114        };
115
116        rocket.handshake()?;
117
118        rocket
119            .stream
120            .set_nonblocking(true)
121            .map_err(Error::SetNonblocking)?;
122
123        Ok(rocket)
124    }
125
126    /// Get track by name.
127    ///
128    /// If the track does not yet exist it will be created.
129    ///
130    /// # Errors
131    ///
132    /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
133    ///
134    /// # Panics
135    ///
136    /// Will panic if `name`'s length exceeds [`u32::MAX`].
137    ///
138    /// # Examples
139    ///
140    /// ```rust,no_run
141    /// # use rust_rocket::RocketClient;
142    /// # let mut rocket = RocketClient::new()?;
143    /// let track = rocket.get_track_mut("namespace:track")?;
144    /// track.get_value(3.5);
145    /// # Ok::<(), rust_rocket::client::Error>(())
146    /// ```
147    pub fn get_track_mut(&mut self, name: &str) -> Result<&mut Track, Error> {
148        if let Some((i, _)) = self
149            .tracks
150            .iter()
151            .enumerate()
152            .find(|(_, t)| t.get_name() == name)
153        {
154            Ok(&mut self.tracks[i])
155        } else {
156            // Send GET_TRACK message
157            let mut buf = vec![2];
158            buf.write_u32::<BigEndian>(u32::try_from(name.len()).expect("Track name too long"))
159                .unwrap_or_else(|_|
160                // Can writes to a vec fail? Consider changing to unreachable_unchecked in 1.0
161                unreachable!());
162            buf.extend_from_slice(name.as_bytes());
163            self.stream.write_all(&buf).map_err(Error::IOError)?;
164
165            self.tracks.push(Track::new(name));
166            Ok(self.tracks.last_mut().unwrap_or_else(||
167                // tracks cannot be empty right after pushing into it, consider changing to
168                // unreachable_unchecked in 1.0
169                unreachable!()))
170        }
171    }
172
173    /// Get track by name.
174    ///
175    /// You should use [`get_track_mut`](RocketClient::get_track_mut) to create a track.
176    pub fn get_track(&self, name: &str) -> Option<&Track> {
177        self.tracks.iter().find(|t| t.get_name() == name)
178    }
179
180    /// Get a snapshot of the tracks in the session.
181    /// The returned [`Tracks`] can be dumped to a file in any [supported format](crate#features).
182    /// The counterpart to this function is [`RocketPlayer::new`](crate::RocketPlayer::new),
183    /// which loads tracks for playback.
184    ///
185    /// # Example
186    ///
187    /// ```rust,no_run
188    /// # use rust_rocket::RocketClient;
189    /// # use std::fs::OpenOptions;
190    /// let mut rocket = RocketClient::new()?;
191    ///
192    /// // Create tracks, call poll_events, etc...
193    ///
194    /// // Open a file for writing
195    /// let mut file = OpenOptions::new()
196    ///     .write(true)
197    ///     .create(true)
198    ///     .truncate(true)
199    ///     .open("tracks.bin")
200    ///     .expect("Failed to open tracks.bin for writing");
201    ///
202    /// // Save a snapshot of the client to a file for playback in release builds
203    /// let tracks = rocket.save_tracks();
204    /// # #[cfg(feature = "bincode")]
205    /// bincode::encode_into_std_write(tracks, &mut file, bincode::config::standard())
206    ///     .expect("Failed to encode tracks.bin");
207    /// # Ok::<(), rust_rocket::client::Error>(())
208    /// ```
209    pub fn save_tracks(&self) -> &Tracks {
210        &self.tracks
211    }
212
213    /// Send a SetRow message.
214    ///
215    /// This changes the current row on the tracker side.
216    ///
217    /// # Errors
218    ///
219    /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
220    pub fn set_row(&mut self, row: u32) -> Result<(), Error> {
221        // Send SET_ROW message
222        let mut buf = vec![3];
223        buf.write_u32::<BigEndian>(row).unwrap_or_else(|_|
224                // Can writes to a vec fail? Consider changing to unreachable_unchecked in 1.0
225                unreachable!());
226        self.stream.write_all(&buf).map_err(Error::IOError)
227    }
228
229    /// Poll for new events from the tracker.
230    ///
231    /// This polls from events from the tracker.
232    /// You should call this fairly often your main loop.
233    /// It is recommended to keep calling this as long as your receive `Some(Event)`.
234    ///
235    /// # Errors
236    ///
237    /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
238    ///
239    /// # Examples
240    ///
241    /// ```rust,no_run
242    /// # use rust_rocket::RocketClient;
243    /// # let mut rocket = RocketClient::new()?;
244    /// while let Some(event) = rocket.poll_events()? {
245    ///     match event {
246    ///         // Do something with the various events.
247    ///         _ => (),
248    ///     }
249    /// }
250    /// # Ok::<(), rust_rocket::client::Error>(())
251    /// ```
252    pub fn poll_events(&mut self) -> Result<Option<Event>, Error> {
253        loop {
254            let result = self.poll_event()?;
255            match result {
256                ReceiveResult::None => return Ok(None),
257                ReceiveResult::Incomplete => (),
258                ReceiveResult::Some(event) => return Ok(Some(event)),
259            }
260        }
261    }
262
263    fn poll_event(&mut self) -> Result<ReceiveResult, Error> {
264        match self.state {
265            ClientState::New => {
266                let mut buf = [0; 1];
267                match self.stream.read_exact(&mut buf) {
268                    Ok(()) => {
269                        self.cmd.extend_from_slice(&buf);
270                        match self.cmd[0] {
271                            0 => self.state = ClientState::Incomplete(4 + 4 + 4 + 1), //SET_KEY
272                            1 => self.state = ClientState::Incomplete(4 + 4),         //DELETE_KEY
273                            3 => self.state = ClientState::Incomplete(4),             //SET_ROW
274                            4 => self.state = ClientState::Incomplete(1),             //PAUSE
275                            5 => self.state = ClientState::Complete,                  //SAVE_TRACKS
276                            _ => self.state = ClientState::Complete, // Error / Unknown
277                        }
278                        Ok(ReceiveResult::Incomplete)
279                    }
280                    Err(e) => match e.kind() {
281                        std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
282                        _ => Err(Error::IOError(e)),
283                    },
284                }
285            }
286            ClientState::Incomplete(bytes) => {
287                let mut buf = vec![0; bytes];
288                match self.stream.read(&mut buf) {
289                    Ok(bytes_read) => {
290                        self.cmd.extend_from_slice(&buf);
291                        if bytes - bytes_read > 0 {
292                            self.state = ClientState::Incomplete(bytes - bytes_read);
293                        } else {
294                            self.state = ClientState::Complete;
295                        }
296                        Ok(ReceiveResult::Incomplete)
297                    }
298                    Err(e) => match e.kind() {
299                        std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
300                        _ => Err(Error::IOError(e)),
301                    },
302                }
303            }
304            ClientState::Complete => {
305                let mut result = ReceiveResult::None;
306                {
307                    // Following reads from cmd should never fail if above match arms are correct
308                    let mut cursor = Cursor::new(&self.cmd);
309                    let cmd = cursor.read_u8().unwrap();
310                    match cmd {
311                        0 => {
312                            // usize::try_from(u32) will only be None if usize is smaller, and
313                            // more than usize::MAX tracks are in use. That isn't possible because
314                            // I'd imagine Vec::push and everything else will panic first.
315                            // If you're running this on a microcontroller, I'd love to see it!
316                            let track = &mut self.tracks
317                                [usize::try_from(cursor.read_u32::<BigEndian>().unwrap()).unwrap()];
318                            let row = cursor.read_u32::<BigEndian>().unwrap();
319                            let value = cursor.read_f32::<BigEndian>().unwrap();
320                            let interpolation = Interpolation::from(cursor.read_u8().unwrap());
321                            let key = Key::new(row, value, interpolation);
322
323                            track.set_key(key);
324                        }
325                        1 => {
326                            let track = &mut self.tracks
327                                [usize::try_from(cursor.read_u32::<BigEndian>().unwrap()).unwrap()];
328                            let row = cursor.read_u32::<BigEndian>().unwrap();
329
330                            track.delete_key(row);
331                        }
332                        3 => {
333                            let row = cursor.read_u32::<BigEndian>().unwrap();
334                            result = ReceiveResult::Some(Event::SetRow(row));
335                        }
336                        4 => {
337                            let flag = cursor.read_u8().unwrap() == 1;
338                            result = ReceiveResult::Some(Event::Pause(flag));
339                        }
340                        5 => {
341                            result = ReceiveResult::Some(Event::SaveTracks);
342                        }
343                        _ => println!("Unknown {:?}", cmd),
344                    }
345                }
346
347                self.cmd.clear();
348                self.state = ClientState::New;
349
350                Ok(result)
351            }
352        }
353    }
354
355    fn handshake(&mut self) -> Result<(), Error> {
356        let client_greeting = b"hello, synctracker!";
357        let server_greeting = b"hello, demo!";
358
359        self.stream
360            .write_all(client_greeting)
361            .map_err(Error::Handshake)?;
362
363        let mut buf = [0; 12];
364        self.stream.read_exact(&mut buf).map_err(Error::Handshake)?;
365
366        if &buf == server_greeting {
367            Ok(())
368        } else {
369            Err(Error::HandshakeGreetingMismatch(buf))
370        }
371    }
372}