rust_rocket/client.rs
1//! This module contains the main client code, including the [`RocketClient`] type.
2use crate::interpolation::*;
3use crate::track::*;
4
5use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
6use std::{
7 convert::TryFrom,
8 io::{Cursor, Read, Write},
9 net::{TcpStream, ToSocketAddrs},
10};
11use thiserror::Error;
12
13#[derive(Debug, Error)]
14/// The `Error` Type. This is the main error type.
15pub enum Error {
16 #[error("Failed to establish a TCP connection with the Rocket server")]
17 /// Failure to connect to a rocket tracker. This can happen if the tracker is not running, the
18 /// address isn't correct or other network-related reasons.
19 Connect(#[source] std::io::Error),
20 #[error("Handshake with the Rocket server failed")]
21 /// Failure to transmit or receive greetings with the tracker
22 Handshake(#[source] std::io::Error),
23 #[error("The Rocket server greeting {0:?} wasn't correct")]
24 /// Handshake was performed but the the received greeting wasn't correct
25 HandshakeGreetingMismatch([u8; 12]),
26 #[error("Cannot set Rocket's TCP connection to nonblocking mode")]
27 /// Error from [`TcpStream::set_nonblocking`]
28 SetNonblocking(#[source] std::io::Error),
29 #[error("Rocket server disconnected")]
30 /// Network IO error during operation
31 IOError(#[source] std::io::Error),
32}
33
/// Progress of the receive state machine for the command currently being read.
#[derive(Debug)]
enum ClientState {
    /// No command is in progress; the next byte read is a command identifier.
    New,
    /// A command identifier was read and this many payload bytes are still missing.
    Incomplete(usize),
    /// The command and its whole payload are buffered and ready to be decoded.
    Complete,
}
40
/// The `Event` Type. These are the various events from the tracker.
#[derive(Debug, Copy, Clone)]
pub enum Event {
    /// The tracker changes row. Carries the new row.
    SetRow(u32),
    /// The tracker pauses or unpauses. Carries the flag sent by the tracker.
    Pause(bool),
    /// The tracker asks us to save our track data.
    /// You may want to call [`RocketClient::save_tracks`] after receiving this event.
    SaveTracks,
}
52
53enum ReceiveResult {
54 Some(Event),
55 None,
56 Incomplete,
57}
58
59#[derive(Debug)]
60/// The `RocketClient` type. This contains the connected socket and other fields.
61pub struct RocketClient {
62 stream: TcpStream,
63 state: ClientState,
64 cmd: Vec<u8>,
65 tracks: Vec<Track>,
66}
67
68impl RocketClient {
69 /// Construct a new RocketClient.
70 ///
71 /// This constructs a new Rocket client and connects to localhost on port 1338.
72 ///
73 /// # Errors
74 ///
75 /// [`Error::Connect`] if connection cannot be established, or [`Error::Handshake`]
76 /// if the handshake fails.
77 ///
78 /// # Examples
79 ///
80 /// ```rust,no_run
81 /// # use rust_rocket::RocketClient;
82 /// let mut rocket = RocketClient::new().unwrap();
83 /// ```
84 pub fn new() -> Result<Self, Error> {
85 Self::connect(("localhost", 1338))
86 }
87
88 /// Construct a new RocketClient.
89 ///
90 /// This constructs a new Rocket client and connects to a specified host and port.
91 ///
92 /// # Errors
93 ///
94 /// [`Error::Connect`] if connection cannot be established, or [`Error::Handshake`]
95 /// if the handshake fails.
96 ///
97 /// # Examples
98 ///
99 /// ```rust,no_run
100 /// # use rust_rocket::RocketClient;
101 /// let mut rocket = RocketClient::connect(("localhost", 1338)).unwrap();
102 /// ```
103 pub fn connect(addr: impl ToSocketAddrs) -> Result<Self, Error> {
104 let stream = TcpStream::connect(addr).map_err(Error::Connect)?;
105
106 let mut rocket = Self {
107 stream,
108 state: ClientState::New,
109 cmd: Vec::new(),
110 tracks: Vec::new(),
111 };
112
113 rocket.handshake()?;
114
115 rocket
116 .stream
117 .set_nonblocking(true)
118 .map_err(Error::SetNonblocking)?;
119
120 Ok(rocket)
121 }
122
123 /// Get track by name.
124 ///
125 /// If the track does not yet exist it will be created.
126 ///
127 /// # Errors
128 ///
129 /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
130 ///
131 /// # Panics
132 ///
133 /// Will panic if `name`'s length exceeds [`u32::MAX`].
134 ///
135 /// # Examples
136 ///
137 /// ```rust,no_run
138 /// # use rust_rocket::RocketClient;
139 /// # let mut rocket = RocketClient::new().unwrap();
140 /// let track = rocket.get_track_mut("namespace:track").unwrap();
141 /// track.get_value(3.5);
142 /// ```
143 pub fn get_track_mut(&mut self, name: &str) -> Result<&mut Track, Error> {
144 if let Some((i, _)) = self
145 .tracks
146 .iter()
147 .enumerate()
148 .find(|(_, t)| t.get_name() == name)
149 {
150 Ok(&mut self.tracks[i])
151 } else {
152 // Send GET_TRACK message
153 let mut buf = vec![2];
154 buf.write_u32::<BigEndian>(u32::try_from(name.len()).expect("Track name too long"))
155 .unwrap_or_else(|_|
156 // Can writes to a vec fail? Consider changing to unreachable_unchecked in 1.0
157 unreachable!());
158 buf.extend_from_slice(name.as_bytes());
159 self.stream.write_all(&buf).map_err(Error::IOError)?;
160
161 self.tracks.push(Track::new(name));
162 Ok(self.tracks.last_mut().unwrap_or_else(||
163 // tracks cannot be empty right after pushing into it, consider changing to
164 // unreachable_unchecked in 1.0
165 unreachable!()))
166 }
167 }
168
169 /// Get track by name.
170 ///
171 /// You should use [`get_track_mut`](RocketClient::get_track_mut) to create a track.
172 pub fn get_track(&self, name: &str) -> Option<&Track> {
173 self.tracks.iter().find(|t| t.get_name() == name)
174 }
175
176 /// Create a clone of the tracks in the session which can then be serialized to a file in any
177 /// format with a serde implementation.
178 /// Tracks can be turned into a [`RocketPlayer`](crate::RocketPlayer::new) for playback.
179 pub fn save_tracks(&self) -> Vec<Track> {
180 self.tracks.clone()
181 }
182
183 /// Send a SetRow message.
184 ///
185 /// This changes the current row on the tracker side.
186 ///
187 /// # Errors
188 ///
189 /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
190 pub fn set_row(&mut self, row: u32) -> Result<(), Error> {
191 // Send SET_ROW message
192 let mut buf = vec![3];
193 buf.write_u32::<BigEndian>(row).unwrap_or_else(|_|
194 // Can writes to a vec fail? Consider changing to unreachable_unchecked in 1.0
195 unreachable!());
196 self.stream.write_all(&buf).map_err(Error::IOError)
197 }
198
199 /// Poll for new events from the tracker.
200 ///
201 /// This polls from events from the tracker.
202 /// You should call this fairly often your main loop.
203 /// It is recommended to keep calling this as long as your receive `Some(Event)`.
204 ///
205 /// # Errors
206 ///
207 /// This method can return an [`Error::IOError`] if Rocket tracker disconnects.
208 ///
209 /// # Examples
210 ///
211 /// ```rust,no_run
212 /// # use rust_rocket::RocketClient;
213 /// # let mut rocket = RocketClient::new().unwrap();
214 /// while let Some(event) = rocket.poll_events().unwrap() {
215 /// match event {
216 /// // Do something with the various events.
217 /// _ => (),
218 /// }
219 /// }
220 /// ```
221 pub fn poll_events(&mut self) -> Result<Option<Event>, Error> {
222 loop {
223 let result = self.poll_event()?;
224 match result {
225 ReceiveResult::None => return Ok(None),
226 ReceiveResult::Incomplete => (),
227 ReceiveResult::Some(event) => return Ok(Some(event)),
228 }
229 }
230 }
231
232 fn poll_event(&mut self) -> Result<ReceiveResult, Error> {
233 match self.state {
234 ClientState::New => {
235 let mut buf = [0; 1];
236 match self.stream.read_exact(&mut buf) {
237 Ok(()) => {
238 self.cmd.extend_from_slice(&buf);
239 match self.cmd[0] {
240 0 => self.state = ClientState::Incomplete(4 + 4 + 4 + 1), //SET_KEY
241 1 => self.state = ClientState::Incomplete(4 + 4), //DELETE_KEY
242 3 => self.state = ClientState::Incomplete(4), //SET_ROW
243 4 => self.state = ClientState::Incomplete(1), //PAUSE
244 5 => self.state = ClientState::Complete, //SAVE_TRACKS
245 _ => self.state = ClientState::Complete, // Error / Unknown
246 }
247 Ok(ReceiveResult::Incomplete)
248 }
249 Err(e) => match e.kind() {
250 std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
251 _ => Err(Error::IOError(e)),
252 },
253 }
254 }
255 ClientState::Incomplete(bytes) => {
256 let mut buf = vec![0; bytes];
257 match self.stream.read(&mut buf) {
258 Ok(bytes_read) => {
259 self.cmd.extend_from_slice(&buf);
260 if bytes - bytes_read > 0 {
261 self.state = ClientState::Incomplete(bytes - bytes_read);
262 } else {
263 self.state = ClientState::Complete;
264 }
265 Ok(ReceiveResult::Incomplete)
266 }
267 Err(e) => match e.kind() {
268 std::io::ErrorKind::WouldBlock => Ok(ReceiveResult::None),
269 _ => Err(Error::IOError(e)),
270 },
271 }
272 }
273 ClientState::Complete => {
274 let mut result = ReceiveResult::None;
275 {
276 // Following reads from cmd should never fail if above match arms are correct
277 let mut cursor = Cursor::new(&self.cmd);
278 let cmd = cursor.read_u8().unwrap();
279 match cmd {
280 0 => {
281 // usize::try_from(u32) will only be None if usize is smaller, and
282 // more than usize::MAX tracks are in use. That isn't possible because
283 // I'd imagine Vec::push and everything else will panic first.
284 // If you're running this on a microcontroller, I'd love to see it!
285 let track = &mut self.tracks
286 [usize::try_from(cursor.read_u32::<BigEndian>().unwrap()).unwrap()];
287 let row = cursor.read_u32::<BigEndian>().unwrap();
288 let value = cursor.read_f32::<BigEndian>().unwrap();
289 let interpolation = Interpolation::from(cursor.read_u8().unwrap());
290 let key = Key::new(row, value, interpolation);
291
292 track.set_key(key);
293 }
294 1 => {
295 let track = &mut self.tracks
296 [usize::try_from(cursor.read_u32::<BigEndian>().unwrap()).unwrap()];
297 let row = cursor.read_u32::<BigEndian>().unwrap();
298
299 track.delete_key(row);
300 }
301 3 => {
302 let row = cursor.read_u32::<BigEndian>().unwrap();
303 result = ReceiveResult::Some(Event::SetRow(row));
304 }
305 4 => {
306 let flag = cursor.read_u8().unwrap() == 1;
307 result = ReceiveResult::Some(Event::Pause(flag));
308 }
309 5 => {
310 result = ReceiveResult::Some(Event::SaveTracks);
311 }
312 _ => println!("Unknown {:?}", cmd),
313 }
314 }
315
316 self.cmd.clear();
317 self.state = ClientState::New;
318
319 Ok(result)
320 }
321 }
322 }
323
324 fn handshake(&mut self) -> Result<(), Error> {
325 let client_greeting = b"hello, synctracker!";
326 let server_greeting = b"hello, demo!";
327
328 self.stream
329 .write_all(client_greeting)
330 .map_err(Error::Handshake)?;
331
332 let mut buf = [0; 12];
333 self.stream.read_exact(&mut buf).map_err(Error::Handshake)?;
334
335 if &buf == server_greeting {
336 Ok(())
337 } else {
338 Err(Error::HandshakeGreetingMismatch(buf))
339 }
340 }
341}