smb_dtyp/binrw_util/
sized_string.rs

1#![allow(unused_assignments)]
2
3use binrw::io::Write;
4use binrw::{Endian, NamedArgs, prelude::*};
5use core::fmt::{self, Write as _};
6use std::{io::prelude::*, string::FromUtf16Error};
7
/// A string of `T`-sized elements whose length is dictated by an externally supplied size.
///
/// Modelled after `binrw::strings::NullWideString`, except that the end of the string is
/// determined by the provided size rather than by a null character.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct BaseSizedString<T> {
    /// The raw elements of the string, stored exactly as read or supplied.
    data: Vec<T>,
}
14
15impl<T> BaseSizedString<T> {
16    const CHAR_WIDTH: u64 = std::mem::size_of::<T>() as u64;
17
18    /// Size of the string's data, in bytes.
19    ///
20    /// When using this struct, it is important to note how the size
21    /// of this string is calculated.
22    pub fn size(&self) -> u64 {
23        self.data.len() as u64 * Self::CHAR_WIDTH
24    }
25}
26
/// Length of a sized string, expressed either in bytes or in characters.
///
/// A "character" here is one element of the string's underlying type (for example one
/// `u16` of a wide string), not a Unicode scalar value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SizedStringSize {
    /// Length given as a number of bytes.
    Bytes(u64),
    /// Length given as a number of elements of the string's underlying type.
    Chars(u64),
}
32
33impl SizedStringSize {
34    /// [SizedStringSize::Bytes] factory for u32 size.
35    #[inline]
36    pub fn bytes(n: u32) -> Self {
37        SizedStringSize::Bytes(n as u64)
38    }
39
40    /// [SizedStringSize::Bytes] factory for u16 size.
41    #[inline]
42    pub fn bytes16(n: u16) -> Self {
43        SizedStringSize::Bytes(n as u64)
44    }
45
46    /// [SizedStringSize::Chars] factory for u32 size.
47    #[inline]
48    pub fn chars(n: u32) -> Self {
49        SizedStringSize::Chars(n as u64)
50    }
51
52    /// [SizedStringSize::Chars] factory for u16 size.
53    #[inline]
54    pub fn chars16(n: u16) -> Self {
55        SizedStringSize::Chars(n as u64)
56    }
57
58    #[inline]
59    fn get_size_bytes<T: Sized>(&self) -> binrw::BinResult<u64> {
60        let size = match self {
61            SizedStringSize::Bytes(b) => *b,
62            SizedStringSize::Chars(c) => *c * std::mem::size_of::<T>() as u64,
63        };
64        if size % std::mem::size_of::<T>() as u64 != 0 {
65            return Err(binrw::Error::Custom {
66                pos: 0,
67                err: Box::new(format!(
68                    "SizedStringSize {:?} is not a multiple of char width {}",
69                    self,
70                    std::mem::size_of::<T>()
71                )),
72            });
73        }
74        Ok(size)
75    }
76}
77
78#[derive(NamedArgs, Debug)]
79pub struct BaseSizedStringReadArgs {
80    pub size: SizedStringSize,
81}
82
83impl<T> BinRead for BaseSizedString<T>
84where
85    T: BinRead,
86    T::Args<'static>: Default,
87{
88    type Args<'a> = BaseSizedStringReadArgs;
89
90    fn read_options<R: Read + Seek>(
91        reader: &mut R,
92        endian: Endian,
93        args: Self::Args<'_>,
94    ) -> BinResult<Self> {
95        let size_to_use = args.size.get_size_bytes::<T>()?;
96        if size_to_use == 0 {
97            return Err(binrw::Error::Custom {
98                pos: reader.stream_position()?,
99                err: Box::new(format!(
100                    "BaseSizedString<{}> had invalid read arguments {:?} - all None or zero",
101                    std::any::type_name::<T>(),
102                    args
103                )),
104            });
105        }
106
107        let size_chars = size_to_use / Self::CHAR_WIDTH;
108
109        let mut values = Vec::with_capacity(size_chars as usize);
110
111        for _ in 0..size_chars {
112            let val = <T>::read_options(reader, endian, Default::default())?;
113            values.push(val);
114        }
115        Ok(Self { data: values })
116    }
117}
118
119impl<T> BinWrite for BaseSizedString<T>
120where
121    T: BinWrite + 'static,
122    for<'a> T::Args<'a>: Clone,
123{
124    type Args<'a> = T::Args<'a>;
125
126    fn write_options<W: Write + Seek>(
127        &self,
128        writer: &mut W,
129        endian: Endian,
130        args: Self::Args<'_>,
131    ) -> BinResult<()> {
132        self.data.write_options(writer, endian, args)?;
133
134        Ok(())
135    }
136}
137
138impl<T> From<BaseSizedString<T>> for Vec<T> {
139    fn from(s: BaseSizedString<T>) -> Self {
140        s.data
141    }
142}
143
144impl<T> core::ops::Deref for BaseSizedString<T> {
145    type Target = Vec<T>;
146
147    fn deref(&self) -> &Self::Target {
148        &self.data
149    }
150}
151
152impl<T> core::ops::DerefMut for BaseSizedString<T> {
153    fn deref_mut(&mut self) -> &mut Self::Target {
154        &mut self.data
155    }
156}
157
// TODO: Use this type everywhere!
// TODO: Implement everything else that is still missing for it, as well.
/// A fixed-size ANSI (single-byte) string, as opposed to [`binrw::NullString`].
///
/// Note: there's no support for locales in this structure; the bytes are stored as-is.
pub type SizedAnsiString = BaseSizedString<u8>;
164
165impl From<&str> for SizedAnsiString {
166    fn from(s: &str) -> Self {
167        assert!(s.is_ascii(), "String must be ASCII");
168        Self {
169            data: s.bytes().collect(),
170        }
171    }
172}
173
174impl FromIterator<u8> for SizedAnsiString {
175    fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
176        Self {
177            data: iter.into_iter().collect(),
178        }
179    }
180}
181
182impl TryFrom<SizedAnsiString> for String {
183    type Error = std::string::FromUtf8Error;
184
185    fn try_from(value: SizedAnsiString) -> Result<Self, Self::Error> {
186        // Every ANSI string is valid UTF-8 (ignoring page codes & locales)
187        String::from_utf8(value.data)
188    }
189}
190
191impl PartialEq<&str> for SizedAnsiString {
192    fn eq(&self, other: &&str) -> bool {
193        if !other.is_ascii() {
194            return false;
195        }
196        other.as_bytes().iter().eq(self.data.iter())
197    }
198}
199
200impl fmt::Display for SizedAnsiString {
201    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202        display_ansi(&self.data, f, core::iter::once)
203    }
204}
205
206impl fmt::Debug for SizedAnsiString {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        write!(f, "SizedAnsiString(\"")?;
209        display_ansi(&self.data, f, char::escape_debug)?;
210        write!(f, "\")")
211    }
212}
213
/// Writes `input` to `f`, treating every byte as the character with the same code point.
///
/// Each character is first passed through `t`, and every character that `t` yields is
/// written in order. Returns the first formatter error encountered, if any.
#[inline]
fn display_ansi<Transformer: Fn(char) -> O + Clone, O: Iterator<Item = char>>(
    input: &[u8],
    f: &mut fmt::Formatter<'_>,
    t: Transformer,
) -> fmt::Result {
    for &byte in input {
        // Every `u8` value maps to a valid `char`, so this conversion cannot fail.
        for transformed in t(char::from(byte)) {
            f.write_char(transformed)?;
        }
    }
    Ok(())
}
225
/// A fixed-size wide (UTF-16) string, as opposed to [`binrw::NullWideString`].
///
/// The content is held as raw `u16` values.
pub type SizedWideString = BaseSizedString<u16>;
228
229impl From<&str> for SizedWideString {
230    fn from(s: &str) -> Self {
231        Self {
232            data: s.encode_utf16().collect(),
233        }
234    }
235}
236
237impl FromIterator<u16> for SizedWideString {
238    fn from_iter<T: IntoIterator<Item = u16>>(iter: T) -> Self {
239        Self {
240            data: iter.into_iter().collect(),
241        }
242    }
243}
244
245impl From<String> for SizedWideString {
246    fn from(s: String) -> Self {
247        Self {
248            data: s.encode_utf16().collect(),
249        }
250    }
251}
252
253impl TryFrom<SizedWideString> for String {
254    type Error = FromUtf16Error;
255
256    fn try_from(value: SizedWideString) -> Result<Self, Self::Error> {
257        String::from_utf16(&value.data)
258    }
259}
260
261impl PartialEq<&str> for SizedWideString {
262    fn eq(&self, other: &&str) -> bool {
263        other.encode_utf16().eq(self.data.iter().copied())
264    }
265}
266
267impl fmt::Display for SizedWideString {
268    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269        display_utf16(&self.data, f, core::iter::once)
270    }
271}
272
273impl fmt::Debug for SizedWideString {
274    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275        write!(f, "SizedWideString(\"")?;
276        display_utf16(&self.data, f, char::escape_debug)?;
277        write!(f, "\")")
278    }
279}
280
/// Decodes `input` as UTF-16 and writes the resulting characters to `f`.
///
/// Anything that fails to decode is replaced by [`char::REPLACEMENT_CHARACTER`]. Each
/// character is passed through `t`, and every character that `t` yields is written in
/// order. Returns the first formatter error encountered, if any.
#[inline]
pub(crate) fn display_utf16<Transformer: Fn(char) -> O, O: Iterator<Item = char>>(
    input: &[u16],
    f: &mut fmt::Formatter<'_>,
    t: Transformer,
) -> fmt::Result {
    for decoded in char::decode_utf16(input.iter().copied()) {
        let ch = decoded.unwrap_or(char::REPLACEMENT_CHARACTER);
        for transformed in t(ch) {
            f.write_char(transformed)?;
        }
    }
    Ok(())
}
291
// Only compiled for test builds, so the module never ends up in regular builds.
#[cfg(test)]
mod tests {
    /// Generates a test named `$name` that checks `&str` comparison, cloning and byte
    /// size for a `BaseSizedString<$type>`.
    macro_rules! make_sized_string_tests {
        ($name:ident, $type:ty) => {
            #[test]
            fn $name() {
                use super::*;
                let a = BaseSizedString::<$type>::from("hello");
                assert_eq!(a, "hello");
                assert_ne!(a, "hello world");
                assert_ne!(a, "hel");
                assert_ne!(a, "hello\0");

                // "hello" is five elements in both the single-byte and the wide form.
                assert_eq!(a.size(), (5 * std::mem::size_of::<$type>()) as u64);

                let b: BaseSizedString<$type> = a.clone();
                assert_eq!(b, a);
                assert_eq!(b.data, a.data);
            }
        };
    }
    make_sized_string_tests!(test_ansi_peq, u8);
    make_sized_string_tests!(test_wide_peq, u16);
}
310    make_sized_string_tests!(test_ansi_peq, u8);
311    make_sized_string_tests!(test_wide_peq, u16);
312}