Skip to main content

tensorlogic_adapters/
merge_strategies.rs

1//! Advanced schema merging strategies for combining schemas from different sources.
2//!
3//! This module provides sophisticated strategies for merging symbol tables with
4//! conflict resolution, validation, and detailed merge reports.
5
6use crate::{DomainInfo, PredicateInfo, SymbolTable};
7use anyhow::{bail, Result};
8use std::collections::HashSet;
9
/// Strategy for resolving conflicts during schema merging.
///
/// A strategy only matters for items (domains, predicates, variables) whose
/// name appears in both the base and the incoming [`SymbolTable`]. Items that
/// exist in only one of the two tables are copied into the merged result
/// regardless of the strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeStrategy {
    /// Keep the base (first) definition of every shared item.
    KeepFirst,
    /// Keep the incoming (second) definition of every shared item.
    KeepSecond,
    /// Abort the merge with an error when a shared item conflicts.
    FailOnConflict,
    /// Union: for shared domains, keep the definition with the larger
    /// cardinality; shared predicates (argument domains) and variables
    /// (bound domain) must match, otherwise the merge fails.
    Union,
    /// Intersection: for shared domains, keep the definition with the smaller
    /// cardinality; shared predicates (argument domains) and variables
    /// (bound domain) must match, otherwise the merge fails.
    ///
    /// Note that, like every other strategy, this does not drop items that
    /// exist in only one of the tables.
    Intersection,
}
24
25/// Result of a merge operation.
26#[derive(Debug, Clone)]
27pub struct MergeResult {
28    /// The merged symbol table
29    pub merged: SymbolTable,
30    /// Report of the merge operation
31    pub report: MergeReport,
32}
33
34/// Report of a merge operation.
35#[derive(Debug, Clone)]
36pub struct MergeReport {
37    /// Domains that were added from base
38    pub base_domains: Vec<String>,
39    /// Domains that were added from incoming
40    pub incoming_domains: Vec<String>,
41    /// Domains that had conflicts
42    pub conflicting_domains: Vec<DomainConflict>,
43    /// Predicates that were added from base
44    pub base_predicates: Vec<String>,
45    /// Predicates that were added from incoming
46    pub incoming_predicates: Vec<String>,
47    /// Predicates that had conflicts
48    pub conflicting_predicates: Vec<PredicateConflict>,
49    /// Variables that were merged
50    pub merged_variables: Vec<String>,
51    /// Variables that had conflicts
52    pub conflicting_variables: Vec<VariableConflict>,
53    /// Overall merge strategy used
54    pub strategy: MergeStrategy,
55}
56
57impl MergeReport {
58    /// Create a new empty merge report.
59    pub fn new(strategy: MergeStrategy) -> Self {
60        Self {
61            base_domains: Vec::new(),
62            incoming_domains: Vec::new(),
63            conflicting_domains: Vec::new(),
64            base_predicates: Vec::new(),
65            incoming_predicates: Vec::new(),
66            conflicting_predicates: Vec::new(),
67            merged_variables: Vec::new(),
68            conflicting_variables: Vec::new(),
69            strategy,
70        }
71    }
72
73    /// Check if there were any conflicts during merging.
74    pub fn has_conflicts(&self) -> bool {
75        !self.conflicting_domains.is_empty()
76            || !self.conflicting_predicates.is_empty()
77            || !self.conflicting_variables.is_empty()
78    }
79
80    /// Get total number of conflicts.
81    pub fn conflict_count(&self) -> usize {
82        self.conflicting_domains.len()
83            + self.conflicting_predicates.len()
84            + self.conflicting_variables.len()
85    }
86
87    /// Get total number of merged items.
88    pub fn merged_count(&self) -> usize {
89        self.base_domains.len()
90            + self.incoming_domains.len()
91            + self.base_predicates.len()
92            + self.incoming_predicates.len()
93            + self.merged_variables.len()
94    }
95}
96
97/// Information about a domain conflict.
98#[derive(Debug, Clone)]
99pub struct DomainConflict {
100    /// Name of the conflicting domain
101    pub name: String,
102    /// Domain from base table
103    pub base: DomainInfo,
104    /// Domain from incoming table
105    pub incoming: DomainInfo,
106    /// How the conflict was resolved
107    pub resolution: MergeConflictResolution,
108}
109
110/// Information about a predicate conflict.
111#[derive(Debug, Clone)]
112pub struct PredicateConflict {
113    /// Name of the conflicting predicate
114    pub name: String,
115    /// Predicate from base table
116    pub base: PredicateInfo,
117    /// Predicate from incoming table
118    pub incoming: PredicateInfo,
119    /// How the conflict was resolved
120    pub resolution: MergeConflictResolution,
121}
122
123/// Information about a variable conflict.
124#[derive(Debug, Clone)]
125pub struct VariableConflict {
126    /// Name of the conflicting variable
127    pub name: String,
128    /// Domain from base table
129    pub base_domain: String,
130    /// Domain from incoming table
131    pub incoming_domain: String,
132    /// How the conflict was resolved
133    pub resolution: MergeConflictResolution,
134}
135
/// Outcome of resolving an item that exists in both merged tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeConflictResolution {
    /// The base table's definition was retained.
    KeptBase,
    /// The incoming table's definition was retained.
    KeptIncoming,
    /// The conflict could not be resolved.
    ///
    /// `SchemaMerger` in this module reports failures by returning an error
    /// rather than by producing this variant.
    Failed,
    /// Both definitions were compatible and were combined into one.
    Merged,
}
148
149/// A schema merger with configurable strategies.
150pub struct SchemaMerger {
151    strategy: MergeStrategy,
152}
153
154impl SchemaMerger {
155    /// Create a new schema merger with the given strategy.
156    pub fn new(strategy: MergeStrategy) -> Self {
157        Self { strategy }
158    }
159
160    /// Merge two symbol tables according to the configured strategy.
161    ///
162    /// # Examples
163    ///
164    /// ```
165    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, SchemaMerger, MergeStrategy};
166    ///
167    /// let mut base = SymbolTable::new();
168    /// base.add_domain(DomainInfo::new("Person", 100)).unwrap();
169    ///
170    /// let mut incoming = SymbolTable::new();
171    /// incoming.add_domain(DomainInfo::new("Organization", 50)).unwrap();
172    ///
173    /// let merger = SchemaMerger::new(MergeStrategy::Union);
174    /// let result = merger.merge(&base, &incoming).unwrap();
175    ///
176    /// assert_eq!(result.merged.domains.len(), 2);
177    /// ```
178    pub fn merge(&self, base: &SymbolTable, incoming: &SymbolTable) -> Result<MergeResult> {
179        let mut merged = SymbolTable::new();
180        let mut report = MergeReport::new(self.strategy);
181
182        // Merge domains
183        self.merge_domains(base, incoming, &mut merged, &mut report)?;
184
185        // Merge predicates
186        self.merge_predicates(base, incoming, &mut merged, &mut report)?;
187
188        // Merge variables
189        self.merge_variables(base, incoming, &mut merged, &mut report)?;
190
191        Ok(MergeResult { merged, report })
192    }
193
194    fn merge_domains(
195        &self,
196        base: &SymbolTable,
197        incoming: &SymbolTable,
198        merged: &mut SymbolTable,
199        report: &mut MergeReport,
200    ) -> Result<()> {
201        let base_keys: HashSet<&String> = base.domains.keys().collect();
202        let incoming_keys: HashSet<&String> = incoming.domains.keys().collect();
203
204        // Domains only in base
205        for key in base_keys.difference(&incoming_keys) {
206            let domain = base.domains.get(*key).unwrap();
207            merged.add_domain(domain.clone())?;
208            report.base_domains.push(key.to_string());
209        }
210
211        // Domains only in incoming
212        for key in incoming_keys.difference(&base_keys) {
213            let domain = incoming.domains.get(*key).unwrap();
214            merged.add_domain(domain.clone())?;
215            report.incoming_domains.push(key.to_string());
216        }
217
218        // Domains in both (conflicts)
219        for key in base_keys.intersection(&incoming_keys) {
220            let base_domain = base.domains.get(*key).unwrap();
221            let incoming_domain = incoming.domains.get(*key).unwrap();
222
223            let (domain, resolution) =
224                self.resolve_domain_conflict(base_domain, incoming_domain)?;
225
226            merged.add_domain(domain)?;
227
228            if resolution != MergeConflictResolution::Merged {
229                report.conflicting_domains.push(DomainConflict {
230                    name: key.to_string(),
231                    base: base_domain.clone(),
232                    incoming: incoming_domain.clone(),
233                    resolution,
234                });
235            }
236        }
237
238        Ok(())
239    }
240
241    fn merge_predicates(
242        &self,
243        base: &SymbolTable,
244        incoming: &SymbolTable,
245        merged: &mut SymbolTable,
246        report: &mut MergeReport,
247    ) -> Result<()> {
248        let base_keys: HashSet<&String> = base.predicates.keys().collect();
249        let incoming_keys: HashSet<&String> = incoming.predicates.keys().collect();
250
251        // Predicates only in base
252        for key in base_keys.difference(&incoming_keys) {
253            let predicate = base.predicates.get(*key).unwrap();
254            merged.add_predicate(predicate.clone())?;
255            report.base_predicates.push(key.to_string());
256        }
257
258        // Predicates only in incoming
259        for key in incoming_keys.difference(&base_keys) {
260            let predicate = incoming.predicates.get(*key).unwrap();
261            merged.add_predicate(predicate.clone())?;
262            report.incoming_predicates.push(key.to_string());
263        }
264
265        // Predicates in both (conflicts)
266        for key in base_keys.intersection(&incoming_keys) {
267            let base_pred = base.predicates.get(*key).unwrap();
268            let incoming_pred = incoming.predicates.get(*key).unwrap();
269
270            let (predicate, resolution) =
271                self.resolve_predicate_conflict(base_pred, incoming_pred)?;
272
273            merged.add_predicate(predicate)?;
274
275            if resolution != MergeConflictResolution::Merged {
276                report.conflicting_predicates.push(PredicateConflict {
277                    name: key.to_string(),
278                    base: base_pred.clone(),
279                    incoming: incoming_pred.clone(),
280                    resolution,
281                });
282            }
283        }
284
285        Ok(())
286    }
287
288    fn merge_variables(
289        &self,
290        base: &SymbolTable,
291        incoming: &SymbolTable,
292        merged: &mut SymbolTable,
293        report: &mut MergeReport,
294    ) -> Result<()> {
295        let base_keys: HashSet<&String> = base.variables.keys().collect();
296        let incoming_keys: HashSet<&String> = incoming.variables.keys().collect();
297
298        // Variables only in base
299        for key in base_keys.difference(&incoming_keys) {
300            let domain = base.variables.get(*key).unwrap();
301            merged.bind_variable(key.to_string(), domain.clone())?;
302            report.merged_variables.push(key.to_string());
303        }
304
305        // Variables only in incoming
306        for key in incoming_keys.difference(&base_keys) {
307            let domain = incoming.variables.get(*key).unwrap();
308            merged.bind_variable(key.to_string(), domain.clone())?;
309            report.merged_variables.push(key.to_string());
310        }
311
312        // Variables in both (conflicts)
313        for key in base_keys.intersection(&incoming_keys) {
314            let base_domain = base.variables.get(*key).unwrap();
315            let incoming_domain = incoming.variables.get(*key).unwrap();
316
317            let (domain, resolution) =
318                self.resolve_variable_conflict(base_domain, incoming_domain)?;
319
320            merged.bind_variable(key.to_string(), domain)?;
321
322            if resolution != MergeConflictResolution::Merged {
323                report.conflicting_variables.push(VariableConflict {
324                    name: key.to_string(),
325                    base_domain: base_domain.clone(),
326                    incoming_domain: incoming_domain.clone(),
327                    resolution,
328                });
329            }
330        }
331
332        Ok(())
333    }
334
335    fn resolve_domain_conflict(
336        &self,
337        base: &DomainInfo,
338        incoming: &DomainInfo,
339    ) -> Result<(DomainInfo, MergeConflictResolution)> {
340        match self.strategy {
341            MergeStrategy::KeepFirst => Ok((base.clone(), MergeConflictResolution::KeptBase)),
342            MergeStrategy::KeepSecond => {
343                Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
344            }
345            MergeStrategy::FailOnConflict => {
346                bail!(
347                    "Domain conflict for '{}': cardinality {} vs {}",
348                    base.name,
349                    base.cardinality,
350                    incoming.cardinality
351                )
352            }
353            MergeStrategy::Union => {
354                // For domains, union means take the larger cardinality
355                if base.cardinality >= incoming.cardinality {
356                    Ok((base.clone(), MergeConflictResolution::KeptBase))
357                } else {
358                    Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
359                }
360            }
361            MergeStrategy::Intersection => {
362                // For domains, intersection means take the smaller cardinality
363                if base.cardinality <= incoming.cardinality {
364                    Ok((base.clone(), MergeConflictResolution::KeptBase))
365                } else {
366                    Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
367                }
368            }
369        }
370    }
371
372    fn resolve_predicate_conflict(
373        &self,
374        base: &PredicateInfo,
375        incoming: &PredicateInfo,
376    ) -> Result<(PredicateInfo, MergeConflictResolution)> {
377        // Check if predicates are compatible (same signature)
378        let compatible = base.arg_domains == incoming.arg_domains;
379
380        match self.strategy {
381            MergeStrategy::KeepFirst => Ok((base.clone(), MergeConflictResolution::KeptBase)),
382            MergeStrategy::KeepSecond => {
383                Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
384            }
385            MergeStrategy::FailOnConflict => {
386                bail!(
387                    "Predicate conflict for '{}': {:?} vs {:?}",
388                    base.name,
389                    base.arg_domains,
390                    incoming.arg_domains
391                )
392            }
393            MergeStrategy::Union => {
394                if compatible {
395                    Ok((base.clone(), MergeConflictResolution::Merged))
396                } else {
397                    bail!(
398                        "Incompatible predicate signatures for '{}': {:?} vs {:?}",
399                        base.name,
400                        base.arg_domains,
401                        incoming.arg_domains
402                    )
403                }
404            }
405            MergeStrategy::Intersection => {
406                if compatible {
407                    Ok((base.clone(), MergeConflictResolution::Merged))
408                } else {
409                    bail!(
410                        "Incompatible predicate signatures for '{}': {:?} vs {:?}",
411                        base.name,
412                        base.arg_domains,
413                        incoming.arg_domains
414                    )
415                }
416            }
417        }
418    }
419
420    fn resolve_variable_conflict(
421        &self,
422        base_domain: &str,
423        incoming_domain: &str,
424    ) -> Result<(String, MergeConflictResolution)> {
425        match self.strategy {
426            MergeStrategy::KeepFirst => {
427                Ok((base_domain.to_string(), MergeConflictResolution::KeptBase))
428            }
429            MergeStrategy::KeepSecond => Ok((
430                incoming_domain.to_string(),
431                MergeConflictResolution::KeptIncoming,
432            )),
433            MergeStrategy::FailOnConflict => {
434                bail!(
435                    "Variable domain conflict: '{}' vs '{}'",
436                    base_domain,
437                    incoming_domain
438                )
439            }
440            MergeStrategy::Union | MergeStrategy::Intersection => {
441                if base_domain == incoming_domain {
442                    Ok((base_domain.to_string(), MergeConflictResolution::Merged))
443                } else {
444                    bail!(
445                        "Incompatible variable domains: '{}' vs '{}'",
446                        base_domain,
447                        incoming_domain
448                    )
449                }
450            }
451        }
452    }
453}
454
455impl Default for SchemaMerger {
456    fn default() -> Self {
457        Self::new(MergeStrategy::Union)
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Base schema: `Person` (100), `knows(Person, Person)` and `x: Person`.
    fn create_base_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        table
            .add_predicate(PredicateInfo::new(
                "knows",
                vec!["Person".to_string(), "Person".to_string()],
            ))
            .unwrap();
        table.bind_variable("x", "Person").unwrap();
        table
    }

    /// Incoming schema: `Person` (150, differs from base), `Organization` (50)
    /// and `age(Person)`.
    fn create_incoming_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 150)).unwrap(); // Different cardinality
        table
            .add_domain(DomainInfo::new("Organization", 50))
            .unwrap();
        table
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .unwrap();
        table
    }

    /// Two tables that both know `Person` and `Agent`, binding `x` differently.
    fn create_variable_conflict_tables() -> (SymbolTable, SymbolTable) {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();
        base.add_domain(DomainInfo::new("Agent", 50)).unwrap();
        base.bind_variable("x", "Person").unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 100)).unwrap();
        incoming.add_domain(DomainInfo::new("Agent", 50)).unwrap();
        incoming.bind_variable("x", "Agent").unwrap();

        (base, incoming)
    }

    #[test]
    fn test_merge_union_resolves_domain_conflict() {
        let base = create_base_table();
        let incoming = create_incoming_table();

        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).unwrap();

        assert_eq!(result.merged.domains.len(), 2); // Person + Organization
        assert_eq!(result.merged.predicates.len(), 2); // knows + age
        // Person differs in cardinality: Union keeps the larger one but still
        // reports the conflict.
        assert_eq!(
            result.merged.domains.get("Person").unwrap().cardinality,
            150
        );
        assert!(result.report.has_conflicts());
    }

    #[test]
    fn test_merge_with_domain_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 200)).unwrap();

        let merger = SchemaMerger::new(MergeStrategy::KeepFirst);
        let result = merger.merge(&base, &incoming).unwrap();

        assert_eq!(result.merged.domains.len(), 1);
        assert_eq!(
            result.merged.domains.get("Person").unwrap().cardinality,
            100
        );
        assert!(result.report.has_conflicts());
    }

    #[test]
    fn test_merge_keep_second() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 200)).unwrap();

        let merger = SchemaMerger::new(MergeStrategy::KeepSecond);
        let result = merger.merge(&base, &incoming).unwrap();

        assert_eq!(
            result.merged.domains.get("Person").unwrap().cardinality,
            200
        );
    }

    #[test]
    fn test_merge_intersection_keeps_smaller_domain() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 200)).unwrap();

        let merger = SchemaMerger::new(MergeStrategy::Intersection);
        let result = merger.merge(&base, &incoming).unwrap();

        assert_eq!(
            result.merged.domains.get("Person").unwrap().cardinality,
            100
        );
        assert_eq!(result.report.conflicting_domains.len(), 1);
        assert_eq!(
            result.report.conflicting_domains[0].resolution,
            MergeConflictResolution::KeptBase
        );
    }

    #[test]
    fn test_merge_fail_on_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 200)).unwrap();

        let merger = SchemaMerger::new(MergeStrategy::FailOnConflict);
        let result = merger.merge(&base, &incoming);

        assert!(result.is_err());
    }

    #[test]
    fn test_merge_report() {
        let base = create_base_table();
        let incoming = create_incoming_table();

        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).unwrap();

        let report = &result.report;
        // Person exists on both sides, so base has no base-only domains.
        assert_eq!(report.base_domains.len(), 0);
        // Organization exists only in incoming.
        assert_eq!(report.incoming_domains.len(), 1);
        // merged_count = base_domains (0) + incoming_domains (1)
        //   + base_predicates (1) + incoming_predicates (1) + merged_variables (1)
        assert_eq!(report.merged_count(), 4);
        assert_eq!(report.conflict_count(), 1); // Person domain conflict
    }

    #[test]
    fn test_predicate_conflict_compatible() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();
        base.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 100)).unwrap();
        incoming
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .unwrap();

        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).unwrap();

        assert_eq!(result.merged.predicates.len(), 1);
        assert_eq!(result.report.conflicting_predicates.len(), 0);
    }

    #[test]
    fn test_merge_union_incompatible_predicate_fails() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();
        base.add_domain(DomainInfo::new("Organization", 50)).unwrap();
        base.add_predicate(PredicateInfo::new(
            "knows",
            vec!["Person".to_string(), "Person".to_string()],
        ))
        .unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 100)).unwrap();
        incoming
            .add_domain(DomainInfo::new("Organization", 50))
            .unwrap();
        incoming
            .add_predicate(PredicateInfo::new(
                "knows",
                vec!["Person".to_string(), "Organization".to_string()],
            ))
            .unwrap();

        let merger = SchemaMerger::new(MergeStrategy::Union);
        assert!(merger.merge(&base, &incoming).is_err());
    }

    #[test]
    fn test_variable_conflict() {
        let (base, incoming) = create_variable_conflict_tables();

        let merger = SchemaMerger::new(MergeStrategy::KeepFirst);
        let result = merger.merge(&base, &incoming).unwrap();

        assert_eq!(result.merged.variables.get("x").unwrap(), "Person");
        assert_eq!(result.report.conflicting_variables.len(), 1);
    }

    #[test]
    fn test_variable_conflict_keep_second() {
        let (base, incoming) = create_variable_conflict_tables();

        let merger = SchemaMerger::new(MergeStrategy::KeepSecond);
        let result = merger.merge(&base, &incoming).unwrap();

        assert_eq!(result.merged.variables.get("x").unwrap(), "Agent");
        assert_eq!(result.report.conflicting_variables.len(), 1);
        assert_eq!(
            result.report.conflicting_variables[0].resolution,
            MergeConflictResolution::KeptIncoming
        );
    }

    #[test]
    fn test_variable_conflict_union_fails() {
        let (base, incoming) = create_variable_conflict_tables();

        let merger = SchemaMerger::new(MergeStrategy::Union);
        assert!(merger.merge(&base, &incoming).is_err());
    }
}