1#[cfg(feature = "std")]
2use std::vec::Vec;
3
4use super::{Error, vec::SecureVec};
5use core::ops::Range;
6use zeroize::Zeroize;
7
8#[derive(Clone)]
53pub struct SecureString {
54 vec: SecureVec<u8>,
55}
56
57impl SecureString {
58 pub fn new() -> Result<Self, Error> {
59 let vec = SecureVec::new()?;
60 Ok(SecureString { vec })
61 }
62
63 pub fn new_with_capacity(capacity: usize) -> Result<Self, Error> {
64 let vec = SecureVec::new_with_capacity(capacity)?;
65 Ok(SecureString { vec })
66 }
67
68 pub fn erase(&mut self) {
69 self.vec.erase();
70 }
71
72 pub fn len(&self) -> usize {
73 self.vec.len()
74 }
75
76 pub fn drain(&mut self, range: Range<usize>) {
77 let _ = self.vec.drain(range);
78 }
79
80 pub fn char_len(&self) -> usize {
81 self.str_scope(|s| s.chars().count())
82 }
83
84 pub fn push_str(&mut self, string: &str) {
86 let slice = string.as_bytes();
87 for s in slice.iter() {
88 self.vec.push(*s);
89 }
90 }
91
92 pub fn str_scope<F, R>(&self, f: F) -> R
101 where
102 F: FnOnce(&str) -> R,
103 {
104 self.vec.slice_scope(|slice| {
105 let str = core::str::from_utf8(slice).unwrap();
106 let result = f(str);
107 result
108 })
109 }
110
111 pub fn mut_scope<F, R>(&mut self, f: F) -> R
120 where
121 F: FnOnce(&mut SecureString) -> R,
122 {
123 f(self)
124 }
125
126 #[cfg(feature = "std")]
127 pub fn insert_text_at_char_idx(&mut self, char_idx: usize, text_to_insert: &str) -> usize {
128 let chars_to_insert_count = text_to_insert.chars().count();
129 if chars_to_insert_count == 0 {
130 return 0;
131 }
132
133 let bytes_to_insert = text_to_insert.as_bytes();
134 let insert_len = bytes_to_insert.len();
135
136 let byte_idx = self
138 .vec
139 .slice_scope(|current_bytes| char_to_byte_idx(current_bytes, char_idx));
140
141 let old_byte_len = self.vec.len();
142 let new_byte_len = old_byte_len + insert_len;
143
144 if new_byte_len > self.vec.capacity {
145 let mut temp_new_content = Vec::with_capacity(new_byte_len);
147 self.vec.slice_scope(|current_bytes| {
148 temp_new_content.extend_from_slice(¤t_bytes[..byte_idx]);
149 temp_new_content.extend_from_slice(bytes_to_insert);
150 if byte_idx < old_byte_len {
151 temp_new_content.extend_from_slice(¤t_bytes[byte_idx..]);
153 }
154 });
155
156 let mut new_secure_vec = SecureVec::new_with_capacity(new_byte_len).unwrap();
157 for &b in temp_new_content.iter() {
158 new_secure_vec.push(b);
159 }
160 temp_new_content.zeroize();
161
162 let mut old_vec_to_drop = core::mem::replace(&mut self.vec, new_secure_vec);
163 old_vec_to_drop.erase();
164 } else {
165 self.vec.unlock_memory();
166 unsafe {
167 let ptr = self.vec.as_mut_ptr();
168
169 if byte_idx < old_byte_len {
170 core::ptr::copy(
171 ptr.add(byte_idx),
172 ptr.add(byte_idx + insert_len),
173 old_byte_len - byte_idx,
174 );
175 }
176
177 core::ptr::copy_nonoverlapping(
178 bytes_to_insert.as_ptr(),
179 ptr.add(byte_idx),
180 insert_len,
181 );
182
183 self.vec.len = new_byte_len;
184 }
185 self.vec.lock_memory();
186 }
187
188 chars_to_insert_count
189 }
190
191 #[cfg(feature = "std")]
192 pub fn delete_text_char_range(&mut self, char_range: core::ops::Range<usize>) {
193 if char_range.start >= char_range.end {
194 return;
195 }
196
197 let (byte_start, byte_end) = self.str_scope(|str| {
198 let byte_start = char_to_byte_idx(str.as_bytes(), char_range.start);
199 let byte_end = char_to_byte_idx(str.as_bytes(), char_range.end);
200 (byte_start, byte_end)
201 });
202
203 let new_len = self.vec.slice_mut_scope(|current_bytes| {
204 if byte_start >= byte_end || byte_end > current_bytes.len() {
205 return 0;
206 }
207
208 let remove_len = byte_end - byte_start;
209 let old_total_len = current_bytes.len();
210
211 current_bytes.copy_within(byte_end..old_total_len, byte_start);
213
214 let new_len = old_total_len - remove_len;
215 for i in new_len..old_total_len {
217 current_bytes[i].zeroize();
218 }
219 new_len
220 });
221 self.vec.len = new_len;
222 }
223}
224
225
#[cfg(feature = "std")]
impl From<String> for SecureString {
    /// Converts an owned `String` into a `SecureString` by handing its byte
    /// buffer to `SecureVec::from_vec`.
    ///
    /// # Panics
    ///
    /// Panics if `SecureVec::from_vec` returns an error.
    fn from(s: String) -> SecureString {
        let raw_bytes = s.into_bytes();
        Self {
            vec: SecureVec::from_vec(raw_bytes).unwrap(),
        }
    }
}
233
234impl From<&str> for SecureString {
235 fn from(s: &str) -> SecureString {
236 let bytes = s.as_bytes();
237 let len = bytes.len();
238
239 let mut new_vec = SecureVec::new_with_capacity(len).unwrap();
240 new_vec.len = len;
241
242 new_vec.slice_mut_scope(|slice| {
243 slice[..len].copy_from_slice(bytes);
244 });
245
246 SecureString { vec: new_vec }
247 }
248}
249
#[cfg(feature = "serde")]
impl serde::Serialize for SecureString {
    /// Serializes the stored text as a plain string.
    ///
    /// The text is exposed to the serializer inside `str_scope`, and the
    /// serializer's result is returned unchanged.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.str_scope(|text| serializer.serialize_str(text))
    }
}
260
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SecureString {
    /// Deserializes a string value into a `SecureString`.
    ///
    /// The deserializer is asked for a string and the resulting `&str` is
    /// copied into a new `SecureString` through `From<&str>`.
    fn deserialize<D>(deserializer: D) -> Result<SecureString, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor that turns any visited string slice into a `SecureString`.
        struct SecureStringVisitor;

        impl<'de> serde::de::Visitor<'de> for SecureStringVisitor {
            type Value = SecureString;

            /// Describes the expected input for serde's error messages.
            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                write!(formatter, "an utf-8 encoded string")
            }

            /// Builds the `SecureString` from the visited text.
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                let secure: SecureString = v.into();
                Ok(secure)
            }
        }

        let visitor = SecureStringVisitor;
        deserializer.deserialize_string(visitor)
    }
}
283
284#[cfg(feature = "std")]
/// Maps a character index to the byte offset at which that character starts.
///
/// Returns `s_bytes.len()` when `char_idx` is at or past the number of
/// characters, and also when `s_bytes` is not valid UTF-8.
fn char_to_byte_idx(s_bytes: &[u8], char_idx: usize) -> usize {
    let end_offset = s_bytes.len();
    match core::str::from_utf8(s_bytes) {
        Ok(text) => text
            .char_indices()
            .map(|(offset, _)| offset)
            .nth(char_idx)
            .unwrap_or(end_offset),
        Err(_) => end_offset,
    }
}
291
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Asserts that `secure` exposes exactly `expected` through `str_scope`.
    fn assert_contents(secure: &SecureString, expected: &str) {
        secure.str_scope(|actual| assert_eq!(actual, expected));
    }

    /// A clone exposes the same text as the value it was cloned from.
    #[test]
    fn test_clone() {
        let text = String::from("Hello, world!");
        let first = SecureString::from(text.clone());
        let second = first.clone();

        assert_contents(&second, &text);
    }

    /// Inserting at an index past the last character appends the text.
    #[test]
    fn test_insert_text_at_char_idx() {
        let mut secure = SecureString::from("My name is ");
        secure.insert_text_at_char_idx(12, "Mike");

        assert_contents(&secure, "My name is Mike");
    }

    /// A range running past the end removes everything from its start onward.
    #[test]
    fn test_delete_text_char_range() {
        let mut secure = SecureString::from("My name is Mike");
        secure.delete_text_char_range(10..17);

        assert_contents(&secure, "My name is");
    }

    /// Draining a leading range leaves the remainder of the text.
    #[test]
    fn test_drain() {
        let mut secure = SecureString::from("Hello, world!");
        secure.drain(0..7);

        assert_contents(&secure, "world!");
    }

    /// A value survives a JSON serialize/deserialize round trip.
    #[cfg(feature = "serde")]
    #[test]
    fn test_secure_string_serde() {
        let text = "Hello, world!";
        let secure = SecureString::from(text);

        let json = serde_json::to_string(&secure).expect("Serialization failed");
        let restored: SecureString =
            serde_json::from_str(&json).expect("Deserialization failed");

        assert_contents(&restored, text);
    }

    /// `str_scope` hands the stored text to the closure and returns its result.
    #[test]
    fn test_str_scope() {
        let text = "Hello, world!";
        let secure = SecureString::from(text);

        let _exposed = secure.str_scope(|inner| {
            assert_eq!(inner, text);
            String::from(inner)
        });
    }

    /// Text pushed onto an empty string can be read back unchanged.
    #[test]
    fn test_push_str() {
        let text = "Hello, world!";
        let mut secure = SecureString::new().unwrap();
        secure.push_str(text);

        assert_contents(&secure, text);
    }

    /// Mutations made inside `mut_scope` are visible afterwards.
    #[test]
    fn test_mut_scope() {
        let mut secure = SecureString::from("Hello, ");
        secure.mut_scope(|inner| inner.push_str("world!"));

        assert_contents(&secure, "Hello, world!");
    }
}