rust_rocket/
client.rs

1//! Main client code, including the [`RocketClient`] type.
2//!
3//! # Usage
4//!
5//! The usual workflow with the low level client API can be described in a few steps:
6//!
7//! 0. Install a rocket tracker ([original Qt editor](https://github.com/rocket/rocket)
8//!    or [emoon's OpenGL-based editor](https://github.com/emoon/rocket))
9//! 1. Connect the [`RocketClient`] to the running tracker by calling [`RocketClient::new`]
10//! 2. Create tracks with [`get_track_mut`](RocketClient::get_track_mut)
11//! 3. In your main loop, poll for updates from the Rocket tracker by calling [`poll_events`](RocketClient::poll_events).
12//! 4. Keep the tracker in sync by calling [`set_row`](RocketClient::set_row) (see tips below)
13//! 5. Get values from the tracks with [`Track::get_value`]
14//!
15//! See the linked documentation items and the examples-directory for more examples.
16//!
17//! # Tips
18//!
19//! The library is agnostic to your source of time. In a typical production, some kind of music player library
20//! determines the time for everything else, including the rocket tracks.
21//! It's recommended that you treat every 8th row as a beat of music instead of real time in seconds.
22//!
23//! ```rust,no_run
24//! # use std::time::Duration;
25//! # use rust_rocket::client::{RocketClient, Event, Error};
26//! struct MusicPlayer; // Your music player, not included in this crate
27//! # impl MusicPlayer {
28//! #     fn new() -> Self { Self }
29//! #     fn get_time(&self) -> Duration { Duration::ZERO }
30//! #     fn seek(&self, _to: Duration) {}
31//! #     fn pause(&self, _state: bool) {}
32//! # }
33//!
34//! const ROWS_PER_BEAT: f32 = 8.;
35//! const BEATS_PER_MIN: f32 = 123.; // This depends on your choice of music track
36//! const SECS_PER_MIN: f32  = 60.;
37//!
38//! fn time_to_row(time: Duration) -> f32 {
39//!     let secs = time.as_secs_f32();
40//!     let beats = secs * BEATS_PER_MIN / SECS_PER_MIN;
41//!     beats * ROWS_PER_BEAT
42//! }
43//!
44//! fn row_to_time(row: u32) -> Duration {
45//!     let beats = row as f32 / ROWS_PER_BEAT;
46//!     let secs = beats / (BEATS_PER_MIN / SECS_PER_MIN);
47//!     Duration::from_secs_f32(secs)
48//! }
49//!
50//! fn get(rocket: &mut RocketClient, track: &str, row: f32) -> f32 {
51//!     let track = rocket.get_track_mut(track).unwrap();
52//!     track.get_value(row)
53//! }
54//!
55//! fn main() -> Result<(), Error> {
56//!     let mut music = MusicPlayer::new(/* ... */);
57//!     let mut rocket = RocketClient::new()?;
58//!
59//!     // Create window, render resources etc...
60//!
61//!     loop {
62//!         // Get current frame's time
63//!         let time = music.get_time();
64//!         let row = time_to_row(time);
65//!
66//!         // Keep the rocket tracker in sync.
//!         // When using the low level API, it's recommended to combine consecutive seek events into a single seek.
//!         // This ensures the smoothest scrolling in the editor.
69//!         let mut seek = None;
70//!         while let Some(event) = rocket.poll_events()? {
71//!             match event {
72//!                 Event::SetRow(to) => seek = Some(to),
73//!                 Event::Pause(state) => music.pause(state),
74//!                 Event::SaveTracks => {/* Call save_tracks and write to a file */}
75//!             }
76//!         }
//!         // When using the low level API, it's recommended to call set_row only when not seeking.
78//!         if let Some(seek) = seek {
79//!             music.seek(row_to_time(seek));
80//!             continue;
81//!         }
82//!         rocket.set_row(row as u32)?;
83//!
84//!         // Render frame and read values with Track's get_value function
85//!         let _ = get(&mut rocket, "track0", row);
86//!     }
87//! }
88//! ```
89use crate::interpolation::*;
90use crate::track::*;
91use crate::Tracks;
92
93use byteorder::ByteOrder;
94use byteorder::{BigEndian, ReadBytesExt};
95use std::hint::unreachable_unchecked;
96use std::{
97    convert::TryFrom,
98    io::{self, Cursor, Read, Write},
99    net::{TcpStream, ToSocketAddrs},
100};
101use thiserror::Error;
102
// Rocket protocol commands

/// Greeting the client sends to start the handshake.
const CLIENT_GREETING: &[u8] = b"hello, synctracker!";
/// Greeting the tracker must answer with for the handshake to succeed.
const SERVER_GREETING: &[u8] = b"hello, demo!";

// Command identifiers: the first byte of every message on the wire.
const SET_KEY: u8 = 0;
const DELETE_KEY: u8 = 1;
const GET_TRACK: u8 = 2;
const SET_ROW: u8 = 3;
const PAUSE: u8 = 4;
const SAVE_TRACKS: u8 = 5;

// Payload sizes in bytes, not counting the leading command byte.

/// Track index (u32), row (u32), value (f32) and interpolation (u8).
const SET_KEY_LEN: usize = std::mem::size_of::<u32>()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<f32>()
    + std::mem::size_of::<u8>();
/// Track index (u32) and row (u32).
const DELETE_KEY_LEN: usize = std::mem::size_of::<u32>() * 2;
/// Name length (u32) only; the name bytes themselves follow and are not counted.
const GET_TRACK_LEN: usize = std::mem::size_of::<u32>();
/// Row (u32).
const SET_ROW_LEN: usize = std::mem::size_of::<u32>();
/// Pause flag (u8).
const PAUSE_LEN: usize = std::mem::size_of::<u8>();

/// Largest fixed payload of any incoming command; sizes the receive buffer.
const MAX_COMMAND_LEN: usize = SET_KEY_LEN;
121
122/// The `Error` Type. This is the main error type.
123#[derive(Debug, Error)]
124pub enum Error {
125    /// Failure to connect to a rocket tracker. This can happen if the tracker is not running, the
126    /// address isn't correct or other network-related reasons.
127    #[error("Failed to establish a TCP connection with the Rocket tracker")]
128    Connect(#[source] std::io::Error),
129    /// Failure to transmit or receive greetings with the tracker
130    #[error("Handshake with the Rocket tracker failed")]
131    Handshake(#[source] std::io::Error),
132    /// Handshake was performed but the the received greeting wasn't correct
133    #[error("The Rocket tracker greeting {0:?} wasn't correct")]
134    HandshakeGreetingMismatch([u8; SERVER_GREETING.len()]),
135    /// Error from [`TcpStream::set_nonblocking`]
136    #[error("Cannot set Rocket's TCP connection to nonblocking mode")]
137    SetNonblocking(#[source] std::io::Error),
138    /// Network IO error during operation
139    #[error("Rocket tracker disconnected")]
140    IOError(#[source] std::io::Error),
141}
142
/// Progress of the incremental parser for the message currently being received.
#[derive(Debug)]
enum ClientState {
    /// Waiting for the command byte that starts the next message.
    New,
    /// The command byte has been read; this many payload bytes are still missing.
    Incomplete(usize),
    /// The whole message is buffered and ready to be processed.
    Complete,
}
149
/// The `Event` Type. These are the various events from the tracker.
///
/// Events are small `Copy` values and can be compared with `==`, which is handy
/// e.g. for de-duplicating consecutive events in a main loop.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Event {
    /// The tracker changes row.
    SetRow(u32),
    /// The tracker pauses or unpauses.
    Pause(bool),
    /// The tracker asks us to save our track data.
    /// You may want to call [`RocketClient::save_tracks`] after receiving this event.
    SaveTracks,
}
161
162#[derive(Debug)]
163enum ReceiveResult {
164    Some(Event),
165    None,
166    Incomplete,
167}
168
169/// The `RocketClient` type. This contains the connected socket and other fields.
170#[derive(Debug)]
171pub struct RocketClient {
172    stream: TcpStream,
173    state: ClientState,
174    cmd: Vec<u8>,
175    tracks: Vec<Track>,
176}
177
178impl RocketClient {
179    /// Construct a new RocketClient.
180    ///
181    /// This constructs a new Rocket client and connects to localhost on port 1338.
182    ///
183    /// # Errors
184    ///
185    /// [`Error::Connect`] if connection cannot be established, or [`Error::Handshake`]
186    /// if the handshake fails.
187    ///
188    /// # Examples
189    ///
190    /// ```rust,no_run
191    /// # use rust_rocket::RocketClient;
192    /// let mut rocket = RocketClient::new()?;
193    /// # Ok::<(), rust_rocket::client::Error>(())
194    /// ```
195    pub fn new() -> Result<Self, Error> {
196        Self::connect(("localhost", 1338))
197    }
198
199    /// Construct a new RocketClient.
200    ///
201    /// This constructs a new Rocket client and connects to a specified host and port.
202    ///
203    /// # Errors
204    ///
205    /// [`Error::Connect`] if connection cannot be established, or [`Error::Handshake`]
206    /// if the handshake fails.
207    ///
208    /// # Examples
209    ///
210    /// ```rust,no_run
211    /// # use rust_rocket::RocketClient;
212    /// let mut rocket = RocketClient::connect(("localhost", 1338))?;
213    /// # Ok::<(), rust_rocket::client::Error>(())
214    /// ```
215    pub fn connect(addr: impl ToSocketAddrs) -> Result<Self, Error> {
216        let stream = TcpStream::connect(addr).map_err(Error::Connect)?;
217
218        let mut rocket = Self {
219            stream,
220            state: ClientState::New,
221            cmd: Vec::new(),
222            tracks: Vec::new(),
223        };
224
225        rocket.handshake()?;
226
227        rocket
228            .stream
229            .set_nonblocking(true)
230            .map_err(Error::SetNonblocking)?;
231
232        Ok(rocket)
233    }
234
235    /// Get track by name.
236    ///
237    /// If the track does not yet exist it will be created.
238    ///
239    /// # Errors
240    ///
241    /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
242    ///
243    /// # Panics
244    ///
245    /// Will panic if `name`'s length exceeds [`u32::MAX`].
246    ///
247    /// # Examples
248    ///
249    /// ```rust,no_run
250    /// # use rust_rocket::RocketClient;
251    /// # let mut rocket = RocketClient::new()?;
252    /// let track = rocket.get_track_mut("namespace:track")?;
253    /// track.get_value(3.5);
254    /// # Ok::<(), rust_rocket::client::Error>(())
255    /// ```
256    pub fn get_track_mut(&mut self, name: &str) -> Result<&mut Track, Error> {
257        if let Some((i, _)) = self
258            .tracks
259            .iter()
260            .enumerate()
261            .find(|(_, t)| t.get_name() == name)
262        {
263            Ok(&mut self.tracks[i])
264        } else {
265            // Send GET_TRACK message
266            let mut buf = [GET_TRACK; 1 + GET_TRACK_LEN];
267            let name_len = u32::try_from(name.len()).expect("Track name too long");
268            BigEndian::write_u32(&mut buf[1..][..GET_TRACK_LEN], name_len);
269            self.stream.write_all(&buf).map_err(Error::IOError)?;
270            self.stream
271                .write_all(name.as_bytes())
272                .map_err(Error::IOError)?;
273
274            self.tracks.push(Track::new(name));
275            // SAFETY: tracks cannot be empty, because it was pushed to on the previous line
276            let track = self
277                .tracks
278                .last_mut()
279                .unwrap_or_else(|| unsafe { unreachable_unchecked() });
280            Ok(track)
281        }
282    }
283
284    /// Get track by name.
285    ///
286    /// You should use [`get_track_mut`](RocketClient::get_track_mut) to create a track.
287    pub fn get_track(&self, name: &str) -> Option<&Track> {
288        self.tracks.iter().find(|t| t.get_name() == name)
289    }
290
291    /// Get a snapshot of the tracks in the session.
292    ///
293    /// The returned [`Tracks`] can be dumped to a file in any [supported format](crate#features).
294    /// The counterpart to this function is [`RocketPlayer::new`](crate::RocketPlayer::new),
295    /// which loads tracks for playback.
296    ///
297    /// # Example
298    ///
299    /// ```rust,no_run
300    /// # use rust_rocket::RocketClient;
301    /// # use std::fs::OpenOptions;
302    /// let mut rocket = RocketClient::new()?;
303    ///
304    /// // Create tracks, call poll_events, etc...
305    ///
306    /// // Open a file for writing
307    /// let mut file = OpenOptions::new()
308    ///     .write(true)
309    ///     .create(true)
310    ///     .truncate(true)
311    ///     .open("tracks.bin")
312    ///     .expect("Failed to open tracks.bin for writing");
313    ///
314    /// // Save a snapshot of the client to a file for playback in release builds
315    /// let tracks = rocket.save_tracks();
316    /// # #[cfg(feature = "bincode")]
317    /// bincode::encode_into_std_write(tracks, &mut file, bincode::config::standard())
318    ///     .expect("Failed to encode tracks.bin");
319    /// # Ok::<(), rust_rocket::client::Error>(())
320    /// ```
321    pub fn save_tracks(&self) -> &Tracks {
322        &self.tracks
323    }
324
325    /// Send a SetRow message.
326    ///
327    /// This changes the current row on the tracker side.
328    ///
329    /// # Errors
330    ///
331    /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
332    pub fn set_row(&mut self, row: u32) -> Result<(), Error> {
333        // Send SET_ROW message
334        let mut buf = [SET_ROW; 1 + SET_ROW_LEN];
335        BigEndian::write_u32(&mut buf[1..][..SET_ROW_LEN], row);
336        self.stream.write_all(&buf).map_err(Error::IOError)
337    }
338
339    /// Poll for new events from the tracker.
340    ///
341    /// This polls from events from the tracker.
342    /// You should call this fairly often your main loop.
343    /// It is recommended to keep calling this as long as your receive `Some(Event)`.
344    ///
345    /// # Errors
346    ///
347    /// This method can return an [`Error::IOError`] if the rocket tracker disconnects.
348    ///
349    /// # Examples
350    ///
351    /// ```rust,no_run
352    /// # use rust_rocket::RocketClient;
353    /// # let mut rocket = RocketClient::new()?;
354    /// while let Some(event) = rocket.poll_events()? {
355    ///     match event {
356    ///         // Do something with the various events.
357    ///         _ => (),
358    ///     }
359    /// }
360    /// # Ok::<(), rust_rocket::client::Error>(())
361    /// ```
362    pub fn poll_events(&mut self) -> Result<Option<Event>, Error> {
363        loop {
364            match self.poll_event()? {
365                ReceiveResult::None => return Ok(None),
366                ReceiveResult::Incomplete => { /* Keep reading */ }
367                ReceiveResult::Some(event) => return Ok(Some(event)),
368            }
369        }
370    }
371
372    fn poll_event(&mut self) -> Result<ReceiveResult, Error> {
373        match self.state {
374            ClientState::New => self.poll_event_new(),
375            ClientState::Incomplete(bytes) => self.poll_event_incomplete(bytes),
376            ClientState::Complete => Ok(self.process_event().unwrap_or_else(|_| unreachable!())),
377        }
378    }
379
380    fn poll_event_new(&mut self) -> Result<ReceiveResult, Error> {
381        let mut buf = [0; 1];
382        match self.stream.read_exact(&mut buf) {
383            Ok(()) => {
384                self.cmd.extend_from_slice(&buf);
385                match self.cmd[0] {
386                    SET_KEY => self.state = ClientState::Incomplete(SET_KEY_LEN),
387                    DELETE_KEY => self.state = ClientState::Incomplete(DELETE_KEY_LEN),
388                    SET_ROW => self.state = ClientState::Incomplete(SET_ROW_LEN),
389                    PAUSE => self.state = ClientState::Incomplete(PAUSE_LEN),
390                    SAVE_TRACKS => self.state = ClientState::Complete,
391                    _ => self.state = ClientState::Complete, // Error / Unknown
392                }
393                Ok(ReceiveResult::Incomplete)
394            }
395            Err(e) => match e.kind() {
396                std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
397                _ => Err(Error::IOError(e)),
398            },
399        }
400    }
401
402    fn poll_event_incomplete(&mut self, bytes: usize) -> Result<ReceiveResult, Error> {
403        let mut buf = [0; MAX_COMMAND_LEN];
404        match self.stream.read(&mut buf[..bytes]) {
405            Ok(bytes_read) => {
406                self.cmd.extend_from_slice(&buf[..bytes_read]);
407                if bytes - bytes_read > 0 {
408                    self.state = ClientState::Incomplete(bytes - bytes_read);
409                } else {
410                    self.state = ClientState::Complete;
411                }
412                Ok(ReceiveResult::Incomplete)
413            }
414            Err(e) => match e.kind() {
415                std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
416                _ => Err(Error::IOError(e)),
417            },
418        }
419    }
420
421    // This function should never fail if [`poll_event_new`] and [`poll_event_incomplete`] are correct
422    fn process_event(&mut self) -> Result<ReceiveResult, io::Error> {
423        let mut result = ReceiveResult::None;
424
425        let mut cursor = Cursor::new(&self.cmd);
426        let cmd = cursor.read_u8()?;
427        match cmd {
428            SET_KEY => {
429                // usize::try_from(u32) will only be None if usize is smaller, and
430                // more than usize::MAX tracks are in use. That isn't possible because
431                // I'd imagine Vec::push and everything else will panic first.
432                // If you're running this on a microcontroller, I'd love to see it!
433                let index = usize::try_from(cursor.read_u32::<BigEndian>()?).unwrap();
434                let track = &mut self.tracks[index];
435                let row = cursor.read_u32::<BigEndian>()?;
436                let value = cursor.read_f32::<BigEndian>()?;
437                let interpolation = Interpolation::from(cursor.read_u8()?);
438                let key = Key::new(row, value, interpolation);
439
440                track.set_key(key);
441            }
442            DELETE_KEY => {
443                let index = usize::try_from(cursor.read_u32::<BigEndian>()?).unwrap();
444                let track = &mut self.tracks[index];
445                let row = cursor.read_u32::<BigEndian>()?;
446
447                track.delete_key(row);
448            }
449            SET_ROW => {
450                let row = cursor.read_u32::<BigEndian>()?;
451                result = ReceiveResult::Some(Event::SetRow(row));
452            }
453            PAUSE => {
454                let flag = cursor.read_u8()? == 1;
455                result = ReceiveResult::Some(Event::Pause(flag));
456            }
457            SAVE_TRACKS => {
458                result = ReceiveResult::Some(Event::SaveTracks);
459            }
460            _ => eprintln!("rocket: Unknown command: {:?}", cmd),
461        }
462
463        self.cmd.clear();
464        self.state = ClientState::New;
465
466        Ok(result)
467    }
468
469    fn handshake(&mut self) -> Result<(), Error> {
470        self.stream
471            .write_all(CLIENT_GREETING)
472            .map_err(Error::Handshake)?;
473
474        let mut buf = [0; SERVER_GREETING.len()];
475        self.stream.read_exact(&mut buf).map_err(Error::Handshake)?;
476
477        if buf == SERVER_GREETING {
478            Ok(())
479        } else {
480            Err(Error::HandshakeGreetingMismatch(buf))
481        }
482    }
483}