1use std::collections::BTreeMap;
2
3use super::Data;
4use super::Inline;
5use super::Position;
6
7pub(crate) fn get() -> std::sync::MutexGuard<'static, Runtime> {
8 static RT: std::sync::Mutex<Runtime> = std::sync::Mutex::new(Runtime::new());
9 RT.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
10}
11
12#[derive(Default)]
13pub(crate) struct Runtime {
14 per_file: Vec<SourceFileRuntime>,
15 path_count: Vec<PathRuntime>,
16}
17
18impl Runtime {
19 const fn new() -> Self {
20 Self {
21 per_file: Vec::new(),
22 path_count: Vec::new(),
23 }
24 }
25
26 pub(crate) fn count(&mut self, path_prefix: &str) -> usize {
27 if let Some(entry) = self
28 .path_count
29 .iter_mut()
30 .find(|entry| entry.is(path_prefix))
31 {
32 entry.next()
33 } else {
34 let entry = PathRuntime::new(path_prefix);
35 let next = entry.count();
36 self.path_count.push(entry);
37 next
38 }
39 }
40
41 pub(crate) fn write(&mut self, actual: &Data, inline: &Inline) -> std::io::Result<()> {
42 let actual = actual.render().expect("`actual` must be UTF-8");
43 if let Some(entry) = self
44 .per_file
45 .iter_mut()
46 .find(|f| f.path == inline.position.file)
47 {
48 entry.update(&actual, inline)?;
49 } else {
50 let mut entry = SourceFileRuntime::new(inline)?;
51 entry.update(&actual, inline)?;
52 self.per_file.push(entry);
53 }
54
55 Ok(())
56 }
57}
58
59struct SourceFileRuntime {
60 path: std::path::PathBuf,
61 original_text: String,
62 patchwork: Patchwork,
63}
64
65impl SourceFileRuntime {
66 fn new(inline: &Inline) -> std::io::Result<SourceFileRuntime> {
67 let path = inline.position.file.clone();
68 let original_text = std::fs::read_to_string(&path)?;
69 let patchwork = Patchwork::new(original_text.clone());
70 Ok(SourceFileRuntime {
71 path,
72 original_text,
73 patchwork,
74 })
75 }
76 fn update(&mut self, actual: &str, inline: &Inline) -> std::io::Result<()> {
77 let span = Span::from_pos(&inline.position, &self.original_text);
78 let patch = format_patch(actual);
79 self.patchwork.patch(span.literal_range, &patch)?;
80 std::fs::write(&inline.position.file, &self.patchwork.text)
81 }
82}
83
/// A string being edited by a set of non-conflicting replacements.
///
/// Every replacement is addressed by its byte range in the text as it was
/// *before any* patching. [`Patchwork::patch`] maps that range into the current
/// text by accounting for the size changes of earlier replacements.
#[derive(Debug)]
struct Patchwork {
    /// The text with every recorded replacement applied.
    text: String,
    /// Original byte range -> (byte length of the replacement, replacement text).
    indels: BTreeMap<OrdRange, (usize, String)>,
}

impl Patchwork {
    /// Starts tracking edits to `text`; no replacements are recorded yet.
    fn new(text: String) -> Patchwork {
        Patchwork {
            text,
            indels: BTreeMap::new(),
        }
    }

    /// Replaces `range`, a byte range in the original text, with `patch`.
    ///
    /// Re-applying exactly the same replacement for the same range is a no-op.
    ///
    /// # Errors
    ///
    /// Returns an error and leaves the text unchanged in either of these cases:
    /// - `range` was already replaced with different content.
    /// - `range` overlaps a different range that was already replaced. Those
    ///   edits would conflict, and the offset translation below assumes they
    ///   do not.
    fn patch(&mut self, range: std::ops::Range<usize>, patch: &str) -> std::io::Result<()> {
        let key = OrdRange::from(range.clone());

        if let Some((_, existing)) = self.indels.get(&key) {
            return if existing == patch {
                Ok(())
            } else {
                Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "cannot update as it was already modified",
                ))
            };
        }

        if self.indels.keys().any(|earlier| earlier.overlaps(key)) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "cannot update as it overlaps an already modified range",
            ));
        }

        // Overlaps were rejected above, so every edit ordered before `key` lies
        // entirely before `range` in the original text. Shift `range` by the net
        // size change of those edits.
        //
        // Comparing whole keys (start, then end) rather than only the start
        // offset also handles an insertion (empty range) next to a replacement
        // starting at the same offset: the insertion is always placed before it.
        let (deleted, inserted) = self
            .indels
            .range(..key)
            .fold((0usize, 0usize), |(deleted, inserted), (earlier, (len, _))| {
                (deleted + (earlier.end - earlier.start), inserted + len)
            });
        self.indels.insert(key, (patch.len(), patch.to_owned()));

        let start = range.start - deleted + inserted;
        let end = range.end - deleted + inserted;
        self.text.replace_range(start..end, patch);
        Ok(())
    }
}

/// A byte range usable as a `BTreeMap` key (`std::ops::Range` is not `Ord`).
///
/// The derived ordering compares `start` first, then `end`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct OrdRange {
    start: usize,
    end: usize,
}

impl OrdRange {
    /// Whether `self` and `other` conflict: either they share at least one
    /// byte, or one of them is empty and lies strictly inside the other.
    fn overlaps(self, other: Self) -> bool {
        self.start < other.end && other.start < self.end
    }
}

impl From<std::ops::Range<usize>> for OrdRange {
    fn from(other: std::ops::Range<usize>) -> Self {
        Self {
            start: other.start,
            end: other.end,
        }
    }
}
146
147fn lit_kind_for_patch(patch: &str) -> StrLitKind {
148 let has_dquote = patch.chars().any(|c| c == '"');
149 if !has_dquote {
150 let has_bslash_or_newline = patch.chars().any(|c| matches!(c, '\\' | '\n'));
151 return if has_bslash_or_newline {
152 StrLitKind::Raw(1)
153 } else {
154 StrLitKind::Normal
155 };
156 }
157
158 let leading_hashes = |s: &str| s.chars().take_while(|&c| c == '#').count();
161 let max_hashes = patch.split('"').map(leading_hashes).max().unwrap();
162 StrLitKind::Raw(max_hashes + 1)
163}
164
165fn format_patch(patch: &str) -> String {
166 let lit_kind = lit_kind_for_patch(patch);
167 let is_multiline = patch.contains('\n');
168
169 let mut buf = String::new();
170 if matches!(lit_kind, StrLitKind::Raw(_)) {
171 buf.push('[');
172 }
173 lit_kind.write_start(&mut buf).unwrap();
174 if is_multiline {
175 buf.push('\n');
176 }
177 buf.push_str(patch);
178 if is_multiline {
179 buf.push('\n');
180 }
181 lit_kind.write_end(&mut buf).unwrap();
182 if matches!(lit_kind, StrLitKind::Raw(_)) {
183 buf.push(']');
184 }
185 buf
186}
187
188#[derive(Clone, Debug)]
189struct Span {
190 literal_range: std::ops::Range<usize>,
192}
193
194impl Span {
195 fn from_pos(pos: &Position, file: &str) -> Span {
196 let mut target_line = None;
197 let mut line_start = 0;
198 for (i, line) in crate::utils::LinesWithTerminator::new(file).enumerate() {
199 if i == pos.line as usize - 1 {
200 #[allow(clippy::skip_while_next)]
209 let byte_offset = line
210 .char_indices()
211 .skip((pos.column - 1).try_into().unwrap())
212 .skip_while(|&(_, c)| c != '!')
213 .skip(1) .skip_while(|&(_, c)| c.is_whitespace())
215 .skip(1) .skip_while(|&(_, c)| c.is_whitespace())
217 .next()
218 .expect("Failed to parse macro invocation")
219 .0;
220
221 let literal_start = line_start + byte_offset;
222 target_line = Some(literal_start);
223 break;
224 }
225 line_start += line.len();
226 }
227 let literal_start = target_line.unwrap();
228
229 let lit_to_eof = &file[literal_start..];
230 let lit_to_eof_trimmed = lit_to_eof.trim_start();
231
232 let literal_start = literal_start + (lit_to_eof.len() - lit_to_eof_trimmed.len());
233
234 let literal_len =
235 locate_end(lit_to_eof_trimmed).expect("Couldn't find closing delimiter for `expect!`.");
236 let literal_range = literal_start..literal_start + literal_len;
237 Span { literal_range }
238 }
239}
240
241fn locate_end(arg_start_to_eof: &str) -> Option<usize> {
242 match arg_start_to_eof.chars().next()? {
243 c if c.is_whitespace() => panic!("skip whitespace before calling `locate_end`"),
244
245 '[' => {
247 let str_start_to_eof = arg_start_to_eof[1..].trim_start();
248 let str_len = find_str_lit_len(str_start_to_eof)?;
249 let str_end_to_eof = &str_start_to_eof[str_len..];
250 let closing_brace_offset = str_end_to_eof.find(']')?;
251 Some((arg_start_to_eof.len() - str_end_to_eof.len()) + closing_brace_offset + 1)
252 }
253
254 ']' | '}' | ')' => Some(0),
256
257 _ => find_str_lit_len(arg_start_to_eof),
259 }
260}
261
262fn find_str_lit_len(str_lit_to_eof: &str) -> Option<usize> {
265 fn try_find_n_hashes(
266 s: &mut impl Iterator<Item = char>,
267 desired_hashes: usize,
268 ) -> Option<(usize, Option<char>)> {
269 let mut n = 0;
270 loop {
271 match s.next()? {
272 '#' => n += 1,
273 c => return Some((n, Some(c))),
274 }
275
276 if n == desired_hashes {
277 return Some((n, None));
278 }
279 }
280 }
281
282 let mut s = str_lit_to_eof.chars();
283 let kind = match s.next()? {
284 '"' => StrLitKind::Normal,
285 'r' => {
286 let (n, c) = try_find_n_hashes(&mut s, usize::MAX)?;
287 if c != Some('"') {
288 return None;
289 }
290 StrLitKind::Raw(n)
291 }
292 _ => return None,
293 };
294
295 let mut oldc = None;
296 loop {
297 let c = oldc.take().or_else(|| s.next())?;
298 match (c, kind) {
299 ('\\', StrLitKind::Normal) => {
300 let _escaped = s.next()?;
301 }
302 ('"', StrLitKind::Normal) => break,
303 ('"', StrLitKind::Raw(0)) => break,
304 ('"', StrLitKind::Raw(n)) => {
305 let (seen, c) = try_find_n_hashes(&mut s, n)?;
306 if seen == n {
307 break;
308 }
309 oldc = c;
310 }
311 _ => {}
312 }
313 }
314
315 Some(str_lit_to_eof.len() - s.as_str().len())
316}
317
/// Flavour of Rust string literal used to write snapshot text into source.
#[derive(Copy, Clone)]
enum StrLitKind {
    /// A plain `"..."` literal.
    Normal,
    /// A raw literal delimited by `r` plus this many `#`s, e.g. `Raw(1)` is `r#"..."#`.
    Raw(usize),
}

impl StrLitKind {
    /// Writes the opening delimiter: `"` or `r` + hashes + `"`.
    fn write_start(self, w: &mut impl std::fmt::Write) -> std::fmt::Result {
        if let Self::Raw(hashes) = self {
            w.write_char('r')?;
            for _ in 0..hashes {
                w.write_char('#')?;
            }
        }
        w.write_char('"')
    }

    /// Writes the closing delimiter: `"` followed by the same hashes, if any.
    fn write_end(self, w: &mut impl std::fmt::Write) -> std::fmt::Result {
        w.write_char('"')?;
        if let Self::Raw(hashes) = self {
            for _ in 0..hashes {
                w.write_char('#')?;
            }
        }
        Ok(())
    }
}
351
/// Sequence counter for a single path prefix, used by `Runtime::count`.
#[derive(Clone)]
struct PathRuntime {
    /// The prefix this counter belongs to.
    path_prefix: String,
    /// How many times `next` has been called.
    count: usize,
}

impl PathRuntime {
    /// Creates a counter for `path_prefix`, starting at zero.
    fn new(path_prefix: &str) -> Self {
        let path_prefix = String::from(path_prefix);
        Self {
            count: 0,
            path_prefix,
        }
    }

    /// Whether this counter tracks exactly `path_prefix` (whole-string equality).
    fn is(&self, path_prefix: &str) -> bool {
        self.path_prefix.as_str() == path_prefix
    }

    /// Advances the counter and returns its new value.
    fn next(&mut self) -> usize {
        self.count += 1;
        self.count()
    }

    /// The counter's current value, without advancing it.
    fn count(&self) -> usize {
        self.count
    }
}
379
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_data_eq;
    use crate::prelude::*;
    use crate::str;

    // `format_patch` chooses a literal kind and layout for several kinds of content.
    #[test]
    fn test_format_patch() {
        // Multi-line content: bracketed raw literal with padding newlines.
        let patch = format_patch("hello\nworld\n");

        assert_data_eq!(
            patch,
            str![[r##"
[r#"
hello
world

"#]
"##]],
        );

        // A backslash forces a raw literal, even on one line.
        let patch = format_patch(r"hello\tworld");
        assert_data_eq!(patch, str![[r##"[r#"hello\tworld"#]"##]].raw());

        // Embedded quotes force a raw literal with enough hashes.
        let patch = format_patch("{\"foo\": 42}");
        assert_data_eq!(patch, str![[r##"[r#"{"foo": 42}"#]"##]]);
    }

    // Edits given out of order, in original-text offsets, are applied at the
    // right places in the current text.
    #[test]
    fn test_patchwork() {
        let mut patchwork = Patchwork::new("one two three".to_owned());
        patchwork.patch(4..7, "zwei").unwrap();
        patchwork.patch(0..3, "один").unwrap();
        patchwork.patch(8..13, "3").unwrap();
        assert_data_eq!(
            patchwork.to_debug(),
            str![[r#"
Patchwork {
    text: "один zwei 3",
    indels: {
        OrdRange {
            start: 0,
            end: 3,
        }: (
            8,
            "один",
        ),
        OrdRange {
            start: 4,
            end: 7,
        }: (
            4,
            "zwei",
        ),
        OrdRange {
            start: 8,
            end: 13,
        }: (
            1,
            "3",
        ),
    },
}

"#]],
        );
    }

    // A second, different edit to the same range is an error and changes nothing.
    #[test]
    fn test_patchwork_overlap_diverge() {
        let mut patchwork = Patchwork::new("one two three".to_owned());
        patchwork.patch(4..7, "zwei").unwrap();
        patchwork.patch(4..7, "abcd").unwrap_err();
        assert_data_eq!(
            patchwork.to_debug(),
            str![[r#"
Patchwork {
    text: "one zwei three",
    indels: {
        OrdRange {
            start: 4,
            end: 7,
        }: (
            4,
            "zwei",
        ),
    },
}

"#]],
        );
    }

    // Repeating an identical edit to the same range is accepted as a no-op.
    #[test]
    fn test_patchwork_overlap_converge() {
        let mut patchwork = Patchwork::new("one two three".to_owned());
        patchwork.patch(4..7, "zwei").unwrap();
        patchwork.patch(4..7, "zwei").unwrap();
        assert_data_eq!(
            patchwork.to_debug(),
            str![[r#"
Patchwork {
    text: "one zwei three",
    indels: {
        OrdRange {
            start: 4,
            end: 7,
        }: (
            4,
            "zwei",
        ),
    },
}

"#]],
        );
    }

    // `locate_end` stops exactly at the bracketed literal's closing `]`, even
    // when the literal itself contains brackets and quotes.
    #[test]
    fn test_locate() {
        macro_rules! check_locate {
            ($( [[$s:literal]] ),* $(,)?) => {$({
                let lit = stringify!($s);
                let with_trailer = format!("{} \t]]\n", lit);
                assert_eq!(locate_end(&with_trailer), Some(lit.len()));
            })*};
        }

        check_locate!(
            [[r#"{ arr: [[1, 2], [3, 4]], other: "foo" } "#]],
            [["]]"]],
            [["\"]]"]],
            [[r#""]]"#]],
        );

        // An empty invocation has a zero-length argument.
        assert_eq!(locate_end("]]"), Some(0));
    }

    // `find_str_lit_len` measures exactly the source text of each literal.
    #[test]
    fn test_find_str_lit_len() {
        macro_rules! check_str_lit_len {
            ($( $s:literal ),* $(,)?) => {$({
                let lit = stringify!($s);
                assert_eq!(find_str_lit_len(lit), Some(lit.len()));
            })*}
        }

        check_str_lit_len![
            r##"foa\""#"##,
            r##"

                asdf][]]""""#
            "##,
            "",
            "\"",
            "\"\"",
            "#\"#\"#",
        ];
    }
}
542}