1#[cfg(feature = "std")]
2use std::vec::Vec;
3
4use super::{Error, vec::SecureVec};
5use core::ops::Range;
6use zeroize::Zeroize;
7
8#[derive(Clone)]
53pub struct SecureString {
54 vec: SecureVec<u8>,
55}
56
57impl SecureString {
58 pub fn new() -> Result<Self, Error> {
59 let vec = SecureVec::new()?;
60 Ok(SecureString { vec })
61 }
62
63 pub fn new_with_capacity(capacity: usize) -> Result<Self, Error> {
64 let vec = SecureVec::with_capacity(capacity)?;
65 Ok(SecureString { vec })
66 }
67
68 pub fn erase(&mut self) {
69 self.vec.erase();
70 }
71
72 pub fn len(&self) -> usize {
73 self.vec.len()
74 }
75
76 pub fn drain(&mut self, range: Range<usize>) {
77 let _ = self.vec.drain(range);
78 }
79
80 pub fn char_len(&self) -> usize {
81 self.str_scope(|s| s.chars().count())
82 }
83
84 pub fn push_str(&mut self, string: &str) {
86 let slice = string.as_bytes();
87 for s in slice.iter() {
88 self.vec.push(*s);
89 }
90 }
91
92 pub fn str_scope<F, R>(&self, f: F) -> R
101 where
102 F: FnOnce(&str) -> R,
103 {
104 self.vec.slice_scope(|slice| {
105 let str = core::str::from_utf8(slice).unwrap();
106 let result = f(str);
107 result
108 })
109 }
110
111 pub fn mut_scope<F, R>(&mut self, f: F) -> R
120 where
121 F: FnOnce(&mut SecureString) -> R,
122 {
123 f(self)
124 }
125
126 #[cfg(feature = "std")]
127 pub fn insert_text_at_char_idx(&mut self, char_idx: usize, text_to_insert: &str) -> usize {
128 let chars_to_insert_count = text_to_insert.chars().count();
129 if chars_to_insert_count == 0 {
130 return 0;
131 }
132
133 let bytes_to_insert = text_to_insert.as_bytes();
134 let insert_len = bytes_to_insert.len();
135
136 let byte_idx = self
138 .vec
139 .slice_scope(|current_bytes| char_to_byte_idx(current_bytes, char_idx));
140
141 let old_byte_len = self.vec.len();
142 let new_byte_len = old_byte_len + insert_len;
143
144 if new_byte_len > self.vec.capacity {
145 let mut temp_new_content = Vec::with_capacity(new_byte_len);
147 self.vec.slice_scope(|current_bytes| {
148 temp_new_content.extend_from_slice(¤t_bytes[..byte_idx]);
149 temp_new_content.extend_from_slice(bytes_to_insert);
150 if byte_idx < old_byte_len {
151 temp_new_content.extend_from_slice(¤t_bytes[byte_idx..]);
153 }
154 });
155
156 let mut new_secure_vec = SecureVec::with_capacity(new_byte_len).unwrap();
157 for &b in temp_new_content.iter() {
158 new_secure_vec.push(b);
159 }
160 temp_new_content.zeroize();
161
162 let mut old_vec_to_drop = core::mem::replace(&mut self.vec, new_secure_vec);
163 old_vec_to_drop.erase();
164 } else {
165 self.vec.unlock_memory();
166 unsafe {
167 let ptr = self.vec.as_mut_ptr();
168
169 if byte_idx < old_byte_len {
170 core::ptr::copy(
171 ptr.add(byte_idx),
172 ptr.add(byte_idx + insert_len),
173 old_byte_len - byte_idx,
174 );
175 }
176
177 core::ptr::copy_nonoverlapping(
178 bytes_to_insert.as_ptr(),
179 ptr.add(byte_idx),
180 insert_len,
181 );
182
183 self.vec.len = new_byte_len;
184 }
185 self.vec.lock_memory();
186 }
187
188 chars_to_insert_count
189 }
190
191 #[cfg(feature = "std")]
192 pub fn delete_text_char_range(&mut self, char_range: core::ops::Range<usize>) {
193 if char_range.start >= char_range.end {
194 return;
195 }
196
197 let (byte_start, byte_end) = self.str_scope(|str| {
198 let byte_start = char_to_byte_idx(str.as_bytes(), char_range.start);
199 let byte_end = char_to_byte_idx(str.as_bytes(), char_range.end);
200 (byte_start, byte_end)
201 });
202
203 let new_len = self.vec.slice_mut_scope(|current_bytes| {
204 if byte_start >= byte_end || byte_end > current_bytes.len() {
205 return 0;
206 }
207
208 let remove_len = byte_end - byte_start;
209 let old_total_len = current_bytes.len();
210
211 current_bytes.copy_within(byte_end..old_total_len, byte_start);
213
214 let new_len = old_total_len - remove_len;
215 for i in new_len..old_total_len {
217 current_bytes[i].zeroize();
218 }
219 new_len
220 });
221 self.vec.len = new_len;
222 }
223}
224
225impl From<&str> for SecureString {
226 fn from(s: &str) -> SecureString {
227 let bytes = s.as_bytes();
228 let len = bytes.len();
229
230 let mut new_vec = SecureVec::with_capacity(len).unwrap();
231 new_vec.len = len;
232
233 new_vec.slice_mut_scope(|slice| {
234 slice[..len].copy_from_slice(bytes);
235 });
236
237 SecureString { vec: new_vec }
238 }
239}
240
#[cfg(feature = "serde")]
impl serde::Serialize for SecureString {
    /// Serializes the contents as a plain string.
    ///
    /// The `&str` view exists only inside `str_scope`, so the serializer is
    /// driven from within that scope and its result is passed straight back.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.str_scope(move |text| serializer.serialize_str(text))
    }
}
251
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SecureString {
    /// Deserializes a string value and copies it into secure storage.
    fn deserialize<D>(deserializer: D) -> Result<SecureString, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Visitor that copies the incoming string straight into a `SecureString`.
        struct SecureStringVisitor;

        impl<'de> serde::de::Visitor<'de> for SecureStringVisitor {
            type Value = SecureString;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                write!(formatter, "an utf-8 encoded string")
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(SecureString::from(v))
            }
        }

        // `deserialize_str`, not `deserialize_string`: the visitor only copies the
        // data, so it gains nothing from an owned `String`. Asking for one would
        // invite the deserializer to make an extra, non-zeroized heap copy of the
        // secret. Owned strings still arrive via the default `visit_string`, which
        // forwards to `visit_str`.
        deserializer.deserialize_str(SecureStringVisitor)
    }
}
274
/// Maps a character index to the byte offset at which that character starts.
///
/// Returns `s_bytes.len()` when `char_idx` is past the last character, or when
/// `s_bytes` is not valid UTF-8. Callers can therefore treat the result as an
/// offset clamped to the end of the buffer.
#[cfg(feature = "std")]
fn char_to_byte_idx(s_bytes: &[u8], char_idx: usize) -> usize {
    let end_of_buffer = s_bytes.len();
    match core::str::from_utf8(s_bytes) {
        Ok(text) => match text.char_indices().nth(char_idx) {
            Some((byte_offset, _)) => byte_offset,
            None => end_of_buffer,
        },
        Err(_) => end_of_buffer,
    }
}
282
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Building from a literal and from a borrowed `String` both succeed.
    #[test]
    fn test_creation() {
        let literal = "Hello, world!";
        let owned = String::from(literal);

        let _ = SecureString::from(literal);
        let _ = SecureString::from(owned.as_str());
    }

    /// A clone exposes the same contents as the value it was cloned from.
    #[test]
    fn test_clone() {
        let expected = "Hello, world!";
        let original = SecureString::from(expected);
        let copy = original.clone();

        copy.str_scope(|text| {
            assert_eq!(text, expected);
        });
    }

    /// Index 12 is past the 11 characters of "My name is ", so the text is appended.
    #[test]
    fn test_insert_text_at_char_idx() {
        let mut secure = SecureString::from("My name is ");
        secure.insert_text_at_char_idx(12, "Mike");
        secure.str_scope(|text| {
            assert_eq!(text, "My name is Mike");
        });
    }

    /// The range end (17) runs past the 15 characters and is clamped to the end.
    #[test]
    fn test_delete_text_char_range() {
        let mut secure = SecureString::from("My name is Mike");
        secure.delete_text_char_range(10..17);
        secure.str_scope(|text| {
            assert_eq!(text, "My name is");
        });
    }

    /// `drain` takes a byte range; bytes 0..7 are "Hello, ".
    #[test]
    fn test_drain() {
        let mut secure = SecureString::from("Hello, world!");
        secure.drain(0..7);
        secure.str_scope(|text| {
            assert_eq!(text, "world!");
        });
    }

    /// A JSON round trip preserves the contents.
    #[cfg(feature = "serde")]
    #[test]
    fn test_secure_string_serde() {
        let expected = "Hello, world!";
        let secure = SecureString::from(expected);
        let json = serde_json::to_string(&secure).expect("Serialization failed");
        let restored: SecureString =
            serde_json::from_str(&json).expect("Deserialization failed");
        restored.str_scope(|text| {
            assert_eq!(text, expected);
        });
    }

    /// `str_scope` returns whatever value the closure produces.
    #[test]
    fn test_str_scope() {
        let expected = "Hello, world!";
        let secure = SecureString::from(expected);
        let _exposed_string = secure.str_scope(|text| {
            assert_eq!(text, expected);
            String::from(text)
        });
    }

    /// Pushing into an empty string yields exactly the pushed text.
    #[test]
    fn test_push_str() {
        let expected = "Hello, world!";

        let mut secure = SecureString::new().unwrap();
        secure.push_str(expected);
        secure.str_scope(|text| {
            assert_eq!(text, expected);
        });
    }

    /// Mutations made inside `mut_scope` persist after the scope ends.
    #[test]
    fn test_mut_scope() {
        let expected = "Hello, world!";
        let mut secure = SecureString::from("Hello, ");
        secure.mut_scope(|inner| {
            inner.push_str("world!");
        });

        secure.str_scope(|text| {
            assert_eq!(text, expected);
        });
    }
}