sov_universal_wallet/schema/
safe_string.rs1use std::fmt::Display;
2use std::str::FromStr;
3
4use borsh::{BorshDeserialize, BorshSerialize};
5use thiserror::Error;
6
7use crate::schema::{IndexLinking, Item, Link, Primitive, Schema, UniversalWallet};
8
9#[derive(Debug, Error, Clone, PartialEq, Eq)]
10pub enum SchemaStringError {
11 #[error("String was too long: {length}, maximum: {max}")]
12 StringTooLong { length: usize, max: usize },
13 #[error("String contained invalid character: {character}. Only printable ASCII characters are allowed.")]
14 InvalidCharacter { character: char },
15}
16
17#[derive(Default, Hash, Clone, PartialEq, Eq, PartialOrd, Ord, BorshSerialize)]
30#[cfg_attr(
31 feature = "serde",
32 derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)
33)]
34#[cfg_attr(feature = "serde", serde(try_from = "String", into = "String"))]
35pub struct SizedSafeString<const MAX_LEN: usize>(String);
36
37pub const DEFAULT_MAX_STRING_LENGTH: usize = 128;
38pub type SafeString = SizedSafeString<DEFAULT_MAX_STRING_LENGTH>;
39
40impl<const MAX_LEN: usize> SizedSafeString<MAX_LEN> {
41 pub fn as_str(&self) -> &str {
42 &self.0
43 }
44
45 pub const fn max_len(&self) -> usize {
47 MAX_LEN
48 }
49
50 pub const fn new() -> Self {
52 Self(String::new())
53 }
54
55 pub fn len(&self) -> usize {
57 self.0.len()
58 }
59
60 pub fn is_empty(&self) -> bool {
62 self.0.is_empty()
63 }
64
65 pub fn try_push(&mut self, c: char) -> Result<(), SchemaStringError> {
67 if self.len() >= MAX_LEN {
68 return Err(SchemaStringError::StringTooLong {
69 length: self.len() + 1,
70 max: MAX_LEN,
71 });
72 }
73
74 if !Self::is_valid_char(c) {
75 return Err(SchemaStringError::InvalidCharacter { character: c });
76 }
77 self.0.push(c);
78 Ok(())
79 }
80
81 pub const fn is_valid_char(c: char) -> bool {
83 c.is_ascii() && !c.is_ascii_control()
84 }
85}
86
87impl<const MAX_LEN: usize> BorshDeserialize for SizedSafeString<MAX_LEN> {
88 fn deserialize_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
89 let len = u32::deserialize_reader(reader)? as usize;
90 if len > MAX_LEN {
91 return Err(std::io::Error::new(
92 std::io::ErrorKind::InvalidData,
93 "Unexpected length of input",
94 ));
95 }
96 let mut output = Vec::with_capacity(len);
97 for _ in 0..len {
98 output.push(u8::deserialize_reader(reader)?);
99 }
100 let string = String::from_utf8(output)
101 .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UTF-8"))?;
102 for c in string.chars() {
103 if !Self::is_valid_char(c) {
104 return Err(std::io::Error::new(
105 std::io::ErrorKind::InvalidData,
106 "Invalid character",
107 ));
108 }
109 }
110 Ok(Self(string))
111 }
112}
113
114impl<const MAX_LEN: usize> TryFrom<String> for SizedSafeString<MAX_LEN> {
115 type Error = SchemaStringError;
116
117 fn try_from(value: String) -> Result<Self, Self::Error> {
118 if value.len() > MAX_LEN {
119 return Err(SchemaStringError::StringTooLong {
120 length: value.len(),
121 max: MAX_LEN,
122 });
123 }
124 if let Some(invalid_c) = value.chars().find(|c| !Self::is_valid_char(*c)) {
125 return Err(SchemaStringError::InvalidCharacter {
126 character: invalid_c,
127 });
128 }
129 Ok(Self(value))
130 }
131}
132
133impl<const MAX_LEN: usize> FromStr for SizedSafeString<MAX_LEN> {
134 type Err = SchemaStringError;
135 fn from_str(s: &str) -> Result<Self, Self::Err> {
136 s.try_into()
137 }
138}
139
140impl<const MAX_LEN: usize> UniversalWallet for SizedSafeString<MAX_LEN> {
141 fn scaffold() -> Item<IndexLinking> {
142 Item::Atom(Primitive::String)
143 }
144 fn get_child_links(_schema: &mut Schema) -> Vec<Link> {
145 Vec::new()
146 }
147}
148
149impl<'a, const MAX_LEN: usize> TryFrom<&'a str> for SizedSafeString<MAX_LEN> {
150 type Error = SchemaStringError;
151
152 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
153 value.to_string().try_into()
154 }
155}
156
157impl<const MAX_LEN: usize> From<SizedSafeString<MAX_LEN>> for String {
158 fn from(value: SizedSafeString<MAX_LEN>) -> Self {
159 value.0
160 }
161}
162
163impl<const MAX_LEN: usize> AsRef<[u8]> for SizedSafeString<MAX_LEN> {
164 fn as_ref(&self) -> &[u8] {
165 self.0.as_ref()
166 }
167}
168
169impl<const MAX_LEN: usize> AsRef<str> for SizedSafeString<MAX_LEN> {
170 fn as_ref(&self) -> &str {
171 self.0.as_ref()
172 }
173}
174
175impl<const MAX_LEN: usize> std::fmt::Debug for SizedSafeString<MAX_LEN> {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 write!(f, "{}", self.0)
178 }
179}
180
181impl<const MAX_LEN: usize> Display for SizedSafeString<MAX_LEN> {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 write!(f, "{}", self.0)
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::{SafeString, SchemaStringError, SizedSafeString};

    #[test]
    fn test_sizedsafestring_maxlen() {
        let string_good: String = ['a'; 31].iter().collect();
        let string_bad: String = ['a'; 32].iter().collect();

        let conversion_good = <SizedSafeString<31>>::try_from(string_good);
        assert!(conversion_good.is_ok());

        let conversion_bad = <SizedSafeString<31>>::try_from(string_bad);
        assert_eq!(
            conversion_bad,
            Err(SchemaStringError::StringTooLong {
                length: 32,
                max: 31
            })
        );
    }

    #[test]
    fn test_safestring_default_len() {
        let string_good: String = ['a'; 128].iter().collect();
        let string_bad: String = ['a'; 129].iter().collect();

        let conversion_good = SafeString::try_from(string_good);
        assert!(conversion_good.is_ok());

        let conversion_bad = SafeString::try_from(string_bad);
        assert_eq!(
            conversion_bad,
            Err(SchemaStringError::StringTooLong {
                length: 129,
                max: 128
            })
        );
    }

    #[test]
    fn test_safestring_rejects_nonascii() {
        let string = "hello •";
        let conversion = SafeString::try_from(string);
        assert_eq!(
            conversion,
            Err(SchemaStringError::InvalidCharacter { character: '•' })
        );
    }

    #[test]
    fn test_safestring_rejects_control_chars() {
        let string = "hello \n world";
        let conversion = SafeString::try_from(string);
        assert_eq!(
            conversion,
            Err(SchemaStringError::InvalidCharacter { character: '\n' })
        );
    }

    #[test]
    fn test_try_push_enforces_limits() {
        let mut s = SizedSafeString::<2>::new();
        assert_eq!(
            s.try_push('\n'),
            Err(SchemaStringError::InvalidCharacter { character: '\n' })
        );
        assert!(s.try_push('a').is_ok());
        assert!(s.try_push('b').is_ok());
        assert_eq!(
            s.try_push('c'),
            Err(SchemaStringError::StringTooLong { length: 3, max: 2 })
        );
        assert_eq!(s.as_str(), "ab");
    }

    // The serde-based tests need the `Deserialize` impl, which is only derived when
    // the `serde` feature is enabled, so they are gated on that feature too.
    #[cfg(feature = "serde")]
    #[test]
    fn json_deserializing_safestring_accepts_valid() {
        let de: SafeString = serde_json::from_str("\"Good string\"").unwrap();
        let expected: SafeString = "Good string".try_into().unwrap();
        assert_eq!(de, expected);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn json_deserializing_safestring_rejects_invalid() {
        let de: Result<SafeString, _> = serde_json::from_str("\"Bad•string\"");
        assert!(de.is_err());
        assert_eq!(
            de.unwrap_err().to_string(),
            "String contained invalid character: •. Only printable ASCII characters are allowed."
        );
    }

    #[test]
    fn test_safe_string_borsh_invalid_char() {
        use borsh::{to_vec, BorshDeserialize};
        let input = String::from("\n");
        assert!(SafeString::try_from(input.clone()).is_err());
        let encoded = to_vec(&input).unwrap();
        let output = SafeString::try_from_slice(&encoded);
        assert!(output.is_err());
    }

    #[test]
    fn test_safe_string_borsh_invalid_utf8() {
        use borsh::BorshDeserialize;
        // Length prefix of 1 followed by a byte that is not valid UTF-8.
        let output = SafeString::try_from_slice(&[1, 0, 0, 0, 0xff]);
        assert!(output.is_err());
    }

    #[test]
    fn test_safe_string_borsh_too_long() {
        use borsh::{to_vec, BorshDeserialize};
        let large_input = "a".repeat(300);
        assert!(SafeString::try_from(large_input.clone()).is_err());
        let encoded = to_vec(&large_input).unwrap();
        let output = SafeString::try_from_slice(&encoded);
        assert!(output.is_err());
    }

    #[test]
    fn test_safe_string_borsh_roundtrip() {
        use borsh::{to_vec, BorshDeserialize};
        let original: SizedSafeString<5> = "hello".try_into().unwrap();
        let encoded = to_vec(&original).unwrap();
        // The wire format is identical to a plain Borsh `String`.
        assert_eq!(encoded, to_vec(&String::from("hello")).unwrap());
        let decoded = SizedSafeString::<5>::try_from_slice(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_safe_string_serde_invalid_char() {
        let de: Result<SafeString, _> = serde_json::from_str("\"\\n\"");
        assert!(de.is_err());
        assert_eq!(
            de.unwrap_err().to_string(),
            "String contained invalid character: \n. Only printable ASCII characters are allowed."
        );
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_safe_string_serde_too_long() {
        let large_input = "a".repeat(300);
        let de: Result<SafeString, _> = serde_json::from_str(&format!("\"{large_input}\""));
        assert!(de.is_err());
        assert_eq!(
            de.unwrap_err().to_string(),
            "String was too long: 300, maximum: 128"
        );
    }
}