secure_types/
string.rs

1use super::{Error, vec::SecureVec};
2use core::ops::Range;
3use zeroize::Zeroize;
4
5/// A securely allocated, growable UTF-8 string, just like `std::string::String`.
6///
7/// It is a wrapper around [SecureVec<u8>] and inherits all of its security guarantees.
8///
9/// Access to the string contents is provided through scoped methods like `unlock_str`,
10/// which ensure the memory is only unlocked for the briefest possible time.
11/// 
12/// # Notes
13/// 
14/// If you return a new allocated `String` from one of the unlock methods you are responsible for zeroizing the memory.
15///
16/// # Example
17///
18/// ```
19/// use secure_types::{SecureString, Zeroize};
20///
21/// // Create a SecureString
22/// let mut secret = SecureString::from("my_super_secret");
23///
24/// // The memory is locked here
25///
26/// // Safely append more data.
27/// secret.push_str("_password");
28///
29/// // The memory is locked here.
30///
31/// // Use a scope to safely access the content as a &str.
32/// secret.unlock_str(|exposed_str| {
33///     assert_eq!(exposed_str, "my_super_secret_password");
34/// });
35/// 
36/// // Not recommended but if you allocate a new String make sure to zeroize it
37/// let mut exposed = secret.unlock_str(|exposed_str| {
38///     String::from(exposed_str)
39/// });
40/// 
41/// // Do what you need to to do with the new string
42/// // When you are done with it, zeroize it
43/// exposed.zeroize();
44///
45/// // When `secret` is dropped, its data zeroized.
46/// ```
47#[derive(Clone)]
48pub struct SecureString {
49   vec: SecureVec<u8>,
50}
51
52impl SecureString {
53   pub fn new() -> Result<Self, Error> {
54      let vec = SecureVec::new()?;
55      Ok(SecureString { vec })
56   }
57
58   pub fn new_with_capacity(capacity: usize) -> Result<Self, Error> {
59      let vec = SecureVec::new_with_capacity(capacity)?;
60      Ok(SecureString { vec })
61   }
62
63   pub fn erase(&mut self) {
64      self.vec.erase();
65   }
66
67   /// Returns the length of the inner `SecureVec`
68   ///
69   /// If you want the character length use [`char_len`](Self::char_len)
70   pub fn len(&self) -> usize {
71      self.vec.len()
72   }
73
74   pub fn is_empty(&self) -> bool {
75      self.vec.is_empty()
76   }
77
78   pub fn drain(&mut self, range: Range<usize>) {
79      let _d = self.vec.drain(range);
80   }
81
82   pub fn char_len(&self) -> usize {
83      self.unlock_str(|s| s.chars().count())
84   }
85
86   /// Push a `&str` into the `SecureString`
87   pub fn push_str(&mut self, string: &str) {
88      let slice = string.as_bytes();
89      for s in slice.iter() {
90         self.vec.push(*s);
91      }
92   }
93
94   /// Immutable access as `&str`
95   pub fn unlock_str<F, R>(&self, f: F) -> R
96   where
97      F: FnOnce(&str) -> R,
98   {
99      self.vec.unlock_slice(|slice| {
100         let str = core::str::from_utf8(slice).unwrap();
101         f(str)
102      })
103   }
104
105   /// Mutable access to the `SecureString`
106   pub fn unlock_mut<F, R>(&mut self, f: F) -> R
107   where
108      F: FnOnce(&mut SecureString) -> R,
109   {
110      f(self)
111   }
112
113   /// Inserts text at the given character index
114   /// 
115   /// # Returns
116   /// 
117   /// The number of characters inserted
118   ///
119   /// # Example
120   ///
121   /// ```
122   /// use secure_types::SecureString;
123   ///
124   /// let mut string = SecureString::from("GreekFeta");
125   /// string.insert_text_at_char_idx(9, "Cheese");
126   /// string.unlock_str(|str| {
127   ///     assert_eq!(str, "GreekFetaCheese");
128   /// });
129   /// ```
130   pub fn insert_text_at_char_idx(&mut self, char_idx: usize, text_to_insert: &str) -> usize {
131      let chars_to_insert_count = text_to_insert.chars().count();
132      if chars_to_insert_count == 0 {
133         return 0;
134      }
135
136      let bytes_to_insert = text_to_insert.as_bytes();
137      let insert_len = bytes_to_insert.len();
138
139      // Get the byte index corresponding to the character index
140      let byte_idx = self
141         .vec
142         .unlock_slice(|current_bytes| char_to_byte_idx(current_bytes, char_idx));
143
144      self.vec.reserve(insert_len);
145
146      let old_byte_len = self.vec.len();
147
148      // Perform the insertion in-place
149      self.vec.unlock_memory();
150      unsafe {
151         let ptr = self.vec.as_mut_ptr();
152
153         // Shift the "tail" of the string (from the insertion point to the end)
154         // to the right to make a gap for the new content.
155         if byte_idx < old_byte_len {
156            core::ptr::copy(
157               ptr.add(byte_idx),
158               ptr.add(byte_idx + insert_len),
159               old_byte_len - byte_idx,
160            );
161         }
162
163         // Copy the new text into the newly created gap.
164         core::ptr::copy_nonoverlapping(
165            bytes_to_insert.as_ptr(),
166            ptr.add(byte_idx),
167            insert_len,
168         );
169
170         self.vec.len += insert_len;
171      }
172
173      self.vec.lock_memory();
174
175      chars_to_insert_count
176   }
177
178   /// Deletes the text in the given character range
179   ///
180   /// # Example
181   ///
182   /// ```
183   /// use secure_types::SecureString;
184   ///
185   /// let mut string = SecureString::from("GreekFetaCheese");
186   /// string.delete_text_char_range(9..15);
187   /// string.unlock_str(|str| {
188   ///     assert_eq!(str, "GreekFeta");
189   /// });
190   /// ```
191   pub fn delete_text_char_range(&mut self, char_range: core::ops::Range<usize>) {
192      if char_range.start >= char_range.end {
193         return;
194      }
195
196      let (byte_start, byte_end) = self.unlock_str(|str| {
197         let byte_start = char_to_byte_idx(str.as_bytes(), char_range.start);
198         let byte_end = char_to_byte_idx(str.as_bytes(), char_range.end);
199         (byte_start, byte_end)
200      });
201
202      let new_len = self.vec.unlock_slice_mut(|current_bytes| {
203         if byte_start >= byte_end || byte_end > current_bytes.len() {
204            return 0;
205         }
206
207         let remove_len = byte_end - byte_start;
208         let old_total_len = current_bytes.len();
209
210         // Shift elements left
211         current_bytes.copy_within(byte_end..old_total_len, byte_start);
212
213         let new_len = old_total_len - remove_len;
214         // Zeroize the tail end that is now unused
215         for i in new_len..old_total_len {
216            current_bytes[i].zeroize();
217         }
218         new_len
219      });
220      self.vec.len = new_len;
221   }
222}
223
#[cfg(feature = "std")]
impl From<String> for SecureString {
   /// Consumes `s` and hands its bytes to `SecureVec::from_vec` to build a `SecureString`.
   ///
   /// The `String` is zeroized afterwards.
   ///
   /// # Panics
   ///
   /// Panics if `SecureVec::from_vec` returns an error.
   fn from(s: String) -> SecureString {
      let bytes = s.into_bytes();
      SecureString {
         vec: SecureVec::from_vec(bytes).unwrap(),
      }
   }
}
234
235impl From<&str> for SecureString {
236   /// Creates a new `SecureString` from a `&str`.
237   ///
238   /// The `&str` is not zeroized, you are responsible for zeroizing it.
239   fn from(s: &str) -> SecureString {
240      let bytes = s.as_bytes();
241      let len = bytes.len();
242
243      let mut new_vec = SecureVec::new_with_capacity(len).unwrap();
244      new_vec.len = len;
245
246      new_vec.unlock_slice_mut(|slice| {
247         slice[..len].copy_from_slice(bytes);
248      });
249
250      SecureString { vec: new_vec }
251   }
252}
253
254impl From<SecureVec<u8>> for SecureString {
255   fn from(vec: SecureVec<u8>) -> Self {
256      let mut new_string = SecureString::new().unwrap();
257      new_string.vec = vec;
258      new_string
259   }
260}
261
#[cfg(feature = "serde")]
impl serde::Serialize for SecureString {
   /// Serializes the contents as a plain string.
   ///
   /// The plaintext is passed to `serializer` inside `unlock_str`. Whatever the
   /// serializer produces from it is outside `SecureString`'s protection.
   fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
   where
      S: serde::Serializer,
   {
      self.unlock_str(|str| serializer.serialize_str(str))
   }
}
272
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SecureString {
   /// Deserializes a string and copies it into a new `SecureString`.
   fn deserialize<D>(deserializer: D) -> Result<SecureString, D::Error>
   where
      D: serde::Deserializer<'de>,
   {
      struct SecureStringVisitor;
      impl<'de> serde::de::Visitor<'de> for SecureStringVisitor {
         type Value = SecureString;
         fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
            write!(formatter, "an utf-8 encoded string")
         }
         fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
         where
            E: serde::de::Error,
         {
            Ok(SecureString::from(v))
         }
         // serde's default `visit_string` forwards to `visit_str` and then drops the
         // owned `String` without wiping it. Here the plaintext buffer is zeroized
         // explicitly once it has been copied into secure memory.
         #[cfg(feature = "std")]
         fn visit_string<E>(self, mut v: String) -> Result<Self::Value, E>
         where
            E: serde::de::Error,
         {
            let secure = SecureString::from(v.as_str());
            v.zeroize();
            Ok(secure)
         }
      }
      // Hint for a transient/borrowed `&str` rather than an owned `String`, so
      // deserializers that honor the hint avoid an extra heap copy of the secret.
      deserializer.deserialize_str(SecureStringVisitor)
   }
}
295
/// Converts a character index into the byte offset at which that character starts.
///
/// Returns `s_bytes.len()` in two cases:
/// - `char_idx` is at or past the last character.
/// - `s_bytes` is not valid UTF-8.
fn char_to_byte_idx(s_bytes: &[u8], char_idx: usize) -> usize {
   let end = s_bytes.len();
   match core::str::from_utf8(s_bytes) {
      Ok(text) => text
         .char_indices()
         .nth(char_idx)
         .map_or(end, |(byte_offset, _)| byte_offset),
      Err(_) => end,
   }
}
302
#[cfg(all(test, feature = "std"))]
mod tests {
   use super::*;

   #[test]
   fn test_creation() {
      let hello_world = "Hello, world!";
      let secure = SecureString::from(hello_world);

      secure.unlock_str(|str| {
         assert_eq!(str, hello_world);
      });
   }

   #[test]
   fn test_from_string() {
      let hello_world = String::from("Hello, world!");
      let string = SecureString::from(hello_world);

      string.unlock_str(|str| {
         assert_eq!(str, "Hello, world!");
      });
   }

   #[test]
   fn test_from_secure_vec() {
      let hello_world = "Hello, world!".to_string();
      let vec: SecureVec<u8> = SecureVec::from_slice(hello_world.as_bytes()).unwrap();

      let string = SecureString::from(hello_world);
      let string2 = SecureString::from(vec);

      string.unlock_str(|str| {
         string2.unlock_str(|str2| {
            assert_eq!(str, str2);
         });
      });
   }

   #[test]
   fn test_clone() {
      let hello_world = "Hello, world!".to_string();
      let secure1 = SecureString::from(hello_world.clone());
      let secure2 = secure1.clone();

      secure2.unlock_str(|str| {
         assert_eq!(str, hello_world);
      });
   }

   #[test]
   fn test_insert_text_at_char_idx() {
      let hello_world = "My name is ";
      let mut secure = SecureString::from(hello_world);
      secure.insert_text_at_char_idx(12, "Mike");
      secure.unlock_str(|str| {
         assert_eq!(str, "My name is Mike");
      });
   }

   #[test]
   fn test_insert_text_multibyte() {
      let mut secure = SecureString::from("h\u{e9}llo");
      let inserted = secure.insert_text_at_char_idx(2, "\u{fc}");
      assert_eq!(inserted, 1);
      secure.unlock_str(|str| {
         assert_eq!(str, "h\u{e9}\u{fc}llo");
      });
   }

   #[test]
   fn test_insert_empty_text() {
      let mut secure = SecureString::from("abc");
      assert_eq!(secure.insert_text_at_char_idx(1, ""), 0);
      secure.unlock_str(|str| {
         assert_eq!(str, "abc");
      });
   }

   #[test]
   fn test_delete_text_char_range() {
      let hello_world = "My name is Mike";
      let mut secure = SecureString::from(hello_world);
      secure.delete_text_char_range(10..17);
      secure.unlock_str(|str| {
         assert_eq!(str, "My name is");
      });
   }

   #[test]
   fn test_delete_text_multibyte() {
      let mut secure = SecureString::from("h\u{e9}llo");
      secure.delete_text_char_range(1..2);
      secure.unlock_str(|str| {
         assert_eq!(str, "hllo");
      });
   }

   #[test]
   fn test_char_len_multibyte() {
      let secure = SecureString::from("h\u{e9}llo w\u{f6}rld");
      assert_eq!(secure.char_len(), 11);
      assert_eq!(secure.len(), 13);
   }

   #[test]
   fn test_drain() {
      let hello_world = "Hello, world!";
      let mut secure = SecureString::from(hello_world);
      secure.drain(0..7);
      secure.unlock_str(|str| {
         assert_eq!(str, "world!");
      });
   }

   #[cfg(feature = "serde")]
   #[test]
   fn test_serde() {
      let hello_world = "Hello, world!";
      let secure = SecureString::from(hello_world);

      let json_string = serde_json::to_string(&secure).expect("Serialization failed");
      let json_bytes = serde_json::to_vec(&secure).expect("Serialization failed");

      let deserialized_string: SecureString =
         serde_json::from_str(&json_string).expect("Deserialization failed");

      let deserialized_bytes: SecureString =
         serde_json::from_slice(&json_bytes).expect("Deserialization failed");

      deserialized_string.unlock_str(|str| {
         assert_eq!(str, hello_world);
      });

      deserialized_bytes.unlock_str(|str| {
         assert_eq!(str, hello_world);
      });
   }

   #[test]
   fn test_unlock_str() {
      let hello_word = "Hello, world!";
      let string = SecureString::from(hello_word);
      let mut exposed_string = string.unlock_str(|str| {
         assert_eq!(str, hello_word);
         String::from(str)
      });
      // Follow the type's own guidance: wipe any copy that escapes the closure.
      exposed_string.zeroize();
   }

   #[test]
   fn test_push_str() {
      let hello_world = "Hello, world!";

      let mut string = SecureString::new().unwrap();
      string.push_str(hello_world);
      string.unlock_str(|str| {
         assert_eq!(str, hello_world);
      });
   }

   #[test]
   fn test_unlock_mut() {
      let hello_world = "Hello, world!";
      let mut string = SecureString::from("Hello, ");
      string.unlock_mut(|string| {
         string.push_str("world!");
      });

      string.unlock_str(|str| {
         assert_eq!(str, hello_world);
      });
   }
}