tensorlogic_adapters/
lazy.rs

1//! Lazy loading support for huge symbol tables.
2//!
3//! This module provides on-demand loading of domains, predicates, and other
4//! schema elements, enabling efficient handling of very large schemas that
5//! don't fit comfortably in memory.
6//!
7//! # Examples
8//!
9//! ```rust
10//! use tensorlogic_adapters::{LazySymbolTable, FileSchemaLoader};
11//! use std::sync::Arc;
12//!
13//! // Create a lazy symbol table with on-demand loading
14//! let loader = Arc::new(FileSchemaLoader::new("/tmp/schema"));
15//! let lazy_table = LazySymbolTable::new(loader);
16//! ```
17
18use anyhow::Result;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21use std::path::PathBuf;
22use std::sync::{Arc, RwLock};
23
24use crate::{DomainInfo, PredicateInfo, SymbolTable};
25
/// Strategy for loading schema elements.
///
/// Describes how schema elements are meant to be brought into memory.
/// [`LoadStrategy::OnDemand`] is the value produced by `Default`. The type
/// derives `Serialize`/`Deserialize`, so a strategy can be persisted or read
/// from configuration.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum LoadStrategy {
    /// Load all elements eagerly (no lazy loading).
    Eager,
    /// Load elements on first access.
    OnDemand,
    /// Preload frequently accessed elements.
    Predictive {
        /// Threshold for "frequent" access.
        // NOTE(review): presumably compared against a per-element access
        // count; the comparison itself is not visible in this file.
        access_threshold: usize,
    },
    /// Load elements in batches.
    Batched {
        /// Batch size for loading (number of elements per batch).
        batch_size: usize,
    },
}
44
45impl Default for LoadStrategy {
46    fn default() -> Self {
47        Self::OnDemand
48    }
49}
50
51/// A loader trait for fetching schema elements on demand.
52///
53/// Implementations can load from files, databases, or remote services.
54pub trait SchemaLoader: Send + Sync {
55    /// Load a domain by name.
56    fn load_domain(&self, name: &str) -> Result<DomainInfo>;
57
58    /// Load a predicate by name.
59    fn load_predicate(&self, name: &str) -> Result<PredicateInfo>;
60
61    /// Check if a domain exists without loading it.
62    fn has_domain(&self, name: &str) -> bool;
63
64    /// Check if a predicate exists without loading it.
65    fn has_predicate(&self, name: &str) -> bool;
66
67    /// List all available domain names.
68    fn list_domains(&self) -> Result<Vec<String>>;
69
70    /// List all available predicate names.
71    fn list_predicates(&self) -> Result<Vec<String>>;
72
73    /// Load a batch of domains by name.
74    fn load_domains_batch(&self, names: &[String]) -> Result<Vec<DomainInfo>> {
75        names.iter().map(|n| self.load_domain(n)).collect()
76    }
77
78    /// Load a batch of predicates by name.
79    fn load_predicates_batch(&self, names: &[String]) -> Result<Vec<PredicateInfo>> {
80        names.iter().map(|n| self.load_predicate(n)).collect()
81    }
82}
83
/// Schema loader backed by JSON files in a directory tree.
///
/// Each element lives in its own file, named after the element:
/// ```text
/// schema_dir/
///   domains/
///     domain1.json
///     domain2.json
///   predicates/
///     pred1.json
///     pred2.json
/// ```
#[derive(Clone, Debug)]
pub struct FileSchemaLoader {
    /// Root of the schema tree; `domains/` and `predicates/` are resolved
    /// relative to it.
    base_dir: PathBuf,
}

impl FileSchemaLoader {
    /// Create a loader rooted at `base_dir`.
    ///
    /// Construction only stores the path; no file-system access happens here.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tensorlogic_adapters::FileSchemaLoader;
    ///
    /// let loader = FileSchemaLoader::new("/path/to/schema");
    /// ```
    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
        let base_dir = base_dir.into();
        Self { base_dir }
    }

    /// Path of the JSON file for the domain `name`:
    /// `<base_dir>/domains/<name>.json`.
    fn domain_path(&self, name: &str) -> PathBuf {
        let mut path = self.base_dir.join("domains");
        path.push(format!("{}.json", name));
        path
    }

    /// Path of the JSON file for the predicate `name`:
    /// `<base_dir>/predicates/<name>.json`.
    fn predicate_path(&self, name: &str) -> PathBuf {
        let mut path = self.base_dir.join("predicates");
        path.push(format!("{}.json", name));
        path
    }
}
128
129impl SchemaLoader for FileSchemaLoader {
130    fn load_domain(&self, name: &str) -> Result<DomainInfo> {
131        let path = self.domain_path(name);
132        let content = std::fs::read_to_string(path)?;
133        let domain: DomainInfo = serde_json::from_str(&content)?;
134        Ok(domain)
135    }
136
137    fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
138        let path = self.predicate_path(name);
139        let content = std::fs::read_to_string(path)?;
140        let predicate: PredicateInfo = serde_json::from_str(&content)?;
141        Ok(predicate)
142    }
143
144    fn has_domain(&self, name: &str) -> bool {
145        self.domain_path(name).exists()
146    }
147
148    fn has_predicate(&self, name: &str) -> bool {
149        self.predicate_path(name).exists()
150    }
151
152    fn list_domains(&self) -> Result<Vec<String>> {
153        let domains_dir = self.base_dir.join("domains");
154        if !domains_dir.exists() {
155            return Ok(Vec::new());
156        }
157
158        let mut names = Vec::new();
159        for entry in std::fs::read_dir(domains_dir)? {
160            let entry = entry?;
161            if let Some(name) = entry.path().file_stem() {
162                names.push(name.to_string_lossy().to_string());
163            }
164        }
165        Ok(names)
166    }
167
168    fn list_predicates(&self) -> Result<Vec<String>> {
169        let predicates_dir = self.base_dir.join("predicates");
170        if !predicates_dir.exists() {
171            return Ok(Vec::new());
172        }
173
174        let mut names = Vec::new();
175        for entry in std::fs::read_dir(predicates_dir)? {
176            let entry = entry?;
177            if let Some(name) = entry.path().file_stem() {
178                names.push(name.to_string_lossy().to_string());
179            }
180        }
181        Ok(names)
182    }
183}
184
/// Counters describing how a lazy table has been used.
///
/// All counters start at zero (`Default`) and are plain public fields.
#[derive(Clone, Debug, Default)]
pub struct LazyLoadStats {
    /// Number of domains loaded into the cache.
    pub domain_loads: usize,
    /// Number of predicates loaded into the cache.
    pub predicate_loads: usize,
    /// Number of lookups answered from the cache.
    pub cache_hits: usize,
    /// Number of lookups that could not be answered from the cache.
    pub cache_misses: usize,
    /// Number of batch (preload) operations.
    pub batch_loads: usize,
}

impl LazyLoadStats {
    /// Fraction of lookups that were cache hits, in the range `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, so the result is
    /// never `NaN`.
    pub fn hit_rate(&self) -> f64 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            lookups => self.cache_hits as f64 / lookups as f64,
        }
    }
}
211
/// A symbol table with lazy loading support.
///
/// Elements are loaded on-demand from a SchemaLoader, reducing memory
/// usage for large schemas. Every piece of mutable state sits behind its own
/// `Arc<RwLock<_>>`, so all methods take `&self`.
pub struct LazySymbolTable {
    /// Symbol table holding the elements loaded so far (acts as cache).
    loaded: Arc<RwLock<SymbolTable>>,
    /// Loader for on-demand fetching.
    loader: Arc<dyn SchemaLoader>,
    /// Loading strategy.
    // NOTE(review): set by the constructors but not read by any method
    // visible in this file.
    strategy: LoadStrategy,
    /// Usage counters (loads, cache hits/misses, batch loads).
    stats: Arc<RwLock<LazyLoadStats>>,
    /// Set of loaded domain names.
    loaded_domains: Arc<RwLock<HashSet<String>>>,
    /// Set of loaded predicate names.
    loaded_predicates: Arc<RwLock<HashSet<String>>>,
    /// Access counts for predictive loading, keyed by element name (domains
    /// and predicates share this map).
    access_counts: Arc<RwLock<HashMap<String, usize>>>,
}
232
233impl LazySymbolTable {
234    /// Create a new lazy symbol table.
235    ///
236    /// # Examples
237    ///
238    /// ```rust
239    /// use tensorlogic_adapters::{LazySymbolTable, FileSchemaLoader};
240    /// use std::sync::Arc;
241    ///
242    /// let loader = Arc::new(FileSchemaLoader::new("/tmp/schema"));
243    /// let lazy_table = LazySymbolTable::new(loader);
244    /// ```
245    pub fn new(loader: Arc<dyn SchemaLoader>) -> Self {
246        Self {
247            loaded: Arc::new(RwLock::new(SymbolTable::new())),
248            loader,
249            strategy: LoadStrategy::default(),
250            stats: Arc::new(RwLock::new(LazyLoadStats::default())),
251            loaded_domains: Arc::new(RwLock::new(HashSet::new())),
252            loaded_predicates: Arc::new(RwLock::new(HashSet::new())),
253            access_counts: Arc::new(RwLock::new(HashMap::new())),
254        }
255    }
256
257    /// Create a lazy table with a specific loading strategy.
258    pub fn with_strategy(loader: Arc<dyn SchemaLoader>, strategy: LoadStrategy) -> Self {
259        let mut table = Self::new(loader);
260        table.strategy = strategy;
261        table
262    }
263
264    /// Get a domain, loading it if necessary.
265    pub fn get_domain(&self, name: &str) -> Result<Option<DomainInfo>> {
266        // Check if already loaded
267        {
268            let loaded_set = self.loaded_domains.read().unwrap();
269            if loaded_set.contains(name) {
270                let table = self.loaded.read().unwrap();
271                let mut stats = self.stats.write().unwrap();
272                stats.cache_hits += 1;
273                return Ok(table.get_domain(name).cloned());
274            }
275        }
276
277        // Check if domain exists
278        if !self.loader.has_domain(name) {
279            let mut stats = self.stats.write().unwrap();
280            stats.cache_misses += 1;
281            return Ok(None);
282        }
283
284        // Load domain
285        self.load_domain_internal(name)?;
286
287        let table = self.loaded.read().unwrap();
288        Ok(table.get_domain(name).cloned())
289    }
290
291    /// Get a predicate, loading it if necessary.
292    pub fn get_predicate(&self, name: &str) -> Result<Option<PredicateInfo>> {
293        // Check if already loaded
294        {
295            let loaded_set = self.loaded_predicates.read().unwrap();
296            if loaded_set.contains(name) {
297                let table = self.loaded.read().unwrap();
298                let mut stats = self.stats.write().unwrap();
299                stats.cache_hits += 1;
300                return Ok(table.get_predicate(name).cloned());
301            }
302        }
303
304        // Check if predicate exists
305        if !self.loader.has_predicate(name) {
306            let mut stats = self.stats.write().unwrap();
307            stats.cache_misses += 1;
308            return Ok(None);
309        }
310
311        // Load predicate
312        self.load_predicate_internal(name)?;
313
314        let table = self.loaded.read().unwrap();
315        Ok(table.get_predicate(name).cloned())
316    }
317
318    /// List all available domains (without loading them).
319    pub fn list_domains(&self) -> Result<Vec<String>> {
320        self.loader.list_domains()
321    }
322
323    /// List all available predicates (without loading them).
324    pub fn list_predicates(&self) -> Result<Vec<String>> {
325        self.loader.list_predicates()
326    }
327
328    /// Preload a batch of domains.
329    pub fn preload_domains(&self, names: &[String]) -> Result<()> {
330        let domains = self.loader.load_domains_batch(names)?;
331        let mut table = self.loaded.write().unwrap();
332        let mut loaded_set = self.loaded_domains.write().unwrap();
333        let mut stats = self.stats.write().unwrap();
334
335        for domain in domains {
336            let name = domain.name.clone();
337            table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
338            loaded_set.insert(name);
339            stats.domain_loads += 1;
340        }
341        stats.batch_loads += 1;
342
343        Ok(())
344    }
345
346    /// Preload a batch of predicates.
347    pub fn preload_predicates(&self, names: &[String]) -> Result<()> {
348        let predicates = self.loader.load_predicates_batch(names)?;
349        let mut table = self.loaded.write().unwrap();
350        let mut loaded_set = self.loaded_predicates.write().unwrap();
351        let mut stats = self.stats.write().unwrap();
352
353        for predicate in predicates {
354            let name = predicate.name.clone();
355            table
356                .add_predicate(predicate)
357                .map_err(|e| anyhow::anyhow!(e))?;
358            loaded_set.insert(name);
359            stats.predicate_loads += 1;
360        }
361        stats.batch_loads += 1;
362
363        Ok(())
364    }
365
366    /// Get loading statistics.
367    pub fn stats(&self) -> LazyLoadStats {
368        self.stats.read().unwrap().clone()
369    }
370
371    /// Clear the cache and force reload.
372    pub fn clear_cache(&self) {
373        let mut table = self.loaded.write().unwrap();
374        *table = SymbolTable::new();
375        self.loaded_domains.write().unwrap().clear();
376        self.loaded_predicates.write().unwrap().clear();
377        self.access_counts.write().unwrap().clear();
378    }
379
380    /// Get the number of loaded domains.
381    pub fn loaded_domain_count(&self) -> usize {
382        self.loaded_domains.read().unwrap().len()
383    }
384
385    /// Get the number of loaded predicates.
386    pub fn loaded_predicate_count(&self) -> usize {
387        self.loaded_predicates.read().unwrap().len()
388    }
389
390    /// Get a read-only reference to the loaded symbol table.
391    ///
392    /// Note: This only includes loaded elements.
393    pub fn as_symbol_table(&self) -> Arc<RwLock<SymbolTable>> {
394        Arc::clone(&self.loaded)
395    }
396
397    fn load_domain_internal(&self, name: &str) -> Result<()> {
398        let domain = self.loader.load_domain(name)?;
399        let mut table = self.loaded.write().unwrap();
400        let mut loaded_set = self.loaded_domains.write().unwrap();
401        let mut stats = self.stats.write().unwrap();
402
403        table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
404        loaded_set.insert(name.to_string());
405        stats.domain_loads += 1;
406        stats.cache_misses += 1;
407
408        // Track access for predictive loading
409        {
410            let mut counts = self.access_counts.write().unwrap();
411            *counts.entry(name.to_string()).or_insert(0) += 1;
412        }
413
414        Ok(())
415    }
416
417    fn load_predicate_internal(&self, name: &str) -> Result<()> {
418        let predicate = self.loader.load_predicate(name)?;
419        let mut table = self.loaded.write().unwrap();
420        let mut loaded_set = self.loaded_predicates.write().unwrap();
421        let mut stats = self.stats.write().unwrap();
422
423        table
424            .add_predicate(predicate)
425            .map_err(|e| anyhow::anyhow!(e))?;
426        loaded_set.insert(name.to_string());
427        stats.predicate_loads += 1;
428        stats.cache_misses += 1;
429
430        // Track access for predictive loading
431        {
432            let mut counts = self.access_counts.write().unwrap();
433            *counts.entry(name.to_string()).or_insert(0) += 1;
434        }
435
436        Ok(())
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory loader with two domains (`Person`, `Location`) and one
    /// predicate (`at`).
    struct MockLoader {
        domains: HashMap<String, DomainInfo>,
        predicates: HashMap<String, PredicateInfo>,
    }

    impl MockLoader {
        fn new() -> Self {
            let domains = vec![
                ("Person".to_string(), DomainInfo::new("Person", 100)),
                ("Location".to_string(), DomainInfo::new("Location", 50)),
            ]
            .into_iter()
            .collect();

            let predicates = vec![(
                "at".to_string(),
                PredicateInfo::new("at", vec!["Person".to_string(), "Location".to_string()]),
            )]
            .into_iter()
            .collect();

            Self {
                domains,
                predicates,
            }
        }
    }

    impl SchemaLoader for MockLoader {
        fn load_domain(&self, name: &str) -> Result<DomainInfo> {
            match self.domains.get(name) {
                Some(domain) => Ok(domain.clone()),
                None => Err(anyhow::anyhow!("Domain not found: {}", name)),
            }
        }

        fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
            match self.predicates.get(name) {
                Some(predicate) => Ok(predicate.clone()),
                None => Err(anyhow::anyhow!("Predicate not found: {}", name)),
            }
        }

        fn has_domain(&self, name: &str) -> bool {
            self.domains.contains_key(name)
        }

        fn has_predicate(&self, name: &str) -> bool {
            self.predicates.contains_key(name)
        }

        fn list_domains(&self) -> Result<Vec<String>> {
            Ok(self.domains.keys().cloned().collect())
        }

        fn list_predicates(&self) -> Result<Vec<String>> {
            Ok(self.predicates.keys().cloned().collect())
        }
    }

    /// Fresh lazy table backed by a [`MockLoader`], using the default strategy.
    fn mock_table() -> LazySymbolTable {
        LazySymbolTable::new(Arc::new(MockLoader::new()))
    }

    #[test]
    fn test_lazy_load_domain() {
        let table = mock_table();

        let found = table.get_domain("Person").unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "Person");
    }

    #[test]
    fn test_lazy_load_predicate() {
        let table = mock_table();

        // The predicate refers to both domains, so load them first.
        table.get_domain("Person").unwrap();
        table.get_domain("Location").unwrap();

        let found = table.get_predicate("at").unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "at");
    }

    #[test]
    fn test_cache_hits() {
        let table = mock_table();

        // The first lookup has to load (miss); the second is served from the
        // cache (hit).
        table.get_domain("Person").unwrap();
        table.get_domain("Person").unwrap();

        let snapshot = table.stats();
        assert_eq!(snapshot.cache_hits, 1);
        assert_eq!(snapshot.cache_misses, 1);
    }

    #[test]
    fn test_list_domains() {
        let table = mock_table();

        let listed = table.list_domains().unwrap();
        assert_eq!(listed.len(), 2);
        assert!(listed.contains(&"Person".to_string()));
        assert!(listed.contains(&"Location".to_string()));
    }

    #[test]
    fn test_preload_domains() {
        let table = mock_table();

        let wanted = vec!["Person".to_string(), "Location".to_string()];
        table.preload_domains(&wanted).unwrap();

        assert_eq!(table.loaded_domain_count(), 2);
        assert_eq!(table.stats().batch_loads, 1);
    }

    #[test]
    fn test_clear_cache() {
        let table = mock_table();

        table.get_domain("Person").unwrap();
        assert_eq!(table.loaded_domain_count(), 1);

        table.clear_cache();
        assert_eq!(table.loaded_domain_count(), 0);
    }

    #[test]
    fn test_load_strategy() {
        let table = LazySymbolTable::with_strategy(
            Arc::new(MockLoader::new()),
            LoadStrategy::Predictive {
                access_threshold: 5,
            },
        );

        table.get_domain("Person").unwrap();
        assert_eq!(table.loaded_domain_count(), 1);
    }

    #[test]
    fn test_hit_rate() {
        let mut counters = LazyLoadStats::default();
        assert_eq!(counters.hit_rate(), 0.0);

        counters.cache_hits = 8;
        counters.cache_misses = 2;
        assert!((counters.hit_rate() - 0.8).abs() < 0.01);
    }
}
598        assert!((stats.hit_rate() - 0.8).abs() < 0.01);
599    }
600}