rust_rocket/client.rs
1//! This module contains the main client code, including the [`RocketClient`] type.
2use crate::interpolation::*;
3use crate::track::*;
4use crate::Tracks;
5
6use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
7use std::{
8 convert::TryFrom,
9 io::{Cursor, Read, Write},
10 net::{TcpStream, ToSocketAddrs},
11};
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15/// The `Error` Type. This is the main error type.
16pub enum Error {
17 #[error("Failed to establish a TCP connection with the Rocket server")]
18 /// Failure to connect to a rocket tracker. This can happen if the tracker is not running, the
19 /// address isn't correct or other network-related reasons.
20 Connect(#[source] std::io::Error),
21 #[error("Handshake with the Rocket server failed")]
22 /// Failure to transmit or receive greetings with the tracker
23 Handshake(#[source] std::io::Error),
24 #[error("The Rocket server greeting {0:?} wasn't correct")]
25 /// Handshake was performed but the the received greeting wasn't correct
26 HandshakeGreetingMismatch([u8; 12]),
27 #[error("Cannot set Rocket's TCP connection to nonblocking mode")]
28 /// Error from [`TcpStream::set_nonblocking`]
29 SetNonblocking(#[source] std::io::Error),
30 #[error("Rocket server disconnected")]
31 /// Network IO error during operation
32 IOError(#[source] std::io::Error),
33}
34
/// Position of the receive state machine within the incoming command stream.
#[derive(Debug)]
enum ClientState {
    /// Waiting for the command byte that starts the next message.
    New,
    /// The command byte is buffered; this many payload bytes are still outstanding.
    Incomplete(usize),
    /// A whole command is buffered and ready to be decoded.
    Complete,
}
41
/// An event received from the Rocket tracker.
///
/// These are produced by [`RocketClient::poll_events`]. Events can be compared with `==`,
/// which is handy in tests and when de-duplicating state changes.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Event {
    /// The tracker changes row.
    SetRow(u32),
    /// The tracker pauses or unpauses.
    /// The flag is `true` when the tracker sent a pause byte of `1`.
    Pause(bool),
    /// The tracker asks us to save our track data.
    /// You may want to call [`RocketClient::save_tracks`] after receiving this event.
    SaveTracks,
}
53
54enum ReceiveResult {
55 Some(Event),
56 None,
57 Incomplete,
58}
59
60#[derive(Debug)]
61/// The `RocketClient` type. This contains the connected socket and other fields.
62pub struct RocketClient {
63 stream: TcpStream,
64 state: ClientState,
65 cmd: Vec<u8>,
66 tracks: Vec<Track>,
67}
68
69impl RocketClient {
70 /// Construct a new RocketClient.
71 ///
72 /// This constructs a new Rocket client and connects to localhost on port 1338.
73 ///
74 /// # Errors
75 ///
76 /// [`Error::Connect`] if connection cannot be established, or [`Error::Handshake`]
77 /// if the handshake fails.
78 ///
79 /// # Examples
80 ///
81 /// ```rust,no_run
82 /// # use rust_rocket::RocketClient;
83 /// let mut rocket = RocketClient::new()?;
84 /// # Ok::<(), rust_rocket::client::Error>(())
85 /// ```
86 pub fn new() -> Result<Self, Error> {
87 Self::connect(("localhost", 1338))
88 }
89
90 /// Construct a new RocketClient.
91 ///
92 /// This constructs a new Rocket client and connects to a specified host and port.
93 ///
94 /// # Errors
95 ///
96 /// [`Error::Connect`] if connection cannot be established, or [`Error::Handshake`]
97 /// if the handshake fails.
98 ///
99 /// # Examples
100 ///
101 /// ```rust,no_run
102 /// # use rust_rocket::RocketClient;
103 /// let mut rocket = RocketClient::connect(("localhost", 1338))?;
104 /// # Ok::<(), rust_rocket::client::Error>(())
105 /// ```
106 pub fn connect(addr: impl ToSocketAddrs) -> Result<Self, Error> {
107 let stream = TcpStream::connect(addr).map_err(Error::Connect)?;
108
109 let mut rocket = Self {
110 stream,
111 state: ClientState::New,
112 cmd: Vec::new(),
113 tracks: Vec::new(),
114 };
115
116 rocket.handshake()?;
117
118 rocket
119 .stream
120 .set_nonblocking(true)
121 .map_err(Error::SetNonblocking)?;
122
123 Ok(rocket)
124 }
125
126 /// Get track by name.
127 ///
128 /// If the track does not yet exist it will be created.
129 ///
130 /// # Errors
131 ///
132 /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
133 ///
134 /// # Panics
135 ///
136 /// Will panic if `name`'s length exceeds [`u32::MAX`].
137 ///
138 /// # Examples
139 ///
140 /// ```rust,no_run
141 /// # use rust_rocket::RocketClient;
142 /// # let mut rocket = RocketClient::new()?;
143 /// let track = rocket.get_track_mut("namespace:track")?;
144 /// track.get_value(3.5);
145 /// # Ok::<(), rust_rocket::client::Error>(())
146 /// ```
147 pub fn get_track_mut(&mut self, name: &str) -> Result<&mut Track, Error> {
148 if let Some((i, _)) = self
149 .tracks
150 .iter()
151 .enumerate()
152 .find(|(_, t)| t.get_name() == name)
153 {
154 Ok(&mut self.tracks[i])
155 } else {
156 // Send GET_TRACK message
157 let mut buf = vec![2];
158 buf.write_u32::<BigEndian>(u32::try_from(name.len()).expect("Track name too long"))
159 .unwrap_or_else(|_|
160 // Can writes to a vec fail? Consider changing to unreachable_unchecked in 1.0
161 unreachable!());
162 buf.extend_from_slice(name.as_bytes());
163 self.stream.write_all(&buf).map_err(Error::IOError)?;
164
165 self.tracks.push(Track::new(name));
166 Ok(self.tracks.last_mut().unwrap_or_else(||
167 // tracks cannot be empty right after pushing into it, consider changing to
168 // unreachable_unchecked in 1.0
169 unreachable!()))
170 }
171 }
172
173 /// Get track by name.
174 ///
175 /// You should use [`get_track_mut`](RocketClient::get_track_mut) to create a track.
176 pub fn get_track(&self, name: &str) -> Option<&Track> {
177 self.tracks.iter().find(|t| t.get_name() == name)
178 }
179
180 /// Get a snapshot of the tracks in the session.
181 /// The returned [`Tracks`] can be dumped to a file in any [supported format](crate#features).
182 /// The counterpart to this function is [`RocketPlayer::new`](crate::RocketPlayer::new),
183 /// which loads tracks for playback.
184 ///
185 /// # Example
186 ///
187 /// ```rust,no_run
188 /// # use rust_rocket::RocketClient;
189 /// # use std::fs::OpenOptions;
190 /// let mut rocket = RocketClient::new()?;
191 ///
192 /// // Create tracks, call poll_events, etc...
193 ///
194 /// // Open a file for writing
195 /// let mut file = OpenOptions::new()
196 /// .write(true)
197 /// .create(true)
198 /// .truncate(true)
199 /// .open("tracks.bin")
200 /// .expect("Failed to open tracks.bin for writing");
201 ///
202 /// // Save a snapshot of the client to a file for playback in release builds
203 /// let tracks = rocket.save_tracks();
204 /// # #[cfg(feature = "bincode")]
205 /// bincode::encode_into_std_write(tracks, &mut file, bincode::config::standard())
206 /// .expect("Failed to encode tracks.bin");
207 /// # Ok::<(), rust_rocket::client::Error>(())
208 /// ```
209 pub fn save_tracks(&self) -> &Tracks {
210 &self.tracks
211 }
212
213 /// Send a SetRow message.
214 ///
215 /// This changes the current row on the tracker side.
216 ///
217 /// # Errors
218 ///
219 /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
220 pub fn set_row(&mut self, row: u32) -> Result<(), Error> {
221 // Send SET_ROW message
222 let mut buf = vec![3];
223 buf.write_u32::<BigEndian>(row).unwrap_or_else(|_|
224 // Can writes to a vec fail? Consider changing to unreachable_unchecked in 1.0
225 unreachable!());
226 self.stream.write_all(&buf).map_err(Error::IOError)
227 }
228
229 /// Poll for new events from the tracker.
230 ///
231 /// This polls from events from the tracker.
232 /// You should call this fairly often your main loop.
233 /// It is recommended to keep calling this as long as your receive `Some(Event)`.
234 ///
235 /// # Errors
236 ///
237 /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
238 ///
239 /// # Examples
240 ///
241 /// ```rust,no_run
242 /// # use rust_rocket::RocketClient;
243 /// # let mut rocket = RocketClient::new()?;
244 /// while let Some(event) = rocket.poll_events()? {
245 /// match event {
246 /// // Do something with the various events.
247 /// _ => (),
248 /// }
249 /// }
250 /// # Ok::<(), rust_rocket::client::Error>(())
251 /// ```
252 pub fn poll_events(&mut self) -> Result<Option<Event>, Error> {
253 loop {
254 let result = self.poll_event()?;
255 match result {
256 ReceiveResult::None => return Ok(None),
257 ReceiveResult::Incomplete => (),
258 ReceiveResult::Some(event) => return Ok(Some(event)),
259 }
260 }
261 }
262
263 fn poll_event(&mut self) -> Result<ReceiveResult, Error> {
264 match self.state {
265 ClientState::New => {
266 let mut buf = [0; 1];
267 match self.stream.read_exact(&mut buf) {
268 Ok(()) => {
269 self.cmd.extend_from_slice(&buf);
270 match self.cmd[0] {
271 0 => self.state = ClientState::Incomplete(4 + 4 + 4 + 1), //SET_KEY
272 1 => self.state = ClientState::Incomplete(4 + 4), //DELETE_KEY
273 3 => self.state = ClientState::Incomplete(4), //SET_ROW
274 4 => self.state = ClientState::Incomplete(1), //PAUSE
275 5 => self.state = ClientState::Complete, //SAVE_TRACKS
276 _ => self.state = ClientState::Complete, // Error / Unknown
277 }
278 Ok(ReceiveResult::Incomplete)
279 }
280 Err(e) => match e.kind() {
281 std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
282 _ => Err(Error::IOError(e)),
283 },
284 }
285 }
286 ClientState::Incomplete(bytes) => {
287 let mut buf = vec![0; bytes];
288 match self.stream.read(&mut buf) {
289 Ok(bytes_read) => {
290 self.cmd.extend_from_slice(&buf);
291 if bytes - bytes_read > 0 {
292 self.state = ClientState::Incomplete(bytes - bytes_read);
293 } else {
294 self.state = ClientState::Complete;
295 }
296 Ok(ReceiveResult::Incomplete)
297 }
298 Err(e) => match e.kind() {
299 std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
300 _ => Err(Error::IOError(e)),
301 },
302 }
303 }
304 ClientState::Complete => {
305 let mut result = ReceiveResult::None;
306 {
307 // Following reads from cmd should never fail if above match arms are correct
308 let mut cursor = Cursor::new(&self.cmd);
309 let cmd = cursor.read_u8().unwrap();
310 match cmd {
311 0 => {
312 // usize::try_from(u32) will only be None if usize is smaller, and
313 // more than usize::MAX tracks are in use. That isn't possible because
314 // I'd imagine Vec::push and everything else will panic first.
315 // If you're running this on a microcontroller, I'd love to see it!
316 let track = &mut self.tracks
317 [usize::try_from(cursor.read_u32::<BigEndian>().unwrap()).unwrap()];
318 let row = cursor.read_u32::<BigEndian>().unwrap();
319 let value = cursor.read_f32::<BigEndian>().unwrap();
320 let interpolation = Interpolation::from(cursor.read_u8().unwrap());
321 let key = Key::new(row, value, interpolation);
322
323 track.set_key(key);
324 }
325 1 => {
326 let track = &mut self.tracks
327 [usize::try_from(cursor.read_u32::<BigEndian>().unwrap()).unwrap()];
328 let row = cursor.read_u32::<BigEndian>().unwrap();
329
330 track.delete_key(row);
331 }
332 3 => {
333 let row = cursor.read_u32::<BigEndian>().unwrap();
334 result = ReceiveResult::Some(Event::SetRow(row));
335 }
336 4 => {
337 let flag = cursor.read_u8().unwrap() == 1;
338 result = ReceiveResult::Some(Event::Pause(flag));
339 }
340 5 => {
341 result = ReceiveResult::Some(Event::SaveTracks);
342 }
343 _ => println!("Unknown {:?}", cmd),
344 }
345 }
346
347 self.cmd.clear();
348 self.state = ClientState::New;
349
350 Ok(result)
351 }
352 }
353 }
354
355 fn handshake(&mut self) -> Result<(), Error> {
356 let client_greeting = b"hello, synctracker!";
357 let server_greeting = b"hello, demo!";
358
359 self.stream
360 .write_all(client_greeting)
361 .map_err(Error::Handshake)?;
362
363 let mut buf = [0; 12];
364 self.stream.read_exact(&mut buf).map_err(Error::Handshake)?;
365
366 if &buf == server_greeting {
367 Ok(())
368 } else {
369 Err(Error::HandshakeGreetingMismatch(buf))
370 }
371 }
372}