xash3d_protocol/
cursor.rs

1// SPDX-License-Identifier: LGPL-3.0-only
2// SPDX-FileCopyrightText: 2023 Denis Drakhnia <numas13@gmail.com>
3
4use std::io::{self, Write as _};
5use std::{fmt, mem, str};
6
7use thiserror::Error;
8
9use super::color;
10use super::wrappers::Str;
11
12/// The error type for `Cursor` and `CursorMut`.
13#[derive(Error, Debug, PartialEq, Eq)]
14pub enum Error {
15    /// Invalid number.
16    #[error("Invalid number")]
17    InvalidNumber,
18    /// Invalid string.
19    #[error("Invalid string")]
20    InvalidString,
21    /// Invalid boolean.
22    #[error("Invalid boolean")]
23    InvalidBool,
24    /// Invalid table entry.
25    #[error("Invalid table key")]
26    InvalidTableKey,
27    /// Invalid table entry.
28    #[error("Invalid table entry")]
29    InvalidTableValue,
30    /// Table end found.
31    #[error("Table end")]
32    TableEnd,
33    /// Expected data not found.
34    #[error("Expected data not found")]
35    Expect,
36    /// An unexpected data found.
37    #[error("Unexpected data")]
38    ExpectEmpty,
39    /// Buffer size is no enougth to decode or encode a packet.
40    #[error("Unexpected end of buffer")]
41    UnexpectedEnd,
42}
43
44pub trait GetKeyValue<'a>: Sized {
45    fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error>;
46}
47
48impl<'a> GetKeyValue<'a> for &'a [u8] {
49    fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error> {
50        cur.get_key_value_raw()
51    }
52}
53
54impl<'a> GetKeyValue<'a> for Str<&'a [u8]> {
55    fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error> {
56        cur.get_key_value_raw().map(Str)
57    }
58}
59
60impl<'a> GetKeyValue<'a> for &'a str {
61    fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error> {
62        let raw = cur.get_key_value_raw()?;
63        str::from_utf8(raw).map_err(|_| Error::InvalidString)
64    }
65}
66
67impl<'a> GetKeyValue<'a> for Box<str> {
68    fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error> {
69        let raw = cur.get_key_value_raw()?;
70        str::from_utf8(raw)
71            .map(|s| s.to_owned().into_boxed_str())
72            .map_err(|_| Error::InvalidString)
73    }
74}
75
76impl<'a> GetKeyValue<'a> for String {
77    fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error> {
78        let raw = cur.get_key_value_raw()?;
79        str::from_utf8(raw)
80            .map(|s| s.to_owned())
81            .map_err(|_| Error::InvalidString)
82    }
83}
84
85impl<'a> GetKeyValue<'a> for bool {
86    fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error> {
87        match cur.get_key_value_raw()? {
88            b"0" => Ok(false),
89            b"1" => Ok(true),
90            _ => Err(Error::InvalidBool),
91        }
92    }
93}
94
95macro_rules! impl_get_value {
96    ($($t:ty),+ $(,)?) => {
97        $(impl<'a> GetKeyValue<'a> for $t {
98            fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error> {
99                let s = cur.get_key_value::<&str>()?;
100                // HACK: special case for one asshole
101                let (_, s) = color::trim_start_color(s);
102                s.parse().map_err(|_| Error::InvalidNumber)
103            }
104        })+
105    };
106}
107
108impl_get_value! {
109    u8,
110    u16,
111    u32,
112    u64,
113
114    i8,
115    i16,
116    i32,
117    i64,
118}
119
120// TODO: impl GetKeyValue for f32 and f64
121
/// A read-only cursor over a byte slice.
///
/// Reading consumes bytes from the front: the cursor only stores the part of
/// the slice that has not been read yet.
#[derive(Clone, Copy)]
pub struct Cursor<'a> {
    /// Bytes that have not been read yet.
    buffer: &'a [u8],
}
126
/// Generates fixed-width numeric readers for `Cursor`.
///
/// Each `name: type = constructor` entry expands to a method that reads
/// `size_of::<type>()` bytes and converts them with `type::constructor`
/// (e.g. `from_le_bytes`).
macro_rules! impl_get {
    ($($n:ident: $t:ty = $f:ident),+ $(,)?) => (
        $(#[inline]
        pub fn $n(&mut self) -> Result<$t, Error> {
            let raw = self.get_array::<{ mem::size_of::<$t>() }>()?;
            Ok(<$t>::$f(raw))
        })+
    );
}
136
137impl<'a> Cursor<'a> {
138    pub fn new(buffer: &'a [u8]) -> Self {
139        Self { buffer }
140    }
141
142    pub fn end(self) -> &'a [u8] {
143        self.buffer
144    }
145
146    pub fn as_slice(&'a self) -> &'a [u8] {
147        self.buffer
148    }
149
150    #[inline(always)]
151    pub fn remaining(&self) -> usize {
152        self.buffer.len()
153    }
154
155    #[inline(always)]
156    pub fn has_remaining(&self) -> bool {
157        self.remaining() != 0
158    }
159
160    pub fn get_bytes(&mut self, count: usize) -> Result<&'a [u8], Error> {
161        if count <= self.remaining() {
162            let (head, tail) = self.buffer.split_at(count);
163            self.buffer = tail;
164            Ok(head)
165        } else {
166            Err(Error::UnexpectedEnd)
167        }
168    }
169
170    pub fn advance(&mut self, count: usize) -> Result<(), Error> {
171        self.get_bytes(count).map(|_| ())
172    }
173
174    pub fn get_array<const N: usize>(&mut self) -> Result<[u8; N], Error> {
175        self.get_bytes(N).map(|s| {
176            let mut array = [0; N];
177            array.copy_from_slice(s);
178            array
179        })
180    }
181
182    pub fn get_str(&mut self, n: usize) -> Result<&'a str, Error> {
183        let mut cur = *self;
184        let s = cur
185            .get_bytes(n)
186            .and_then(|s| str::from_utf8(s).map_err(|_| Error::InvalidString))?;
187        *self = cur;
188        Ok(s)
189    }
190
191    pub fn get_cstr(&mut self) -> Result<Str<&'a [u8]>, Error> {
192        let pos = self
193            .buffer
194            .iter()
195            .position(|&c| c == b'\0')
196            .ok_or(Error::UnexpectedEnd)?;
197        let (head, tail) = self.buffer.split_at(pos);
198        self.buffer = &tail[1..];
199        Ok(Str(&head[..pos]))
200    }
201
202    pub fn get_cstr_as_str(&mut self) -> Result<&'a str, Error> {
203        str::from_utf8(&self.get_cstr()?).map_err(|_| Error::InvalidString)
204    }
205
206    #[inline(always)]
207    pub fn get_u8(&mut self) -> Result<u8, Error> {
208        self.get_array::<1>().map(|s| s[0])
209    }
210
211    #[inline(always)]
212    pub fn get_i8(&mut self) -> Result<i8, Error> {
213        self.get_array::<1>().map(|s| s[0] as i8)
214    }
215
216    impl_get! {
217        get_u16_le: u16 = from_le_bytes,
218        get_u32_le: u32 = from_le_bytes,
219        get_u64_le: u64 = from_le_bytes,
220        get_i16_le: i16 = from_le_bytes,
221        get_i32_le: i32 = from_le_bytes,
222        get_i64_le: i64 = from_le_bytes,
223        get_f32_le: f32 = from_le_bytes,
224        get_f64_le: f64 = from_le_bytes,
225
226        get_u16_be: u16 = from_be_bytes,
227        get_u32_be: u32 = from_be_bytes,
228        get_u64_be: u64 = from_be_bytes,
229        get_i16_be: i16 = from_be_bytes,
230        get_i32_be: i32 = from_be_bytes,
231        get_i64_be: i64 = from_be_bytes,
232        get_f32_be: f32 = from_be_bytes,
233        get_f64_be: f64 = from_be_bytes,
234
235        get_u16_ne: u16 = from_ne_bytes,
236        get_u32_ne: u32 = from_ne_bytes,
237        get_u64_ne: u64 = from_ne_bytes,
238        get_i16_ne: i16 = from_ne_bytes,
239        get_i32_ne: i32 = from_ne_bytes,
240        get_i64_ne: i64 = from_ne_bytes,
241        get_f32_ne: f32 = from_ne_bytes,
242        get_f64_ne: f64 = from_ne_bytes,
243    }
244
245    pub fn expect(&mut self, s: &[u8]) -> Result<(), Error> {
246        if self.buffer.starts_with(s) {
247            self.advance(s.len())?;
248            Ok(())
249        } else {
250            Err(Error::Expect)
251        }
252    }
253
254    pub fn expect_empty(&self) -> Result<(), Error> {
255        if self.has_remaining() {
256            Err(Error::ExpectEmpty)
257        } else {
258            Ok(())
259        }
260    }
261
262    pub fn take_while<F>(&mut self, mut cond: F) -> Result<&'a [u8], Error>
263    where
264        F: FnMut(u8) -> bool,
265    {
266        self.buffer
267            .iter()
268            .position(|&i| !cond(i))
269            .ok_or(Error::UnexpectedEnd)
270            .and_then(|n| self.get_bytes(n))
271    }
272
273    pub fn take_while_or_all<F>(&mut self, cond: F) -> &'a [u8]
274    where
275        F: FnMut(u8) -> bool,
276    {
277        self.take_while(cond).unwrap_or_else(|_| {
278            let (head, tail) = self.buffer.split_at(self.buffer.len());
279            self.buffer = tail;
280            head
281        })
282    }
283
284    pub fn get_key_value_raw(&mut self) -> Result<&'a [u8], Error> {
285        let mut cur = *self;
286        match cur.get_u8()? {
287            b'\\' => {
288                let value = cur.take_while_or_all(|c| c != b'\\' && c != b'\n');
289                *self = cur;
290                Ok(value)
291            }
292            _ => Err(Error::InvalidTableValue),
293        }
294    }
295
296    pub fn get_key_value<T: GetKeyValue<'a>>(&mut self) -> Result<T, Error> {
297        T::get_key_value(self)
298    }
299
300    pub fn skip_key_value<T: GetKeyValue<'a>>(&mut self) -> Result<(), Error> {
301        T::get_key_value(self).map(|_| ())
302    }
303
304    pub fn get_key_raw(&mut self) -> Result<&'a [u8], Error> {
305        let mut cur = *self;
306        match cur.get_u8() {
307            Ok(b'\\') => {
308                let value = cur.take_while(|c| c != b'\\' && c != b'\n')?;
309                *self = cur;
310                Ok(value)
311            }
312            Ok(b'\n') | Err(Error::UnexpectedEnd) => Err(Error::TableEnd),
313            _ => Err(Error::InvalidTableKey),
314        }
315    }
316
317    pub fn get_key<T: GetKeyValue<'a>>(&mut self) -> Result<(&'a [u8], T), Error> {
318        Ok((self.get_key_raw()?, self.get_key_value()?))
319    }
320}
321
322pub trait PutKeyValue {
323    fn put_key_value<'a, 'b>(
324        &self,
325        cur: &'b mut CursorMut<'a>,
326    ) -> Result<&'b mut CursorMut<'a>, Error>;
327}
328
329impl<T> PutKeyValue for &T
330where
331    T: PutKeyValue,
332{
333    fn put_key_value<'a, 'b>(
334        &self,
335        cur: &'b mut CursorMut<'a>,
336    ) -> Result<&'b mut CursorMut<'a>, Error> {
337        (*self).put_key_value(cur)
338    }
339}
340
341impl PutKeyValue for &str {
342    fn put_key_value<'a, 'b>(
343        &self,
344        cur: &'b mut CursorMut<'a>,
345    ) -> Result<&'b mut CursorMut<'a>, Error> {
346        cur.put_str(self)
347    }
348}
349
350impl PutKeyValue for bool {
351    fn put_key_value<'a, 'b>(
352        &self,
353        cur: &'b mut CursorMut<'a>,
354    ) -> Result<&'b mut CursorMut<'a>, Error> {
355        cur.put_u8(if *self { b'1' } else { b'0' })
356    }
357}
358
359macro_rules! impl_put_key_value {
360    ($($t:ty),+ $(,)?) => {
361        $(impl PutKeyValue for $t {
362            fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>, Error> {
363                cur.put_as_str(self)
364            }
365        })+
366    };
367}
368
369impl_put_key_value! {
370    u8,
371    u16,
372    u32,
373    u64,
374
375    i8,
376    i16,
377    i32,
378    i64,
379
380    f32,
381    f64,
382}
383
/// A write cursor over a mutable byte slice.
///
/// Data is appended at `pos`; the bytes written so far are
/// `buffer[..pos]`.
pub struct CursorMut<'a> {
    /// Destination buffer.
    buffer: &'a mut [u8],
    /// Number of bytes written so far, i.e. the offset of the next write.
    pos: usize,
}
388
/// Generates fixed-width numeric writers for `CursorMut`.
///
/// Each `name: type = converter` entry expands to a method that converts the
/// number to bytes with `type::converter` (e.g. `to_le_bytes`) and writes
/// them.
macro_rules! impl_put {
    ($($n:ident: $t:ty = $f:ident),+ $(,)?) => (
        $(#[inline]
        pub fn $n(&mut self, n: $t) -> Result<&mut Self, Error> {
            let bytes = <$t>::$f(n);
            self.put_array(&bytes)
        })+
    );
}
397
398impl<'a> CursorMut<'a> {
399    pub fn new(buffer: &'a mut [u8]) -> Self {
400        Self { buffer, pos: 0 }
401    }
402
403    pub fn pos(&mut self) -> usize {
404        self.pos
405    }
406
407    #[inline(always)]
408    pub fn remaining(&self) -> usize {
409        self.buffer.len() - self.pos
410    }
411
412    pub fn advance<F>(&mut self, count: usize, mut f: F) -> Result<&mut Self, Error>
413    where
414        F: FnMut(&mut [u8]),
415    {
416        if count <= self.remaining() {
417            f(&mut self.buffer[self.pos..self.pos + count]);
418            self.pos += count;
419            Ok(self)
420        } else {
421            Err(Error::UnexpectedEnd)
422        }
423    }
424
425    pub fn put_bytes(&mut self, s: &[u8]) -> Result<&mut Self, Error> {
426        self.advance(s.len(), |i| {
427            i.copy_from_slice(s);
428        })
429    }
430
431    pub fn put_array<const N: usize>(&mut self, s: &[u8; N]) -> Result<&mut Self, Error> {
432        self.advance(N, |i| {
433            i.copy_from_slice(s);
434        })
435    }
436
437    pub fn put_str(&mut self, s: &str) -> Result<&mut Self, Error> {
438        self.put_bytes(s.as_bytes())
439    }
440
441    pub fn put_cstr(&mut self, s: &str) -> Result<&mut Self, Error> {
442        self.put_str(s)?.put_u8(0)
443    }
444
445    #[inline(always)]
446    pub fn put_u8(&mut self, n: u8) -> Result<&mut Self, Error> {
447        self.put_array(&[n])
448    }
449
450    #[inline(always)]
451    pub fn put_i8(&mut self, n: i8) -> Result<&mut Self, Error> {
452        self.put_u8(n as u8)
453    }
454
455    impl_put! {
456        put_u16_le: u16 = to_le_bytes,
457        put_u32_le: u32 = to_le_bytes,
458        put_u64_le: u64 = to_le_bytes,
459        put_i16_le: i16 = to_le_bytes,
460        put_i32_le: i32 = to_le_bytes,
461        put_i64_le: i64 = to_le_bytes,
462        put_f32_le: f32 = to_le_bytes,
463        put_f64_le: f64 = to_le_bytes,
464
465        put_u16_be: u16 = to_be_bytes,
466        put_u32_be: u32 = to_be_bytes,
467        put_u64_be: u64 = to_be_bytes,
468        put_i16_be: i16 = to_be_bytes,
469        put_i32_be: i32 = to_be_bytes,
470        put_i64_be: i64 = to_be_bytes,
471        put_f32_be: f32 = to_be_bytes,
472        put_f64_be: f64 = to_be_bytes,
473
474        put_u16_ne: u16 = to_ne_bytes,
475        put_u32_ne: u32 = to_ne_bytes,
476        put_u64_ne: u64 = to_ne_bytes,
477        put_i16_ne: i16 = to_ne_bytes,
478        put_i32_ne: i32 = to_ne_bytes,
479        put_i64_ne: i64 = to_ne_bytes,
480        put_f32_ne: f32 = to_ne_bytes,
481        put_f64_ne: f64 = to_ne_bytes,
482    }
483
484    pub fn put_as_str<T: fmt::Display>(&mut self, value: T) -> Result<&mut Self, Error> {
485        let mut cur = io::Cursor::new(&mut self.buffer[self.pos..]);
486        write!(&mut cur, "{}", value).map_err(|_| Error::UnexpectedEnd)?;
487        self.pos += cur.position() as usize;
488        Ok(self)
489    }
490
491    pub fn put_key_value<T: PutKeyValue>(&mut self, value: T) -> Result<&mut Self, Error> {
492        value.put_key_value(self)
493    }
494
495    pub fn put_key_raw(&mut self, key: &str, value: &[u8]) -> Result<&mut Self, Error> {
496        self.put_u8(b'\\')?
497            .put_str(key)?
498            .put_u8(b'\\')?
499            .put_bytes(value)
500    }
501
502    pub fn put_key<T: PutKeyValue>(&mut self, key: &str, value: T) -> Result<&mut Self, Error> {
503        self.put_u8(b'\\')?
504            .put_str(key)?
505            .put_u8(b'\\')?
506            .put_key_value(value)
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes one value with each basic `put_*` method and reads them back
    /// in the same order.
    #[test]
    fn cursor() -> Result<(), Error> {
        let mut storage = [0; 64];
        let written = CursorMut::new(&mut storage)
            .put_bytes(b"12345678")?
            .put_array(b"4321")?
            .put_str("abc")?
            .put_cstr("def")?
            .put_u8(0x7f)?
            .put_i8(-128)?
            .put_u32_le(0x44332211)?
            .pos();

        let mut reader = Cursor::new(&storage[..written]);
        assert_eq!(reader.get_bytes(8), Ok(&b"12345678"[..]));
        assert_eq!(reader.get_array::<4>(), Ok(*b"4321"));
        assert_eq!(reader.get_str(3), Ok("abc"));
        assert_eq!(reader.get_cstr(), Ok(Str(&b"def"[..])));
        assert_eq!(reader.get_u8(), Ok(0x7f));
        assert_eq!(reader.get_i8(), Ok(-128));
        assert_eq!(reader.get_u32_le(), Ok(0x44332211));
        // Everything has been consumed.
        assert_eq!(reader.get_u8(), Err(Error::UnexpectedEnd));

        Ok(())
    }

    /// Writes a key/value table and reads every entry back with a matching
    /// value type.
    #[test]
    fn key() -> Result<(), Error> {
        let mut storage = [0; 512];
        let written = CursorMut::new(&mut storage)
            .put_key("p", 49)?
            .put_key("map", "crossfire")?
            .put_key("dm", true)?
            .put_key("team", false)?
            .put_key("coop", false)?
            .put_key("numcl", 4)?
            .put_key("maxcl", 32)?
            .put_key("gamedir", "valve")?
            .put_key("password", false)?
            .put_key("host", "test")?
            .pos();

        let mut reader = Cursor::new(&storage[..written]);
        assert_eq!(reader.get_key(), Ok((&b"p"[..], 49_u8)));
        assert_eq!(reader.get_key(), Ok((&b"map"[..], "crossfire")));
        assert_eq!(reader.get_key(), Ok((&b"dm"[..], true)));
        assert_eq!(reader.get_key(), Ok((&b"team"[..], false)));
        assert_eq!(reader.get_key(), Ok((&b"coop"[..], false)));
        assert_eq!(reader.get_key(), Ok((&b"numcl"[..], 4_u8)));
        assert_eq!(reader.get_key(), Ok((&b"maxcl"[..], 32_u8)));
        assert_eq!(reader.get_key(), Ok((&b"gamedir"[..], "valve")));
        assert_eq!(reader.get_key(), Ok((&b"password"[..], false)));
        assert_eq!(reader.get_key(), Ok((&b"host"[..], "test")));
        // Nothing is left, which is reported as the end of the table.
        assert_eq!(reader.get_key::<&[u8]>(), Err(Error::TableEnd));

        Ok(())
    }
}
573}