1use super::{Error, vec::SecureVec};
2use core::ops::Range;
3use zeroize::Zeroize;
4
/// A string whose bytes are stored in a [`SecureVec<u8>`] instead of a plain `String`.
///
/// The contents are treated as UTF-8: [`SecureString::unlock_str`] decodes the stored
/// bytes with `core::str::from_utf8` and panics if they are not valid UTF-8.
/// Cloning duplicates the underlying storage through `SecureVec`'s own `Clone` impl.
#[derive(Clone)]
pub struct SecureString {
    // Raw UTF-8 bytes; all access goes through the SecureVec unlock helpers.
    vec: SecureVec<u8>,
}
51
52impl SecureString {
53 pub fn new() -> Result<Self, Error> {
54 let vec = SecureVec::new()?;
55 Ok(SecureString { vec })
56 }
57
58 pub fn new_with_capacity(capacity: usize) -> Result<Self, Error> {
59 let vec = SecureVec::new_with_capacity(capacity)?;
60 Ok(SecureString { vec })
61 }
62
63 pub fn erase(&mut self) {
64 self.vec.erase();
65 }
66
67 pub fn len(&self) -> usize {
71 self.vec.len()
72 }
73
74 pub fn is_empty(&self) -> bool {
75 self.vec.is_empty()
76 }
77
78 pub fn drain(&mut self, range: Range<usize>) {
79 let _d = self.vec.drain(range);
80 }
81
82 pub fn char_len(&self) -> usize {
83 self.unlock_str(|s| s.chars().count())
84 }
85
86 pub fn push_str(&mut self, string: &str) {
88 let slice = string.as_bytes();
89 for s in slice.iter() {
90 self.vec.push(*s);
91 }
92 }
93
94 pub fn unlock_str<F, R>(&self, f: F) -> R
96 where
97 F: FnOnce(&str) -> R,
98 {
99 self.vec.unlock_slice(|slice| {
100 let str = core::str::from_utf8(slice).unwrap();
101 f(str)
102 })
103 }
104
105 pub fn unlock_mut<F, R>(&mut self, f: F) -> R
107 where
108 F: FnOnce(&mut SecureString) -> R,
109 {
110 f(self)
111 }
112
113 pub fn insert_text_at_char_idx(&mut self, char_idx: usize, text_to_insert: &str) -> usize {
131 let chars_to_insert_count = text_to_insert.chars().count();
132 if chars_to_insert_count == 0 {
133 return 0;
134 }
135
136 let bytes_to_insert = text_to_insert.as_bytes();
137 let insert_len = bytes_to_insert.len();
138
139 let byte_idx = self
141 .vec
142 .unlock_slice(|current_bytes| char_to_byte_idx(current_bytes, char_idx));
143
144 self.vec.reserve(insert_len);
145
146 let old_byte_len = self.vec.len();
147
148 self.vec.unlock_memory();
150 unsafe {
151 let ptr = self.vec.as_mut_ptr();
152
153 if byte_idx < old_byte_len {
156 core::ptr::copy(
157 ptr.add(byte_idx),
158 ptr.add(byte_idx + insert_len),
159 old_byte_len - byte_idx,
160 );
161 }
162
163 core::ptr::copy_nonoverlapping(
165 bytes_to_insert.as_ptr(),
166 ptr.add(byte_idx),
167 insert_len,
168 );
169
170 self.vec.len += insert_len;
171 }
172
173 self.vec.lock_memory();
174
175 chars_to_insert_count
176 }
177
178 pub fn delete_text_char_range(&mut self, char_range: core::ops::Range<usize>) {
192 if char_range.start >= char_range.end {
193 return;
194 }
195
196 let (byte_start, byte_end) = self.unlock_str(|str| {
197 let byte_start = char_to_byte_idx(str.as_bytes(), char_range.start);
198 let byte_end = char_to_byte_idx(str.as_bytes(), char_range.end);
199 (byte_start, byte_end)
200 });
201
202 let new_len = self.vec.unlock_slice_mut(|current_bytes| {
203 if byte_start >= byte_end || byte_end > current_bytes.len() {
204 return 0;
205 }
206
207 let remove_len = byte_end - byte_start;
208 let old_total_len = current_bytes.len();
209
210 current_bytes.copy_within(byte_end..old_total_len, byte_start);
212
213 let new_len = old_total_len - remove_len;
214 for i in new_len..old_total_len {
216 current_bytes[i].zeroize();
217 }
218 new_len
219 });
220 self.vec.len = new_len;
221 }
222}
223
#[cfg(feature = "std")]
impl From<String> for SecureString {
    /// Builds a `SecureString` from the bytes of `s` via `SecureVec::from_vec`.
    ///
    /// NOTE(review): whether `from_vec` wipes or reuses the original `String` buffer is
    /// not visible here — confirm before relying on the source allocation being cleared.
    ///
    /// # Panics
    /// Panics if `SecureVec::from_vec` returns an error.
    fn from(s: String) -> SecureString {
        let vec = SecureVec::from_vec(s.into_bytes())
            .expect("SecureVec::from_vec failed while building SecureString");
        SecureString { vec }
    }
}
234
235impl From<&str> for SecureString {
236 fn from(s: &str) -> SecureString {
240 let bytes = s.as_bytes();
241 let len = bytes.len();
242
243 let mut new_vec = SecureVec::new_with_capacity(len).unwrap();
244 new_vec.len = len;
245
246 new_vec.unlock_slice_mut(|slice| {
247 slice[..len].copy_from_slice(bytes);
248 });
249
250 SecureString { vec: new_vec }
251 }
252}
253
254impl From<SecureVec<u8>> for SecureString {
255 fn from(vec: SecureVec<u8>) -> Self {
256 let mut new_string = SecureString::new().unwrap();
257 new_string.vec = vec;
258 new_string
259 }
260}
261
#[cfg(feature = "serde")]
impl serde::Serialize for SecureString {
    /// Serializes the contents as a plain string.
    ///
    /// Note that this hands the plaintext to `serializer`; whatever it produces is
    /// outside `SecureVec`'s protection.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.unlock_str(|s| serializer.serialize_str(s))
    }
}
272
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SecureString {
    /// Deserializes a string and copies it into a new `SecureString`.
    fn deserialize<D>(deserializer: D) -> Result<SecureString, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor that copies any string it is given into secure storage.
        struct SecureStringVisitor;

        impl<'de> serde::de::Visitor<'de> for SecureStringVisitor {
            type Value = SecureString;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                write!(formatter, "an utf-8 encoded string")
            }

            // `visit_string` and `visit_borrowed_str` default to this method.
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(SecureString::from(v))
            }
        }

        // Ask for `str`, not `string`. The visitor copies into secure storage and gains
        // nothing from owning a buffer, and the `string` hint can make some formats
        // allocate an extra, never-zeroized `String` containing the secret.
        deserializer.deserialize_str(SecureStringVisitor)
    }
}
295
/// Maps a character index to the byte offset where that character starts.
///
/// Returns `s_bytes.len()` (the end position) when `char_idx` is at or past the last
/// character, or when `s_bytes` is not valid UTF-8.
fn char_to_byte_idx(s_bytes: &[u8], char_idx: usize) -> usize {
    let end = s_bytes.len();
    match core::str::from_utf8(s_bytes) {
        Ok(text) => text
            .char_indices()
            .map(|(byte_pos, _)| byte_pos)
            .nth(char_idx)
            .unwrap_or(end),
        Err(_) => end,
    }
}
302
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_creation() {
        let hello_world = "Hello, world!";
        let secure = SecureString::from(hello_world);

        secure.unlock_str(|s| {
            assert_eq!(s, hello_world);
        });
    }

    #[test]
    fn test_from_string() {
        let hello_world = String::from("Hello, world!");
        let string = SecureString::from(hello_world);

        string.unlock_str(|s| {
            assert_eq!(s, "Hello, world!");
        });
    }

    #[test]
    fn test_from_secure_vec() {
        let hello_world = "Hello, world!".to_string();
        let vec: SecureVec<u8> = SecureVec::from_slice(hello_world.as_bytes()).unwrap();

        let string = SecureString::from(hello_world);
        let string2 = SecureString::from(vec);

        string.unlock_str(|s| {
            string2.unlock_str(|s2| {
                assert_eq!(s, s2);
            });
        });
    }

    #[test]
    fn test_clone() {
        let hello_world = "Hello, world!".to_string();
        let secure1 = SecureString::from(hello_world.clone());
        let secure2 = secure1.clone();

        secure2.unlock_str(|s| {
            assert_eq!(s, hello_world);
        });
    }

    #[test]
    fn test_insert_text_at_char_idx() {
        let hello_world = "My name is ";
        let mut secure = SecureString::from(hello_world);
        // Index 12 is past the end of the 11-char string, so the text is appended.
        let inserted = secure.insert_text_at_char_idx(12, "Mike");
        assert_eq!(inserted, 4);
        secure.unlock_str(|s| {
            assert_eq!(s, "My name is Mike");
        });
    }

    #[test]
    fn test_insert_text_multibyte_middle() {
        let mut secure = SecureString::from("a\u{f1}b");
        let inserted = secure.insert_text_at_char_idx(2, "\u{e9}");
        assert_eq!(inserted, 1);
        secure.unlock_str(|s| {
            assert_eq!(s, "a\u{f1}\u{e9}b");
        });
    }

    #[test]
    fn test_insert_text_at_start_and_empty() {
        let mut secure = SecureString::from("there");
        assert_eq!(secure.insert_text_at_char_idx(0, "Hi "), 3);
        assert_eq!(secure.insert_text_at_char_idx(1, ""), 0);
        secure.unlock_str(|s| {
            assert_eq!(s, "Hi there");
        });
    }

    #[test]
    fn test_delete_text_char_range() {
        let hello_world = "My name is Mike";
        let mut secure = SecureString::from(hello_world);
        secure.delete_text_char_range(10..17);
        secure.unlock_str(|s| {
            assert_eq!(s, "My name is");
        });
    }

    #[test]
    fn test_delete_text_char_range_multibyte() {
        let mut secure = SecureString::from("a\u{f1}\u{e9}b");
        secure.delete_text_char_range(1..3);
        assert_eq!(secure.len(), 2);
        secure.unlock_str(|s| {
            assert_eq!(s, "ab");
        });
    }

    #[test]
    fn test_delete_text_char_range_reversed_is_noop() {
        let mut secure = SecureString::from("abc");
        #[allow(clippy::reversed_empty_ranges)] // deliberately exercising the guard
        secure.delete_text_char_range(2..1);
        secure.unlock_str(|s| {
            assert_eq!(s, "abc");
        });
    }

    #[test]
    fn test_char_len_vs_len() {
        let secure = SecureString::from("a\u{f1}b");
        assert_eq!(secure.len(), 4);
        assert_eq!(secure.char_len(), 3);
        assert!(!secure.is_empty());
    }

    #[test]
    fn test_drain() {
        let hello_world = "Hello, world!";
        let mut secure = SecureString::from(hello_world);
        secure.drain(0..7);
        secure.unlock_str(|s| {
            assert_eq!(s, "world!");
        });
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        let hello_world = "Hello, world!";
        let secure = SecureString::from(hello_world);

        let json_string = serde_json::to_string(&secure).expect("Serialization failed");
        let json_bytes = serde_json::to_vec(&secure).expect("Serialization failed");

        let deserialized_string: SecureString =
            serde_json::from_str(&json_string).expect("Deserialization failed");

        let deserialized_bytes: SecureString =
            serde_json::from_slice(&json_bytes).expect("Deserialization failed");

        deserialized_string.unlock_str(|s| {
            assert_eq!(s, hello_world);
        });

        deserialized_bytes.unlock_str(|s| {
            assert_eq!(s, hello_world);
        });
    }

    #[test]
    fn test_unlock_str() {
        let hello_world = "Hello, world!";
        let string = SecureString::from(hello_world);
        let _exposed_string = string.unlock_str(|s| {
            assert_eq!(s, hello_world);
            String::from(s)
        });
    }

    #[test]
    fn test_push_str() {
        let hello_world = "Hello, world!";

        let mut string = SecureString::new().unwrap();
        string.push_str(hello_world);
        string.unlock_str(|s| {
            assert_eq!(s, hello_world);
        });
    }

    #[test]
    fn test_unlock_mut() {
        let hello_world = "Hello, world!";
        let mut string = SecureString::from("Hello, ");
        string.unlock_mut(|string| {
            string.push_str("world!");
        });

        string.unlock_str(|s| {
            assert_eq!(s, hello_world);
        });
    }
}