Skip to main content

sov_universal_wallet/schema/
safe_string.rs

1use std::fmt::Display;
2use std::str::FromStr;
3
4use borsh::{BorshDeserialize, BorshSerialize};
5use thiserror::Error;
6
7use crate::schema::{IndexLinking, Item, Link, Primitive, Schema, UniversalWallet};
8
9#[derive(Debug, Error, Clone, PartialEq, Eq)]
10pub enum SchemaStringError {
11    #[error("String was too long: {length}, maximum: {max}")]
12    StringTooLong { length: usize, max: usize },
13    #[error("String contained invalid character: {character}. Only printable ASCII characters are allowed.")]
14    InvalidCharacter { character: char },
15}
16
17/// A String wrapper which enforces certain constraints to ensure it is safely displayable as part
18/// of a transaction without confusing the user. Only printable ASCII is allowed, and the length is
19/// limited.
20///
21/// `UniversalWallet` implementation is forbidden on `std::String` by default, to avoid the possibility
22/// of untrusted input supplying highly confusing text that tricks users into misunderstanding the
23/// transaction they are signing. `SafeString` enforces some constraints to mitigate this risk. If
24/// you need to encode a large data blob such as a hex string, use a `Vec<u8>` with the
25/// `[sov_wallet(display = "hex")]` attribute (or any of the other display styles). Avoid raw
26/// `String`s if possible.
27/// If an actual `String` is absolutely necessary, then a newtype wrapper can be used, on which
28/// `UniversalWallet` is derived manually.
29#[derive(Default, Hash, Clone, PartialEq, Eq, PartialOrd, Ord, BorshSerialize)]
30#[cfg_attr(
31    feature = "serde",
32    derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)
33)]
34#[cfg_attr(feature = "serde", serde(try_from = "String", into = "String"))]
35pub struct SizedSafeString<const MAX_LEN: usize>(String);
36
37pub const DEFAULT_MAX_STRING_LENGTH: usize = 128;
38pub type SafeString = SizedSafeString<DEFAULT_MAX_STRING_LENGTH>;
39
40impl<const MAX_LEN: usize> SizedSafeString<MAX_LEN> {
41    pub fn as_str(&self) -> &str {
42        &self.0
43    }
44
45    /// A convenience method to get the maximum length of SafeString instance
46    pub const fn max_len(&self) -> usize {
47        MAX_LEN
48    }
49
50    /// Return an empty SafeString. This method does not allocate
51    pub const fn new() -> Self {
52        Self(String::new())
53    }
54
55    /// Returns the length (*not* capacity or max_length) of the string in bytes
56    pub fn len(&self) -> usize {
57        self.0.len()
58    }
59
60    /// Returns true if the string is empty
61    pub fn is_empty(&self) -> bool {
62        self.0.is_empty()
63    }
64
65    /// Appends the given `char`` to the end of this `SizedSafeString` if possible.
66    pub fn try_push(&mut self, c: char) -> Result<(), SchemaStringError> {
67        if self.len() >= MAX_LEN {
68            return Err(SchemaStringError::StringTooLong {
69                length: self.len() + 1,
70                max: MAX_LEN,
71            });
72        }
73
74        if !Self::is_valid_char(c) {
75            return Err(SchemaStringError::InvalidCharacter { character: c });
76        }
77        self.0.push(c);
78        Ok(())
79    }
80
81    /// Returns true if the character is a valid member of `SizedSafeString`
82    pub const fn is_valid_char(c: char) -> bool {
83        c.is_ascii() && !c.is_ascii_control()
84    }
85}
86
87impl<const MAX_LEN: usize> BorshDeserialize for SizedSafeString<MAX_LEN> {
88    fn deserialize_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
89        let len = u32::deserialize_reader(reader)? as usize;
90        if len > MAX_LEN {
91            return Err(std::io::Error::new(
92                std::io::ErrorKind::InvalidData,
93                "Unexpected length of input",
94            ));
95        }
96        let mut output = Vec::with_capacity(len);
97        for _ in 0..len {
98            output.push(u8::deserialize_reader(reader)?);
99        }
100        let string = String::from_utf8(output)
101            .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UTF-8"))?;
102        for c in string.chars() {
103            if !Self::is_valid_char(c) {
104                return Err(std::io::Error::new(
105                    std::io::ErrorKind::InvalidData,
106                    "Invalid character",
107                ));
108            }
109        }
110        Ok(Self(string))
111    }
112}
113
114impl<const MAX_LEN: usize> TryFrom<String> for SizedSafeString<MAX_LEN> {
115    type Error = SchemaStringError;
116
117    fn try_from(value: String) -> Result<Self, Self::Error> {
118        if value.len() > MAX_LEN {
119            return Err(SchemaStringError::StringTooLong {
120                length: value.len(),
121                max: MAX_LEN,
122            });
123        }
124        if let Some(invalid_c) = value.chars().find(|c| !Self::is_valid_char(*c)) {
125            return Err(SchemaStringError::InvalidCharacter {
126                character: invalid_c,
127            });
128        }
129        Ok(Self(value))
130    }
131}
132
133impl<const MAX_LEN: usize> FromStr for SizedSafeString<MAX_LEN> {
134    type Err = SchemaStringError;
135    fn from_str(s: &str) -> Result<Self, Self::Err> {
136        s.try_into()
137    }
138}
139
140impl<const MAX_LEN: usize> UniversalWallet for SizedSafeString<MAX_LEN> {
141    fn scaffold() -> Item<IndexLinking> {
142        Item::Atom(Primitive::String)
143    }
144    fn get_child_links(_schema: &mut Schema) -> Vec<Link> {
145        Vec::new()
146    }
147}
148
149impl<'a, const MAX_LEN: usize> TryFrom<&'a str> for SizedSafeString<MAX_LEN> {
150    type Error = SchemaStringError;
151
152    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
153        value.to_string().try_into()
154    }
155}
156
157impl<const MAX_LEN: usize> From<SizedSafeString<MAX_LEN>> for String {
158    fn from(value: SizedSafeString<MAX_LEN>) -> Self {
159        value.0
160    }
161}
162
163impl<const MAX_LEN: usize> AsRef<[u8]> for SizedSafeString<MAX_LEN> {
164    fn as_ref(&self) -> &[u8] {
165        self.0.as_ref()
166    }
167}
168
169impl<const MAX_LEN: usize> AsRef<str> for SizedSafeString<MAX_LEN> {
170    fn as_ref(&self) -> &str {
171        self.0.as_ref()
172    }
173}
174
175impl<const MAX_LEN: usize> std::fmt::Debug for SizedSafeString<MAX_LEN> {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        write!(f, "{}", self.0)
178    }
179}
180
181impl<const MAX_LEN: usize> Display for SizedSafeString<MAX_LEN> {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        write!(f, "{}", self.0)
184    }
185}
186
#[cfg(test)]
mod tests {
    use borsh::{to_vec, BorshDeserialize};

    use super::{SafeString, SchemaStringError, SizedSafeString, DEFAULT_MAX_STRING_LENGTH};

    #[test]
    fn test_sizedsafestring_maxlen() {
        let string_good = "a".repeat(31);
        let string_bad = "a".repeat(32);

        let conversion_good = <SizedSafeString<31>>::try_from(string_good);
        assert!(conversion_good.is_ok());

        let conversion_bad = <SizedSafeString<31>>::try_from(string_bad);
        assert_eq!(
            conversion_bad,
            Err(SchemaStringError::StringTooLong {
                length: 32,
                max: 31
            })
        );
    }

    #[test]
    fn test_safestring_default_len() {
        assert_eq!(DEFAULT_MAX_STRING_LENGTH, 128);
        let string_good = "a".repeat(128);
        let string_bad = "a".repeat(129);

        let conversion_good = SafeString::try_from(string_good);
        assert!(conversion_good.is_ok());

        let conversion_bad = SafeString::try_from(string_bad);
        assert_eq!(
            conversion_bad,
            Err(SchemaStringError::StringTooLong {
                length: 129,
                max: 128
            })
        );
    }

    #[test]
    fn test_safestring_rejects_nonascii() {
        let string = "hello •";
        let conversion = SafeString::try_from(string);
        assert_eq!(
            conversion,
            Err(SchemaStringError::InvalidCharacter { character: '•' })
        );
    }

    #[test]
    fn test_safestring_rejects_control_chars() {
        let string = "hello \n world";
        let conversion = SafeString::try_from(string);
        assert_eq!(
            conversion,
            Err(SchemaStringError::InvalidCharacter { character: '\n' })
        );
    }

    #[test]
    fn test_try_push_enforces_length_and_charset() {
        let mut s = SizedSafeString::<2>::new();
        assert_eq!(s.try_push('a'), Ok(()));
        assert_eq!(
            s.try_push('\n'),
            Err(SchemaStringError::InvalidCharacter { character: '\n' })
        );
        assert_eq!(s.try_push('b'), Ok(()));
        assert_eq!(
            s.try_push('c'),
            Err(SchemaStringError::StringTooLong { length: 3, max: 2 })
        );
        // A failed push leaves the string untouched.
        assert_eq!(s.as_str(), "ab");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn json_deserializing_safestring_accepts_valid() {
        let de: SafeString = serde_json::from_str("\"Good string\"").unwrap();
        let expected: SafeString = "Good string".try_into().unwrap();
        assert_eq!(de, expected);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn json_deserializing_safestring_rejects_invalid() {
        let de: Result<SafeString, _> = serde_json::from_str("\"Bad•string\"");
        assert!(de.is_err());
        assert_eq!(
            de.unwrap_err().to_string(),
            "String contained invalid character: •. Only printable ASCII characters are allowed."
        );
    }

    #[test]
    fn test_safe_string_borsh_roundtrip() {
        let original = SafeString::try_from("Send 10 tokens to alice").unwrap();
        let encoded = to_vec(&original).unwrap();
        // The wire format is identical to that of a plain `String`.
        assert_eq!(encoded, to_vec(&original.as_str().to_string()).unwrap());
        assert_eq!(SafeString::try_from_slice(&encoded).unwrap(), original);
    }

    #[test]
    fn test_safe_string_borsh_at_limit() {
        let exact = to_vec(&"abcd".to_string()).unwrap();
        assert_eq!(
            SizedSafeString::<4>::try_from_slice(&exact).unwrap().as_str(),
            "abcd"
        );
        let over = to_vec(&"abcde".to_string()).unwrap();
        assert!(SizedSafeString::<4>::try_from_slice(&over).is_err());
    }

    #[test]
    fn test_safe_string_borsh_invalid_char() {
        // ASCII control characters are rejected by both constructors.
        let input = "\n".to_string();
        assert!(SafeString::try_from(input.clone()).is_err());
        let encoded = to_vec(&input).unwrap();
        assert!(SafeString::try_from_slice(&encoded).is_err());
    }

    #[test]
    fn test_safe_string_borsh_rejects_non_ascii() {
        // Valid UTF-8 that is not ASCII must still be rejected.
        let encoded = to_vec(&"•".to_string()).unwrap();
        assert!(SafeString::try_from_slice(&encoded).is_err());
    }

    #[test]
    fn test_safe_string_borsh_truncated_input() {
        let mut encoded = to_vec(&"hello".to_string()).unwrap();
        encoded.truncate(encoded.len() - 1);
        assert!(SafeString::try_from_slice(&encoded).is_err());
    }

    #[test]
    fn test_safe_string_borsh_too_long() {
        // `SafeString` is limited to 128 bytes, so a 300-byte string is rejected.
        let large_input = "a".repeat(300);
        assert!(SafeString::try_from(large_input.clone()).is_err());
        let encoded = to_vec(&large_input).unwrap();
        assert!(SafeString::try_from_slice(&encoded).is_err());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_safe_string_serde_invalid_char() {
        let de: Result<SafeString, _> = serde_json::from_str("\"\\n\"");
        assert!(de.is_err());
        assert_eq!(
            de.unwrap_err().to_string(),
            "String contained invalid character: \n. Only printable ASCII characters are allowed."
        );
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_safe_string_serde_too_long() {
        let large_input = "a".repeat(300);
        let de: Result<SafeString, _> = serde_json::from_str(&format!("\"{large_input}\""));
        assert!(de.is_err());
        assert_eq!(
            de.unwrap_err().to_string(),
            "String was too long: 300, maximum: 128"
        );
    }
}