1#![no_std]
6
7#[cfg(test)]
8extern crate std;
9
10use core::str;
11
/// Lets `f` edit the bytes of `s` in place, then checks that they are still
/// valid UTF-8, and returns whatever `f` returned.
///
/// While `f` runs, the buffer may hold arbitrary bytes; only the contents at
/// the moment `f` returns are checked.
///
/// # Panics
///
/// * If `f` panics, every byte of `s` is set to `0` before the panic continues
///   to unwind, leaving `s` as a run of NUL characters.
/// * If `f` returns but the bytes are not valid UTF-8, every byte from the
///   first invalid position to the end is set to `0` (the valid prefix is
///   kept), and this function panics with a message describing the error.
///
/// In both cases `s` holds valid UTF-8 again before the panic propagates.
pub fn with_str_bytes<R, F>(s: &mut str, f: F) -> R
where
    F: FnOnce(&mut [u8]) -> R,
{
    // Zeroes the whole buffer when dropped while `armed` is set. It is only
    // still armed if `f` unwinds, so this destructor never panics itself.
    struct ZeroOnUnwind<'a> {
        bytes: &'a mut [u8],
        armed: bool,
    }

    impl<'a> Drop for ZeroOnUnwind<'a> {
        fn drop(&mut self) {
            if self.armed {
                // All-NUL is valid UTF-8 no matter what `f` left behind.
                self.bytes.fill(0);
            }
        }
    }

    let mut guard = ZeroOnUnwind {
        // SAFETY: `as_bytes_mut` requires the bytes to be valid UTF-8 again
        // before the borrow ends and the `str` is used. If `f` unwinds, the
        // guard NUL-fills the buffer; if `f` returns invalid UTF-8, the check
        // below NUL-fills the invalid tail. Either way `s` is valid UTF-8 by the
        // time this function exits.
        bytes: unsafe { s.as_bytes_mut() },
        armed: true,
    };

    let ret = f(&mut *guard.bytes);

    // `f` returned normally: disarm the guard so it cannot clobber the buffer,
    // then validate here instead of inside `Drop`. A failure then panics from
    // ordinary code, and `ret` is dropped normally while unwinding.
    guard.armed = false;
    if let Err(err) = str::from_utf8(&*guard.bytes) {
        // Everything before `valid_up_to()` is valid UTF-8; NUL-filling the
        // remainder makes the whole buffer valid again.
        guard.bytes[err.valid_up_to()..].fill(0);
        panic!("`with_bytes` encountered invalid utf-8: {}", err);
    }
    ret
}
69
#[cfg(test)]
mod tests {
    use std::boxed::Box;
    use std::panic::{self, AssertUnwindSafe};
    use std::string::String;

    use super::with_str_bytes;

    #[test]
    fn empty() {
        let mut data: Box<str> = Box::from("");
        with_str_bytes(&mut data, |bytes| {
            assert_eq!(bytes, &mut []);
        });
        assert_eq!(&*data, "");
    }

    #[test]
    fn valid_utf8() {
        let initial = "--------------------------";
        let replaced = b"Lorem ipsum dolor sit amet";

        let mut data: Box<str> = Box::from(initial);
        with_str_bytes(&mut data, |bytes| {
            bytes.copy_from_slice(replaced);
        });
        assert_eq!(data.as_bytes(), replaced);
    }

    // The closure's return value must be handed back to the caller unchanged.
    #[test]
    fn returns_closure_result() {
        let mut data: Box<str> = Box::from("abc");
        let len = with_str_bytes(&mut data, |bytes| {
            bytes.make_ascii_uppercase();
            bytes.len()
        });
        assert_eq!(len, 3);
        assert_eq!(&*data, "ABC");
    }

    // Editing inside a multi-byte character is fine as long as the result is
    // still valid UTF-8 (U+00E9 is C3 A9, U+00E8 is C3 A8).
    #[test]
    fn valid_multibyte_edit() {
        let mut data: Box<str> = Box::from("h\u{e9}llo");
        with_str_bytes(&mut data, |bytes| {
            bytes[2] = 0xA8;
        });
        assert_eq!(&*data, "h\u{e8}llo");
    }

    #[test]
    fn invalid_utf8() {
        let mut data: Box<str> = Box::from("abc");

        let msg = *panic::catch_unwind(AssertUnwindSafe(|| {
            with_str_bytes(&mut data, |bytes| {
                bytes[1] = 0xC0;
            });
        }))
        .unwrap_err()
        .downcast::<String>()
        .unwrap();

        assert_eq!(msg, "`with_bytes` encountered invalid utf-8: invalid utf-8 sequence of 1 bytes from index 1");

        assert_eq!(&*data, "a\0\0");
    }

    // A truncated multi-byte sequence at the very end takes the
    // `error_len() == None` path of `Utf8Error`; only that tail is zeroed.
    #[test]
    fn incomplete_sequence_at_end() {
        let mut data: Box<str> = Box::from("abc");

        let msg = *panic::catch_unwind(AssertUnwindSafe(|| {
            with_str_bytes(&mut data, |bytes| {
                bytes[2] = 0xC3;
            });
        }))
        .unwrap_err()
        .downcast::<String>()
        .unwrap();

        assert_eq!(msg, "`with_bytes` encountered invalid utf-8: incomplete utf-8 byte sequence from index 2");

        assert_eq!(&*data, "ab\0");
    }

    #[test]
    fn panics() {
        let mut data: Box<str> = Box::from("abc");

        let msg = *panic::catch_unwind(AssertUnwindSafe(|| {
            with_str_bytes(&mut data, |_| panic!("Oh no"));
        }))
        .unwrap_err()
        .downcast::<&'static str>()
        .unwrap();

        assert_eq!(msg, "Oh no");

        assert_eq!(&*data, "\0\0\0");
    }
}