1use crate::{DomainInfo, PredicateInfo, SymbolTable};
7use anyhow::{bail, Result};
8use std::collections::HashSet;
9
/// How [`SchemaMerger`] resolves a name that is defined in both the base and the incoming table.
///
/// Names defined in only one of the two tables are always copied into the merged table,
/// whatever the strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MergeStrategy {
    /// Keep the base table's definition.
    KeepFirst,
    /// Keep the incoming table's definition.
    KeepSecond,
    /// Abort the merge with an error when the two definitions conflict.
    FailOnConflict,
    /// Domains: keep the one with the larger cardinality. Predicates and variables must
    /// agree (same argument domains / same bound domain), otherwise the merge fails.
    Union,
    /// Domains: keep the one with the smaller cardinality. Predicates and variables must
    /// agree, otherwise the merge fails.
    Intersection,
}
24
25#[derive(Debug, Clone)]
27pub struct MergeResult {
28 pub merged: SymbolTable,
30 pub report: MergeReport,
32}
33
34#[derive(Debug, Clone)]
36pub struct MergeReport {
37 pub base_domains: Vec<String>,
39 pub incoming_domains: Vec<String>,
41 pub conflicting_domains: Vec<DomainConflict>,
43 pub base_predicates: Vec<String>,
45 pub incoming_predicates: Vec<String>,
47 pub conflicting_predicates: Vec<PredicateConflict>,
49 pub merged_variables: Vec<String>,
51 pub conflicting_variables: Vec<VariableConflict>,
53 pub strategy: MergeStrategy,
55}
56
57impl MergeReport {
58 pub fn new(strategy: MergeStrategy) -> Self {
60 Self {
61 base_domains: Vec::new(),
62 incoming_domains: Vec::new(),
63 conflicting_domains: Vec::new(),
64 base_predicates: Vec::new(),
65 incoming_predicates: Vec::new(),
66 conflicting_predicates: Vec::new(),
67 merged_variables: Vec::new(),
68 conflicting_variables: Vec::new(),
69 strategy,
70 }
71 }
72
73 pub fn has_conflicts(&self) -> bool {
75 !self.conflicting_domains.is_empty()
76 || !self.conflicting_predicates.is_empty()
77 || !self.conflicting_variables.is_empty()
78 }
79
80 pub fn conflict_count(&self) -> usize {
82 self.conflicting_domains.len()
83 + self.conflicting_predicates.len()
84 + self.conflicting_variables.len()
85 }
86
87 pub fn merged_count(&self) -> usize {
89 self.base_domains.len()
90 + self.incoming_domains.len()
91 + self.base_predicates.len()
92 + self.incoming_predicates.len()
93 + self.merged_variables.len()
94 }
95}
96
97#[derive(Debug, Clone)]
99pub struct DomainConflict {
100 pub name: String,
102 pub base: DomainInfo,
104 pub incoming: DomainInfo,
106 pub resolution: MergeConflictResolution,
108}
109
110#[derive(Debug, Clone)]
112pub struct PredicateConflict {
113 pub name: String,
115 pub base: PredicateInfo,
117 pub incoming: PredicateInfo,
119 pub resolution: MergeConflictResolution,
121}
122
123#[derive(Debug, Clone)]
125pub struct VariableConflict {
126 pub name: String,
128 pub base_domain: String,
130 pub incoming_domain: String,
132 pub resolution: MergeConflictResolution,
134}
135
/// Outcome recorded for a name that is defined in both merged tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MergeConflictResolution {
    /// The base table's definition was kept.
    KeptBase,
    /// The incoming table's definition was kept.
    KeptIncoming,
    /// The clash could not be resolved. `SchemaMerger::merge` does not produce this
    /// variant: under `FailOnConflict` it returns an error instead.
    Failed,
    /// The two definitions were compatible and were combined; names with this outcome are
    /// not reported as conflicts.
    Merged,
}
148
149pub struct SchemaMerger {
151 strategy: MergeStrategy,
152}
153
154impl SchemaMerger {
155 pub fn new(strategy: MergeStrategy) -> Self {
157 Self { strategy }
158 }
159
160 pub fn merge(&self, base: &SymbolTable, incoming: &SymbolTable) -> Result<MergeResult> {
179 let mut merged = SymbolTable::new();
180 let mut report = MergeReport::new(self.strategy);
181
182 self.merge_domains(base, incoming, &mut merged, &mut report)?;
184
185 self.merge_predicates(base, incoming, &mut merged, &mut report)?;
187
188 self.merge_variables(base, incoming, &mut merged, &mut report)?;
190
191 Ok(MergeResult { merged, report })
192 }
193
194 fn merge_domains(
195 &self,
196 base: &SymbolTable,
197 incoming: &SymbolTable,
198 merged: &mut SymbolTable,
199 report: &mut MergeReport,
200 ) -> Result<()> {
201 let base_keys: HashSet<&String> = base.domains.keys().collect();
202 let incoming_keys: HashSet<&String> = incoming.domains.keys().collect();
203
204 for key in base_keys.difference(&incoming_keys) {
206 let domain = base
207 .domains
208 .get(*key)
209 .expect("key from HashMap iteration is always present");
210 merged.add_domain(domain.clone())?;
211 report.base_domains.push(key.to_string());
212 }
213
214 for key in incoming_keys.difference(&base_keys) {
216 let domain = incoming
217 .domains
218 .get(*key)
219 .expect("key from HashMap iteration is always present");
220 merged.add_domain(domain.clone())?;
221 report.incoming_domains.push(key.to_string());
222 }
223
224 for key in base_keys.intersection(&incoming_keys) {
226 let base_domain = base
227 .domains
228 .get(*key)
229 .expect("key from HashMap iteration is always present");
230 let incoming_domain = incoming
231 .domains
232 .get(*key)
233 .expect("key from HashMap iteration is always present");
234
235 let (domain, resolution) =
236 self.resolve_domain_conflict(base_domain, incoming_domain)?;
237
238 merged.add_domain(domain)?;
239
240 if resolution != MergeConflictResolution::Merged {
241 report.conflicting_domains.push(DomainConflict {
242 name: key.to_string(),
243 base: base_domain.clone(),
244 incoming: incoming_domain.clone(),
245 resolution,
246 });
247 }
248 }
249
250 Ok(())
251 }
252
253 fn merge_predicates(
254 &self,
255 base: &SymbolTable,
256 incoming: &SymbolTable,
257 merged: &mut SymbolTable,
258 report: &mut MergeReport,
259 ) -> Result<()> {
260 let base_keys: HashSet<&String> = base.predicates.keys().collect();
261 let incoming_keys: HashSet<&String> = incoming.predicates.keys().collect();
262
263 for key in base_keys.difference(&incoming_keys) {
265 let predicate = base
266 .predicates
267 .get(*key)
268 .expect("key from HashMap iteration is always present");
269 merged.add_predicate(predicate.clone())?;
270 report.base_predicates.push(key.to_string());
271 }
272
273 for key in incoming_keys.difference(&base_keys) {
275 let predicate = incoming
276 .predicates
277 .get(*key)
278 .expect("key from HashMap iteration is always present");
279 merged.add_predicate(predicate.clone())?;
280 report.incoming_predicates.push(key.to_string());
281 }
282
283 for key in base_keys.intersection(&incoming_keys) {
285 let base_pred = base
286 .predicates
287 .get(*key)
288 .expect("key from HashMap iteration is always present");
289 let incoming_pred = incoming
290 .predicates
291 .get(*key)
292 .expect("key from HashMap iteration is always present");
293
294 let (predicate, resolution) =
295 self.resolve_predicate_conflict(base_pred, incoming_pred)?;
296
297 merged.add_predicate(predicate)?;
298
299 if resolution != MergeConflictResolution::Merged {
300 report.conflicting_predicates.push(PredicateConflict {
301 name: key.to_string(),
302 base: base_pred.clone(),
303 incoming: incoming_pred.clone(),
304 resolution,
305 });
306 }
307 }
308
309 Ok(())
310 }
311
312 fn merge_variables(
313 &self,
314 base: &SymbolTable,
315 incoming: &SymbolTable,
316 merged: &mut SymbolTable,
317 report: &mut MergeReport,
318 ) -> Result<()> {
319 let base_keys: HashSet<&String> = base.variables.keys().collect();
320 let incoming_keys: HashSet<&String> = incoming.variables.keys().collect();
321
322 for key in base_keys.difference(&incoming_keys) {
324 let domain = base
325 .variables
326 .get(*key)
327 .expect("key from HashMap iteration is always present");
328 merged.bind_variable(key.to_string(), domain.clone())?;
329 report.merged_variables.push(key.to_string());
330 }
331
332 for key in incoming_keys.difference(&base_keys) {
334 let domain = incoming
335 .variables
336 .get(*key)
337 .expect("key from HashMap iteration is always present");
338 merged.bind_variable(key.to_string(), domain.clone())?;
339 report.merged_variables.push(key.to_string());
340 }
341
342 for key in base_keys.intersection(&incoming_keys) {
344 let base_domain = base
345 .variables
346 .get(*key)
347 .expect("key from HashMap iteration is always present");
348 let incoming_domain = incoming
349 .variables
350 .get(*key)
351 .expect("key from HashMap iteration is always present");
352
353 let (domain, resolution) =
354 self.resolve_variable_conflict(base_domain, incoming_domain)?;
355
356 merged.bind_variable(key.to_string(), domain)?;
357
358 if resolution != MergeConflictResolution::Merged {
359 report.conflicting_variables.push(VariableConflict {
360 name: key.to_string(),
361 base_domain: base_domain.clone(),
362 incoming_domain: incoming_domain.clone(),
363 resolution,
364 });
365 }
366 }
367
368 Ok(())
369 }
370
371 fn resolve_domain_conflict(
372 &self,
373 base: &DomainInfo,
374 incoming: &DomainInfo,
375 ) -> Result<(DomainInfo, MergeConflictResolution)> {
376 match self.strategy {
377 MergeStrategy::KeepFirst => Ok((base.clone(), MergeConflictResolution::KeptBase)),
378 MergeStrategy::KeepSecond => {
379 Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
380 }
381 MergeStrategy::FailOnConflict => {
382 bail!(
383 "Domain conflict for '{}': cardinality {} vs {}",
384 base.name,
385 base.cardinality,
386 incoming.cardinality
387 )
388 }
389 MergeStrategy::Union => {
390 if base.cardinality >= incoming.cardinality {
392 Ok((base.clone(), MergeConflictResolution::KeptBase))
393 } else {
394 Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
395 }
396 }
397 MergeStrategy::Intersection => {
398 if base.cardinality <= incoming.cardinality {
400 Ok((base.clone(), MergeConflictResolution::KeptBase))
401 } else {
402 Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
403 }
404 }
405 }
406 }
407
408 fn resolve_predicate_conflict(
409 &self,
410 base: &PredicateInfo,
411 incoming: &PredicateInfo,
412 ) -> Result<(PredicateInfo, MergeConflictResolution)> {
413 let compatible = base.arg_domains == incoming.arg_domains;
415
416 match self.strategy {
417 MergeStrategy::KeepFirst => Ok((base.clone(), MergeConflictResolution::KeptBase)),
418 MergeStrategy::KeepSecond => {
419 Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
420 }
421 MergeStrategy::FailOnConflict => {
422 bail!(
423 "Predicate conflict for '{}': {:?} vs {:?}",
424 base.name,
425 base.arg_domains,
426 incoming.arg_domains
427 )
428 }
429 MergeStrategy::Union => {
430 if compatible {
431 Ok((base.clone(), MergeConflictResolution::Merged))
432 } else {
433 bail!(
434 "Incompatible predicate signatures for '{}': {:?} vs {:?}",
435 base.name,
436 base.arg_domains,
437 incoming.arg_domains
438 )
439 }
440 }
441 MergeStrategy::Intersection => {
442 if compatible {
443 Ok((base.clone(), MergeConflictResolution::Merged))
444 } else {
445 bail!(
446 "Incompatible predicate signatures for '{}': {:?} vs {:?}",
447 base.name,
448 base.arg_domains,
449 incoming.arg_domains
450 )
451 }
452 }
453 }
454 }
455
456 fn resolve_variable_conflict(
457 &self,
458 base_domain: &str,
459 incoming_domain: &str,
460 ) -> Result<(String, MergeConflictResolution)> {
461 match self.strategy {
462 MergeStrategy::KeepFirst => {
463 Ok((base_domain.to_string(), MergeConflictResolution::KeptBase))
464 }
465 MergeStrategy::KeepSecond => Ok((
466 incoming_domain.to_string(),
467 MergeConflictResolution::KeptIncoming,
468 )),
469 MergeStrategy::FailOnConflict => {
470 bail!(
471 "Variable domain conflict: '{}' vs '{}'",
472 base_domain,
473 incoming_domain
474 )
475 }
476 MergeStrategy::Union | MergeStrategy::Intersection => {
477 if base_domain == incoming_domain {
478 Ok((base_domain.to_string(), MergeConflictResolution::Merged))
479 } else {
480 bail!(
481 "Incompatible variable domains: '{}' vs '{}'",
482 base_domain,
483 incoming_domain
484 )
485 }
486 }
487 }
488 }
489}
490
491impl Default for SchemaMerger {
492 fn default() -> Self {
493 Self::new(MergeStrategy::Union)
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Base table: `Person` (cardinality 100), `knows(Person, Person)` and `x: Person`.
    fn create_base_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        table
            .add_predicate(PredicateInfo::new(
                "knows",
                vec!["Person".to_string(), "Person".to_string()],
            ))
            .expect("unwrap");
        table.bind_variable("x", "Person").expect("unwrap");
        table
    }

    /// Incoming table: `Person` (cardinality 150, clashing with the base table),
    /// `Organization` (50) and `age(Person)`. No variables are bound.
    fn create_incoming_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table
            .add_domain(DomainInfo::new("Person", 150))
            .expect("unwrap");
        table
            .add_domain(DomainInfo::new("Organization", 50))
            .expect("unwrap");
        table
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");
        table
    }

    #[test]
    fn test_merge_union_resolves_domain_conflict() {
        let base = create_base_table();
        let incoming = create_incoming_table();

        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        // Person + Organization.
        assert_eq!(result.merged.domains.len(), 2);
        // knows + age.
        assert_eq!(result.merged.predicates.len(), 2);
        // `Person` differs in cardinality (100 vs 150): Union keeps the larger declaration
        // but still reports the clash.
        assert!(result.report.has_conflicts());
        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            150
        );
        assert_eq!(result.report.conflicting_domains.len(), 1);
        assert_eq!(
            result.report.conflicting_domains[0].resolution,
            MergeConflictResolution::KeptIncoming
        );
    }

    #[test]
    fn test_merge_with_domain_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::KeepFirst);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        assert_eq!(result.merged.domains.len(), 1);
        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            100
        );
        assert!(result.report.has_conflicts());
        assert_eq!(
            result.report.conflicting_domains[0].resolution,
            MergeConflictResolution::KeptBase
        );
    }

    #[test]
    fn test_merge_keep_second() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::KeepSecond);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            200
        );
    }

    #[test]
    fn test_merge_intersection_keeps_smaller_domain() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::Intersection);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            100
        );
        assert_eq!(
            result.report.conflicting_domains[0].resolution,
            MergeConflictResolution::KeptBase
        );
    }

    #[test]
    fn test_merge_fail_on_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::FailOnConflict);
        let result = merger.merge(&base, &incoming);

        assert!(result.is_err());
    }

    #[test]
    fn test_merge_union_rejects_incompatible_predicates() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        base.add_domain(DomainInfo::new("Organization", 50))
            .expect("unwrap");
        base.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        incoming
            .add_domain(DomainInfo::new("Organization", 50))
            .expect("unwrap");
        incoming
            .add_predicate(PredicateInfo::new("age", vec!["Organization".to_string()]))
            .expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::Union);
        assert!(merger.merge(&base, &incoming).is_err());
    }

    #[test]
    fn test_merge_report() {
        let base = create_base_table();
        let incoming = create_incoming_table();

        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        let report = &result.report;
        // `Person` is declared in both tables, so no domain is base-only.
        assert_eq!(report.base_domains.len(), 0);
        // Only `Organization` is incoming-only.
        assert_eq!(report.incoming_domains.len(), 1);
        // Organization + knows + age + x.
        assert_eq!(report.merged_count(), 4);
        // The `Person` cardinality clash.
        assert_eq!(report.conflict_count(), 1);
    }

    #[test]
    fn test_predicate_conflict_compatible() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        base.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        incoming
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        assert_eq!(result.merged.predicates.len(), 1);
        assert_eq!(result.report.conflicting_predicates.len(), 0);
    }

    #[test]
    fn test_variable_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        base.add_domain(DomainInfo::new("Agent", 50))
            .expect("unwrap");
        base.bind_variable("x", "Person").expect("unwrap");

        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        incoming
            .add_domain(DomainInfo::new("Agent", 50))
            .expect("unwrap");
        incoming.bind_variable("x", "Agent").expect("unwrap");

        let merger = SchemaMerger::new(MergeStrategy::KeepFirst);
        let result = merger.merge(&base, &incoming).expect("unwrap");

        assert_eq!(result.merged.variables.get("x").expect("unwrap"), "Person");
        assert_eq!(result.report.conflicting_variables.len(), 1);
    }
}