snapbox/data/
runtime.rs

1use std::collections::BTreeMap;
2
3use super::Data;
4use super::Inline;
5use super::Position;
6
7pub(crate) fn get() -> std::sync::MutexGuard<'static, Runtime> {
8    static RT: std::sync::Mutex<Runtime> = std::sync::Mutex::new(Runtime::new());
9    RT.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
10}
11
12#[derive(Default)]
13pub(crate) struct Runtime {
14    per_file: Vec<SourceFileRuntime>,
15    path_count: Vec<PathRuntime>,
16}
17
18impl Runtime {
19    const fn new() -> Self {
20        Self {
21            per_file: Vec::new(),
22            path_count: Vec::new(),
23        }
24    }
25
26    pub(crate) fn count(&mut self, path_prefix: &str) -> usize {
27        if let Some(entry) = self
28            .path_count
29            .iter_mut()
30            .find(|entry| entry.is(path_prefix))
31        {
32            entry.next()
33        } else {
34            let entry = PathRuntime::new(path_prefix);
35            let next = entry.count();
36            self.path_count.push(entry);
37            next
38        }
39    }
40
41    pub(crate) fn write(&mut self, actual: &Data, inline: &Inline) -> std::io::Result<()> {
42        let actual = actual.render().expect("`actual` must be UTF-8");
43        if let Some(entry) = self
44            .per_file
45            .iter_mut()
46            .find(|f| f.path == inline.position.file)
47        {
48            entry.update(&actual, inline)?;
49        } else {
50            let mut entry = SourceFileRuntime::new(inline)?;
51            entry.update(&actual, inline)?;
52            self.per_file.push(entry);
53        }
54
55        Ok(())
56    }
57}
58
59struct SourceFileRuntime {
60    path: std::path::PathBuf,
61    original_text: String,
62    patchwork: Patchwork,
63}
64
65impl SourceFileRuntime {
66    fn new(inline: &Inline) -> std::io::Result<SourceFileRuntime> {
67        let path = inline.position.file.clone();
68        let original_text = std::fs::read_to_string(&path)?;
69        let patchwork = Patchwork::new(original_text.clone());
70        Ok(SourceFileRuntime {
71            path,
72            original_text,
73            patchwork,
74        })
75    }
76    fn update(&mut self, actual: &str, inline: &Inline) -> std::io::Result<()> {
77        let span = Span::from_pos(&inline.position, &self.original_text);
78        let patch = format_patch(actual);
79        self.patchwork.patch(span.literal_range, &patch)?;
80        std::fs::write(&inline.position.file, &self.patchwork.text)
81    }
82}
83
84#[derive(Debug)]
85struct Patchwork {
86    text: String,
87    indels: BTreeMap<OrdRange, (usize, String)>,
88}
89
90impl Patchwork {
91    fn new(text: String) -> Patchwork {
92        Patchwork {
93            text,
94            indels: BTreeMap::new(),
95        }
96    }
97    fn patch(&mut self, mut range: std::ops::Range<usize>, patch: &str) -> std::io::Result<()> {
98        let key: OrdRange = range.clone().into();
99        match self.indels.entry(key) {
100            std::collections::btree_map::Entry::Vacant(entry) => {
101                entry.insert((patch.len(), patch.to_owned()));
102            }
103            std::collections::btree_map::Entry::Occupied(entry) => {
104                if entry.get().1 == patch {
105                    return Ok(());
106                } else {
107                    return Err(std::io::Error::new(
108                        std::io::ErrorKind::Other,
109                        "cannot update as it was already modified",
110                    ));
111                }
112            }
113        }
114
115        let (delete, insert) = self
116            .indels
117            .iter()
118            .take_while(|(delete, _)| delete.start < range.start)
119            .map(|(delete, (insert, _))| (delete.end - delete.start, insert))
120            .fold((0usize, 0usize), |(x1, y1), (x2, y2)| (x1 + x2, y1 + y2));
121
122        for pos in &mut [&mut range.start, &mut range.end] {
123            **pos -= delete;
124            **pos += insert;
125        }
126
127        self.text.replace_range(range, patch);
128        Ok(())
129    }
130}
131
/// A `Range<usize>` stand-in that implements `Ord` so it can key a `BTreeMap`.
///
/// The derived ordering compares `start` first, then `end`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct OrdRange {
    start: usize,
    end: usize,
}

impl From<std::ops::Range<usize>> for OrdRange {
    fn from(std::ops::Range { start, end }: std::ops::Range<usize>) -> Self {
        Self { start, end }
    }
}
146
147fn lit_kind_for_patch(patch: &str) -> StrLitKind {
148    let has_dquote = patch.chars().any(|c| c == '"');
149    if !has_dquote {
150        let has_bslash_or_newline = patch.chars().any(|c| matches!(c, '\\' | '\n'));
151        return if has_bslash_or_newline {
152            StrLitKind::Raw(1)
153        } else {
154            StrLitKind::Normal
155        };
156    }
157
158    // Find the maximum number of hashes that follow a double quote in the string.
159    // We need to use one more than that to delimit the string.
160    let leading_hashes = |s: &str| s.chars().take_while(|&c| c == '#').count();
161    let max_hashes = patch.split('"').map(leading_hashes).max().unwrap();
162    StrLitKind::Raw(max_hashes + 1)
163}
164
165fn format_patch(patch: &str) -> String {
166    let lit_kind = lit_kind_for_patch(patch);
167    let is_multiline = patch.contains('\n');
168
169    let mut buf = String::new();
170    if matches!(lit_kind, StrLitKind::Raw(_)) {
171        buf.push('[');
172    }
173    lit_kind.write_start(&mut buf).unwrap();
174    if is_multiline {
175        buf.push('\n');
176    }
177    buf.push_str(patch);
178    if is_multiline {
179        buf.push('\n');
180    }
181    lit_kind.write_end(&mut buf).unwrap();
182    if matches!(lit_kind, StrLitKind::Raw(_)) {
183        buf.push(']');
184    }
185    buf
186}
187
188#[derive(Clone, Debug)]
189struct Span {
190    /// The byte range of the argument to `expect!`, including the inner `[]` if it exists.
191    literal_range: std::ops::Range<usize>,
192}
193
194impl Span {
195    fn from_pos(pos: &Position, file: &str) -> Span {
196        let mut target_line = None;
197        let mut line_start = 0;
198        for (i, line) in crate::utils::LinesWithTerminator::new(file).enumerate() {
199            if i == pos.line as usize - 1 {
200                // `column` points to the first character of the macro invocation:
201                //
202                //    expect![[r#""#]]        expect![""]
203                //    ^       ^               ^       ^
204                //  column   offset                 offset
205                //
206                // Seek past the exclam, then skip any whitespace and
207                // the macro delimiter to get to our argument.
208                #[allow(clippy::skip_while_next)]
209                let byte_offset = line
210                    .char_indices()
211                    .skip((pos.column - 1).try_into().unwrap())
212                    .skip_while(|&(_, c)| c != '!')
213                    .skip(1) // !
214                    .skip_while(|&(_, c)| c.is_whitespace())
215                    .skip(1) // [({
216                    .skip_while(|&(_, c)| c.is_whitespace())
217                    .next()
218                    .expect("Failed to parse macro invocation")
219                    .0;
220
221                let literal_start = line_start + byte_offset;
222                target_line = Some(literal_start);
223                break;
224            }
225            line_start += line.len();
226        }
227        let literal_start = target_line.unwrap();
228
229        let lit_to_eof = &file[literal_start..];
230        let lit_to_eof_trimmed = lit_to_eof.trim_start();
231
232        let literal_start = literal_start + (lit_to_eof.len() - lit_to_eof_trimmed.len());
233
234        let literal_len =
235            locate_end(lit_to_eof_trimmed).expect("Couldn't find closing delimiter for `expect!`.");
236        let literal_range = literal_start..literal_start + literal_len;
237        Span { literal_range }
238    }
239}
240
241fn locate_end(arg_start_to_eof: &str) -> Option<usize> {
242    match arg_start_to_eof.chars().next()? {
243        c if c.is_whitespace() => panic!("skip whitespace before calling `locate_end`"),
244
245        // expect![[]]
246        '[' => {
247            let str_start_to_eof = arg_start_to_eof[1..].trim_start();
248            let str_len = find_str_lit_len(str_start_to_eof)?;
249            let str_end_to_eof = &str_start_to_eof[str_len..];
250            let closing_brace_offset = str_end_to_eof.find(']')?;
251            Some((arg_start_to_eof.len() - str_end_to_eof.len()) + closing_brace_offset + 1)
252        }
253
254        // expect![] | expect!{} | expect!()
255        ']' | '}' | ')' => Some(0),
256
257        // expect!["..."] | expect![r#"..."#]
258        _ => find_str_lit_len(arg_start_to_eof),
259    }
260}
261
262/// Parses a string literal, returning the byte index of its last character
263/// (either a quote or a hash).
264fn find_str_lit_len(str_lit_to_eof: &str) -> Option<usize> {
265    fn try_find_n_hashes(
266        s: &mut impl Iterator<Item = char>,
267        desired_hashes: usize,
268    ) -> Option<(usize, Option<char>)> {
269        let mut n = 0;
270        loop {
271            match s.next()? {
272                '#' => n += 1,
273                c => return Some((n, Some(c))),
274            }
275
276            if n == desired_hashes {
277                return Some((n, None));
278            }
279        }
280    }
281
282    let mut s = str_lit_to_eof.chars();
283    let kind = match s.next()? {
284        '"' => StrLitKind::Normal,
285        'r' => {
286            let (n, c) = try_find_n_hashes(&mut s, usize::MAX)?;
287            if c != Some('"') {
288                return None;
289            }
290            StrLitKind::Raw(n)
291        }
292        _ => return None,
293    };
294
295    let mut oldc = None;
296    loop {
297        let c = oldc.take().or_else(|| s.next())?;
298        match (c, kind) {
299            ('\\', StrLitKind::Normal) => {
300                let _escaped = s.next()?;
301            }
302            ('"', StrLitKind::Normal) => break,
303            ('"', StrLitKind::Raw(0)) => break,
304            ('"', StrLitKind::Raw(n)) => {
305                let (seen, c) = try_find_n_hashes(&mut s, n)?;
306                if seen == n {
307                    break;
308                }
309                oldc = c;
310            }
311            _ => {}
312        }
313    }
314
315    Some(str_lit_to_eof.len() - s.as_str().len())
316}
317
/// Syntax of a Rust string literal.
#[derive(Copy, Clone)]
enum StrLitKind {
    /// `"..."`
    Normal,
    /// `r"..."` delimited by the given number of `#` characters.
    Raw(usize),
}

impl StrLitKind {
    /// Writes the opening delimiter, e.g. `"` or `r##"`.
    fn write_start(self, w: &mut impl std::fmt::Write) -> std::fmt::Result {
        if let Self::Raw(hashes) = self {
            w.write_char('r')?;
            (0..hashes).try_for_each(|_| w.write_char('#'))?;
        }
        w.write_char('"')
    }

    /// Writes the closing delimiter, e.g. `"` or `"##`.
    fn write_end(self, w: &mut impl std::fmt::Write) -> std::fmt::Result {
        w.write_char('"')?;
        if let Self::Raw(hashes) = self {
            (0..hashes).try_for_each(|_| w.write_char('#'))?;
        }
        Ok(())
    }
}
351
/// Counter for one path prefix, used by [`Runtime::count`].
#[derive(Clone)]
struct PathRuntime {
    path_prefix: String,
    count: usize,
}

impl PathRuntime {
    /// Starts a counter for `path_prefix` at zero.
    fn new(path_prefix: &str) -> Self {
        let path_prefix = String::from(path_prefix);
        Self {
            path_prefix,
            count: 0,
        }
    }

    /// Whether this counter belongs to `path_prefix`.
    fn is(&self, path_prefix: &str) -> bool {
        path_prefix == self.path_prefix.as_str()
    }

    /// Increments the counter and returns the new value.
    fn next(&mut self) -> usize {
        let bumped = self.count + 1;
        self.count = bumped;
        bumped
    }

    /// Current value of the counter.
    fn count(&self) -> usize {
        self.count
    }
}
379
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_data_eq;
    use crate::prelude::*;
    use crate::str;

    #[test]
    fn test_format_patch() {
        let patch = format_patch("hello\nworld\n");

        assert_data_eq!(
            patch,
            str![[r##"
[r#"
hello
world

"#]
"##]],
        );

        let patch = format_patch(r"hello\tworld");
        assert_data_eq!(patch, str![[r##"[r#"hello\tworld"#]"##]].raw());

        let patch = format_patch("{\"foo\": 42}");
        assert_data_eq!(patch, str![[r##"[r#"{"foo": 42}"#]"##]]);
    }

    #[test]
    fn test_patchwork() {
        let mut patchwork = Patchwork::new("one two three".to_owned());
        patchwork.patch(4..7, "zwei").unwrap();
        patchwork.patch(0..3, "один").unwrap();
        patchwork.patch(8..13, "3").unwrap();
        assert_data_eq!(
            patchwork.to_debug(),
            str![[r#"
Patchwork {
    text: "один zwei 3",
    indels: {
        OrdRange {
            start: 0,
            end: 3,
        }: (
            8,
            "один",
        ),
        OrdRange {
            start: 4,
            end: 7,
        }: (
            4,
            "zwei",
        ),
        OrdRange {
            start: 8,
            end: 13,
        }: (
            1,
            "3",
        ),
    },
}

"#]],
        );
    }

    #[test]
    fn test_patchwork_overlap_diverge() {
        let mut patchwork = Patchwork::new("one two three".to_owned());
        patchwork.patch(4..7, "zwei").unwrap();
        patchwork.patch(4..7, "abcd").unwrap_err();
        assert_data_eq!(
            patchwork.to_debug(),
            str![[r#"
Patchwork {
    text: "one zwei three",
    indels: {
        OrdRange {
            start: 4,
            end: 7,
        }: (
            4,
            "zwei",
        ),
    },
}

"#]],
        );
    }

    #[test]
    fn test_patchwork_overlap_converge() {
        let mut patchwork = Patchwork::new("one two three".to_owned());
        patchwork.patch(4..7, "zwei").unwrap();
        patchwork.patch(4..7, "zwei").unwrap();
        assert_data_eq!(
            patchwork.to_debug(),
            str![[r#"
Patchwork {
    text: "one zwei three",
    indels: {
        OrdRange {
            start: 4,
            end: 7,
        }: (
            4,
            "zwei",
        ),
    },
}

"#]],
        );
    }

    #[test]
    fn test_locate() {
        macro_rules! check_locate {
            ($( [[$s:literal]] ),* $(,)?) => {$({
                let lit = stringify!($s);
                let with_trailer = format!("{} \t]]\n", lit);
                assert_eq!(locate_end(&with_trailer), Some(lit.len()));
            })*};
        }

        // Check that we handle string literals containing "]]" correctly.
        check_locate!(
            [[r#"{ arr: [[1, 2], [3, 4]], other: "foo" } "#]],
            [["]]"]],
            [["\"]]"]],
            [[r#""]]"#]],
        );

        // Check `str![[  ]]` as well.
        assert_eq!(locate_end("]]"), Some(0));
    }

    #[test]
    fn test_find_str_lit_len() {
        macro_rules! check_str_lit_len {
            ($( $s:literal ),* $(,)?) => {$({
                let lit = stringify!($s);
                assert_eq!(find_str_lit_len(lit), Some(lit.len()));
            })*}
        }

        check_str_lit_len![
            r##"foa\""#"##,
            r##"

                asdf][]]""""#
            "##,
            "",
            "\"",
            "\"\"",
            "#\"#\"#",
        ];
    }

    #[test]
    fn test_lit_kind_for_patch() {
        assert!(matches!(lit_kind_for_patch("plain"), StrLitKind::Normal));
        assert!(matches!(
            lit_kind_for_patch(r"back\slash"),
            StrLitKind::Raw(1)
        ));
        assert!(matches!(
            lit_kind_for_patch("two\nlines"),
            StrLitKind::Raw(1)
        ));
        assert!(matches!(
            lit_kind_for_patch(r#"say "hi""#),
            StrLitKind::Raw(1)
        ));
        // The delimiter needs one more `#` than the longest `"#...` run inside.
        assert!(matches!(
            lit_kind_for_patch(r###""## and "#"###),
            StrLitKind::Raw(3)
        ));
    }

    #[test]
    fn test_runtime_count() {
        let mut runtime = Runtime::new();
        assert_eq!(runtime.count("a"), 0);
        assert_eq!(runtime.count("a"), 1);
        assert_eq!(runtime.count("b"), 0);
        assert_eq!(runtime.count("a"), 2);
        assert_eq!(runtime.count("b"), 1);
    }
}