1use super::{Error, vec::SecureVec};
2use core::ops::Range;
3use zeroize::Zeroize;
4
5#[derive(Clone)]
50pub struct SecureString {
51 vec: SecureVec<u8>,
52}
53
54impl SecureString {
55 pub fn new() -> Result<Self, Error> {
56 let vec = SecureVec::new()?;
57 Ok(SecureString { vec })
58 }
59
60 pub fn new_with_capacity(capacity: usize) -> Result<Self, Error> {
61 let vec = SecureVec::new_with_capacity(capacity)?;
62 Ok(SecureString { vec })
63 }
64
65 pub fn erase(&mut self) {
66 self.vec.erase();
67 }
68
69 pub fn len(&self) -> usize {
70 self.vec.len()
71 }
72
73 pub fn drain(&mut self, range: Range<usize>) {
74 let _d = self.vec.drain(range);
75 }
76
77 pub fn char_len(&self) -> usize {
78 self.str_scope(|s| s.chars().count())
79 }
80
81 pub fn push_str(&mut self, string: &str) {
83 let slice = string.as_bytes();
84 for s in slice.iter() {
85 self.vec.push(*s);
86 }
87 }
88
89 pub fn str_scope<F, R>(&self, f: F) -> R
98 where
99 F: FnOnce(&str) -> R,
100 {
101 self.vec.slice_scope(|slice| {
102 let str = core::str::from_utf8(slice).unwrap();
103 let result = f(str);
104 result
105 })
106 }
107
108 pub fn mut_scope<F, R>(&mut self, f: F) -> R
117 where
118 F: FnOnce(&mut SecureString) -> R,
119 {
120 f(self)
121 }
122
123 pub fn insert_text_at_char_idx(&mut self, char_idx: usize, text_to_insert: &str) -> usize {
124 let chars_to_insert_count = text_to_insert.chars().count();
125 if chars_to_insert_count == 0 {
126 return 0;
127 }
128
129 let bytes_to_insert = text_to_insert.as_bytes();
130 let insert_len = bytes_to_insert.len();
131
132 let byte_idx = self
134 .vec
135 .slice_scope(|current_bytes| char_to_byte_idx(current_bytes, char_idx));
136
137 self.vec.reserve(insert_len);
138
139 let old_byte_len = self.vec.len();
140
141 self.vec.unlock_memory();
143 unsafe {
144 let ptr = self.vec.as_mut_ptr();
145
146 if byte_idx < old_byte_len {
149 core::ptr::copy(
150 ptr.add(byte_idx),
151 ptr.add(byte_idx + insert_len),
152 old_byte_len - byte_idx,
153 );
154 }
155
156 core::ptr::copy_nonoverlapping(
158 bytes_to_insert.as_ptr(),
159 ptr.add(byte_idx),
160 insert_len,
161 );
162
163 self.vec.len += insert_len;
164 }
165
166 self.vec.lock_memory();
167
168 chars_to_insert_count
169 }
170
171 pub fn delete_text_char_range(&mut self, char_range: core::ops::Range<usize>) {
172 if char_range.start >= char_range.end {
173 return;
174 }
175
176 let (byte_start, byte_end) = self.str_scope(|str| {
177 let byte_start = char_to_byte_idx(str.as_bytes(), char_range.start);
178 let byte_end = char_to_byte_idx(str.as_bytes(), char_range.end);
179 (byte_start, byte_end)
180 });
181
182 let new_len = self.vec.slice_mut_scope(|current_bytes| {
183 if byte_start >= byte_end || byte_end > current_bytes.len() {
184 return 0;
185 }
186
187 let remove_len = byte_end - byte_start;
188 let old_total_len = current_bytes.len();
189
190 current_bytes.copy_within(byte_end..old_total_len, byte_start);
192
193 let new_len = old_total_len - remove_len;
194 for i in new_len..old_total_len {
196 current_bytes[i].zeroize();
197 }
198 new_len
199 });
200 self.vec.len = new_len;
201 }
202}
203
#[cfg(feature = "std")]
impl From<String> for SecureString {
    /// Moves the bytes of `s` into a new `SecureString` via `SecureVec::from_vec`.
    ///
    /// NOTE(review): whether the original `String` buffer is wiped depends on
    /// `SecureVec::from_vec`; confirm.
    ///
    /// # Panics
    ///
    /// Panics if `SecureVec::from_vec` returns an error, because `From` cannot
    /// report failure.
    fn from(s: String) -> SecureString {
        let vec = SecureVec::from_vec(s.into_bytes())
            .expect("SecureString::from(String): failed to create SecureVec");
        SecureString { vec }
    }
}
211
212impl From<&str> for SecureString {
213 fn from(s: &str) -> SecureString {
214 let bytes = s.as_bytes();
215 let len = bytes.len();
216
217 let mut new_vec = SecureVec::new_with_capacity(len).unwrap();
218 new_vec.len = len;
219
220 new_vec.slice_mut_scope(|slice| {
221 slice[..len].copy_from_slice(bytes);
222 });
223
224 SecureString { vec: new_vec }
225 }
226}
227
228impl From<SecureVec<u8>> for SecureString {
229 fn from(vec: SecureVec<u8>) -> Self {
230 let mut new_string = SecureString::new().unwrap();
231 new_string.vec = vec;
232 new_string
233 }
234}
235
#[cfg(feature = "serde")]
impl serde::Serialize for SecureString {
    /// Serializes the contents as a plain string.
    ///
    /// The `&str` is only borrowed inside `str_scope`. Whatever the serializer
    /// writes, however, is an ordinary copy outside the `SecureVec`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.str_scope(move |text| serializer.serialize_str(text))
    }
}
246
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SecureString {
    /// Deserializes a UTF-8 string (or UTF-8 byte sequence) into a `SecureString`.
    ///
    /// Uses the `deserialize_str` hint rather than `deserialize_string`. The
    /// visitor immediately copies into a `SecureVec`, so it gains nothing from an
    /// owned `String`. Requesting one would only invite the deserializer to
    /// allocate an extra, unprotected heap copy of the secret.
    fn deserialize<D>(deserializer: D) -> Result<SecureString, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct SecureStringVisitor;

        impl<'de> serde::de::Visitor<'de> for SecureStringVisitor {
            type Value = SecureString;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                write!(formatter, "an utf-8 encoded string")
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(SecureString::from(v))
            }

            /// Accepts formats that deliver strings as raw bytes, provided they
            /// are valid UTF-8.
            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                match core::str::from_utf8(v) {
                    Ok(text) => Ok(SecureString::from(text)),
                    Err(_) => Err(E::invalid_value(serde::de::Unexpected::Bytes(v), &self)),
                }
            }
        }

        deserializer.deserialize_str(SecureStringVisitor)
    }
}
269
/// Translates a `char` index into the byte offset at which that `char` starts in
/// `s_bytes`.
///
/// Returns `s_bytes.len()` in two cases:
/// - `char_idx` is at or beyond the number of characters;
/// - `s_bytes` is not valid UTF-8.
fn char_to_byte_idx(s_bytes: &[u8], char_idx: usize) -> usize {
    let end = s_bytes.len();
    match core::str::from_utf8(s_bytes) {
        Ok(text) => text
            .char_indices()
            .nth(char_idx)
            .map_or(end, |(byte_offset, _)| byte_offset),
        Err(_) => end,
    }
}
276
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// A cloned `SecureString` holds the same contents as the original.
    #[test]
    fn test_clone() {
        let hello_world = "Hello, world!".to_string();
        let secure1 = SecureString::from(hello_world.clone());
        let secure2 = secure1.clone();

        secure2.str_scope(|str| {
            assert_eq!(str, hello_world);
        });
    }

    /// Wrapping a `SecureVec<u8>` yields the same text as converting a `String`.
    #[test]
    fn test_from_secure_vec() {
        let hello_world = "Hello, world!".to_string();
        let vec: SecureVec<u8> = SecureVec::from_slice(hello_world.as_bytes()).unwrap();

        let string = SecureString::from(hello_world);
        let string2 = SecureString::from(vec);

        string.str_scope(|str| {
            string2.str_scope(|str2| {
                assert_eq!(str, str2);
            });
        });
    }

    /// Inserting past the last character appends the text.
    #[test]
    fn test_insert_text_at_char_idx() {
        let hello_world = "My name is ";
        let mut secure = SecureString::from(hello_world);
        secure.insert_text_at_char_idx(12, "Mike");
        secure.str_scope(|str| {
            assert_eq!(str, "My name is Mike");
        });
    }

    /// Char indices are translated to byte offsets when multi-byte chars are present.
    #[test]
    fn test_insert_text_at_multibyte_char_idx() {
        let mut secure = SecureString::from("h\u{e9}llo");
        let inserted = secure.insert_text_at_char_idx(2, "\u{f8}");
        assert_eq!(inserted, 1);
        secure.str_scope(|str| {
            assert_eq!(str, "h\u{e9}\u{f8}llo");
        });
    }

    /// Inserting an empty string is a no-op that reports zero chars inserted.
    #[test]
    fn test_insert_empty_text_is_noop() {
        let mut secure = SecureString::from("abc");
        assert_eq!(secure.insert_text_at_char_idx(1, ""), 0);
        secure.str_scope(|str| {
            assert_eq!(str, "abc");
        });
    }

    /// A range ending past the last character is clamped to the end.
    #[test]
    fn test_delete_text_char_range() {
        let hello_world = "My name is Mike";
        let mut secure = SecureString::from(hello_world);
        secure.delete_text_char_range(10..17);
        secure.str_scope(|str| {
            assert_eq!(str, "My name is");
        });
    }

    /// Deleting a multi-byte character removes all of its bytes.
    #[test]
    fn test_delete_text_char_range_multibyte() {
        let mut secure = SecureString::from("na\u{ef}ve caf\u{e9}");
        secure.delete_text_char_range(2..3);
        secure.str_scope(|str| {
            assert_eq!(str, "nave caf\u{e9}");
        });
    }

    /// `drain` removes a byte range.
    #[test]
    fn test_drain() {
        let hello_world = "Hello, world!";
        let mut secure = SecureString::from(hello_world);
        secure.drain(0..7);
        secure.str_scope(|str| {
            assert_eq!(str, "world!");
        });
    }

    /// `len` counts bytes while `char_len` counts characters.
    #[test]
    fn test_char_len() {
        let secure = SecureString::from("h\u{e9}llo");
        assert_eq!(secure.len(), 6);
        assert_eq!(secure.char_len(), 5);
    }

    /// Round-trips through serde_json from both `&str` and byte-slice input.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        let hello_world = "Hello, world!";
        let secure = SecureString::from(hello_world);
        let json_string = serde_json::to_string(&secure).expect("Serialization failed");
        let json_bytes = serde_json::to_vec(&secure).expect("Serialization failed");
        let deserialized_string: SecureString = serde_json::from_str(&json_string).expect("Deserialization failed");
        let deserialized_bytes: SecureString = serde_json::from_slice(&json_bytes).expect("Deserialization failed");

        deserialized_string.str_scope(|str| {
            assert_eq!(str, hello_world);
        });

        deserialized_bytes.str_scope(|str| {
            assert_eq!(str, hello_world);
        });
    }

    /// `str_scope` returns the closure's result, which may be an owned copy.
    #[test]
    fn test_str_scope() {
        let hello_word = "Hello, world!";
        let string = SecureString::from(hello_word);
        let _exposed_string = string.str_scope(|str| {
            assert_eq!(str, hello_word);
            String::from(str)
        });
    }

    /// `push_str` appends to an empty string.
    #[test]
    fn test_push_str() {
        let hello_world = "Hello, world!";

        let mut string = SecureString::new().unwrap();
        string.push_str(hello_world);
        string.str_scope(|str| {
            assert_eq!(str, hello_world);
        });
    }

    /// Mutations made inside `mut_scope` are visible afterwards.
    #[test]
    fn test_mut_scope() {
        let hello_world = "Hello, world!";
        let mut string = SecureString::from("Hello, ");
        string.mut_scope(|string| {
            string.push_str("world!");
        });

        string.str_scope(|str| {
            assert_eq!(str, hello_world);
        });
    }
}