rust_rocket/
client.rs

1//! This module contains the main client code, including the [`RocketClient`] type.
2use crate::interpolation::*;
3use crate::track::*;
4
5use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
6use std::{
7    convert::TryFrom,
8    io::{Cursor, Read, Write},
9    net::{TcpStream, ToSocketAddrs},
10};
11use thiserror::Error;
12
13#[derive(Debug, Error)]
14/// The `Error` Type. This is the main error type.
15pub enum Error {
16    #[error("Failed to establish a TCP connection with the Rocket server")]
17    /// Failure to connect to a rocket tracker. This can happen if the tracker is not running, the
18    /// address isn't correct or other network-related reasons.
19    Connect(#[source] std::io::Error),
20    #[error("Handshake with the Rocket server failed")]
21    /// Failure to transmit or receive greetings with the tracker
22    Handshake(#[source] std::io::Error),
23    #[error("The Rocket server greeting {0:?} wasn't correct")]
24    /// Handshake was performed but the the received greeting wasn't correct
25    HandshakeGreetingMismatch([u8; 12]),
26    #[error("Cannot set Rocket's TCP connection to nonblocking mode")]
27    /// Error from [`TcpStream::set_nonblocking`]
28    SetNonblocking(#[source] std::io::Error),
29    #[error("Rocket server disconnected")]
30    /// Network IO error during operation
31    IOError(#[source] std::io::Error),
32}
33
/// Progress of reading the current incoming message from the tracker.
#[derive(Debug)]
enum ClientState {
    /// Waiting for the command byte that starts the next message.
    New,
    /// The command byte has been read; this many payload bytes are still outstanding.
    Incomplete(usize),
    /// The whole message is buffered and ready to be decoded.
    Complete,
}
40
/// The `Event` Type. These are the various events from the tracker.
///
/// Events are plain values: they can be copied and compared, e.g.
/// `event == Event::SaveTracks`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Event {
    /// The tracker changes row.
    SetRow(u32),
    /// The tracker pauses or unpauses.
    Pause(bool),
    /// The tracker asks us to save our track data.
    /// You may want to call [`RocketClient::save_tracks`] after receiving this event.
    SaveTracks,
}
52
53enum ReceiveResult {
54    Some(Event),
55    None,
56    Incomplete,
57}
58
59#[derive(Debug)]
60/// The `RocketClient` type. This contains the connected socket and other fields.
61pub struct RocketClient {
62    stream: TcpStream,
63    state: ClientState,
64    cmd: Vec<u8>,
65    tracks: Vec<Track>,
66}
67
68impl RocketClient {
69    /// Construct a new RocketClient.
70    ///
71    /// This constructs a new Rocket client and connects to localhost on port 1338.
72    ///
73    /// # Errors
74    ///
75    /// [`Error::Connect`] if connection cannot be established, or [`Error::Handshake`]
76    /// if the handshake fails.
77    ///
78    /// # Examples
79    ///
80    /// ```rust,no_run
81    /// # use rust_rocket::RocketClient;
82    /// let mut rocket = RocketClient::new().unwrap();
83    /// ```
84    pub fn new() -> Result<Self, Error> {
85        Self::connect(("localhost", 1338))
86    }
87
88    /// Construct a new RocketClient.
89    ///
90    /// This constructs a new Rocket client and connects to a specified host and port.
91    ///
92    /// # Errors
93    ///
94    /// [`Error::Connect`] if connection cannot be established, or [`Error::Handshake`]
95    /// if the handshake fails.
96    ///
97    /// # Examples
98    ///
99    /// ```rust,no_run
100    /// # use rust_rocket::RocketClient;
101    /// let mut rocket = RocketClient::connect(("localhost", 1338)).unwrap();
102    /// ```
103    pub fn connect(addr: impl ToSocketAddrs) -> Result<Self, Error> {
104        let stream = TcpStream::connect(addr).map_err(Error::Connect)?;
105
106        let mut rocket = Self {
107            stream,
108            state: ClientState::New,
109            cmd: Vec::new(),
110            tracks: Vec::new(),
111        };
112
113        rocket.handshake()?;
114
115        rocket
116            .stream
117            .set_nonblocking(true)
118            .map_err(Error::SetNonblocking)?;
119
120        Ok(rocket)
121    }
122
123    /// Get track by name.
124    ///
125    /// If the track does not yet exist it will be created.
126    ///
127    /// # Errors
128    ///
129    /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
130    ///
131    /// # Panics
132    ///
133    /// Will panic if `name`'s length exceeds [`u32::MAX`].
134    ///
135    /// # Examples
136    ///
137    /// ```rust,no_run
138    /// # use rust_rocket::RocketClient;
139    /// # let mut rocket = RocketClient::new().unwrap();
140    /// let track = rocket.get_track_mut("namespace:track").unwrap();
141    /// track.get_value(3.5);
142    /// ```
143    pub fn get_track_mut(&mut self, name: &str) -> Result<&mut Track, Error> {
144        if let Some((i, _)) = self
145            .tracks
146            .iter()
147            .enumerate()
148            .find(|(_, t)| t.get_name() == name)
149        {
150            Ok(&mut self.tracks[i])
151        } else {
152            // Send GET_TRACK message
153            let mut buf = vec![2];
154            buf.write_u32::<BigEndian>(u32::try_from(name.len()).expect("Track name too long"))
155                .unwrap_or_else(|_|
156                // Can writes to a vec fail? Consider changing to unreachable_unchecked in 1.0
157                unreachable!());
158            buf.extend_from_slice(name.as_bytes());
159            self.stream.write_all(&buf).map_err(Error::IOError)?;
160
161            self.tracks.push(Track::new(name));
162            Ok(self.tracks.last_mut().unwrap_or_else(||
163                // tracks cannot be empty right after pushing into it, consider changing to
164                // unreachable_unchecked in 1.0
165                unreachable!()))
166        }
167    }
168
169    /// Get track by name.
170    ///
171    /// You should use [`get_track_mut`](RocketClient::get_track_mut) to create a track.
172    pub fn get_track(&self, name: &str) -> Option<&Track> {
173        self.tracks.iter().find(|t| t.get_name() == name)
174    }
175
176    /// Create a clone of the tracks in the session which can then be serialized to a file in any
177    /// format with a serde implementation.
178    /// Tracks can be turned into a [`RocketPlayer`](crate::RocketPlayer::new) for playback.
179    pub fn save_tracks(&self) -> Vec<Track> {
180        self.tracks.clone()
181    }
182
183    /// Send a SetRow message.
184    ///
185    /// This changes the current row on the tracker side.
186    ///
187    /// # Errors
188    ///
189    /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
190    pub fn set_row(&mut self, row: u32) -> Result<(), Error> {
191        // Send SET_ROW message
192        let mut buf = vec![3];
193        buf.write_u32::<BigEndian>(row).unwrap_or_else(|_|
194                // Can writes to a vec fail? Consider changing to unreachable_unchecked in 1.0
195                unreachable!());
196        self.stream.write_all(&buf).map_err(Error::IOError)
197    }
198
199    /// Poll for new events from the tracker.
200    ///
201    /// This polls from events from the tracker.
202    /// You should call this fairly often your main loop.
203    /// It is recommended to keep calling this as long as your receive `Some(Event)`.
204    ///
205    /// # Errors
206    ///
207    /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
208    ///
209    /// # Examples
210    ///
211    /// ```rust,no_run
212    /// # use rust_rocket::RocketClient;
213    /// # let mut rocket = RocketClient::new().unwrap();
214    /// while let Some(event) = rocket.poll_events().unwrap() {
215    ///     match event {
216    ///         // Do something with the various events.
217    ///         _ => (),
218    ///     }
219    /// }
220    /// ```
221    pub fn poll_events(&mut self) -> Result<Option<Event>, Error> {
222        loop {
223            let result = self.poll_event()?;
224            match result {
225                ReceiveResult::None => return Ok(None),
226                ReceiveResult::Incomplete => (),
227                ReceiveResult::Some(event) => return Ok(Some(event)),
228            }
229        }
230    }
231
232    fn poll_event(&mut self) -> Result<ReceiveResult, Error> {
233        match self.state {
234            ClientState::New => {
235                let mut buf = [0; 1];
236                match self.stream.read_exact(&mut buf) {
237                    Ok(()) => {
238                        self.cmd.extend_from_slice(&buf);
239                        match self.cmd[0] {
240                            0 => self.state = ClientState::Incomplete(4 + 4 + 4 + 1), //SET_KEY
241                            1 => self.state = ClientState::Incomplete(4 + 4),         //DELETE_KEY
242                            3 => self.state = ClientState::Incomplete(4),             //SET_ROW
243                            4 => self.state = ClientState::Incomplete(1),             //PAUSE
244                            5 => self.state = ClientState::Complete,                  //SAVE_TRACKS
245                            _ => self.state = ClientState::Complete, // Error / Unknown
246                        }
247                        Ok(ReceiveResult::Incomplete)
248                    }
249                    Err(e) => match e.kind() {
250                        std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
251                        _ => Err(Error::IOError(e)),
252                    },
253                }
254            }
255            ClientState::Incomplete(bytes) => {
256                let mut buf = vec![0; bytes];
257                match self.stream.read(&mut buf) {
258                    Ok(bytes_read) => {
259                        self.cmd.extend_from_slice(&buf);
260                        if bytes - bytes_read > 0 {
261                            self.state = ClientState::Incomplete(bytes - bytes_read);
262                        } else {
263                            self.state = ClientState::Complete;
264                        }
265                        Ok(ReceiveResult::Incomplete)
266                    }
267                    Err(e) => match e.kind() {
268                        std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
269                        _ => Err(Error::IOError(e)),
270                    },
271                }
272            }
273            ClientState::Complete => {
274                let mut result = ReceiveResult::None;
275                {
276                    // Following reads from cmd should never fail if above match arms are correct
277                    let mut cursor = Cursor::new(&self.cmd);
278                    let cmd = cursor.read_u8().unwrap();
279                    match cmd {
280                        0 => {
281                            // usize::try_from(u32) will only be None if usize is smaller, and
282                            // more than usize::MAX tracks are in use. That isn't possible because
283                            // I'd imagine Vec::push and everything else will panic first.
284                            // If you're running this on a microcontroller, I'd love to see it!
285                            let track = &mut self.tracks
286                                [usize::try_from(cursor.read_u32::<BigEndian>().unwrap()).unwrap()];
287                            let row = cursor.read_u32::<BigEndian>().unwrap();
288                            let value = cursor.read_f32::<BigEndian>().unwrap();
289                            let interpolation = Interpolation::from(cursor.read_u8().unwrap());
290                            let key = Key::new(row, value, interpolation);
291
292                            track.set_key(key);
293                        }
294                        1 => {
295                            let track = &mut self.tracks
296                                [usize::try_from(cursor.read_u32::<BigEndian>().unwrap()).unwrap()];
297                            let row = cursor.read_u32::<BigEndian>().unwrap();
298
299                            track.delete_key(row);
300                        }
301                        3 => {
302                            let row = cursor.read_u32::<BigEndian>().unwrap();
303                            result = ReceiveResult::Some(Event::SetRow(row));
304                        }
305                        4 => {
306                            let flag = cursor.read_u8().unwrap() == 1;
307                            result = ReceiveResult::Some(Event::Pause(flag));
308                        }
309                        5 => {
310                            result = ReceiveResult::Some(Event::SaveTracks);
311                        }
312                        _ => println!("Unknown {:?}", cmd),
313                    }
314                }
315
316                self.cmd.clear();
317                self.state = ClientState::New;
318
319                Ok(result)
320            }
321        }
322    }
323
324    fn handshake(&mut self) -> Result<(), Error> {
325        let client_greeting = b"hello, synctracker!";
326        let server_greeting = b"hello, demo!";
327
328        self.stream
329            .write_all(client_greeting)
330            .map_err(Error::Handshake)?;
331
332        let mut buf = [0; 12];
333        self.stream.read_exact(&mut buf).map_err(Error::Handshake)?;
334
335        if &buf == server_greeting {
336            Ok(())
337        } else {
338            Err(Error::HandshakeGreetingMismatch(buf))
339        }
340    }
341}