Skip to main content

semdiff_core/
lib.rs

1use rayon::Scope;
2use std::cmp::Ordering;
3use std::error::Error;
4use std::mem;
5use std::ops::{Deref, DerefMut};
6use std::sync::Mutex;
7use thiserror::Error;
8
9pub mod fs;
10
11#[cfg(test)]
12mod tests;
13
/// A single child entry yielded while traversing a tree.
///
/// An entry is either an inner node, which can itself be traversed through
/// [`NodeTraverse::children`], or a leaf, which is handed to the diff reports.
#[derive(Debug)]
pub enum TraversalNode<Node, Leaf> {
    /// An inner entry that may contain further children.
    Node(Node),
    /// A terminal entry that is compared or reported as a whole.
    Leaf(Leaf),
}
19
20impl<Node, Leaf> PartialEq for TraversalNode<Node, Leaf>
21where
22    Node: NodeTraverse,
23    Leaf: LeafTraverse,
24{
25    fn eq(&self, other: &Self) -> bool {
26        if mem::discriminant(self) != mem::discriminant(other) {
27            return false;
28        }
29        match (self, other) {
30            (TraversalNode::Node(a), TraversalNode::Node(b)) => a.name() == b.name(),
31            (TraversalNode::Leaf(a), TraversalNode::Leaf(b)) => a.name() == b.name(),
32            _ => unreachable!(),
33        }
34    }
35}
36
37impl<Node, Leaf> Eq for TraversalNode<Node, Leaf>
38where
39    Node: NodeTraverse,
40    Leaf: LeafTraverse,
41{
42}
43
44impl<Node, Leaf> PartialOrd for TraversalNode<Node, Leaf>
45where
46    Node: NodeTraverse,
47    Leaf: LeafTraverse,
48{
49    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50        Some(self.cmp(other))
51    }
52}
53
54impl<Node, Leaf> Ord for TraversalNode<Node, Leaf>
55where
56    Node: NodeTraverse,
57    Leaf: LeafTraverse,
58{
59    fn cmp(&self, other: &Self) -> Ordering {
60        match (self, other) {
61            (TraversalNode::Node(_), TraversalNode::Leaf(_)) => Ordering::Less,
62            (TraversalNode::Leaf(_), TraversalNode::Node(_)) => Ordering::Greater,
63            (TraversalNode::Node(a), TraversalNode::Node(b)) => a.name().cmp(b.name()),
64            (TraversalNode::Leaf(a), TraversalNode::Leaf(b)) => a.name().cmp(b.name()),
65        }
66    }
67}
68
/// A terminal entry of a traversed tree.
pub trait LeafTraverse {
    /// Name of this leaf; used to match and order it among its siblings.
    fn name(&self) -> &str;
}
72
73pub trait NodeTraverse: Sized {
74    type Leaf: LeafTraverse + Clone;
75    type TraverseError: Error + Send + 'static;
76    fn name(&self) -> &str;
77    #[allow(clippy::type_complexity)]
78    fn children(
79        &mut self,
80    ) -> Result<impl Iterator<Item = Result<TraversalNode<Self, Self::Leaf>, Self::TraverseError>>, Self::TraverseError>;
81}
82
/// Result of comparing two leaves.
pub trait Diff {
    /// Returns `true` when the comparison found no difference; such diffs are
    /// reported as unchanged rather than modified.
    fn equal(&self) -> bool;
}
86
/// Outcome of an operation that may decline to handle its input.
#[derive(Debug)]
pub enum MayUnsupported<T> {
    /// The input was handled and produced a value.
    Ok(T),
    /// The input is not supported by this handler.
    Unsupported,
}
92
93pub trait DiffCalculator<T> {
94    type Error: Error + Send + 'static;
95    type Diff: Diff + Send;
96    fn diff(&self, name: &str, expected: T, actual: T) -> Result<MayUnsupported<Self::Diff>, Self::Error>;
97}
98
99pub trait DetailReporter<Diff, T, Reporter> {
100    type Error: Error + Send + 'static;
101    fn report_unchanged(&self, name: &str, diff: &Diff, reporter: &Reporter)
102    -> Result<MayUnsupported<()>, Self::Error>;
103    fn report_modified(&self, name: &str, diff: &Diff, reporter: &Reporter) -> Result<MayUnsupported<()>, Self::Error>;
104    fn report_added(&self, name: &str, data: &T, reporter: &Reporter) -> Result<MayUnsupported<()>, Self::Error>;
105    fn report_deleted(&self, name: &str, data: &T, reporter: &Reporter) -> Result<MayUnsupported<()>, Self::Error>;
106}
107
108#[derive(Debug, Error)]
109pub enum EitherError<T1, T2> {
110    #[error("{0}")]
111    Left(#[source] T1),
112    #[error("{0}")]
113    Right(#[source] T2),
114}
115
116impl<R, Diff, T, R1, R2> DetailReporter<Diff, T, (R1, R2)> for R
117where
118    R: DetailReporter<Diff, T, R1>,
119    R: DetailReporter<Diff, T, R2>,
120{
121    type Error = EitherError<<R as DetailReporter<Diff, T, R1>>::Error, <R as DetailReporter<Diff, T, R2>>::Error>;
122
123    fn report_unchanged(
124        &self,
125        name: &str,
126        diff: &Diff,
127        (reporter1, reporter2): &(R1, R2),
128    ) -> Result<MayUnsupported<()>, Self::Error> {
129        match <R as DetailReporter<Diff, T, R1>>::report_unchanged(self, name, diff, reporter1) {
130            Ok(MayUnsupported::Unsupported) => return Ok(MayUnsupported::Unsupported),
131            Ok(MayUnsupported::Ok(())) => {}
132            Err(e) => return Err(EitherError::Left(e)),
133        }
134        <R as DetailReporter<Diff, T, R2>>::report_unchanged(self, name, diff, reporter2).map_err(EitherError::Right)
135    }
136
137    fn report_modified(
138        &self,
139        name: &str,
140        diff: &Diff,
141        (reporter1, reporter2): &(R1, R2),
142    ) -> Result<MayUnsupported<()>, Self::Error> {
143        match <R as DetailReporter<Diff, T, R1>>::report_modified(self, name, diff, reporter1) {
144            Ok(MayUnsupported::Unsupported) => return Ok(MayUnsupported::Unsupported),
145            Ok(MayUnsupported::Ok(())) => {}
146            Err(e) => return Err(EitherError::Left(e)),
147        }
148        <R as DetailReporter<Diff, T, R2>>::report_modified(self, name, diff, reporter2).map_err(EitherError::Right)
149    }
150
151    fn report_added(
152        &self,
153        name: &str,
154        data: &T,
155        (reporter1, reporter2): &(R1, R2),
156    ) -> Result<MayUnsupported<()>, Self::Error> {
157        match <R as DetailReporter<Diff, T, R1>>::report_added(self, name, data, reporter1) {
158            Ok(MayUnsupported::Unsupported) => return Ok(MayUnsupported::Unsupported),
159            Ok(MayUnsupported::Ok(())) => {}
160            Err(e) => return Err(EitherError::Left(e)),
161        }
162        <R as DetailReporter<Diff, T, R2>>::report_added(self, name, data, reporter2).map_err(EitherError::Right)
163    }
164
165    fn report_deleted(
166        &self,
167        name: &str,
168        data: &T,
169        (reporter1, reporter2): &(R1, R2),
170    ) -> Result<MayUnsupported<()>, Self::Error> {
171        match <R as DetailReporter<Diff, T, R1>>::report_deleted(self, name, data, reporter1) {
172            Ok(MayUnsupported::Unsupported) => return Ok(MayUnsupported::Unsupported),
173            Ok(MayUnsupported::Ok(())) => {}
174            Err(e) => return Err(EitherError::Left(e)),
175        }
176        <R as DetailReporter<Diff, T, R2>>::report_deleted(self, name, data, reporter2).map_err(EitherError::Right)
177    }
178}
179
/// Private home of the `Sealed` marker trait, which `DiffReport` requires as a
/// supertrait.
#[doc(hidden)]
mod __sealed {
    /// Marker trait; the module that contains it is private.
    pub trait Sealed {}
}
184
185pub trait DiffReport<T, Reporter>: __sealed::Sealed + Sync {
186    fn diff(
187        &self,
188        name: &str,
189        expected: T,
190        actual: T,
191        reporter: &Reporter,
192    ) -> Result<MayUnsupported<()>, Box<dyn Error + Send>>;
193    fn added(&self, name: &str, data: T, reporter: &Reporter) -> Result<MayUnsupported<()>, Box<dyn Error + Send>>;
194    fn deleted(&self, name: &str, data: T, reporter: &Reporter) -> Result<MayUnsupported<()>, Box<dyn Error + Send>>;
195}
196
/// Pairs a diff calculator with the detail reporter that writes its results.
#[derive(Debug)]
pub struct DiffAndReport<D, R> {
    /// Computes the diff between two leaves.
    diff: D,
    /// Writes the computed outcome to a reporter.
    report: R,
}

impl<D, R> DiffAndReport<D, R> {
    /// Bundles `diff` and `report` into a single handler.
    pub fn new(diff: D, report: R) -> Self {
        DiffAndReport { diff, report }
    }
}
208
209impl<DiffCalculator, DetailReporter> __sealed::Sealed for DiffAndReport<DiffCalculator, DetailReporter> {}
210
211impl<D, R, T, Reporter> DiffReport<T, Reporter> for DiffAndReport<D, R>
212where
213    D: DiffCalculator<T> + Sync,
214    R: DetailReporter<D::Diff, T, Reporter> + Sync,
215    T: Send,
216    Reporter: Sync,
217{
218    fn diff(
219        &self,
220        name: &str,
221        expected: T,
222        actual: T,
223        reporter: &Reporter,
224    ) -> Result<MayUnsupported<()>, Box<dyn Error + Send>> {
225        let diff = self
226            .diff
227            .diff(name, expected, actual)
228            .map_err(|e| Box::new(e) as Box<dyn Error + Send>)?;
229        let MayUnsupported::Ok(diff) = diff else {
230            return Ok(MayUnsupported::Unsupported);
231        };
232        if diff.equal() {
233            self.report
234                .report_unchanged(name, &diff, reporter)
235                .map_err(|e| Box::new(e) as Box<dyn Error + Send>)
236        } else {
237            self.report
238                .report_modified(name, &diff, reporter)
239                .map_err(|e| Box::new(e) as Box<dyn Error + Send>)
240        }
241    }
242
243    fn added(&self, name: &str, data: T, reporter: &Reporter) -> Result<MayUnsupported<()>, Box<dyn Error + Send>> {
244        self.report
245            .report_added(name, &data, reporter)
246            .map_err(|e| Box::new(e) as Box<dyn Error + Send>)
247    }
248
249    fn deleted(&self, name: &str, data: T, reporter: &Reporter) -> Result<MayUnsupported<()>, Box<dyn Error + Send>> {
250        self.report
251            .report_deleted(name, &data, reporter)
252            .map_err(|e| Box::new(e) as Box<dyn Error + Send>)
253    }
254}
255
/// Sink for diff results with an explicit begin/end lifecycle.
pub trait Reporter {
    /// Error raised by `start` or `finish`.
    type Error: Error + Send + 'static;

    /// Called once before any entry is reported.
    fn start(&mut self) -> Result<(), Self::Error>;

    /// Called once after all entries were reported; consumes the reporter.
    fn finish(self) -> Result<(), Self::Error>;
}
261
262impl<R1, R2> Reporter for (R1, R2)
263where
264    R1: Reporter,
265    R2: Reporter,
266{
267    type Error = EitherError<R1::Error, R2::Error>;
268
269    fn start(&mut self) -> Result<(), Self::Error> {
270        self.0.start().map_err(EitherError::Left)?;
271        self.1.start().map_err(EitherError::Right)?;
272        Ok(())
273    }
274
275    fn finish(self) -> Result<(), Self::Error> {
276        let result1 = self.0.finish();
277        let result2 = self.1.finish();
278        result1.map_err(EitherError::Left)?;
279        result2.map_err(EitherError::Right)?;
280        Ok(())
281    }
282}
283
284#[derive(Debug, Error)]
285pub enum CalcDiffError<TraverseError, ReporterError> {
286    #[error("{0}")]
287    TraverseError(#[source] TraverseError),
288    #[error("{0}")]
289    ReporterError(#[source] ReporterError),
290    #[error("{0}")]
291    DiffError(#[source] Box<dyn Error + Send>),
292    #[error("No diff report matched")]
293    NoDiffReportMatched,
294}
295
296pub fn calc_diff<N, R>(
297    expected: N,
298    actual: N,
299    diff: &[Box<dyn DiffReport<N::Leaf, R>>],
300    mut reporter: R,
301) -> Result<(), CalcDiffError<N::TraverseError, R::Error>>
302where
303    N: NodeTraverse + Send,
304    N::Leaf: Send,
305    R: Reporter + Sync,
306{
307    reporter.start().map_err(CalcDiffError::ReporterError)?;
308    let errors = Mutex::new(None);
309    rayon::scope(|scope| {
310        if let Err(error) = calc_diff_inner::<N, R, R::Error>(
311            &mut String::new(),
312            Some(expected),
313            Some(actual),
314            diff,
315            &reporter,
316            scope,
317            &errors,
318        ) {
319            record_error(&errors, error);
320        }
321    });
322    if let Some(error) = errors.lock().unwrap().take() {
323        return Err(error);
324    }
325    reporter.finish().map_err(CalcDiffError::ReporterError)?;
326    Ok(())
327}
328
329fn calc_diff_inner<'scope, N, R, RE>(
330    name: &mut String,
331    expected: Option<N>,
332    actual: Option<N>,
333    diff: &'scope [Box<dyn DiffReport<N::Leaf, R>>],
334    reporter: &'scope R,
335    scope: &Scope<'scope>,
336    errors: &'scope Mutex<Option<CalcDiffError<N::TraverseError, RE>>>,
337) -> Result<(), CalcDiffError<N::TraverseError, RE>>
338where
339    N: NodeTraverse,
340    N::Leaf: Send,
341    R: Reporter + Sync,
342    RE: Send + 'scope,
343{
344    match (expected, actual) {
345        (Some(mut expected), Some(mut actual)) => {
346            let mut expected = expected
347                .children()
348                .map_err(CalcDiffError::TraverseError)?
349                .collect::<Result<Vec<_>, _>>()
350                .map_err(CalcDiffError::TraverseError)?;
351            let mut actual = actual
352                .children()
353                .map_err(CalcDiffError::TraverseError)?
354                .collect::<Result<Vec<_>, _>>()
355                .map_err(CalcDiffError::TraverseError)?;
356            expected.sort_unstable();
357            actual.sort_unstable();
358            let mut expected_iter = expected.into_iter().peekable();
359            let mut actual_iter = actual.into_iter().peekable();
360
361            loop {
362                let pair = match (expected_iter.peek(), actual_iter.peek()) {
363                    (Some(expected), Some(actual)) => match expected.cmp(actual) {
364                        Ordering::Less => (expected_iter.next(), None),
365                        Ordering::Equal => (expected_iter.next(), actual_iter.next()),
366                        Ordering::Greater => (None, actual_iter.next()),
367                    },
368                    (Some(_), None) => (expected_iter.next(), None),
369                    (None, Some(_)) => (None, actual_iter.next()),
370                    (None, None) => (None, None),
371                };
372                match pair {
373                    (None, None) => break,
374                    (Some(expected), Some(actual)) => match (expected, actual) {
375                        (TraversalNode::Node(expected), TraversalNode::Node(actual)) => {
376                            let mut name = AppendedName::new(name, expected.name());
377                            calc_diff_inner(&mut name, Some(expected), Some(actual), diff, reporter, scope, errors)?;
378                        }
379                        (TraversalNode::Leaf(expected), TraversalNode::Leaf(actual)) => {
380                            let name = AppendedName::new(name, expected.name());
381                            let name = name.clone();
382                            spawn_task(scope, errors, move || {
383                                run_diff::<N, R, RE>(diff, reporter, &name, &expected, &actual)
384                            });
385                        }
386                        _ => unreachable!(),
387                    },
388                    (Some(expected), None) => match expected {
389                        TraversalNode::Node(node) => {
390                            let mut name = AppendedName::new(name, node.name());
391                            calc_diff_inner(&mut name, Some(node), None, diff, reporter, scope, errors)?;
392                        }
393                        TraversalNode::Leaf(leaf) => {
394                            let name = AppendedName::new(name, leaf.name());
395                            let name = name.clone();
396                            spawn_task(scope, errors, move || {
397                                run_deleted::<N, R, RE>(diff, reporter, &name, &leaf)
398                            });
399                        }
400                    },
401                    (None, Some(actual)) => match actual {
402                        TraversalNode::Node(node) => {
403                            let mut name = AppendedName::new(name, node.name());
404                            calc_diff_inner(&mut name, None, Some(node), diff, reporter, scope, errors)?;
405                        }
406                        TraversalNode::Leaf(leaf) => {
407                            let name = AppendedName::new(name, leaf.name());
408                            let name = name.clone();
409                            spawn_task(scope, errors, move || {
410                                run_added::<N, R, RE>(diff, reporter, &name, &leaf)
411                            });
412                        }
413                    },
414                }
415            }
416        }
417        (Some(mut expected), None) => {
418            for result in expected.children().map_err(CalcDiffError::TraverseError)? {
419                let node = result.map_err(CalcDiffError::TraverseError)?;
420                match node {
421                    TraversalNode::Node(node) => {
422                        let mut name = AppendedName::new(name, node.name());
423                        calc_diff_inner(&mut name, Some(node), None, diff, reporter, scope, errors)?;
424                    }
425                    TraversalNode::Leaf(leaf) => {
426                        let name = AppendedName::new(name, leaf.name());
427                        let name = name.clone();
428                        spawn_task(scope, errors, move || {
429                            run_deleted::<N, R, RE>(diff, reporter, &name, &leaf)
430                        });
431                    }
432                }
433            }
434        }
435        (None, Some(mut actual)) => {
436            for result in actual.children().map_err(CalcDiffError::TraverseError)? {
437                let node = result.map_err(CalcDiffError::TraverseError)?;
438                match node {
439                    TraversalNode::Node(node) => {
440                        let mut name = AppendedName::new(name, node.name());
441                        calc_diff_inner(&mut name, Some(node), None, diff, reporter, scope, errors)?;
442                    }
443                    TraversalNode::Leaf(leaf) => {
444                        let name = AppendedName::new(name, leaf.name());
445                        let name = name.clone();
446                        spawn_task(scope, errors, move || {
447                            run_added::<N, R, RE>(diff, reporter, &name, &leaf)
448                        });
449                    }
450                }
451            }
452        }
453        (None, None) => {}
454    }
455    Ok(())
456}
457
458fn record_error<TE, RE>(errors: &Mutex<Option<CalcDiffError<TE, RE>>>, error: CalcDiffError<TE, RE>) {
459    let mut guard = errors.lock().unwrap();
460    if guard.is_none() {
461        *guard = Some(error);
462    }
463}
464
465fn spawn_task<'scope, TE, RE>(
466    scope: &Scope<'scope>,
467    errors: &'scope Mutex<Option<CalcDiffError<TE, RE>>>,
468    task: impl FnOnce() -> Result<(), CalcDiffError<TE, RE>> + Send + 'scope,
469) where
470    TE: Send + 'scope,
471    RE: Send + 'scope,
472{
473    scope.spawn(move |_| {
474        if let Err(error) = task() {
475            record_error(errors, error);
476        }
477    });
478}
479
/// Guard that extends a borrowed path buffer by one segment and cuts the
/// buffer back to its previous length when dropped.
struct AppendedName<'a> {
    /// Length of the buffer before the segment was appended.
    original_len: usize,
    /// The borrowed buffer, holding the extended path while the guard lives.
    name: &'a mut String,
}

impl AppendedName<'_> {
    /// Appends `segment` to `name`, inserting a `/` separator first unless
    /// `name` is empty.
    fn new<'a>(name: &'a mut String, segment: &str) -> AppendedName<'a> {
        let original_len = name.len();
        if original_len != 0 {
            name.push('/');
        }
        name.push_str(segment);
        AppendedName { original_len, name }
    }
}

/// Read access to the extended path.
impl Deref for AppendedName<'_> {
    type Target = String;

    fn deref(&self) -> &String {
        self.name
    }
}

/// Mutable access to the buffer, so a guard can itself be extended further.
impl DerefMut for AppendedName<'_> {
    fn deref_mut(&mut self) -> &mut String {
        self.name
    }
}

impl Drop for AppendedName<'_> {
    /// Removes the separator and segment added by `new`.
    fn drop(&mut self) {
        let restore_to = self.original_len;
        self.name.truncate(restore_to);
    }
}
515
516fn run_diff<N, R, RE>(
517    diff: &[Box<dyn DiffReport<N::Leaf, R>>],
518    reporter: &R,
519    name: &str,
520    expected: &N::Leaf,
521    actual: &N::Leaf,
522) -> Result<(), CalcDiffError<N::TraverseError, RE>>
523where
524    N: NodeTraverse,
525    N::Leaf: Clone,
526    R: Reporter + Sync,
527{
528    for diff in diff {
529        if let MayUnsupported::Ok(()) = diff
530            .diff(name, expected.clone(), actual.clone(), reporter)
531            .map_err(CalcDiffError::DiffError)?
532        {
533            return Ok(());
534        }
535    }
536    Err(CalcDiffError::<N::TraverseError, RE>::NoDiffReportMatched)
537}
538
539fn run_added<N, R, RE>(
540    diff: &[Box<dyn DiffReport<N::Leaf, R>>],
541    reporter: &R,
542    name: &str,
543    actual: &N::Leaf,
544) -> Result<(), CalcDiffError<N::TraverseError, RE>>
545where
546    N: NodeTraverse,
547    N::Leaf: Clone,
548    R: Reporter + Sync,
549{
550    for diff in diff {
551        if let MayUnsupported::Ok(()) = diff
552            .added(name, actual.clone(), reporter)
553            .map_err(CalcDiffError::DiffError)?
554        {
555            return Ok(());
556        }
557    }
558    Err(CalcDiffError::<N::TraverseError, RE>::NoDiffReportMatched)
559}
560
561fn run_deleted<N, R, RE>(
562    diff: &[Box<dyn DiffReport<N::Leaf, R>>],
563    reporter: &R,
564    name: &str,
565    expected: &N::Leaf,
566) -> Result<(), CalcDiffError<N::TraverseError, RE>>
567where
568    N: NodeTraverse,
569    N::Leaf: Clone,
570    R: Reporter + Sync,
571{
572    for diff in diff {
573        if let MayUnsupported::Ok(()) = diff
574            .deleted(name, expected.clone(), reporter)
575            .map_err(CalcDiffError::DiffError)?
576        {
577            return Ok(());
578        }
579    }
580    Err(CalcDiffError::<N::TraverseError, RE>::NoDiffReportMatched)
581}