1use anyhow::Result;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21use std::path::PathBuf;
22use std::sync::{Arc, RwLock};
23
24use crate::{DomainInfo, PredicateInfo, SymbolTable};
25
26#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
28pub enum LoadStrategy {
29 Eager,
31 OnDemand,
33 Predictive {
35 access_threshold: usize,
37 },
38 Batched {
40 batch_size: usize,
42 },
43}
44
45impl Default for LoadStrategy {
46 fn default() -> Self {
47 Self::OnDemand
48 }
49}
50
51pub trait SchemaLoader: Send + Sync {
55 fn load_domain(&self, name: &str) -> Result<DomainInfo>;
57
58 fn load_predicate(&self, name: &str) -> Result<PredicateInfo>;
60
61 fn has_domain(&self, name: &str) -> bool;
63
64 fn has_predicate(&self, name: &str) -> bool;
66
67 fn list_domains(&self) -> Result<Vec<String>>;
69
70 fn list_predicates(&self) -> Result<Vec<String>>;
72
73 fn load_domains_batch(&self, names: &[String]) -> Result<Vec<DomainInfo>> {
75 names.iter().map(|n| self.load_domain(n)).collect()
76 }
77
78 fn load_predicates_batch(&self, names: &[String]) -> Result<Vec<PredicateInfo>> {
80 names.iter().map(|n| self.load_predicate(n)).collect()
81 }
82}
83
/// Schema loader rooted at a directory on disk.
///
/// Domain files are addressed as `<base_dir>/domains/<name>.json` and
/// predicate files as `<base_dir>/predicates/<name>.json`.
#[derive(Debug, Clone)]
pub struct FileSchemaLoader {
    /// Root directory containing the `domains` and `predicates` folders.
    base_dir: PathBuf,
}

impl FileSchemaLoader {
    /// Creates a loader rooted at `base_dir`.
    ///
    /// No file-system access happens here; the path is only stored.
    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
        let base_dir = base_dir.into();
        Self { base_dir }
    }

    /// Builds the path of the JSON file holding the domain `name`.
    fn domain_path(&self, name: &str) -> PathBuf {
        let mut path = self.base_dir.join("domains");
        path.push(format!("{}.json", name));
        path
    }

    /// Builds the path of the JSON file holding the predicate `name`.
    fn predicate_path(&self, name: &str) -> PathBuf {
        let mut path = self.base_dir.join("predicates");
        path.push(format!("{}.json", name));
        path
    }
}
128
129impl SchemaLoader for FileSchemaLoader {
130 fn load_domain(&self, name: &str) -> Result<DomainInfo> {
131 let path = self.domain_path(name);
132 let content = std::fs::read_to_string(path)?;
133 let domain: DomainInfo = serde_json::from_str(&content)?;
134 Ok(domain)
135 }
136
137 fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
138 let path = self.predicate_path(name);
139 let content = std::fs::read_to_string(path)?;
140 let predicate: PredicateInfo = serde_json::from_str(&content)?;
141 Ok(predicate)
142 }
143
144 fn has_domain(&self, name: &str) -> bool {
145 self.domain_path(name).exists()
146 }
147
148 fn has_predicate(&self, name: &str) -> bool {
149 self.predicate_path(name).exists()
150 }
151
152 fn list_domains(&self) -> Result<Vec<String>> {
153 let domains_dir = self.base_dir.join("domains");
154 if !domains_dir.exists() {
155 return Ok(Vec::new());
156 }
157
158 let mut names = Vec::new();
159 for entry in std::fs::read_dir(domains_dir)? {
160 let entry = entry?;
161 if let Some(name) = entry.path().file_stem() {
162 names.push(name.to_string_lossy().to_string());
163 }
164 }
165 Ok(names)
166 }
167
168 fn list_predicates(&self) -> Result<Vec<String>> {
169 let predicates_dir = self.base_dir.join("predicates");
170 if !predicates_dir.exists() {
171 return Ok(Vec::new());
172 }
173
174 let mut names = Vec::new();
175 for entry in std::fs::read_dir(predicates_dir)? {
176 let entry = entry?;
177 if let Some(name) = entry.path().file_stem() {
178 names.push(name.to_string_lossy().to_string());
179 }
180 }
181 Ok(names)
182 }
183}
184
/// Counters describing how a lazily loaded table has been used.
#[derive(Debug, Default, Clone)]
pub struct LazyLoadStats {
    /// Number of domains added to the table.
    pub domain_loads: usize,
    /// Number of predicates added to the table.
    pub predicate_loads: usize,
    /// Number of lookups answered from already-loaded data.
    pub cache_hits: usize,
    /// Number of lookups that were not answered from already-loaded data.
    pub cache_misses: usize,
    /// Number of completed batch preload calls.
    pub batch_loads: usize,
}

impl LazyLoadStats {
    /// Returns the fraction of lookups that were cache hits, in `0.0..=1.0`.
    ///
    /// When no lookups have been recorded the rate is defined as `0.0` so the
    /// division by zero is avoided.
    pub fn hit_rate(&self) -> f64 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            lookups => self.cache_hits as f64 / lookups as f64,
        }
    }
}
211
212pub struct LazySymbolTable {
217 loaded: Arc<RwLock<SymbolTable>>,
219 loader: Arc<dyn SchemaLoader>,
221 strategy: LoadStrategy,
223 stats: Arc<RwLock<LazyLoadStats>>,
225 loaded_domains: Arc<RwLock<HashSet<String>>>,
227 loaded_predicates: Arc<RwLock<HashSet<String>>>,
229 access_counts: Arc<RwLock<HashMap<String, usize>>>,
231}
232
233impl LazySymbolTable {
234 pub fn new(loader: Arc<dyn SchemaLoader>) -> Self {
246 Self {
247 loaded: Arc::new(RwLock::new(SymbolTable::new())),
248 loader,
249 strategy: LoadStrategy::default(),
250 stats: Arc::new(RwLock::new(LazyLoadStats::default())),
251 loaded_domains: Arc::new(RwLock::new(HashSet::new())),
252 loaded_predicates: Arc::new(RwLock::new(HashSet::new())),
253 access_counts: Arc::new(RwLock::new(HashMap::new())),
254 }
255 }
256
257 pub fn with_strategy(loader: Arc<dyn SchemaLoader>, strategy: LoadStrategy) -> Self {
259 let mut table = Self::new(loader);
260 table.strategy = strategy;
261 table
262 }
263
264 pub fn get_domain(&self, name: &str) -> Result<Option<DomainInfo>> {
266 {
268 let loaded_set = self.loaded_domains.read().unwrap();
269 if loaded_set.contains(name) {
270 let table = self.loaded.read().unwrap();
271 let mut stats = self.stats.write().unwrap();
272 stats.cache_hits += 1;
273 return Ok(table.get_domain(name).cloned());
274 }
275 }
276
277 if !self.loader.has_domain(name) {
279 let mut stats = self.stats.write().unwrap();
280 stats.cache_misses += 1;
281 return Ok(None);
282 }
283
284 self.load_domain_internal(name)?;
286
287 let table = self.loaded.read().unwrap();
288 Ok(table.get_domain(name).cloned())
289 }
290
291 pub fn get_predicate(&self, name: &str) -> Result<Option<PredicateInfo>> {
293 {
295 let loaded_set = self.loaded_predicates.read().unwrap();
296 if loaded_set.contains(name) {
297 let table = self.loaded.read().unwrap();
298 let mut stats = self.stats.write().unwrap();
299 stats.cache_hits += 1;
300 return Ok(table.get_predicate(name).cloned());
301 }
302 }
303
304 if !self.loader.has_predicate(name) {
306 let mut stats = self.stats.write().unwrap();
307 stats.cache_misses += 1;
308 return Ok(None);
309 }
310
311 self.load_predicate_internal(name)?;
313
314 let table = self.loaded.read().unwrap();
315 Ok(table.get_predicate(name).cloned())
316 }
317
318 pub fn list_domains(&self) -> Result<Vec<String>> {
320 self.loader.list_domains()
321 }
322
323 pub fn list_predicates(&self) -> Result<Vec<String>> {
325 self.loader.list_predicates()
326 }
327
328 pub fn preload_domains(&self, names: &[String]) -> Result<()> {
330 let domains = self.loader.load_domains_batch(names)?;
331 let mut table = self.loaded.write().unwrap();
332 let mut loaded_set = self.loaded_domains.write().unwrap();
333 let mut stats = self.stats.write().unwrap();
334
335 for domain in domains {
336 let name = domain.name.clone();
337 table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
338 loaded_set.insert(name);
339 stats.domain_loads += 1;
340 }
341 stats.batch_loads += 1;
342
343 Ok(())
344 }
345
346 pub fn preload_predicates(&self, names: &[String]) -> Result<()> {
348 let predicates = self.loader.load_predicates_batch(names)?;
349 let mut table = self.loaded.write().unwrap();
350 let mut loaded_set = self.loaded_predicates.write().unwrap();
351 let mut stats = self.stats.write().unwrap();
352
353 for predicate in predicates {
354 let name = predicate.name.clone();
355 table
356 .add_predicate(predicate)
357 .map_err(|e| anyhow::anyhow!(e))?;
358 loaded_set.insert(name);
359 stats.predicate_loads += 1;
360 }
361 stats.batch_loads += 1;
362
363 Ok(())
364 }
365
366 pub fn stats(&self) -> LazyLoadStats {
368 self.stats.read().unwrap().clone()
369 }
370
371 pub fn clear_cache(&self) {
373 let mut table = self.loaded.write().unwrap();
374 *table = SymbolTable::new();
375 self.loaded_domains.write().unwrap().clear();
376 self.loaded_predicates.write().unwrap().clear();
377 self.access_counts.write().unwrap().clear();
378 }
379
380 pub fn loaded_domain_count(&self) -> usize {
382 self.loaded_domains.read().unwrap().len()
383 }
384
385 pub fn loaded_predicate_count(&self) -> usize {
387 self.loaded_predicates.read().unwrap().len()
388 }
389
390 pub fn as_symbol_table(&self) -> Arc<RwLock<SymbolTable>> {
394 Arc::clone(&self.loaded)
395 }
396
397 fn load_domain_internal(&self, name: &str) -> Result<()> {
398 let domain = self.loader.load_domain(name)?;
399 let mut table = self.loaded.write().unwrap();
400 let mut loaded_set = self.loaded_domains.write().unwrap();
401 let mut stats = self.stats.write().unwrap();
402
403 table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
404 loaded_set.insert(name.to_string());
405 stats.domain_loads += 1;
406 stats.cache_misses += 1;
407
408 {
410 let mut counts = self.access_counts.write().unwrap();
411 *counts.entry(name.to_string()).or_insert(0) += 1;
412 }
413
414 Ok(())
415 }
416
417 fn load_predicate_internal(&self, name: &str) -> Result<()> {
418 let predicate = self.loader.load_predicate(name)?;
419 let mut table = self.loaded.write().unwrap();
420 let mut loaded_set = self.loaded_predicates.write().unwrap();
421 let mut stats = self.stats.write().unwrap();
422
423 table
424 .add_predicate(predicate)
425 .map_err(|e| anyhow::anyhow!(e))?;
426 loaded_set.insert(name.to_string());
427 stats.predicate_loads += 1;
428 stats.cache_misses += 1;
429
430 {
432 let mut counts = self.access_counts.write().unwrap();
433 *counts.entry(name.to_string()).or_insert(0) += 1;
434 }
435
436 Ok(())
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory loader with two domains (`Person`, `Location`) and one
    /// predicate (`at`).
    struct MockLoader {
        domains: HashMap<String, DomainInfo>,
        predicates: HashMap<String, PredicateInfo>,
    }

    impl MockLoader {
        fn new() -> Self {
            let mut domains = HashMap::new();
            domains.insert(String::from("Person"), DomainInfo::new("Person", 100));
            domains.insert(String::from("Location"), DomainInfo::new("Location", 50));

            let mut predicates = HashMap::new();
            let at_args = vec!["Person".to_string(), "Location".to_string()];
            predicates.insert(String::from("at"), PredicateInfo::new("at", at_args));

            Self {
                domains,
                predicates,
            }
        }
    }

    impl SchemaLoader for MockLoader {
        fn load_domain(&self, name: &str) -> Result<DomainInfo> {
            match self.domains.get(name) {
                Some(domain) => Ok(domain.clone()),
                None => Err(anyhow::anyhow!("Domain not found: {}", name)),
            }
        }

        fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
            match self.predicates.get(name) {
                Some(predicate) => Ok(predicate.clone()),
                None => Err(anyhow::anyhow!("Predicate not found: {}", name)),
            }
        }

        fn has_domain(&self, name: &str) -> bool {
            self.domains.contains_key(name)
        }

        fn has_predicate(&self, name: &str) -> bool {
            self.predicates.contains_key(name)
        }

        fn list_domains(&self) -> Result<Vec<String>> {
            Ok(self.domains.keys().cloned().collect())
        }

        fn list_predicates(&self) -> Result<Vec<String>> {
            Ok(self.predicates.keys().cloned().collect())
        }
    }

    /// Builds a table over a fresh `MockLoader` with the default strategy.
    fn mock_table() -> LazySymbolTable {
        LazySymbolTable::new(Arc::new(MockLoader::new()))
    }

    #[test]
    fn test_lazy_load_domain() {
        let table = mock_table();

        let found = table.get_domain("Person").unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "Person");
    }

    #[test]
    fn test_lazy_load_predicate() {
        let table = mock_table();

        // The predicate refers to both domains, so load them first.
        table.get_domain("Person").unwrap();
        table.get_domain("Location").unwrap();

        let found = table.get_predicate("at").unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "at");
    }

    #[test]
    fn test_cache_hits() {
        let table = mock_table();

        // First lookup goes to the loader (miss) ...
        table.get_domain("Person").unwrap();
        // ... second one is served from loaded data (hit).
        table.get_domain("Person").unwrap();

        let counters = table.stats();
        assert_eq!(counters.cache_hits, 1);
        assert_eq!(counters.cache_misses, 1);
    }

    #[test]
    fn test_list_domains() {
        let table = mock_table();

        let listed = table.list_domains().unwrap();
        assert_eq!(listed.len(), 2);
        assert!(listed.contains(&"Person".to_string()));
        assert!(listed.contains(&"Location".to_string()));
    }

    #[test]
    fn test_preload_domains() {
        let table = mock_table();

        let wanted = vec!["Person".to_string(), "Location".to_string()];
        table.preload_domains(&wanted).unwrap();

        assert_eq!(table.loaded_domain_count(), 2);
        assert_eq!(table.stats().batch_loads, 1);
    }

    #[test]
    fn test_clear_cache() {
        let table = mock_table();

        table.get_domain("Person").unwrap();
        assert_eq!(table.loaded_domain_count(), 1);

        table.clear_cache();
        assert_eq!(table.loaded_domain_count(), 0);
    }

    #[test]
    fn test_load_strategy() {
        let table = LazySymbolTable::with_strategy(
            Arc::new(MockLoader::new()),
            LoadStrategy::Predictive {
                access_threshold: 5,
            },
        );

        table.get_domain("Person").unwrap();
        assert_eq!(table.loaded_domain_count(), 1);
    }

    #[test]
    fn test_hit_rate() {
        let mut counters = LazyLoadStats::default();
        assert_eq!(counters.hit_rate(), 0.0);

        counters.cache_hits = 8;
        counters.cache_misses = 2;
        assert!((counters.hit_rate() - 0.8).abs() < 0.01);
    }
}