Skip to main content

tensorlogic_adapters/
merge_strategies.rs

1//! Advanced schema merging strategies for combining schemas from different sources.
2//!
3//! This module provides sophisticated strategies for merging symbol tables with
4//! conflict resolution, validation, and detailed merge reports.
5
6use crate::{DomainInfo, PredicateInfo, SymbolTable};
7use anyhow::{bail, Result};
8use std::collections::HashSet;
9
/// Policy a [`SchemaMerger`] applies when combining two symbol tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeStrategy {
    /// On conflict, keep the definition from the first (base) table.
    KeepFirst,
    /// On conflict, keep the definition from the second (incoming) table.
    KeepSecond,
    /// Fail on any conflict instead of resolving it.
    FailOnConflict,
    /// Keep items from both tables. A domain defined in both with differing
    /// cardinality resolves to the larger one; differing predicate signatures
    /// or variable domains are rejected with an error.
    Union,
    /// Keep only items defined in both tables. A domain with differing
    /// cardinality resolves to the smaller one; differing predicate signatures
    /// or variable domains are rejected with an error.
    Intersection,
}
24
25/// Result of a merge operation.
26#[derive(Debug, Clone)]
27pub struct MergeResult {
28    /// The merged symbol table
29    pub merged: SymbolTable,
30    /// Report of the merge operation
31    pub report: MergeReport,
32}
33
34/// Report of a merge operation.
35#[derive(Debug, Clone)]
36pub struct MergeReport {
37    /// Domains that were added from base
38    pub base_domains: Vec<String>,
39    /// Domains that were added from incoming
40    pub incoming_domains: Vec<String>,
41    /// Domains that had conflicts
42    pub conflicting_domains: Vec<DomainConflict>,
43    /// Predicates that were added from base
44    pub base_predicates: Vec<String>,
45    /// Predicates that were added from incoming
46    pub incoming_predicates: Vec<String>,
47    /// Predicates that had conflicts
48    pub conflicting_predicates: Vec<PredicateConflict>,
49    /// Variables that were merged
50    pub merged_variables: Vec<String>,
51    /// Variables that had conflicts
52    pub conflicting_variables: Vec<VariableConflict>,
53    /// Overall merge strategy used
54    pub strategy: MergeStrategy,
55}
56
57impl MergeReport {
58    /// Create a new empty merge report.
59    pub fn new(strategy: MergeStrategy) -> Self {
60        Self {
61            base_domains: Vec::new(),
62            incoming_domains: Vec::new(),
63            conflicting_domains: Vec::new(),
64            base_predicates: Vec::new(),
65            incoming_predicates: Vec::new(),
66            conflicting_predicates: Vec::new(),
67            merged_variables: Vec::new(),
68            conflicting_variables: Vec::new(),
69            strategy,
70        }
71    }
72
73    /// Check if there were any conflicts during merging.
74    pub fn has_conflicts(&self) -> bool {
75        !self.conflicting_domains.is_empty()
76            || !self.conflicting_predicates.is_empty()
77            || !self.conflicting_variables.is_empty()
78    }
79
80    /// Get total number of conflicts.
81    pub fn conflict_count(&self) -> usize {
82        self.conflicting_domains.len()
83            + self.conflicting_predicates.len()
84            + self.conflicting_variables.len()
85    }
86
87    /// Get total number of merged items.
88    pub fn merged_count(&self) -> usize {
89        self.base_domains.len()
90            + self.incoming_domains.len()
91            + self.base_predicates.len()
92            + self.incoming_predicates.len()
93            + self.merged_variables.len()
94    }
95}
96
97/// Information about a domain conflict.
98#[derive(Debug, Clone)]
99pub struct DomainConflict {
100    /// Name of the conflicting domain
101    pub name: String,
102    /// Domain from base table
103    pub base: DomainInfo,
104    /// Domain from incoming table
105    pub incoming: DomainInfo,
106    /// How the conflict was resolved
107    pub resolution: MergeConflictResolution,
108}
109
110/// Information about a predicate conflict.
111#[derive(Debug, Clone)]
112pub struct PredicateConflict {
113    /// Name of the conflicting predicate
114    pub name: String,
115    /// Predicate from base table
116    pub base: PredicateInfo,
117    /// Predicate from incoming table
118    pub incoming: PredicateInfo,
119    /// How the conflict was resolved
120    pub resolution: MergeConflictResolution,
121}
122
123/// Information about a variable conflict.
124#[derive(Debug, Clone)]
125pub struct VariableConflict {
126    /// Name of the conflicting variable
127    pub name: String,
128    /// Domain from base table
129    pub base_domain: String,
130    /// Domain from incoming table
131    pub incoming_domain: String,
132    /// How the conflict was resolved
133    pub resolution: MergeConflictResolution,
134}
135
/// Outcome of resolving an item that is defined in both merged tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeConflictResolution {
    /// The base table's definition was kept
    KeptBase,
    /// The incoming table's definition was kept
    KeptIncoming,
    /// The conflict could not be resolved
    Failed,
    /// The two definitions were compatible and were combined into one
    Merged,
}
148
149/// A schema merger with configurable strategies.
150pub struct SchemaMerger {
151    strategy: MergeStrategy,
152}
153
154impl SchemaMerger {
155    /// Create a new schema merger with the given strategy.
156    pub fn new(strategy: MergeStrategy) -> Self {
157        Self { strategy }
158    }
159
160    /// Merge two symbol tables according to the configured strategy.
161    ///
162    /// # Examples
163    ///
164    /// ```
165    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, SchemaMerger, MergeStrategy};
166    ///
167    /// let mut base = SymbolTable::new();
168    /// base.add_domain(DomainInfo::new("Person", 100)).expect("unwrap");
169    ///
170    /// let mut incoming = SymbolTable::new();
171    /// incoming.add_domain(DomainInfo::new("Organization", 50)).expect("unwrap");
172    ///
173    /// let merger = SchemaMerger::new(MergeStrategy::Union);
174    /// let result = merger.merge(&base, &incoming).expect("unwrap");
175    ///
176    /// assert_eq!(result.merged.domains.len(), 2);
177    /// ```
178    pub fn merge(&self, base: &SymbolTable, incoming: &SymbolTable) -> Result<MergeResult> {
179        let mut merged = SymbolTable::new();
180        let mut report = MergeReport::new(self.strategy);
181
182        // Merge domains
183        self.merge_domains(base, incoming, &mut merged, &mut report)?;
184
185        // Merge predicates
186        self.merge_predicates(base, incoming, &mut merged, &mut report)?;
187
188        // Merge variables
189        self.merge_variables(base, incoming, &mut merged, &mut report)?;
190
191        Ok(MergeResult { merged, report })
192    }
193
194    fn merge_domains(
195        &self,
196        base: &SymbolTable,
197        incoming: &SymbolTable,
198        merged: &mut SymbolTable,
199        report: &mut MergeReport,
200    ) -> Result<()> {
201        let base_keys: HashSet<&String> = base.domains.keys().collect();
202        let incoming_keys: HashSet<&String> = incoming.domains.keys().collect();
203
204        // Domains only in base
205        for key in base_keys.difference(&incoming_keys) {
206            let domain = base
207                .domains
208                .get(*key)
209                .expect("key from HashMap iteration is always present");
210            merged.add_domain(domain.clone())?;
211            report.base_domains.push(key.to_string());
212        }
213
214        // Domains only in incoming
215        for key in incoming_keys.difference(&base_keys) {
216            let domain = incoming
217                .domains
218                .get(*key)
219                .expect("key from HashMap iteration is always present");
220            merged.add_domain(domain.clone())?;
221            report.incoming_domains.push(key.to_string());
222        }
223
224        // Domains in both (conflicts)
225        for key in base_keys.intersection(&incoming_keys) {
226            let base_domain = base
227                .domains
228                .get(*key)
229                .expect("key from HashMap iteration is always present");
230            let incoming_domain = incoming
231                .domains
232                .get(*key)
233                .expect("key from HashMap iteration is always present");
234
235            let (domain, resolution) =
236                self.resolve_domain_conflict(base_domain, incoming_domain)?;
237
238            merged.add_domain(domain)?;
239
240            if resolution != MergeConflictResolution::Merged {
241                report.conflicting_domains.push(DomainConflict {
242                    name: key.to_string(),
243                    base: base_domain.clone(),
244                    incoming: incoming_domain.clone(),
245                    resolution,
246                });
247            }
248        }
249
250        Ok(())
251    }
252
253    fn merge_predicates(
254        &self,
255        base: &SymbolTable,
256        incoming: &SymbolTable,
257        merged: &mut SymbolTable,
258        report: &mut MergeReport,
259    ) -> Result<()> {
260        let base_keys: HashSet<&String> = base.predicates.keys().collect();
261        let incoming_keys: HashSet<&String> = incoming.predicates.keys().collect();
262
263        // Predicates only in base
264        for key in base_keys.difference(&incoming_keys) {
265            let predicate = base
266                .predicates
267                .get(*key)
268                .expect("key from HashMap iteration is always present");
269            merged.add_predicate(predicate.clone())?;
270            report.base_predicates.push(key.to_string());
271        }
272
273        // Predicates only in incoming
274        for key in incoming_keys.difference(&base_keys) {
275            let predicate = incoming
276                .predicates
277                .get(*key)
278                .expect("key from HashMap iteration is always present");
279            merged.add_predicate(predicate.clone())?;
280            report.incoming_predicates.push(key.to_string());
281        }
282
283        // Predicates in both (conflicts)
284        for key in base_keys.intersection(&incoming_keys) {
285            let base_pred = base
286                .predicates
287                .get(*key)
288                .expect("key from HashMap iteration is always present");
289            let incoming_pred = incoming
290                .predicates
291                .get(*key)
292                .expect("key from HashMap iteration is always present");
293
294            let (predicate, resolution) =
295                self.resolve_predicate_conflict(base_pred, incoming_pred)?;
296
297            merged.add_predicate(predicate)?;
298
299            if resolution != MergeConflictResolution::Merged {
300                report.conflicting_predicates.push(PredicateConflict {
301                    name: key.to_string(),
302                    base: base_pred.clone(),
303                    incoming: incoming_pred.clone(),
304                    resolution,
305                });
306            }
307        }
308
309        Ok(())
310    }
311
312    fn merge_variables(
313        &self,
314        base: &SymbolTable,
315        incoming: &SymbolTable,
316        merged: &mut SymbolTable,
317        report: &mut MergeReport,
318    ) -> Result<()> {
319        let base_keys: HashSet<&String> = base.variables.keys().collect();
320        let incoming_keys: HashSet<&String> = incoming.variables.keys().collect();
321
322        // Variables only in base
323        for key in base_keys.difference(&incoming_keys) {
324            let domain = base
325                .variables
326                .get(*key)
327                .expect("key from HashMap iteration is always present");
328            merged.bind_variable(key.to_string(), domain.clone())?;
329            report.merged_variables.push(key.to_string());
330        }
331
332        // Variables only in incoming
333        for key in incoming_keys.difference(&base_keys) {
334            let domain = incoming
335                .variables
336                .get(*key)
337                .expect("key from HashMap iteration is always present");
338            merged.bind_variable(key.to_string(), domain.clone())?;
339            report.merged_variables.push(key.to_string());
340        }
341
342        // Variables in both (conflicts)
343        for key in base_keys.intersection(&incoming_keys) {
344            let base_domain = base
345                .variables
346                .get(*key)
347                .expect("key from HashMap iteration is always present");
348            let incoming_domain = incoming
349                .variables
350                .get(*key)
351                .expect("key from HashMap iteration is always present");
352
353            let (domain, resolution) =
354                self.resolve_variable_conflict(base_domain, incoming_domain)?;
355
356            merged.bind_variable(key.to_string(), domain)?;
357
358            if resolution != MergeConflictResolution::Merged {
359                report.conflicting_variables.push(VariableConflict {
360                    name: key.to_string(),
361                    base_domain: base_domain.clone(),
362                    incoming_domain: incoming_domain.clone(),
363                    resolution,
364                });
365            }
366        }
367
368        Ok(())
369    }
370
371    fn resolve_domain_conflict(
372        &self,
373        base: &DomainInfo,
374        incoming: &DomainInfo,
375    ) -> Result<(DomainInfo, MergeConflictResolution)> {
376        match self.strategy {
377            MergeStrategy::KeepFirst => Ok((base.clone(), MergeConflictResolution::KeptBase)),
378            MergeStrategy::KeepSecond => {
379                Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
380            }
381            MergeStrategy::FailOnConflict => {
382                bail!(
383                    "Domain conflict for '{}': cardinality {} vs {}",
384                    base.name,
385                    base.cardinality,
386                    incoming.cardinality
387                )
388            }
389            MergeStrategy::Union => {
390                // For domains, union means take the larger cardinality
391                if base.cardinality >= incoming.cardinality {
392                    Ok((base.clone(), MergeConflictResolution::KeptBase))
393                } else {
394                    Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
395                }
396            }
397            MergeStrategy::Intersection => {
398                // For domains, intersection means take the smaller cardinality
399                if base.cardinality <= incoming.cardinality {
400                    Ok((base.clone(), MergeConflictResolution::KeptBase))
401                } else {
402                    Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
403                }
404            }
405        }
406    }
407
408    fn resolve_predicate_conflict(
409        &self,
410        base: &PredicateInfo,
411        incoming: &PredicateInfo,
412    ) -> Result<(PredicateInfo, MergeConflictResolution)> {
413        // Check if predicates are compatible (same signature)
414        let compatible = base.arg_domains == incoming.arg_domains;
415
416        match self.strategy {
417            MergeStrategy::KeepFirst => Ok((base.clone(), MergeConflictResolution::KeptBase)),
418            MergeStrategy::KeepSecond => {
419                Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
420            }
421            MergeStrategy::FailOnConflict => {
422                bail!(
423                    "Predicate conflict for '{}': {:?} vs {:?}",
424                    base.name,
425                    base.arg_domains,
426                    incoming.arg_domains
427                )
428            }
429            MergeStrategy::Union => {
430                if compatible {
431                    Ok((base.clone(), MergeConflictResolution::Merged))
432                } else {
433                    bail!(
434                        "Incompatible predicate signatures for '{}': {:?} vs {:?}",
435                        base.name,
436                        base.arg_domains,
437                        incoming.arg_domains
438                    )
439                }
440            }
441            MergeStrategy::Intersection => {
442                if compatible {
443                    Ok((base.clone(), MergeConflictResolution::Merged))
444                } else {
445                    bail!(
446                        "Incompatible predicate signatures for '{}': {:?} vs {:?}",
447                        base.name,
448                        base.arg_domains,
449                        incoming.arg_domains
450                    )
451                }
452            }
453        }
454    }
455
456    fn resolve_variable_conflict(
457        &self,
458        base_domain: &str,
459        incoming_domain: &str,
460    ) -> Result<(String, MergeConflictResolution)> {
461        match self.strategy {
462            MergeStrategy::KeepFirst => {
463                Ok((base_domain.to_string(), MergeConflictResolution::KeptBase))
464            }
465            MergeStrategy::KeepSecond => Ok((
466                incoming_domain.to_string(),
467                MergeConflictResolution::KeptIncoming,
468            )),
469            MergeStrategy::FailOnConflict => {
470                bail!(
471                    "Variable domain conflict: '{}' vs '{}'",
472                    base_domain,
473                    incoming_domain
474                )
475            }
476            MergeStrategy::Union | MergeStrategy::Intersection => {
477                if base_domain == incoming_domain {
478                    Ok((base_domain.to_string(), MergeConflictResolution::Merged))
479                } else {
480                    bail!(
481                        "Incompatible variable domains: '{}' vs '{}'",
482                        base_domain,
483                        incoming_domain
484                    )
485                }
486            }
487        }
488    }
489}
490
491impl Default for SchemaMerger {
492    fn default() -> Self {
493        Self::new(MergeStrategy::Union)
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Base table: domain Person(100), predicate knows(Person, Person),
    /// variable x: Person.
    fn create_base_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        table
            .add_predicate(PredicateInfo::new(
                "knows",
                vec!["Person".to_string(), "Person".to_string()],
            ))
            .expect("unwrap");
        table.bind_variable("x", "Person").expect("unwrap");
        table
    }

    /// Incoming table: domains Person(150) and Organization(50), predicate
    /// age(Person), no variables.
    fn create_incoming_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        // Same domain name as the base table, but a different cardinality.
        table
            .add_domain(DomainInfo::new("Person", 150))
            .expect("unwrap");
        table
            .add_domain(DomainInfo::new("Organization", 50))
            .expect("unwrap");
        table
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");
        table
    }

    #[test]
    fn test_merge_union_resolves_domain_conflict() {
        let base = create_base_table();
        let incoming = create_incoming_table();

        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        // Person + Organization
        assert_eq!(result.merged.domains.len(), 2);
        // knows + age
        assert_eq!(result.merged.predicates.len(), 2);
        // Union resolves the Person conflict by taking the larger cardinality...
        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            150
        );
        // ...but the differing definitions are still reported as a conflict.
        assert!(result.report.has_conflicts());
    }

    #[test]
    fn test_merge_with_domain_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::KeepFirst);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        assert_eq!(result.merged.domains.len(), 1);
        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            100
        );
        assert!(result.report.has_conflicts());
    }

    #[test]
    fn test_merge_keep_second() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::KeepSecond);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            200
        );
    }

    #[test]
    fn test_merge_fail_on_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::FailOnConflict);
        let result = merger.merge(&base, &incoming);

        assert!(result.is_err());
    }

    #[test]
    fn test_merge_report() {
        let base = create_base_table();
        let incoming = create_incoming_table();

        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        let report = &result.report;
        // Person is defined in both tables, so no domain is unique to base.
        assert_eq!(report.base_domains.len(), 0);
        // Organization is unique to incoming.
        assert_eq!(report.incoming_domains.len(), 1);
        // base_domains (0) + incoming_domains (1) + base_predicates (1)
        // + incoming_predicates (1) + merged_variables (1)
        assert_eq!(report.merged_count(), 4);
        // Person domain conflict
        assert_eq!(report.conflict_count(), 1);
    }

    #[test]
    fn test_predicate_conflict_compatible() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        base.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        incoming
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        assert_eq!(result.merged.predicates.len(), 1);
        assert_eq!(result.report.conflicting_predicates.len(), 0);
    }

    #[test]
    fn test_variable_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        base.add_domain(DomainInfo::new("Agent", 50))
            .expect("unwrap");
        base.bind_variable("x", "Person").expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        incoming
            .add_domain(DomainInfo::new("Agent", 50))
            .expect("unwrap");
        incoming.bind_variable("x", "Agent").expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::KeepFirst);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        assert_eq!(result.merged.variables.get("x").expect("unwrap"), "Person");
        assert_eq!(result.report.conflicting_variables.len(), 1);
    }
}