xash3d_protocol/
filter.rs

1// SPDX-License-Identifier: LGPL-3.0-only
2// SPDX-FileCopyrightText: 2023 Denis Drakhnia <numas13@gmail.com>
3
4//! Server query filter.
5//!
6//! # Supported filters:
7//!
8//! | Filter    | Type | Description                    | Examples |
9//! | --------- | ---- | ------------------------------ | -------- |
10//! | map       | str  | Map name                       | `crossfire`, `de_dust` |
11//! | gamedir   | str  | Game directory                 | `valve`, `cstrike` |
//! | protocol  | u8   | Protocol version               | `48`, `49` |
13//! | dedicated | bool | Server running dedicated       | `0`, `1` |
14//! | lan       | bool | Server is LAN                  | `0`, `1` |
15//! | nat       | bool | Server behind NAT              | `0`, `1` |
16//! | noplayers | bool | Server is empty                | `0`, `1` |
//! | empty     | bool | Server is empty                | `0`, `1` |
//! | full      | bool | Server is full                 | `0`, `1` |
//! | password  | bool | Server is password protected   | `0`, `1` |
20//! | secure    | bool | Server using anti-cheat        | `0`, `1` |
21//! | bots      | bool | Server has bots                | `0`, `1` |
22//!
23//! # Examples:
24//!
//! Filter `\gamedir\valve\full\0\bots\0\password\0` selects a server if:
//!
//! * It is a Half-Life server
//! * It is not full
//! * It does not have bots
//! * It is not protected by a password
31
32use std::fmt;
33use std::net::SocketAddr;
34use std::str::FromStr;
35
36use bitflags::bitflags;
37
38use crate::cursor::{Cursor, GetKeyValue, PutKeyValue};
39use crate::server::{ServerAdd, ServerFlags, ServerType};
40use crate::wrappers::Str;
41use crate::{CursorError, Error, ServerInfo};
42
43bitflags! {
44    /// Additional filter flags.
45    #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
46    pub struct FilterFlags: u16 {
47        /// Servers running dedicated
48        const DEDICATED     = 1 << 0;
49        /// Servers using anti-cheat technology (VAC, but potentially others as well)
50        const SECURE        = 1 << 1;
51        /// Servers that are not password protected
52        const PASSWORD      = 1 << 2;
53        /// Servers that are empty
54        const EMPTY         = 1 << 3;
55        /// Servers that are not full
56        const FULL          = 1 << 4;
57        /// Servers that are empty
58        const NOPLAYERS     = 1 << 5;
59        /// Servers that are behind NAT
60        const NAT           = 1 << 6;
61        /// Servers that are LAN
62        const LAN           = 1 << 7;
63        /// Servers that has bots
64        const BOTS          = 1 << 8;
65    }
66}
67
68impl<T> From<&ServerAdd<T>> for FilterFlags {
69    fn from(info: &ServerAdd<T>) -> Self {
70        let mut flags = Self::empty();
71
72        flags.set(Self::DEDICATED, info.server_type == ServerType::Dedicated);
73        flags.set(Self::SECURE, info.flags.contains(ServerFlags::SECURE));
74        flags.set(Self::PASSWORD, info.flags.contains(ServerFlags::PASSWORD));
75        flags.set(Self::EMPTY, info.players == 0);
76        flags.set(Self::FULL, info.players >= info.max);
77        flags.set(Self::NOPLAYERS, info.players == 0);
78        flags.set(Self::NAT, info.flags.contains(ServerFlags::NAT));
79        flags.set(Self::LAN, info.flags.contains(ServerFlags::LAN));
80        flags.set(Self::BOTS, info.flags.contains(ServerFlags::BOTS));
81
82        flags
83    }
84}
85
/// Client or server version in `MAJOR.MINOR[.PATCH]` form.
///
/// Ordering compares `major`, then `minor`, then `patch` (derived from the
/// field declaration order).
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Version {
    /// MAJOR version.
    pub major: u8,
    /// MINOR version.
    pub minor: u8,
    /// PATCH version; a zero patch is omitted when formatting.
    pub patch: u8,
}

impl Version {
    /// Creates a new `Version` whose `patch` component is zero.
    pub const fn new(major: u8, minor: u8) -> Self {
        Self {
            major,
            minor,
            patch: 0,
        }
    }

    /// Creates a new `Version` with an explicit `patch` component.
    pub const fn with_patch(major: u8, minor: u8, patch: u8) -> Self {
        Self {
            major,
            minor,
            patch,
        }
    }
}

/// Formats as `MAJOR.MINOR`, appending `.PATCH` only when `patch` is non-zero.
impl fmt::Debug for Version {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        match self.patch {
            0 => write!(fmt, "{}.{}", self.major, self.minor),
            patch => write!(fmt, "{}.{}.{}", self.major, self.minor, patch),
        }
    }
}

/// Uses the same textual form as the `Debug` implementation.
impl fmt::Display for Version {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(self, fmt)
    }
}
128
129impl FromStr for Version {
130    type Err = CursorError;
131
132    fn from_str(s: &str) -> Result<Self, Self::Err> {
133        let (major, tail) = s.split_once('.').unwrap_or((s, "0"));
134        let (minor, patch) = tail.split_once('.').unwrap_or((tail, "0"));
135        let major = major.parse().map_err(|_| CursorError::InvalidNumber)?;
136        let minor = minor.parse().map_err(|_| CursorError::InvalidNumber)?;
137        let patch = patch.parse().map_err(|_| CursorError::InvalidNumber)?;
138        Ok(Self::with_patch(major, minor, patch))
139    }
140}
141
142impl GetKeyValue<'_> for Version {
143    fn get_key_value(cur: &mut Cursor) -> Result<Self, CursorError> {
144        cur.get_key_value().and_then(Self::from_str)
145    }
146}
147
148impl PutKeyValue for Version {
149    fn put_key_value<'a, 'b>(
150        &self,
151        cur: &'b mut crate::cursor::CursorMut<'a>,
152    ) -> Result<&'b mut crate::cursor::CursorMut<'a>, CursorError> {
153        cur.put_key_value(self.major)?
154            .put_u8(b'.')?
155            .put_key_value(self.minor)?;
156        if self.patch > 0 {
157            cur.put_u8(b'.')?.put_key_value(self.patch)?;
158        }
159        Ok(cur)
160    }
161}
162
163/// Server filter.
164#[derive(Clone, Debug, Default, PartialEq, Eq)]
165pub struct Filter<'a> {
166    /// Servers running the specified modification (ex. cstrike)
167    pub gamedir: Option<Str<&'a [u8]>>,
168    /// Servers running the specified map (ex. cs_italy)
169    pub map: Option<Str<&'a [u8]>>,
170    /// Client version.
171    pub clver: Option<Version>,
172    /// Protocol version
173    pub protocol: Option<u8>,
174    /// A number that master must sent back to game client.
175    pub key: Option<u32>,
176    /// Additional filter flags.
177    pub flags: FilterFlags,
178    /// Filter flags mask.
179    pub flags_mask: FilterFlags,
180}
181
182impl Filter<'_> {
183    /// Insert filter flag.
184    pub fn insert_flag(&mut self, flag: FilterFlags, value: bool) {
185        self.flags.set(flag, value);
186        self.flags_mask.insert(flag);
187    }
188
189    /// Test if all `other` flags are set in `flags_mask` and in `flags`.
190    pub fn contains_flags(&self, other: FilterFlags) -> Option<bool> {
191        if self.flags_mask.contains(other) {
192            Some(self.flags.contains(other))
193        } else {
194            None
195        }
196    }
197
198    /// Returns `true` if a server matches the filter.
199    pub fn matches(&self, _addr: SocketAddr, info: &ServerInfo) -> bool {
200        // TODO: match addr
201        !((info.flags & self.flags_mask) != self.flags
202            || self.gamedir.map_or(false, |s| *s != &*info.gamedir)
203            || self.map.map_or(false, |s| *s != &*info.map)
204            || self.protocol.map_or(false, |s| s != info.protocol))
205    }
206}
207
208impl<'a> TryFrom<&'a [u8]> for Filter<'a> {
209    type Error = Error;
210
211    fn try_from(src: &'a [u8]) -> Result<Self, Self::Error> {
212        trait Helper<'a> {
213            fn get<T: GetKeyValue<'a>>(&mut self, key: &'static str) -> Result<T, Error>;
214        }
215
216        impl<'a> Helper<'a> for Cursor<'a> {
217            fn get<T: GetKeyValue<'a>>(&mut self, key: &'static str) -> Result<T, Error> {
218                T::get_key_value(self).map_err(|e| Error::InvalidFilterValue(key, e))
219            }
220        }
221
222        let mut cur = Cursor::new(src);
223        let mut filter = Self::default();
224
225        loop {
226            let key = match cur.get_key_raw().map(Str) {
227                Ok(s) => s,
228                Err(CursorError::TableEnd) => break,
229                Err(e) => Err(e)?,
230            };
231
232            match *key {
233                b"dedicated" => filter.insert_flag(FilterFlags::DEDICATED, cur.get("dedicated")?),
234                b"secure" => filter.insert_flag(FilterFlags::SECURE, cur.get("secure")?),
235                b"gamedir" => filter.gamedir = Some(cur.get("gamedir")?),
236                b"map" => filter.map = Some(cur.get("map")?),
237                b"protocol" => filter.protocol = Some(cur.get("protocol")?),
238                b"empty" => filter.insert_flag(FilterFlags::EMPTY, cur.get("empty")?),
239                b"full" => filter.insert_flag(FilterFlags::FULL, cur.get("full")?),
240                b"password" => filter.insert_flag(FilterFlags::PASSWORD, cur.get("password")?),
241                b"noplayers" => filter.insert_flag(FilterFlags::NOPLAYERS, cur.get("noplayers")?),
242                b"clver" => filter.clver = Some(cur.get("clver")?),
243                b"nat" => filter.insert_flag(FilterFlags::NAT, cur.get("nat")?),
244                b"lan" => filter.insert_flag(FilterFlags::LAN, cur.get("lan")?),
245                b"bots" => filter.insert_flag(FilterFlags::BOTS, cur.get("bots")?),
246                b"key" => {
247                    filter.key = Some(
248                        cur.get_key_value::<&str>()
249                            .and_then(|s| {
250                                u32::from_str_radix(s, 16).map_err(|_| CursorError::InvalidNumber)
251                            })
252                            .map_err(|e| Error::InvalidFilterValue("key", e))?,
253                    )
254                }
255                _ => {
256                    // skip unknown fields
257                    let value = Str(cur.get_key_value_raw()?);
258                    debug!("Invalid Filter field \"{}\" = \"{}\"", key, value);
259                }
260            }
261        }
262
263        Ok(filter)
264    }
265}
266
267impl fmt::Display for &Filter<'_> {
268    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
269        macro_rules! display_flag {
270            ($n:expr, $f:expr) => {
271                if self.flags_mask.contains($f) {
272                    let flag = if self.flags.contains($f) { '1' } else { '0' };
273                    write!(fmt, "\\{}\\{}", $n, flag)?;
274                }
275            };
276        }
277
278        display_flag!("dedicated", FilterFlags::DEDICATED);
279        display_flag!("secure", FilterFlags::SECURE);
280        if let Some(s) = self.gamedir {
281            write!(fmt, "\\gamedir\\{}", s)?;
282        }
283        display_flag!("secure", FilterFlags::SECURE);
284        if let Some(s) = self.map {
285            write!(fmt, "\\map\\{}", s)?;
286        }
287        display_flag!("empty", FilterFlags::EMPTY);
288        display_flag!("full", FilterFlags::FULL);
289        display_flag!("password", FilterFlags::PASSWORD);
290        display_flag!("noplayers", FilterFlags::NOPLAYERS);
291        if let Some(v) = self.clver {
292            write!(fmt, "\\clver\\{}", v)?;
293        }
294        display_flag!("nat", FilterFlags::NAT);
295        display_flag!("lan", FilterFlags::LAN);
296        display_flag!("bots", FilterFlags::BOTS);
297        if let Some(x) = self.key {
298            write!(fmt, "\\key\\{:x}", x)?;
299        }
300        if let Some(x) = self.protocol {
301            write!(fmt, "\\protocol\\{}", x)?;
302        }
303        Ok(())
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cursor::CursorMut;
    use crate::wrappers::Str;
    use std::net::SocketAddr;

    // Generates one `#[test]` per name. Each `src => { fields }` pair asserts
    // that parsing `src` yields `Filter { fields, ..predefined }`, where
    // `predefined` is built from the optional `(field: value, ...)` list.
    macro_rules! tests {
        ($($name:ident$(($($predefined_f:ident: $predefined_v:expr),+ $(,)?))? {
            $($src:expr => {
                $($field:ident: $value:expr),* $(,)?
            })+
        })+) => {
            $(#[test]
            fn $name() {
                let predefined = Filter {
                    $($($predefined_f: $predefined_v,)+)?
                    .. Filter::default()
                };
                $(assert_eq!(
                    Filter::try_from($src as &[u8]),
                    Ok(Filter {
                        $($field: $value,)*
                        ..predefined
                    })
                );)+
            })+
        };
    }

    tests! {
        parse_gamedir {
            b"\\gamedir\\valve" => {
                gamedir: Some(Str(&b"valve"[..])),
            }
        }
        parse_map {
            b"\\map\\crossfire" => {
                map: Some(Str(&b"crossfire"[..])),
            }
        }
        parse_clver {
            b"\\clver\\0.20" => {
                clver: Some(Version::new(0, 20)),
            }
            b"\\clver\\0.19.3" => {
                clver: Some(Version::with_patch(0, 19, 3)),
            }
        }
        parse_protocol {
            b"\\protocol\\48" => {
                protocol: Some(48)
            }
        }
        parse_dedicated(flags_mask: FilterFlags::DEDICATED) {
            b"\\dedicated\\0" => {}
            b"\\dedicated\\1" => {
                flags: FilterFlags::DEDICATED,
            }
        }
        parse_secure(flags_mask: FilterFlags::SECURE) {
            b"\\secure\\0" => {}
            b"\\secure\\1" => {
                flags: FilterFlags::SECURE,
            }
        }
        parse_password(flags_mask: FilterFlags::PASSWORD) {
            b"\\password\\0" => {}
            b"\\password\\1" => {
                flags: FilterFlags::PASSWORD,
            }
        }
        parse_empty(flags_mask: FilterFlags::EMPTY) {
            b"\\empty\\0" => {}
            b"\\empty\\1" => {
                flags: FilterFlags::EMPTY,
            }
        }
        parse_full(flags_mask: FilterFlags::FULL) {
            b"\\full\\0" => {}
            b"\\full\\1" => {
                flags: FilterFlags::FULL,
            }
        }
        parse_noplayers(flags_mask: FilterFlags::NOPLAYERS) {
            b"\\noplayers\\0" => {}
            b"\\noplayers\\1" => {
                flags: FilterFlags::NOPLAYERS,
            }
        }
        parse_nat(flags_mask: FilterFlags::NAT) {
            b"\\nat\\0" => {}
            b"\\nat\\1" => {
                flags: FilterFlags::NAT,
            }
        }
        parse_lan(flags_mask: FilterFlags::LAN) {
            b"\\lan\\0" => {}
            b"\\lan\\1" => {
                flags: FilterFlags::LAN,
            }
        }
        parse_bots(flags_mask: FilterFlags::BOTS) {
            b"\\bots\\0" => {}
            b"\\bots\\1" => {
                flags: FilterFlags::BOTS,
            }
        }

        parse_all {
            b"\
              \\bots\\1\
              \\clver\\0.20\
              \\dedicated\\1\
              \\empty\\1\
              \\full\\1\
              \\gamedir\\valve\
              \\lan\\1\
              \\map\\crossfire\
              \\nat\\1\
              \\noplayers\\1\
              \\password\\1\
              \\secure\\1\
              \\protocol\\49\
            " => {
                gamedir: Some(Str(&b"valve"[..])),
                map: Some(Str(&b"crossfire"[..])),
                protocol: Some(49),
                clver: Some(Version::new(0, 20)),
                flags: FilterFlags::all(),
                flags_mask: FilterFlags::all(),
            }
        }
    }

    #[test]
    fn version_to_key_value() {
        let mut buf = [0; 64];
        let n = CursorMut::new(&mut buf[..])
            .put_key_value(Version::with_patch(0, 19, 3))
            .unwrap()
            .pos();
        assert_eq!(&buf[..n], b"0.19.3");
    }

    #[test]
    fn version_to_key_value_without_patch() {
        let mut buf = [0; 64];
        let n = CursorMut::new(&mut buf[..])
            .put_key_value(Version::new(0, 20))
            .unwrap()
            .pos();
        assert_eq!(&buf[..n], b"0.20");
    }

    #[test]
    fn version_from_str() {
        assert_eq!(
            "0.19.3".parse::<Version>().ok(),
            Some(Version::with_patch(0, 19, 3))
        );
        assert_eq!("0.20".parse::<Version>().ok(), Some(Version::new(0, 20)));
        assert_eq!("1".parse::<Version>().ok(), Some(Version::new(1, 0)));
        assert!("".parse::<Version>().is_err());
        assert!("1.".parse::<Version>().is_err());
        assert!("1.2.3.4".parse::<Version>().is_err());
    }

    #[test]
    fn version_display() {
        assert_eq!(Version::new(0, 20).to_string(), "0.20");
        assert_eq!(Version::with_patch(0, 19, 3).to_string(), "0.19.3");
    }

    // Builds an array of `(addr, ServerInfo)` pairs from `ServerAdd` payloads.
    // The optional `=> func` mutates the decoded `ServerInfo` in place.
    macro_rules! servers {
        ($($addr:expr => $info:expr $(=> $func:expr)?)+) => (
            [$({
                let addr = $addr.parse::<SocketAddr>().unwrap();
                let mut buf = [0; 512];
                let n = CursorMut::new(&mut buf)
                    .put_bytes(ServerAdd::HEADER).unwrap()
                    .put_key("challenge", 0).unwrap()
                    .put_bytes($info).unwrap()
                    .pos();
                let p = ServerAdd::<Str<&[u8]>>::decode(&buf[..n]).unwrap();
                let server = ServerInfo::new(&p);
                $(
                    let mut server = server;
                    let func: fn(&mut ServerInfo) = $func;
                    func(&mut server);
                )?
                (addr, server)
            }),+]
        );
    }

    // Asserts that exactly the servers at the listed indices match `$filter`.
    // NOTE: this shadows `std::matches!` inside the test module.
    macro_rules! matches {
        ($servers:expr, $filter:expr$(, $expected:expr)*) => (
            let servers = &$servers;
            let filter = Filter::try_from($filter as &[u8]).unwrap();
            let iter = servers
                .iter()
                .enumerate()
                .filter(|(_, (addr, server))| filter.matches(*addr, &server))
                .map(|(i, _)| i);
            assert_eq!(iter.collect::<Vec<_>>(), [$($expected),*])
        );
    }

    #[test]
    fn match_dedicated() {
        let s = servers! {
            "0.0.0.0:0" => b""
            "0.0.0.0:0" => b"\\type\\d"
            "0.0.0.0:0" => b"\\type\\p"
            "0.0.0.0:0" => b"\\type\\l"
        };
        matches!(s, b"", 0, 1, 2, 3);
        matches!(s, b"\\dedicated\\0", 0, 2, 3);
        matches!(s, b"\\dedicated\\1", 1);
    }

    #[test]
    fn match_password() {
        let s = servers! {
            "0.0.0.0:0" => b""
            "0.0.0.0:0" => b"\\password\\0"
            "0.0.0.0:0" => b"\\password\\1"
        };
        matches!(s, b"", 0, 1, 2);
        matches!(s, b"\\password\\0", 0, 1);
        matches!(s, b"\\password\\1", 2);
    }

    #[test]
    fn match_not_empty() {
        let servers = servers! {
            "0.0.0.0:0" => b"\\players\\0\\max\\8"
            "0.0.0.0:0" => b"\\players\\4\\max\\8"
            "0.0.0.0:0" => b"\\players\\8\\max\\8"
        };
        matches!(servers, b"", 0, 1, 2);
        matches!(servers, b"\\empty\\0", 1, 2);
        matches!(servers, b"\\empty\\1", 0);
    }

    #[test]
    fn match_full() {
        let servers = servers! {
            "0.0.0.0:0" => b"\\players\\0\\max\\8"
            "0.0.0.0:0" => b"\\players\\4\\max\\8"
            "0.0.0.0:0" => b"\\players\\8\\max\\8"
        };
        matches!(servers, b"", 0, 1, 2);
        matches!(servers, b"\\full\\0", 0, 1);
        matches!(servers, b"\\full\\1", 2);
    }

    #[test]
    fn match_noplayers() {
        let servers = servers! {
            "0.0.0.0:0" => b"\\players\\0\\max\\8"
            "0.0.0.0:0" => b"\\players\\4\\max\\8"
            "0.0.0.0:0" => b"\\players\\8\\max\\8"
        };
        matches!(servers, b"", 0, 1, 2);
        matches!(servers, b"\\noplayers\\0", 1, 2);
        matches!(servers, b"\\noplayers\\1", 0);
    }

    #[test]
    fn match_nat() {
        let servers = servers! {
            "0.0.0.0:0" => b""
            "0.0.0.0:0" => b"\\nat\\0"
            "0.0.0.0:0" => b"\\nat\\1"
        };
        matches!(servers, b"", 0, 1, 2);
        matches!(servers, b"\\nat\\0", 0, 1);
        matches!(servers, b"\\nat\\1", 2);
    }

    #[test]
    fn match_lan() {
        let servers = servers! {
            "0.0.0.0:0" => b""
            "0.0.0.0:0" => b"\\lan\\0"
            "0.0.0.0:0" => b"\\lan\\1"
        };
        matches!(servers, b"", 0, 1, 2);
        matches!(servers, b"\\lan\\0", 0, 1);
        matches!(servers, b"\\lan\\1", 2);
    }

    #[test]
    fn match_bots() {
        let servers = servers! {
            "0.0.0.0:0" => b""
            "0.0.0.0:0" => b"\\bots\\0"
            "0.0.0.0:0" => b"\\bots\\1"
        };
        matches!(servers, b"", 0, 1, 2);
        matches!(servers, b"\\bots\\0", 0, 1);
        matches!(servers, b"\\bots\\1", 2);
    }

    #[test]
    fn match_gamedir() {
        let servers = servers! {
            "0.0.0.0:0" => b"\\gamedir\\valve"
            "0.0.0.0:0" => b"\\gamedir\\cstrike"
            "0.0.0.0:0" => b"\\gamedir\\dod"
            "0.0.0.0:0" => b"\\gamedir\\portal"
            "0.0.0.0:0" => b"\\gamedir\\left4dead"
        };
        matches!(servers, b"", 0, 1, 2, 3, 4);
        matches!(servers, b"\\gamedir\\valve", 0);
        matches!(servers, b"\\gamedir\\portal", 3);
        matches!(servers, b"\\gamedir\\left4dead", 4);
    }

    #[test]
    fn match_map() {
        let servers = servers! {
            "0.0.0.0:0" => b"\\map\\crossfire"
            "0.0.0.0:0" => b"\\map\\boot_camp"
            "0.0.0.0:0" => b"\\map\\de_dust"
            "0.0.0.0:0" => b"\\map\\cs_office"
        };
        matches!(servers, b"", 0, 1, 2, 3);
        matches!(servers, b"\\map\\crossfire", 0);
        matches!(servers, b"\\map\\de_dust", 2);
        matches!(servers, b"\\map\\cs_office", 3);
    }

    #[test]
    fn match_protocol() {
        let s = servers! {
            "0.0.0.0:0" => b"\\protocol\\47"
            "0.0.0.0:0" => b"\\protocol\\48"
            "0.0.0.0:0" => b"\\protocol\\49"
        };
        matches!(s, b"", 0, 1, 2);
        matches!(s, b"\\protocol\\48", 1);
        matches!(s, b"\\protocol\\49", 2);
    }
}