use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::{Arc, RwLock};

use crate::{DomainInfo, PredicateInfo, SymbolTable};
25
/// Policy describing when a [`LazySymbolTable`] should fetch schemas from its
/// [`SchemaLoader`].
///
/// NOTE(review): `LazySymbolTable` in this file stores the chosen strategy, but
/// none of its loading paths branch on it yet. Every variant currently behaves
/// like [`LoadStrategy::OnDemand`]. Confirm before relying on the other variants.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum LoadStrategy {
    /// Presumably: load every schema up front. Not acted on yet (see note above).
    Eager,
    /// Load each schema the first time it is looked up. This is the default.
    #[default]
    OnDemand,
    /// Presumably: start prefetching once accesses cross a threshold. Not acted on yet.
    Predictive {
        /// Access count meant to trigger prediction. TODO confirm the exact semantics.
        access_threshold: usize,
    },
    /// Presumably: fetch schemas in groups rather than one at a time. Not acted on yet.
    Batched {
        /// Number of schemas per batch. TODO confirm the exact semantics.
        batch_size: usize,
    },
}
45
46pub trait SchemaLoader: Send + Sync {
50 fn load_domain(&self, name: &str) -> Result<DomainInfo>;
52
53 fn load_predicate(&self, name: &str) -> Result<PredicateInfo>;
55
56 fn has_domain(&self, name: &str) -> bool;
58
59 fn has_predicate(&self, name: &str) -> bool;
61
62 fn list_domains(&self) -> Result<Vec<String>>;
64
65 fn list_predicates(&self) -> Result<Vec<String>>;
67
68 fn load_domains_batch(&self, names: &[String]) -> Result<Vec<DomainInfo>> {
70 names.iter().map(|n| self.load_domain(n)).collect()
71 }
72
73 fn load_predicates_batch(&self, names: &[String]) -> Result<Vec<PredicateInfo>> {
75 names.iter().map(|n| self.load_predicate(n)).collect()
76 }
77}
78
/// [`SchemaLoader`] that reads JSON schema files from a directory tree.
///
/// Files are looked up at these locations:
///
/// ```text
/// <base_dir>/domains/<name>.json
/// <base_dir>/predicates/<name>.json
/// ```
#[derive(Clone, Debug)]
pub struct FileSchemaLoader {
    /// Root directory holding the `domains/` and `predicates/` subdirectories.
    base_dir: PathBuf,
}

impl FileSchemaLoader {
    /// Creates a loader rooted at `base_dir`. The filesystem is not touched here.
    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
        let base_dir = base_dir.into();
        Self { base_dir }
    }

    /// Returns `<base_dir>/<category>/<name>.json`.
    fn schema_path(&self, category: &str, name: &str) -> PathBuf {
        let mut path = self.base_dir.join(category);
        path.push(format!("{name}.json"));
        path
    }

    /// Location of the JSON file for domain `name`.
    fn domain_path(&self, name: &str) -> PathBuf {
        self.schema_path("domains", name)
    }

    /// Location of the JSON file for predicate `name`.
    fn predicate_path(&self, name: &str) -> PathBuf {
        self.schema_path("predicates", name)
    }
}
123
124impl SchemaLoader for FileSchemaLoader {
125 fn load_domain(&self, name: &str) -> Result<DomainInfo> {
126 let path = self.domain_path(name);
127 let content = std::fs::read_to_string(path)?;
128 let domain: DomainInfo = serde_json::from_str(&content)?;
129 Ok(domain)
130 }
131
132 fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
133 let path = self.predicate_path(name);
134 let content = std::fs::read_to_string(path)?;
135 let predicate: PredicateInfo = serde_json::from_str(&content)?;
136 Ok(predicate)
137 }
138
139 fn has_domain(&self, name: &str) -> bool {
140 self.domain_path(name).exists()
141 }
142
143 fn has_predicate(&self, name: &str) -> bool {
144 self.predicate_path(name).exists()
145 }
146
147 fn list_domains(&self) -> Result<Vec<String>> {
148 let domains_dir = self.base_dir.join("domains");
149 if !domains_dir.exists() {
150 return Ok(Vec::new());
151 }
152
153 let mut names = Vec::new();
154 for entry in std::fs::read_dir(domains_dir)? {
155 let entry = entry?;
156 if let Some(name) = entry.path().file_stem() {
157 names.push(name.to_string_lossy().to_string());
158 }
159 }
160 Ok(names)
161 }
162
163 fn list_predicates(&self) -> Result<Vec<String>> {
164 let predicates_dir = self.base_dir.join("predicates");
165 if !predicates_dir.exists() {
166 return Ok(Vec::new());
167 }
168
169 let mut names = Vec::new();
170 for entry in std::fs::read_dir(predicates_dir)? {
171 let entry = entry?;
172 if let Some(name) = entry.path().file_stem() {
173 names.push(name.to_string_lossy().to_string());
174 }
175 }
176 Ok(names)
177 }
178}
179
/// Counters describing how a lazily loaded symbol table has served requests so far.
#[derive(Clone, Debug, Default)]
pub struct LazyLoadStats {
    /// Domains added to the table from the loader (single or batched).
    pub domain_loads: usize,
    /// Predicates added to the table from the loader (single or batched).
    pub predicate_loads: usize,
    /// Lookups answered from entries that were already loaded.
    pub cache_hits: usize,
    /// Lookups for entries that were not already loaded.
    pub cache_misses: usize,
    /// Batch preload operations performed.
    pub batch_loads: usize,
}

impl LazyLoadStats {
    /// Fraction of lookups served from cache, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            lookups => self.cache_hits as f64 / lookups as f64,
        }
    }
}
206
/// Symbol table that pulls domain and predicate schemas from a [`SchemaLoader`]
/// the first time they are requested, and serves them from memory afterwards.
///
/// Mutable state lives behind `Arc<RwLock<..>>`, so every lookup takes `&self`.
pub struct LazySymbolTable {
    /// Entries loaded so far; a handle is shared with callers via `as_symbol_table`.
    loaded: Arc<RwLock<SymbolTable>>,
    /// Backend that supplies schemas on demand.
    loader: Arc<dyn SchemaLoader>,
    /// Configured loading policy. NOTE(review): no loading path in this file
    /// branches on it yet; confirm this is intended.
    strategy: LoadStrategy,
    /// Hit/miss/load counters returned by `stats`.
    stats: Arc<RwLock<LazyLoadStats>>,
    /// Domain names recorded as loaded; lookups for these skip the loader.
    loaded_domains: Arc<RwLock<HashSet<String>>>,
    /// Predicate names recorded as loaded; lookups for these skip the loader.
    loaded_predicates: Arc<RwLock<HashSet<String>>>,
    /// Per-name counter bumped when a name is loaded on demand; cleared by
    /// `clear_cache`. Domain and predicate names share this one map.
    access_counts: Arc<RwLock<HashMap<String, usize>>>,
}
227
228impl LazySymbolTable {
229 pub fn new(loader: Arc<dyn SchemaLoader>) -> Self {
241 Self {
242 loaded: Arc::new(RwLock::new(SymbolTable::new())),
243 loader,
244 strategy: LoadStrategy::default(),
245 stats: Arc::new(RwLock::new(LazyLoadStats::default())),
246 loaded_domains: Arc::new(RwLock::new(HashSet::new())),
247 loaded_predicates: Arc::new(RwLock::new(HashSet::new())),
248 access_counts: Arc::new(RwLock::new(HashMap::new())),
249 }
250 }
251
252 pub fn with_strategy(loader: Arc<dyn SchemaLoader>, strategy: LoadStrategy) -> Self {
254 let mut table = Self::new(loader);
255 table.strategy = strategy;
256 table
257 }
258
259 pub fn get_domain(&self, name: &str) -> Result<Option<DomainInfo>> {
261 {
263 let loaded_set = self.loaded_domains.read().unwrap();
264 if loaded_set.contains(name) {
265 let table = self.loaded.read().unwrap();
266 let mut stats = self.stats.write().unwrap();
267 stats.cache_hits += 1;
268 return Ok(table.get_domain(name).cloned());
269 }
270 }
271
272 if !self.loader.has_domain(name) {
274 let mut stats = self.stats.write().unwrap();
275 stats.cache_misses += 1;
276 return Ok(None);
277 }
278
279 self.load_domain_internal(name)?;
281
282 let table = self.loaded.read().unwrap();
283 Ok(table.get_domain(name).cloned())
284 }
285
286 pub fn get_predicate(&self, name: &str) -> Result<Option<PredicateInfo>> {
288 {
290 let loaded_set = self.loaded_predicates.read().unwrap();
291 if loaded_set.contains(name) {
292 let table = self.loaded.read().unwrap();
293 let mut stats = self.stats.write().unwrap();
294 stats.cache_hits += 1;
295 return Ok(table.get_predicate(name).cloned());
296 }
297 }
298
299 if !self.loader.has_predicate(name) {
301 let mut stats = self.stats.write().unwrap();
302 stats.cache_misses += 1;
303 return Ok(None);
304 }
305
306 self.load_predicate_internal(name)?;
308
309 let table = self.loaded.read().unwrap();
310 Ok(table.get_predicate(name).cloned())
311 }
312
313 pub fn list_domains(&self) -> Result<Vec<String>> {
315 self.loader.list_domains()
316 }
317
318 pub fn list_predicates(&self) -> Result<Vec<String>> {
320 self.loader.list_predicates()
321 }
322
323 pub fn preload_domains(&self, names: &[String]) -> Result<()> {
325 let domains = self.loader.load_domains_batch(names)?;
326 let mut table = self.loaded.write().unwrap();
327 let mut loaded_set = self.loaded_domains.write().unwrap();
328 let mut stats = self.stats.write().unwrap();
329
330 for domain in domains {
331 let name = domain.name.clone();
332 table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
333 loaded_set.insert(name);
334 stats.domain_loads += 1;
335 }
336 stats.batch_loads += 1;
337
338 Ok(())
339 }
340
341 pub fn preload_predicates(&self, names: &[String]) -> Result<()> {
343 let predicates = self.loader.load_predicates_batch(names)?;
344 let mut table = self.loaded.write().unwrap();
345 let mut loaded_set = self.loaded_predicates.write().unwrap();
346 let mut stats = self.stats.write().unwrap();
347
348 for predicate in predicates {
349 let name = predicate.name.clone();
350 table
351 .add_predicate(predicate)
352 .map_err(|e| anyhow::anyhow!(e))?;
353 loaded_set.insert(name);
354 stats.predicate_loads += 1;
355 }
356 stats.batch_loads += 1;
357
358 Ok(())
359 }
360
361 pub fn stats(&self) -> LazyLoadStats {
363 self.stats.read().unwrap().clone()
364 }
365
366 pub fn clear_cache(&self) {
368 let mut table = self.loaded.write().unwrap();
369 *table = SymbolTable::new();
370 self.loaded_domains.write().unwrap().clear();
371 self.loaded_predicates.write().unwrap().clear();
372 self.access_counts.write().unwrap().clear();
373 }
374
375 pub fn loaded_domain_count(&self) -> usize {
377 self.loaded_domains.read().unwrap().len()
378 }
379
380 pub fn loaded_predicate_count(&self) -> usize {
382 self.loaded_predicates.read().unwrap().len()
383 }
384
385 pub fn as_symbol_table(&self) -> Arc<RwLock<SymbolTable>> {
389 Arc::clone(&self.loaded)
390 }
391
392 fn load_domain_internal(&self, name: &str) -> Result<()> {
393 let domain = self.loader.load_domain(name)?;
394 let mut table = self.loaded.write().unwrap();
395 let mut loaded_set = self.loaded_domains.write().unwrap();
396 let mut stats = self.stats.write().unwrap();
397
398 table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
399 loaded_set.insert(name.to_string());
400 stats.domain_loads += 1;
401 stats.cache_misses += 1;
402
403 {
405 let mut counts = self.access_counts.write().unwrap();
406 *counts.entry(name.to_string()).or_insert(0) += 1;
407 }
408
409 Ok(())
410 }
411
412 fn load_predicate_internal(&self, name: &str) -> Result<()> {
413 let predicate = self.loader.load_predicate(name)?;
414 let mut table = self.loaded.write().unwrap();
415 let mut loaded_set = self.loaded_predicates.write().unwrap();
416 let mut stats = self.stats.write().unwrap();
417
418 table
419 .add_predicate(predicate)
420 .map_err(|e| anyhow::anyhow!(e))?;
421 loaded_set.insert(name.to_string());
422 stats.predicate_loads += 1;
423 stats.cache_misses += 1;
424
425 {
427 let mut counts = self.access_counts.write().unwrap();
428 *counts.entry(name.to_string()).or_insert(0) += 1;
429 }
430
431 Ok(())
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory loader with two domains (`Person`, `Location`) and one
    /// predicate (`at(Person, Location)`).
    struct MockLoader {
        domains: HashMap<String, DomainInfo>,
        predicates: HashMap<String, PredicateInfo>,
    }

    impl MockLoader {
        fn new() -> Self {
            let mut domains = HashMap::new();
            domains.insert("Person".to_string(), DomainInfo::new("Person", 100));
            domains.insert("Location".to_string(), DomainInfo::new("Location", 50));

            let mut predicates = HashMap::new();
            predicates.insert(
                "at".to_string(),
                PredicateInfo::new("at", vec!["Person".to_string(), "Location".to_string()]),
            );

            Self {
                domains,
                predicates,
            }
        }
    }

    impl SchemaLoader for MockLoader {
        fn load_domain(&self, name: &str) -> Result<DomainInfo> {
            self.domains
                .get(name)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Domain not found: {}", name))
        }

        fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
            self.predicates
                .get(name)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Predicate not found: {}", name))
        }

        fn has_domain(&self, name: &str) -> bool {
            self.domains.contains_key(name)
        }

        fn has_predicate(&self, name: &str) -> bool {
            self.predicates.contains_key(name)
        }

        fn list_domains(&self) -> Result<Vec<String>> {
            Ok(self.domains.keys().cloned().collect())
        }

        fn list_predicates(&self) -> Result<Vec<String>> {
            Ok(self.predicates.keys().cloned().collect())
        }
    }

    #[test]
    fn test_lazy_load_domain() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        let domain = lazy_table.get_domain("Person").unwrap();
        assert!(domain.is_some());
        assert_eq!(domain.unwrap().name, "Person");
    }

    #[test]
    fn test_lazy_load_predicate() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        // Load the predicate's argument domains before the predicate itself.
        lazy_table.get_domain("Person").unwrap();
        lazy_table.get_domain("Location").unwrap();

        let predicate = lazy_table.get_predicate("at").unwrap();
        assert!(predicate.is_some());
        assert_eq!(predicate.unwrap().name, "at");
    }

    #[test]
    fn test_cache_hits() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        // First lookup misses and loads from the loader.
        lazy_table.get_domain("Person").unwrap();

        // Second lookup is served from the cache.
        lazy_table.get_domain("Person").unwrap();

        let stats = lazy_table.stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    #[test]
    fn test_missing_domain_is_a_miss() {
        let lazy_table = LazySymbolTable::new(Arc::new(MockLoader::new()));

        assert!(lazy_table.get_domain("Unknown").unwrap().is_none());

        let stats = lazy_table.stats();
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.domain_loads, 0);
        assert_eq!(lazy_table.loaded_domain_count(), 0);
    }

    #[test]
    fn test_missing_predicate_is_a_miss() {
        let lazy_table = LazySymbolTable::new(Arc::new(MockLoader::new()));

        assert!(lazy_table.get_predicate("unknown").unwrap().is_none());

        let stats = lazy_table.stats();
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.predicate_loads, 0);
        assert_eq!(lazy_table.loaded_predicate_count(), 0);
    }

    #[test]
    fn test_predicate_cache_hits() {
        let lazy_table = LazySymbolTable::new(Arc::new(MockLoader::new()));

        // Mirror `test_lazy_load_predicate`: argument domains first.
        lazy_table.get_domain("Person").unwrap();
        lazy_table.get_domain("Location").unwrap();

        lazy_table.get_predicate("at").unwrap();
        lazy_table.get_predicate("at").unwrap();

        let stats = lazy_table.stats();
        assert_eq!(stats.predicate_loads, 1);
        assert_eq!(stats.cache_hits, 1);
        // Two domain misses plus one predicate miss.
        assert_eq!(stats.cache_misses, 3);
        assert_eq!(lazy_table.loaded_predicate_count(), 1);
    }

    #[test]
    fn test_list_domains() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        let domains = lazy_table.list_domains().unwrap();
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"Person".to_string()));
        assert!(domains.contains(&"Location".to_string()));
    }

    #[test]
    fn test_preload_domains() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        let names = vec!["Person".to_string(), "Location".to_string()];
        lazy_table.preload_domains(&names).unwrap();

        assert_eq!(lazy_table.loaded_domain_count(), 2);

        let stats = lazy_table.stats();
        assert_eq!(stats.batch_loads, 1);
    }

    #[test]
    fn test_preload_predicates() {
        let lazy_table = LazySymbolTable::new(Arc::new(MockLoader::new()));

        lazy_table
            .preload_domains(&["Person".to_string(), "Location".to_string()])
            .unwrap();
        lazy_table.preload_predicates(&["at".to_string()]).unwrap();

        assert_eq!(lazy_table.loaded_predicate_count(), 1);
        let stats = lazy_table.stats();
        assert_eq!(stats.batch_loads, 2);
        assert_eq!(stats.predicate_loads, 1);

        // Preloaded entries are served from the cache afterwards.
        assert!(lazy_table.get_predicate("at").unwrap().is_some());
        assert_eq!(lazy_table.stats().cache_hits, 1);
    }

    #[test]
    fn test_preload_unknown_domain_fails() {
        let lazy_table = LazySymbolTable::new(Arc::new(MockLoader::new()));

        assert!(lazy_table
            .preload_domains(&["Unknown".to_string()])
            .is_err());
        assert_eq!(lazy_table.loaded_domain_count(), 0);
    }

    #[test]
    fn test_clear_cache() {
        let loader = Arc::new(MockLoader::new());
        let lazy_table = LazySymbolTable::new(loader);

        lazy_table.get_domain("Person").unwrap();
        assert_eq!(lazy_table.loaded_domain_count(), 1);

        lazy_table.clear_cache();
        assert_eq!(lazy_table.loaded_domain_count(), 0);
    }

    #[test]
    fn test_load_strategy() {
        let loader = Arc::new(MockLoader::new());
        let strategy = LoadStrategy::Predictive {
            access_threshold: 5,
        };
        let lazy_table = LazySymbolTable::with_strategy(loader, strategy);

        lazy_table.get_domain("Person").unwrap();
        assert_eq!(lazy_table.loaded_domain_count(), 1);
    }

    #[test]
    fn test_hit_rate() {
        let mut stats = LazyLoadStats::default();
        assert_eq!(stats.hit_rate(), 0.0);

        stats.cache_hits = 8;
        stats.cache_misses = 2;
        assert!((stats.hit_rate() - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_file_loader_missing_directories() {
        let base = std::env::temp_dir().join(format!(
            "lazy_schema_loader_missing_{}",
            std::process::id()
        ));
        let loader = FileSchemaLoader::new(base.clone());

        assert_eq!(
            loader.domain_path("Person"),
            base.join("domains").join("Person.json")
        );
        assert_eq!(
            loader.predicate_path("at"),
            base.join("predicates").join("at.json")
        );
        assert!(loader.list_domains().unwrap().is_empty());
        assert!(loader.list_predicates().unwrap().is_empty());
        assert!(!loader.has_domain("Person"));
        assert!(!loader.has_predicate("at"));
        assert!(loader.load_domain("Person").is_err());
    }
}