1use crate::{DomainInfo, PredicateInfo, SymbolTable};
7use anyhow::{bail, Result};
8use std::collections::HashSet;
9
/// Policy applied when the base and incoming tables both define the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeStrategy {
    /// Keep the definition that comes from the base table.
    KeepFirst,
    /// Keep the definition that comes from the incoming table.
    KeepSecond,
    /// Return an error rather than choosing between clashing definitions.
    FailOnConflict,
    /// For domains, keep the definition with the larger cardinality (the base
    /// wins ties). Predicates and variables must agree or the merge fails.
    Union,
    /// For domains, keep the definition with the smaller cardinality (the base
    /// wins ties). Predicates and variables must agree or the merge fails.
    Intersection,
}
24
25#[derive(Debug, Clone)]
27pub struct MergeResult {
28 pub merged: SymbolTable,
30 pub report: MergeReport,
32}
33
34#[derive(Debug, Clone)]
36pub struct MergeReport {
37 pub base_domains: Vec<String>,
39 pub incoming_domains: Vec<String>,
41 pub conflicting_domains: Vec<DomainConflict>,
43 pub base_predicates: Vec<String>,
45 pub incoming_predicates: Vec<String>,
47 pub conflicting_predicates: Vec<PredicateConflict>,
49 pub merged_variables: Vec<String>,
51 pub conflicting_variables: Vec<VariableConflict>,
53 pub strategy: MergeStrategy,
55}
56
57impl MergeReport {
58 pub fn new(strategy: MergeStrategy) -> Self {
60 Self {
61 base_domains: Vec::new(),
62 incoming_domains: Vec::new(),
63 conflicting_domains: Vec::new(),
64 base_predicates: Vec::new(),
65 incoming_predicates: Vec::new(),
66 conflicting_predicates: Vec::new(),
67 merged_variables: Vec::new(),
68 conflicting_variables: Vec::new(),
69 strategy,
70 }
71 }
72
73 pub fn has_conflicts(&self) -> bool {
75 !self.conflicting_domains.is_empty()
76 || !self.conflicting_predicates.is_empty()
77 || !self.conflicting_variables.is_empty()
78 }
79
80 pub fn conflict_count(&self) -> usize {
82 self.conflicting_domains.len()
83 + self.conflicting_predicates.len()
84 + self.conflicting_variables.len()
85 }
86
87 pub fn merged_count(&self) -> usize {
89 self.base_domains.len()
90 + self.incoming_domains.len()
91 + self.base_predicates.len()
92 + self.incoming_predicates.len()
93 + self.merged_variables.len()
94 }
95}
96
97#[derive(Debug, Clone)]
99pub struct DomainConflict {
100 pub name: String,
102 pub base: DomainInfo,
104 pub incoming: DomainInfo,
106 pub resolution: MergeConflictResolution,
108}
109
110#[derive(Debug, Clone)]
112pub struct PredicateConflict {
113 pub name: String,
115 pub base: PredicateInfo,
117 pub incoming: PredicateInfo,
119 pub resolution: MergeConflictResolution,
121}
122
123#[derive(Debug, Clone)]
125pub struct VariableConflict {
126 pub name: String,
128 pub base_domain: String,
130 pub incoming_domain: String,
132 pub resolution: MergeConflictResolution,
134}
135
/// Outcome of resolving a name that both tables defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeConflictResolution {
    /// The base table's definition was used.
    KeptBase,
    /// The incoming table's definition was used.
    KeptIncoming,
    /// The clash could not be resolved.
    Failed,
    /// Both sides agreed; entries with this outcome are not listed as
    /// conflicts in the report.
    Merged,
}
148
149pub struct SchemaMerger {
151 strategy: MergeStrategy,
152}
153
154impl SchemaMerger {
155 pub fn new(strategy: MergeStrategy) -> Self {
157 Self { strategy }
158 }
159
160 pub fn merge(&self, base: &SymbolTable, incoming: &SymbolTable) -> Result<MergeResult> {
179 let mut merged = SymbolTable::new();
180 let mut report = MergeReport::new(self.strategy);
181
182 self.merge_domains(base, incoming, &mut merged, &mut report)?;
184
185 self.merge_predicates(base, incoming, &mut merged, &mut report)?;
187
188 self.merge_variables(base, incoming, &mut merged, &mut report)?;
190
191 Ok(MergeResult { merged, report })
192 }
193
194 fn merge_domains(
195 &self,
196 base: &SymbolTable,
197 incoming: &SymbolTable,
198 merged: &mut SymbolTable,
199 report: &mut MergeReport,
200 ) -> Result<()> {
201 let base_keys: HashSet<&String> = base.domains.keys().collect();
202 let incoming_keys: HashSet<&String> = incoming.domains.keys().collect();
203
204 for key in base_keys.difference(&incoming_keys) {
206 let domain = base.domains.get(*key).unwrap();
207 merged.add_domain(domain.clone())?;
208 report.base_domains.push(key.to_string());
209 }
210
211 for key in incoming_keys.difference(&base_keys) {
213 let domain = incoming.domains.get(*key).unwrap();
214 merged.add_domain(domain.clone())?;
215 report.incoming_domains.push(key.to_string());
216 }
217
218 for key in base_keys.intersection(&incoming_keys) {
220 let base_domain = base.domains.get(*key).unwrap();
221 let incoming_domain = incoming.domains.get(*key).unwrap();
222
223 let (domain, resolution) =
224 self.resolve_domain_conflict(base_domain, incoming_domain)?;
225
226 merged.add_domain(domain)?;
227
228 if resolution != MergeConflictResolution::Merged {
229 report.conflicting_domains.push(DomainConflict {
230 name: key.to_string(),
231 base: base_domain.clone(),
232 incoming: incoming_domain.clone(),
233 resolution,
234 });
235 }
236 }
237
238 Ok(())
239 }
240
241 fn merge_predicates(
242 &self,
243 base: &SymbolTable,
244 incoming: &SymbolTable,
245 merged: &mut SymbolTable,
246 report: &mut MergeReport,
247 ) -> Result<()> {
248 let base_keys: HashSet<&String> = base.predicates.keys().collect();
249 let incoming_keys: HashSet<&String> = incoming.predicates.keys().collect();
250
251 for key in base_keys.difference(&incoming_keys) {
253 let predicate = base.predicates.get(*key).unwrap();
254 merged.add_predicate(predicate.clone())?;
255 report.base_predicates.push(key.to_string());
256 }
257
258 for key in incoming_keys.difference(&base_keys) {
260 let predicate = incoming.predicates.get(*key).unwrap();
261 merged.add_predicate(predicate.clone())?;
262 report.incoming_predicates.push(key.to_string());
263 }
264
265 for key in base_keys.intersection(&incoming_keys) {
267 let base_pred = base.predicates.get(*key).unwrap();
268 let incoming_pred = incoming.predicates.get(*key).unwrap();
269
270 let (predicate, resolution) =
271 self.resolve_predicate_conflict(base_pred, incoming_pred)?;
272
273 merged.add_predicate(predicate)?;
274
275 if resolution != MergeConflictResolution::Merged {
276 report.conflicting_predicates.push(PredicateConflict {
277 name: key.to_string(),
278 base: base_pred.clone(),
279 incoming: incoming_pred.clone(),
280 resolution,
281 });
282 }
283 }
284
285 Ok(())
286 }
287
288 fn merge_variables(
289 &self,
290 base: &SymbolTable,
291 incoming: &SymbolTable,
292 merged: &mut SymbolTable,
293 report: &mut MergeReport,
294 ) -> Result<()> {
295 let base_keys: HashSet<&String> = base.variables.keys().collect();
296 let incoming_keys: HashSet<&String> = incoming.variables.keys().collect();
297
298 for key in base_keys.difference(&incoming_keys) {
300 let domain = base.variables.get(*key).unwrap();
301 merged.bind_variable(key.to_string(), domain.clone())?;
302 report.merged_variables.push(key.to_string());
303 }
304
305 for key in incoming_keys.difference(&base_keys) {
307 let domain = incoming.variables.get(*key).unwrap();
308 merged.bind_variable(key.to_string(), domain.clone())?;
309 report.merged_variables.push(key.to_string());
310 }
311
312 for key in base_keys.intersection(&incoming_keys) {
314 let base_domain = base.variables.get(*key).unwrap();
315 let incoming_domain = incoming.variables.get(*key).unwrap();
316
317 let (domain, resolution) =
318 self.resolve_variable_conflict(base_domain, incoming_domain)?;
319
320 merged.bind_variable(key.to_string(), domain)?;
321
322 if resolution != MergeConflictResolution::Merged {
323 report.conflicting_variables.push(VariableConflict {
324 name: key.to_string(),
325 base_domain: base_domain.clone(),
326 incoming_domain: incoming_domain.clone(),
327 resolution,
328 });
329 }
330 }
331
332 Ok(())
333 }
334
335 fn resolve_domain_conflict(
336 &self,
337 base: &DomainInfo,
338 incoming: &DomainInfo,
339 ) -> Result<(DomainInfo, MergeConflictResolution)> {
340 match self.strategy {
341 MergeStrategy::KeepFirst => Ok((base.clone(), MergeConflictResolution::KeptBase)),
342 MergeStrategy::KeepSecond => {
343 Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
344 }
345 MergeStrategy::FailOnConflict => {
346 bail!(
347 "Domain conflict for '{}': cardinality {} vs {}",
348 base.name,
349 base.cardinality,
350 incoming.cardinality
351 )
352 }
353 MergeStrategy::Union => {
354 if base.cardinality >= incoming.cardinality {
356 Ok((base.clone(), MergeConflictResolution::KeptBase))
357 } else {
358 Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
359 }
360 }
361 MergeStrategy::Intersection => {
362 if base.cardinality <= incoming.cardinality {
364 Ok((base.clone(), MergeConflictResolution::KeptBase))
365 } else {
366 Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
367 }
368 }
369 }
370 }
371
372 fn resolve_predicate_conflict(
373 &self,
374 base: &PredicateInfo,
375 incoming: &PredicateInfo,
376 ) -> Result<(PredicateInfo, MergeConflictResolution)> {
377 let compatible = base.arg_domains == incoming.arg_domains;
379
380 match self.strategy {
381 MergeStrategy::KeepFirst => Ok((base.clone(), MergeConflictResolution::KeptBase)),
382 MergeStrategy::KeepSecond => {
383 Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
384 }
385 MergeStrategy::FailOnConflict => {
386 bail!(
387 "Predicate conflict for '{}': {:?} vs {:?}",
388 base.name,
389 base.arg_domains,
390 incoming.arg_domains
391 )
392 }
393 MergeStrategy::Union => {
394 if compatible {
395 Ok((base.clone(), MergeConflictResolution::Merged))
396 } else {
397 bail!(
398 "Incompatible predicate signatures for '{}': {:?} vs {:?}",
399 base.name,
400 base.arg_domains,
401 incoming.arg_domains
402 )
403 }
404 }
405 MergeStrategy::Intersection => {
406 if compatible {
407 Ok((base.clone(), MergeConflictResolution::Merged))
408 } else {
409 bail!(
410 "Incompatible predicate signatures for '{}': {:?} vs {:?}",
411 base.name,
412 base.arg_domains,
413 incoming.arg_domains
414 )
415 }
416 }
417 }
418 }
419
420 fn resolve_variable_conflict(
421 &self,
422 base_domain: &str,
423 incoming_domain: &str,
424 ) -> Result<(String, MergeConflictResolution)> {
425 match self.strategy {
426 MergeStrategy::KeepFirst => {
427 Ok((base_domain.to_string(), MergeConflictResolution::KeptBase))
428 }
429 MergeStrategy::KeepSecond => Ok((
430 incoming_domain.to_string(),
431 MergeConflictResolution::KeptIncoming,
432 )),
433 MergeStrategy::FailOnConflict => {
434 bail!(
435 "Variable domain conflict: '{}' vs '{}'",
436 base_domain,
437 incoming_domain
438 )
439 }
440 MergeStrategy::Union | MergeStrategy::Intersection => {
441 if base_domain == incoming_domain {
442 Ok((base_domain.to_string(), MergeConflictResolution::Merged))
443 } else {
444 bail!(
445 "Incompatible variable domains: '{}' vs '{}'",
446 base_domain,
447 incoming_domain
448 )
449 }
450 }
451 }
452 }
453}
454
455impl Default for SchemaMerger {
456 fn default() -> Self {
457 Self::new(MergeStrategy::Union)
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Base fixture: domain `Person` (100), predicate `knows(Person, Person)`
    /// and variable `x` bound to `Person`.
    fn create_base_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();

        let knows = PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()]);
        table.add_predicate(knows).unwrap();

        table.bind_variable("x", "Person").unwrap();
        table
    }

    /// Incoming fixture: domains `Person` (150) and `Organization` (50), and
    /// predicate `age(Person)`.
    fn create_incoming_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 150)).unwrap();
        table
            .add_domain(DomainInfo::new("Organization", 50))
            .unwrap();

        let age = PredicateInfo::new("age", vec!["Person".to_string()]);
        table.add_predicate(age).unwrap();
        table
    }

    #[test]
    fn test_merge_union_no_conflicts() {
        let base = create_base_table();
        let incoming = create_incoming_table();

        let result = SchemaMerger::new(MergeStrategy::Union)
            .merge(&base, &incoming)
            .unwrap();

        // Person + Organization, knows + age.
        assert_eq!(result.merged.domains.len(), 2);
        assert_eq!(result.merged.predicates.len(), 2);
        // Both fixtures declare `Person` with different cardinalities, and the
        // side Union picks is still recorded as a conflict in the report.
        assert!(result.report.has_conflicts());
    }

    #[test]
    fn test_merge_with_domain_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 200)).unwrap();

        let result = SchemaMerger::new(MergeStrategy::KeepFirst)
            .merge(&base, &incoming)
            .unwrap();

        assert_eq!(result.merged.domains.len(), 1);
        // KeepFirst must retain the base cardinality.
        let person = result.merged.domains.get("Person").unwrap();
        assert_eq!(person.cardinality, 100);
        assert!(result.report.has_conflicts());
    }

    #[test]
    fn test_merge_keep_second() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 200)).unwrap();

        let result = SchemaMerger::new(MergeStrategy::KeepSecond)
            .merge(&base, &incoming)
            .unwrap();

        // KeepSecond must adopt the incoming cardinality.
        let person = result.merged.domains.get("Person").unwrap();
        assert_eq!(person.cardinality, 200);
    }

    #[test]
    fn test_merge_fail_on_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 200)).unwrap();

        let outcome = SchemaMerger::new(MergeStrategy::FailOnConflict).merge(&base, &incoming);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_merge_report() {
        let base = create_base_table();
        let incoming = create_incoming_table();

        let result = SchemaMerger::new(MergeStrategy::Union)
            .merge(&base, &incoming)
            .unwrap();
        let report = &result.report;

        // `Person` is shared, so the base contributes no unique domain.
        assert_eq!(report.base_domains.len(), 0);
        // `Organization` exists only in the incoming table.
        assert_eq!(report.incoming_domains.len(), 1);
        // Organization + knows + age + x.
        assert_eq!(report.merged_count(), 4);
        // The shared `Person` domain is the single recorded conflict.
        assert_eq!(report.conflict_count(), 1);
    }

    #[test]
    fn test_predicate_conflict_compatible() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();
        base.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 100)).unwrap();
        incoming
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .unwrap();

        let result = SchemaMerger::new(MergeStrategy::Union)
            .merge(&base, &incoming)
            .unwrap();

        // Identical signatures merge cleanly and are not reported.
        assert_eq!(result.merged.predicates.len(), 1);
        assert_eq!(result.report.conflicting_predicates.len(), 0);
    }

    #[test]
    fn test_variable_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100)).unwrap();
        base.add_domain(DomainInfo::new("Agent", 50)).unwrap();
        base.bind_variable("x", "Person").unwrap();

        let mut incoming = SymbolTable::new();
        incoming.add_domain(DomainInfo::new("Person", 100)).unwrap();
        incoming.add_domain(DomainInfo::new("Agent", 50)).unwrap();
        incoming.bind_variable("x", "Agent").unwrap();

        let result = SchemaMerger::new(MergeStrategy::KeepFirst)
            .merge(&base, &incoming)
            .unwrap();

        // KeepFirst keeps the base binding and records the clash.
        assert_eq!(result.merged.variables.get("x").unwrap(), "Person");
        assert_eq!(result.report.conflicting_variables.len(), 1);
    }
}