Skip to main content

tensorlogic_adapters/
lazy.rs

1//! Lazy loading support for huge symbol tables.
2//!
3//! This module provides on-demand loading of domains, predicates, and other
4//! schema elements, enabling efficient handling of very large schemas that
5//! don't fit comfortably in memory.
6//!
7//! # Examples
8//!
9//! ```rust
10//! use tensorlogic_adapters::{LazySymbolTable, FileSchemaLoader};
11//! use std::sync::Arc;
12//!
13//! // Create a lazy symbol table with on-demand loading
14//! let loader = Arc::new(FileSchemaLoader::new("/tmp/schema"));
15//! let lazy_table = LazySymbolTable::new(loader);
16//! ```
17
18use anyhow::Result;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21use std::path::PathBuf;
22use std::sync::{Arc, RwLock};
23
24use crate::{DomainInfo, PredicateInfo, SymbolTable};
25
26/// Strategy for loading schema elements.
27#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
28pub enum LoadStrategy {
29    /// Load all elements eagerly (no lazy loading).
30    Eager,
31    /// Load elements on first access.
32    #[default]
33    OnDemand,
34    /// Preload frequently accessed elements.
35    Predictive {
36        /// Threshold for "frequent" access.
37        access_threshold: usize,
38    },
39    /// Load elements in batches.
40    Batched {
41        /// Batch size for loading.
42        batch_size: usize,
43    },
44}
45
46/// A loader trait for fetching schema elements on demand.
47///
48/// Implementations can load from files, databases, or remote services.
49pub trait SchemaLoader: Send + Sync {
50    /// Load a domain by name.
51    fn load_domain(&self, name: &str) -> Result<DomainInfo>;
52
53    /// Load a predicate by name.
54    fn load_predicate(&self, name: &str) -> Result<PredicateInfo>;
55
56    /// Check if a domain exists without loading it.
57    fn has_domain(&self, name: &str) -> bool;
58
59    /// Check if a predicate exists without loading it.
60    fn has_predicate(&self, name: &str) -> bool;
61
62    /// List all available domain names.
63    fn list_domains(&self) -> Result<Vec<String>>;
64
65    /// List all available predicate names.
66    fn list_predicates(&self) -> Result<Vec<String>>;
67
68    /// Load a batch of domains by name.
69    fn load_domains_batch(&self, names: &[String]) -> Result<Vec<DomainInfo>> {
70        names.iter().map(|n| self.load_domain(n)).collect()
71    }
72
73    /// Load a batch of predicates by name.
74    fn load_predicates_batch(&self, names: &[String]) -> Result<Vec<PredicateInfo>> {
75        names.iter().map(|n| self.load_predicate(n)).collect()
76    }
77}
78
/// Schema loader backed by JSON files under a single directory.
///
/// Files are looked up using this layout:
/// ```text
/// schema_dir/
///   domains/
///     domain1.json
///     domain2.json
///   predicates/
///     pred1.json
///     pred2.json
/// ```
#[derive(Debug, Clone)]
pub struct FileSchemaLoader {
    /// Root directory containing the `domains/` and `predicates/` folders.
    base_dir: PathBuf,
}

impl FileSchemaLoader {
    /// Build a loader rooted at `base_dir`.
    ///
    /// The directory is not touched here; it is only remembered.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tensorlogic_adapters::FileSchemaLoader;
    ///
    /// let loader = FileSchemaLoader::new("/path/to/schema");
    /// ```
    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
        let base_dir = base_dir.into();
        Self { base_dir }
    }

    /// Path of the JSON file for the domain called `name`.
    fn domain_path(&self, name: &str) -> PathBuf {
        self.json_path("domains", name)
    }

    /// Path of the JSON file for the predicate called `name`.
    fn predicate_path(&self, name: &str) -> PathBuf {
        self.json_path("predicates", name)
    }

    /// Compose `<base_dir>/<subdir>/<name>.json`.
    fn json_path(&self, subdir: &str, name: &str) -> PathBuf {
        let mut path = self.base_dir.join(subdir);
        path.push(format!("{name}.json"));
        path
    }
}
123
124impl SchemaLoader for FileSchemaLoader {
125    fn load_domain(&self, name: &str) -> Result<DomainInfo> {
126        let path = self.domain_path(name);
127        let content = std::fs::read_to_string(path)?;
128        let domain: DomainInfo = serde_json::from_str(&content)?;
129        Ok(domain)
130    }
131
132    fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
133        let path = self.predicate_path(name);
134        let content = std::fs::read_to_string(path)?;
135        let predicate: PredicateInfo = serde_json::from_str(&content)?;
136        Ok(predicate)
137    }
138
139    fn has_domain(&self, name: &str) -> bool {
140        self.domain_path(name).exists()
141    }
142
143    fn has_predicate(&self, name: &str) -> bool {
144        self.predicate_path(name).exists()
145    }
146
147    fn list_domains(&self) -> Result<Vec<String>> {
148        let domains_dir = self.base_dir.join("domains");
149        if !domains_dir.exists() {
150            return Ok(Vec::new());
151        }
152
153        let mut names = Vec::new();
154        for entry in std::fs::read_dir(domains_dir)? {
155            let entry = entry?;
156            if let Some(name) = entry.path().file_stem() {
157                names.push(name.to_string_lossy().to_string());
158            }
159        }
160        Ok(names)
161    }
162
163    fn list_predicates(&self) -> Result<Vec<String>> {
164        let predicates_dir = self.base_dir.join("predicates");
165        if !predicates_dir.exists() {
166            return Ok(Vec::new());
167        }
168
169        let mut names = Vec::new();
170        for entry in std::fs::read_dir(predicates_dir)? {
171            let entry = entry?;
172            if let Some(name) = entry.path().file_stem() {
173                names.push(name.to_string_lossy().to_string());
174            }
175        }
176        Ok(names)
177    }
178}
179
/// Counters describing how a lazy symbol table has been used.
#[derive(Debug, Clone, Default)]
pub struct LazyLoadStats {
    /// How many domains have been loaded.
    pub domain_loads: usize,
    /// How many predicates have been loaded.
    pub predicate_loads: usize,
    /// How many lookups were answered from the cache.
    pub cache_hits: usize,
    /// How many lookups were not answered from the cache.
    pub cache_misses: usize,
    /// How many batch load operations were performed.
    pub batch_loads: usize,
}

impl LazyLoadStats {
    /// Fraction of recorded lookups that were cache hits.
    ///
    /// Computed as `cache_hits / (cache_hits + cache_misses)`; yields `0.0`
    /// when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            lookups => self.cache_hits as f64 / lookups as f64,
        }
    }
}
206
207/// A symbol table with lazy loading support.
208///
209/// Elements are loaded on-demand from a SchemaLoader, reducing memory
210/// usage for large schemas.
211pub struct LazySymbolTable {
212    /// Eagerly loaded symbol table (acts as cache).
213    loaded: Arc<RwLock<SymbolTable>>,
214    /// Loader for on-demand fetching.
215    loader: Arc<dyn SchemaLoader>,
216    /// Loading strategy.
217    strategy: LoadStrategy,
218    /// Statistics.
219    stats: Arc<RwLock<LazyLoadStats>>,
220    /// Set of loaded domain names.
221    loaded_domains: Arc<RwLock<HashSet<String>>>,
222    /// Set of loaded predicate names.
223    loaded_predicates: Arc<RwLock<HashSet<String>>>,
224    /// Access counts for predictive loading.
225    access_counts: Arc<RwLock<HashMap<String, usize>>>,
226}
227
228impl LazySymbolTable {
229    /// Create a new lazy symbol table.
230    ///
231    /// # Examples
232    ///
233    /// ```rust
234    /// use tensorlogic_adapters::{LazySymbolTable, FileSchemaLoader};
235    /// use std::sync::Arc;
236    ///
237    /// let loader = Arc::new(FileSchemaLoader::new("/tmp/schema"));
238    /// let lazy_table = LazySymbolTable::new(loader);
239    /// ```
240    pub fn new(loader: Arc<dyn SchemaLoader>) -> Self {
241        Self {
242            loaded: Arc::new(RwLock::new(SymbolTable::new())),
243            loader,
244            strategy: LoadStrategy::default(),
245            stats: Arc::new(RwLock::new(LazyLoadStats::default())),
246            loaded_domains: Arc::new(RwLock::new(HashSet::new())),
247            loaded_predicates: Arc::new(RwLock::new(HashSet::new())),
248            access_counts: Arc::new(RwLock::new(HashMap::new())),
249        }
250    }
251
252    /// Create a lazy table with a specific loading strategy.
253    pub fn with_strategy(loader: Arc<dyn SchemaLoader>, strategy: LoadStrategy) -> Self {
254        let mut table = Self::new(loader);
255        table.strategy = strategy;
256        table
257    }
258
259    /// Get a domain, loading it if necessary.
260    pub fn get_domain(&self, name: &str) -> Result<Option<DomainInfo>> {
261        // Check if already loaded
262        {
263            let loaded_set = self.loaded_domains.read().unwrap();
264            if loaded_set.contains(name) {
265                let table = self.loaded.read().unwrap();
266                let mut stats = self.stats.write().unwrap();
267                stats.cache_hits += 1;
268                return Ok(table.get_domain(name).cloned());
269            }
270        }
271
272        // Check if domain exists
273        if !self.loader.has_domain(name) {
274            let mut stats = self.stats.write().unwrap();
275            stats.cache_misses += 1;
276            return Ok(None);
277        }
278
279        // Load domain
280        self.load_domain_internal(name)?;
281
282        let table = self.loaded.read().unwrap();
283        Ok(table.get_domain(name).cloned())
284    }
285
286    /// Get a predicate, loading it if necessary.
287    pub fn get_predicate(&self, name: &str) -> Result<Option<PredicateInfo>> {
288        // Check if already loaded
289        {
290            let loaded_set = self.loaded_predicates.read().unwrap();
291            if loaded_set.contains(name) {
292                let table = self.loaded.read().unwrap();
293                let mut stats = self.stats.write().unwrap();
294                stats.cache_hits += 1;
295                return Ok(table.get_predicate(name).cloned());
296            }
297        }
298
299        // Check if predicate exists
300        if !self.loader.has_predicate(name) {
301            let mut stats = self.stats.write().unwrap();
302            stats.cache_misses += 1;
303            return Ok(None);
304        }
305
306        // Load predicate
307        self.load_predicate_internal(name)?;
308
309        let table = self.loaded.read().unwrap();
310        Ok(table.get_predicate(name).cloned())
311    }
312
313    /// List all available domains (without loading them).
314    pub fn list_domains(&self) -> Result<Vec<String>> {
315        self.loader.list_domains()
316    }
317
318    /// List all available predicates (without loading them).
319    pub fn list_predicates(&self) -> Result<Vec<String>> {
320        self.loader.list_predicates()
321    }
322
323    /// Preload a batch of domains.
324    pub fn preload_domains(&self, names: &[String]) -> Result<()> {
325        let domains = self.loader.load_domains_batch(names)?;
326        let mut table = self.loaded.write().unwrap();
327        let mut loaded_set = self.loaded_domains.write().unwrap();
328        let mut stats = self.stats.write().unwrap();
329
330        for domain in domains {
331            let name = domain.name.clone();
332            table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
333            loaded_set.insert(name);
334            stats.domain_loads += 1;
335        }
336        stats.batch_loads += 1;
337
338        Ok(())
339    }
340
341    /// Preload a batch of predicates.
342    pub fn preload_predicates(&self, names: &[String]) -> Result<()> {
343        let predicates = self.loader.load_predicates_batch(names)?;
344        let mut table = self.loaded.write().unwrap();
345        let mut loaded_set = self.loaded_predicates.write().unwrap();
346        let mut stats = self.stats.write().unwrap();
347
348        for predicate in predicates {
349            let name = predicate.name.clone();
350            table
351                .add_predicate(predicate)
352                .map_err(|e| anyhow::anyhow!(e))?;
353            loaded_set.insert(name);
354            stats.predicate_loads += 1;
355        }
356        stats.batch_loads += 1;
357
358        Ok(())
359    }
360
361    /// Get loading statistics.
362    pub fn stats(&self) -> LazyLoadStats {
363        self.stats.read().unwrap().clone()
364    }
365
366    /// Clear the cache and force reload.
367    pub fn clear_cache(&self) {
368        let mut table = self.loaded.write().unwrap();
369        *table = SymbolTable::new();
370        self.loaded_domains.write().unwrap().clear();
371        self.loaded_predicates.write().unwrap().clear();
372        self.access_counts.write().unwrap().clear();
373    }
374
375    /// Get the number of loaded domains.
376    pub fn loaded_domain_count(&self) -> usize {
377        self.loaded_domains.read().unwrap().len()
378    }
379
380    /// Get the number of loaded predicates.
381    pub fn loaded_predicate_count(&self) -> usize {
382        self.loaded_predicates.read().unwrap().len()
383    }
384
385    /// Get a read-only reference to the loaded symbol table.
386    ///
387    /// Note: This only includes loaded elements.
388    pub fn as_symbol_table(&self) -> Arc<RwLock<SymbolTable>> {
389        Arc::clone(&self.loaded)
390    }
391
392    fn load_domain_internal(&self, name: &str) -> Result<()> {
393        let domain = self.loader.load_domain(name)?;
394        let mut table = self.loaded.write().unwrap();
395        let mut loaded_set = self.loaded_domains.write().unwrap();
396        let mut stats = self.stats.write().unwrap();
397
398        table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
399        loaded_set.insert(name.to_string());
400        stats.domain_loads += 1;
401        stats.cache_misses += 1;
402
403        // Track access for predictive loading
404        {
405            let mut counts = self.access_counts.write().unwrap();
406            *counts.entry(name.to_string()).or_insert(0) += 1;
407        }
408
409        Ok(())
410    }
411
412    fn load_predicate_internal(&self, name: &str) -> Result<()> {
413        let predicate = self.loader.load_predicate(name)?;
414        let mut table = self.loaded.write().unwrap();
415        let mut loaded_set = self.loaded_predicates.write().unwrap();
416        let mut stats = self.stats.write().unwrap();
417
418        table
419            .add_predicate(predicate)
420            .map_err(|e| anyhow::anyhow!(e))?;
421        loaded_set.insert(name.to_string());
422        stats.predicate_loads += 1;
423        stats.cache_misses += 1;
424
425        // Track access for predictive loading
426        {
427            let mut counts = self.access_counts.write().unwrap();
428            *counts.entry(name.to_string()).or_insert(0) += 1;
429        }
430
431        Ok(())
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory loader with two domains and one predicate.
    struct MockLoader {
        domains: HashMap<String, DomainInfo>,
        predicates: HashMap<String, PredicateInfo>,
    }

    impl MockLoader {
        fn new() -> Self {
            let domains = HashMap::from([
                ("Person".to_string(), DomainInfo::new("Person", 100)),
                ("Location".to_string(), DomainInfo::new("Location", 50)),
            ]);
            let predicates = HashMap::from([(
                "at".to_string(),
                PredicateInfo::new("at", vec!["Person".to_string(), "Location".to_string()]),
            )]);

            Self {
                domains,
                predicates,
            }
        }
    }

    impl SchemaLoader for MockLoader {
        fn load_domain(&self, name: &str) -> Result<DomainInfo> {
            match self.domains.get(name) {
                Some(domain) => Ok(domain.clone()),
                None => Err(anyhow::anyhow!("Domain not found: {}", name)),
            }
        }

        fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
            match self.predicates.get(name) {
                Some(predicate) => Ok(predicate.clone()),
                None => Err(anyhow::anyhow!("Predicate not found: {}", name)),
            }
        }

        fn has_domain(&self, name: &str) -> bool {
            self.domains.contains_key(name)
        }

        fn has_predicate(&self, name: &str) -> bool {
            self.predicates.contains_key(name)
        }

        fn list_domains(&self) -> Result<Vec<String>> {
            Ok(self.domains.keys().cloned().collect())
        }

        fn list_predicates(&self) -> Result<Vec<String>> {
            Ok(self.predicates.keys().cloned().collect())
        }
    }

    /// Fresh lazy table over a `MockLoader`, using the default strategy.
    fn mock_table() -> LazySymbolTable {
        LazySymbolTable::new(Arc::new(MockLoader::new()))
    }

    #[test]
    fn test_lazy_load_domain() {
        let table = mock_table();

        let found = table.get_domain("Person").unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "Person");
    }

    #[test]
    fn test_lazy_load_predicate() {
        let table = mock_table();

        // The predicate's argument domains are loaded first.
        for domain in ["Person", "Location"] {
            table.get_domain(domain).unwrap();
        }

        let found = table.get_predicate("at").unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "at");
    }

    #[test]
    fn test_cache_hits() {
        let table = mock_table();

        // The first lookup is a miss, the second is served from the cache.
        table.get_domain("Person").unwrap();
        table.get_domain("Person").unwrap();

        let snapshot = table.stats();
        assert_eq!(snapshot.cache_hits, 1);
        assert_eq!(snapshot.cache_misses, 1);
    }

    #[test]
    fn test_list_domains() {
        let table = mock_table();

        let listed = table.list_domains().unwrap();
        assert_eq!(listed.len(), 2);
        assert!(listed.contains(&"Person".to_string()));
        assert!(listed.contains(&"Location".to_string()));
    }

    #[test]
    fn test_preload_domains() {
        let table = mock_table();

        let wanted = vec!["Person".to_string(), "Location".to_string()];
        table.preload_domains(&wanted).unwrap();

        assert_eq!(table.loaded_domain_count(), 2);
        assert_eq!(table.stats().batch_loads, 1);
    }

    #[test]
    fn test_clear_cache() {
        let table = mock_table();

        table.get_domain("Person").unwrap();
        assert_eq!(table.loaded_domain_count(), 1);

        table.clear_cache();
        assert_eq!(table.loaded_domain_count(), 0);
    }

    #[test]
    fn test_load_strategy() {
        let table = LazySymbolTable::with_strategy(
            Arc::new(MockLoader::new()),
            LoadStrategy::Predictive {
                access_threshold: 5,
            },
        );

        table.get_domain("Person").unwrap();
        assert_eq!(table.loaded_domain_count(), 1);
    }

    #[test]
    fn test_hit_rate() {
        let mut counters = LazyLoadStats::default();
        assert_eq!(counters.hit_rate(), 0.0);

        counters.cache_hits = 8;
        counters.cache_misses = 2;
        assert!((counters.hit_rate() - 0.8).abs() < 0.01);
    }
}