use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};

use crate::{DomainInfo, PredicateInfo, SymbolTable};
25
26#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
28pub enum LoadStrategy {
29 Eager,
31 #[default]
33 OnDemand,
34 Predictive {
36 access_threshold: usize,
38 },
39 Batched {
41 batch_size: usize,
43 },
44}
45
46pub trait SchemaLoader: Send + Sync {
50 fn load_domain(&self, name: &str) -> Result<DomainInfo>;
52
53 fn load_predicate(&self, name: &str) -> Result<PredicateInfo>;
55
56 fn has_domain(&self, name: &str) -> bool;
58
59 fn has_predicate(&self, name: &str) -> bool;
61
62 fn list_domains(&self) -> Result<Vec<String>>;
64
65 fn list_predicates(&self) -> Result<Vec<String>>;
67
68 fn load_domains_batch(&self, names: &[String]) -> Result<Vec<DomainInfo>> {
70 names.iter().map(|n| self.load_domain(n)).collect()
71 }
72
73 fn load_predicates_batch(&self, names: &[String]) -> Result<Vec<PredicateInfo>> {
75 names.iter().map(|n| self.load_predicate(n)).collect()
76 }
77}
78
/// [`SchemaLoader`] that reads JSON schema files from a directory tree.
///
/// Files are expected at:
///
/// ```text
/// <base_dir>/domains/<name>.json
/// <base_dir>/predicates/<name>.json
/// ```
#[derive(Debug, Clone)]
pub struct FileSchemaLoader {
    base_dir: PathBuf,
}

impl FileSchemaLoader {
    /// Creates a loader rooted at `base_dir`.
    ///
    /// Nothing on disk is touched until a schema is requested.
    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
        let base_dir = base_dir.into();
        Self { base_dir }
    }

    /// Path of the JSON file that holds domain `name`.
    fn domain_path(&self, name: &str) -> PathBuf {
        self.schema_path("domains", name)
    }

    /// Path of the JSON file that holds predicate `name`.
    fn predicate_path(&self, name: &str) -> PathBuf {
        self.schema_path("predicates", name)
    }

    /// Builds `<base_dir>/<kind>/<name>.json`.
    ///
    /// NOTE(review): `name` is joined verbatim. A name containing path
    /// separators (e.g. `../x`) or an absolute path resolves outside
    /// `base_dir`. Confirm that callers never pass untrusted names.
    fn schema_path(&self, kind: &str, name: &str) -> PathBuf {
        let mut path = self.base_dir.join(kind);
        path.push(format!("{name}.json"));
        path
    }
}
123
124impl SchemaLoader for FileSchemaLoader {
125 fn load_domain(&self, name: &str) -> Result<DomainInfo> {
126 let path = self.domain_path(name);
127 let content = std::fs::read_to_string(path)?;
128 let domain: DomainInfo = serde_json::from_str(&content)?;
129 Ok(domain)
130 }
131
132 fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
133 let path = self.predicate_path(name);
134 let content = std::fs::read_to_string(path)?;
135 let predicate: PredicateInfo = serde_json::from_str(&content)?;
136 Ok(predicate)
137 }
138
139 fn has_domain(&self, name: &str) -> bool {
140 self.domain_path(name).exists()
141 }
142
143 fn has_predicate(&self, name: &str) -> bool {
144 self.predicate_path(name).exists()
145 }
146
147 fn list_domains(&self) -> Result<Vec<String>> {
148 let domains_dir = self.base_dir.join("domains");
149 if !domains_dir.exists() {
150 return Ok(Vec::new());
151 }
152
153 let mut names = Vec::new();
154 for entry in std::fs::read_dir(domains_dir)? {
155 let entry = entry?;
156 if let Some(name) = entry.path().file_stem() {
157 names.push(name.to_string_lossy().to_string());
158 }
159 }
160 Ok(names)
161 }
162
163 fn list_predicates(&self) -> Result<Vec<String>> {
164 let predicates_dir = self.base_dir.join("predicates");
165 if !predicates_dir.exists() {
166 return Ok(Vec::new());
167 }
168
169 let mut names = Vec::new();
170 for entry in std::fs::read_dir(predicates_dir)? {
171 let entry = entry?;
172 if let Some(name) = entry.path().file_stem() {
173 names.push(name.to_string_lossy().to_string());
174 }
175 }
176 Ok(names)
177 }
178}
179
/// Counters describing how a [`LazySymbolTable`] has served lookups.
#[derive(Debug, Default, Clone)]
pub struct LazyLoadStats {
    /// Domains inserted into the table, either lazily or through a preload.
    pub domain_loads: usize,
    /// Predicates inserted into the table, either lazily or through a preload.
    pub predicate_loads: usize,
    /// Lookups answered from entries that were already loaded.
    pub cache_hits: usize,
    /// Lookups that could not be answered from already-loaded entries.
    pub cache_misses: usize,
    /// Preload calls that completed successfully.
    pub batch_loads: usize,
}

impl LazyLoadStats {
    /// Fraction of recorded lookups that were cache hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            lookups => self.cache_hits as f64 / lookups as f64,
        }
    }
}
206
207pub struct LazySymbolTable {
212 loaded: Arc<RwLock<SymbolTable>>,
214 loader: Arc<dyn SchemaLoader>,
216 strategy: LoadStrategy,
218 stats: Arc<RwLock<LazyLoadStats>>,
220 loaded_domains: Arc<RwLock<HashSet<String>>>,
222 loaded_predicates: Arc<RwLock<HashSet<String>>>,
224 access_counts: Arc<RwLock<HashMap<String, usize>>>,
226}
227
228impl LazySymbolTable {
229 pub fn new(loader: Arc<dyn SchemaLoader>) -> Self {
241 Self {
242 loaded: Arc::new(RwLock::new(SymbolTable::new())),
243 loader,
244 strategy: LoadStrategy::default(),
245 stats: Arc::new(RwLock::new(LazyLoadStats::default())),
246 loaded_domains: Arc::new(RwLock::new(HashSet::new())),
247 loaded_predicates: Arc::new(RwLock::new(HashSet::new())),
248 access_counts: Arc::new(RwLock::new(HashMap::new())),
249 }
250 }
251
252 pub fn with_strategy(loader: Arc<dyn SchemaLoader>, strategy: LoadStrategy) -> Self {
254 let mut table = Self::new(loader);
255 table.strategy = strategy;
256 table
257 }
258
259 pub fn get_domain(&self, name: &str) -> Result<Option<DomainInfo>> {
261 {
263 let loaded_set = self
264 .loaded_domains
265 .read()
266 .expect("lock should not be poisoned");
267 if loaded_set.contains(name) {
268 let table = self.loaded.read().expect("lock should not be poisoned");
269 let mut stats = self.stats.write().expect("lock should not be poisoned");
270 stats.cache_hits += 1;
271 return Ok(table.get_domain(name).cloned());
272 }
273 }
274
275 if !self.loader.has_domain(name) {
277 let mut stats = self.stats.write().expect("lock should not be poisoned");
278 stats.cache_misses += 1;
279 return Ok(None);
280 }
281
282 self.load_domain_internal(name)?;
284
285 let table = self.loaded.read().expect("lock should not be poisoned");
286 Ok(table.get_domain(name).cloned())
287 }
288
289 pub fn get_predicate(&self, name: &str) -> Result<Option<PredicateInfo>> {
291 {
293 let loaded_set = self
294 .loaded_predicates
295 .read()
296 .expect("lock should not be poisoned");
297 if loaded_set.contains(name) {
298 let table = self.loaded.read().expect("lock should not be poisoned");
299 let mut stats = self.stats.write().expect("lock should not be poisoned");
300 stats.cache_hits += 1;
301 return Ok(table.get_predicate(name).cloned());
302 }
303 }
304
305 if !self.loader.has_predicate(name) {
307 let mut stats = self.stats.write().expect("lock should not be poisoned");
308 stats.cache_misses += 1;
309 return Ok(None);
310 }
311
312 self.load_predicate_internal(name)?;
314
315 let table = self.loaded.read().expect("lock should not be poisoned");
316 Ok(table.get_predicate(name).cloned())
317 }
318
319 pub fn list_domains(&self) -> Result<Vec<String>> {
321 self.loader.list_domains()
322 }
323
324 pub fn list_predicates(&self) -> Result<Vec<String>> {
326 self.loader.list_predicates()
327 }
328
329 pub fn preload_domains(&self, names: &[String]) -> Result<()> {
331 let domains = self.loader.load_domains_batch(names)?;
332 let mut table = self.loaded.write().expect("lock should not be poisoned");
333 let mut loaded_set = self
334 .loaded_domains
335 .write()
336 .expect("lock should not be poisoned");
337 let mut stats = self.stats.write().expect("lock should not be poisoned");
338
339 for domain in domains {
340 let name = domain.name.clone();
341 table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
342 loaded_set.insert(name);
343 stats.domain_loads += 1;
344 }
345 stats.batch_loads += 1;
346
347 Ok(())
348 }
349
350 pub fn preload_predicates(&self, names: &[String]) -> Result<()> {
352 let predicates = self.loader.load_predicates_batch(names)?;
353 let mut table = self.loaded.write().expect("lock should not be poisoned");
354 let mut loaded_set = self
355 .loaded_predicates
356 .write()
357 .expect("lock should not be poisoned");
358 let mut stats = self.stats.write().expect("lock should not be poisoned");
359
360 for predicate in predicates {
361 let name = predicate.name.clone();
362 table
363 .add_predicate(predicate)
364 .map_err(|e| anyhow::anyhow!(e))?;
365 loaded_set.insert(name);
366 stats.predicate_loads += 1;
367 }
368 stats.batch_loads += 1;
369
370 Ok(())
371 }
372
373 pub fn stats(&self) -> LazyLoadStats {
375 self.stats
376 .read()
377 .expect("lock should not be poisoned")
378 .clone()
379 }
380
381 pub fn clear_cache(&self) {
383 let mut table = self.loaded.write().expect("lock should not be poisoned");
384 *table = SymbolTable::new();
385 self.loaded_domains
386 .write()
387 .expect("lock should not be poisoned")
388 .clear();
389 self.loaded_predicates
390 .write()
391 .expect("lock should not be poisoned")
392 .clear();
393 self.access_counts
394 .write()
395 .expect("lock should not be poisoned")
396 .clear();
397 }
398
399 pub fn loaded_domain_count(&self) -> usize {
401 self.loaded_domains
402 .read()
403 .expect("lock should not be poisoned")
404 .len()
405 }
406
407 pub fn loaded_predicate_count(&self) -> usize {
409 self.loaded_predicates
410 .read()
411 .expect("lock should not be poisoned")
412 .len()
413 }
414
415 pub fn as_symbol_table(&self) -> Arc<RwLock<SymbolTable>> {
419 Arc::clone(&self.loaded)
420 }
421
422 fn load_domain_internal(&self, name: &str) -> Result<()> {
423 let domain = self.loader.load_domain(name)?;
424 let mut table = self.loaded.write().expect("lock should not be poisoned");
425 let mut loaded_set = self
426 .loaded_domains
427 .write()
428 .expect("lock should not be poisoned");
429 let mut stats = self.stats.write().expect("lock should not be poisoned");
430
431 table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
432 loaded_set.insert(name.to_string());
433 stats.domain_loads += 1;
434 stats.cache_misses += 1;
435
436 {
438 let mut counts = self
439 .access_counts
440 .write()
441 .expect("lock should not be poisoned");
442 *counts.entry(name.to_string()).or_insert(0) += 1;
443 }
444
445 Ok(())
446 }
447
448 fn load_predicate_internal(&self, name: &str) -> Result<()> {
449 let predicate = self.loader.load_predicate(name)?;
450 let mut table = self.loaded.write().expect("lock should not be poisoned");
451 let mut loaded_set = self
452 .loaded_predicates
453 .write()
454 .expect("lock should not be poisoned");
455 let mut stats = self.stats.write().expect("lock should not be poisoned");
456
457 table
458 .add_predicate(predicate)
459 .map_err(|e| anyhow::anyhow!(e))?;
460 loaded_set.insert(name.to_string());
461 stats.predicate_loads += 1;
462 stats.cache_misses += 1;
463
464 {
466 let mut counts = self
467 .access_counts
468 .write()
469 .expect("lock should not be poisoned");
470 *counts.entry(name.to_string()).or_insert(0) += 1;
471 }
472
473 Ok(())
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory loader with domains `Person` and `Location` and a predicate
    /// `at(Person, Location)`.
    struct MockLoader {
        domains: HashMap<String, DomainInfo>,
        predicates: HashMap<String, PredicateInfo>,
    }

    impl MockLoader {
        fn new() -> Self {
            let mut domains = HashMap::new();
            domains.insert("Person".to_string(), DomainInfo::new("Person", 100));
            domains.insert("Location".to_string(), DomainInfo::new("Location", 50));

            let mut predicates = HashMap::new();
            predicates.insert(
                "at".to_string(),
                PredicateInfo::new("at", vec!["Person".to_string(), "Location".to_string()]),
            );

            Self {
                domains,
                predicates,
            }
        }
    }

    impl SchemaLoader for MockLoader {
        fn load_domain(&self, name: &str) -> Result<DomainInfo> {
            self.domains
                .get(name)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Domain not found: {}", name))
        }

        fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
            self.predicates
                .get(name)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Predicate not found: {}", name))
        }

        fn has_domain(&self, name: &str) -> bool {
            self.domains.contains_key(name)
        }

        fn has_predicate(&self, name: &str) -> bool {
            self.predicates.contains_key(name)
        }

        fn list_domains(&self) -> Result<Vec<String>> {
            Ok(self.domains.keys().cloned().collect())
        }

        fn list_predicates(&self) -> Result<Vec<String>> {
            Ok(self.predicates.keys().cloned().collect())
        }
    }

    #[test]
    fn test_lazy_load_domain() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        let domain = lazy_table.get_domain("Person").expect("unwrap");
        assert!(domain.is_some());
        assert_eq!(domain.expect("unwrap").name, "Person");
    }

    #[test]
    fn test_lazy_load_predicate() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        // Load the argument domains before the predicate that references them.
        lazy_table.get_domain("Person").expect("unwrap");
        lazy_table.get_domain("Location").expect("unwrap");

        let predicate = lazy_table.get_predicate("at").expect("unwrap");
        assert!(predicate.is_some());
        assert_eq!(predicate.expect("unwrap").name, "at");
    }

    #[test]
    fn test_cache_hits() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        // First access misses and loads from the loader.
        lazy_table.get_domain("Person").expect("unwrap");
        // Second access is served from the cache.
        lazy_table.get_domain("Person").expect("unwrap");

        let stats = lazy_table.stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    #[test]
    fn test_list_domains() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        let domains = lazy_table.list_domains().expect("unwrap");
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"Person".to_string()));
        assert!(domains.contains(&"Location".to_string()));
    }

    #[test]
    fn test_preload_domains() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        let names = vec!["Person".to_string(), "Location".to_string()];
        lazy_table.preload_domains(&names).expect("unwrap");

        assert_eq!(lazy_table.loaded_domain_count(), 2);

        let stats = lazy_table.stats();
        assert_eq!(stats.batch_loads, 1);
    }

    #[test]
    fn test_clear_cache() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        lazy_table.get_domain("Person").expect("unwrap");
        assert_eq!(lazy_table.loaded_domain_count(), 1);

        lazy_table.clear_cache();
        assert_eq!(lazy_table.loaded_domain_count(), 0);
    }

    #[test]
    fn test_load_strategy() {
        let loader = Arc::new(MockLoader::new());
        let strategy = LoadStrategy::Predictive {
            access_threshold: 5,
        };
        let lazy_table = LazySymbolTable::with_strategy(loader, strategy);

        lazy_table.get_domain("Person").expect("unwrap");
        assert_eq!(lazy_table.loaded_domain_count(), 1);
    }

    #[test]
    fn test_hit_rate() {
        let mut stats = LazyLoadStats::default();
        assert_eq!(stats.hit_rate(), 0.0);

        stats.cache_hits = 8;
        stats.cache_misses = 2;
        assert!((stats.hit_rate() - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_missing_domain_is_a_miss() {
        let lazy_table = LazySymbolTable::new(Arc::new(MockLoader::new()));

        let domain = lazy_table.get_domain("Unknown").expect("lookup should not error");
        assert!(domain.is_none());
        assert_eq!(lazy_table.loaded_domain_count(), 0);

        let stats = lazy_table.stats();
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.domain_loads, 0);
    }

    #[test]
    fn test_preload_predicates() {
        let lazy_table = LazySymbolTable::new(Arc::new(MockLoader::new()));

        lazy_table
            .preload_domains(&["Person".to_string(), "Location".to_string()])
            .expect("preload domains");
        lazy_table
            .preload_predicates(&["at".to_string()])
            .expect("preload predicates");

        assert_eq!(lazy_table.loaded_predicate_count(), 1);
        let stats = lazy_table.stats();
        assert_eq!(stats.batch_loads, 2);
        assert_eq!(stats.predicate_loads, 1);

        // A preloaded predicate is served from the cache.
        assert!(lazy_table.get_predicate("at").expect("lookup").is_some());
        assert_eq!(lazy_table.stats().cache_hits, 1);
    }

    #[test]
    fn test_preload_unknown_domain_loads_nothing() {
        let lazy_table = LazySymbolTable::new(Arc::new(MockLoader::new()));

        let names = vec!["Person".to_string(), "Missing".to_string()];
        assert!(lazy_table.preload_domains(&names).is_err());

        assert_eq!(lazy_table.loaded_domain_count(), 0);
        assert_eq!(lazy_table.stats().batch_loads, 0);
    }

    #[test]
    fn test_clear_cache_forgets_predicates_but_keeps_stats() {
        let lazy_table = LazySymbolTable::new(Arc::new(MockLoader::new()));

        lazy_table.get_domain("Person").expect("unwrap");
        lazy_table.get_domain("Location").expect("unwrap");
        lazy_table.get_predicate("at").expect("unwrap");
        assert_eq!(lazy_table.loaded_predicate_count(), 1);

        lazy_table.clear_cache();
        assert_eq!(lazy_table.loaded_domain_count(), 0);
        assert_eq!(lazy_table.loaded_predicate_count(), 0);

        // After clearing, the next lookup reloads from the loader.
        assert!(lazy_table.get_domain("Person").expect("unwrap").is_some());
        assert_eq!(lazy_table.loaded_domain_count(), 1);
        assert_eq!(lazy_table.stats().domain_loads, 3);
    }
}