Skip to main content

tensorlogic_adapters/
lazy.rs

1//! Lazy loading support for huge symbol tables.
2//!
3//! This module provides on-demand loading of domains, predicates, and other
4//! schema elements, enabling efficient handling of very large schemas that
5//! don't fit comfortably in memory.
6//!
7//! # Examples
8//!
9//! ```rust
10//! use tensorlogic_adapters::{LazySymbolTable, FileSchemaLoader};
11//! use std::sync::Arc;
12//!
13//! // Create a lazy symbol table with on-demand loading
14//! let loader = Arc::new(FileSchemaLoader::new("/tmp/schema"));
15//! let lazy_table = LazySymbolTable::new(loader);
16//! ```
17
18use anyhow::Result;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21use std::path::PathBuf;
22use std::sync::{Arc, RwLock};
23
24use crate::{DomainInfo, PredicateInfo, SymbolTable};
25
26/// Strategy for loading schema elements.
27#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
28pub enum LoadStrategy {
29    /// Load all elements eagerly (no lazy loading).
30    Eager,
31    /// Load elements on first access.
32    #[default]
33    OnDemand,
34    /// Preload frequently accessed elements.
35    Predictive {
36        /// Threshold for "frequent" access.
37        access_threshold: usize,
38    },
39    /// Load elements in batches.
40    Batched {
41        /// Batch size for loading.
42        batch_size: usize,
43    },
44}
45
46/// A loader trait for fetching schema elements on demand.
47///
48/// Implementations can load from files, databases, or remote services.
49pub trait SchemaLoader: Send + Sync {
50    /// Load a domain by name.
51    fn load_domain(&self, name: &str) -> Result<DomainInfo>;
52
53    /// Load a predicate by name.
54    fn load_predicate(&self, name: &str) -> Result<PredicateInfo>;
55
56    /// Check if a domain exists without loading it.
57    fn has_domain(&self, name: &str) -> bool;
58
59    /// Check if a predicate exists without loading it.
60    fn has_predicate(&self, name: &str) -> bool;
61
62    /// List all available domain names.
63    fn list_domains(&self) -> Result<Vec<String>>;
64
65    /// List all available predicate names.
66    fn list_predicates(&self) -> Result<Vec<String>>;
67
68    /// Load a batch of domains by name.
69    fn load_domains_batch(&self, names: &[String]) -> Result<Vec<DomainInfo>> {
70        names.iter().map(|n| self.load_domain(n)).collect()
71    }
72
73    /// Load a batch of predicates by name.
74    fn load_predicates_batch(&self, names: &[String]) -> Result<Vec<PredicateInfo>> {
75        names.iter().map(|n| self.load_predicate(n)).collect()
76    }
77}
78
/// [`SchemaLoader`] that reads schema elements from a directory tree.
///
/// Every element is stored in its own file, named after the element:
///
/// ```text
/// schema_dir/
///   domains/
///     domain1.json
///     domain2.json
///   predicates/
///     pred1.json
///     pred2.json
/// ```
#[derive(Debug, Clone)]
pub struct FileSchemaLoader {
    /// Root directory holding the `domains/` and `predicates/` subdirectories.
    base_dir: PathBuf,
}
96
97impl FileSchemaLoader {
98    /// Create a new file-based schema loader.
99    ///
100    /// # Examples
101    ///
102    /// ```rust
103    /// use tensorlogic_adapters::FileSchemaLoader;
104    ///
105    /// let loader = FileSchemaLoader::new("/path/to/schema");
106    /// ```
107    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
108        Self {
109            base_dir: base_dir.into(),
110        }
111    }
112
113    fn domain_path(&self, name: &str) -> PathBuf {
114        self.base_dir.join("domains").join(format!("{}.json", name))
115    }
116
117    fn predicate_path(&self, name: &str) -> PathBuf {
118        self.base_dir
119            .join("predicates")
120            .join(format!("{}.json", name))
121    }
122}
123
124impl SchemaLoader for FileSchemaLoader {
125    fn load_domain(&self, name: &str) -> Result<DomainInfo> {
126        let path = self.domain_path(name);
127        let content = std::fs::read_to_string(path)?;
128        let domain: DomainInfo = serde_json::from_str(&content)?;
129        Ok(domain)
130    }
131
132    fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
133        let path = self.predicate_path(name);
134        let content = std::fs::read_to_string(path)?;
135        let predicate: PredicateInfo = serde_json::from_str(&content)?;
136        Ok(predicate)
137    }
138
139    fn has_domain(&self, name: &str) -> bool {
140        self.domain_path(name).exists()
141    }
142
143    fn has_predicate(&self, name: &str) -> bool {
144        self.predicate_path(name).exists()
145    }
146
147    fn list_domains(&self) -> Result<Vec<String>> {
148        let domains_dir = self.base_dir.join("domains");
149        if !domains_dir.exists() {
150            return Ok(Vec::new());
151        }
152
153        let mut names = Vec::new();
154        for entry in std::fs::read_dir(domains_dir)? {
155            let entry = entry?;
156            if let Some(name) = entry.path().file_stem() {
157                names.push(name.to_string_lossy().to_string());
158            }
159        }
160        Ok(names)
161    }
162
163    fn list_predicates(&self) -> Result<Vec<String>> {
164        let predicates_dir = self.base_dir.join("predicates");
165        if !predicates_dir.exists() {
166            return Ok(Vec::new());
167        }
168
169        let mut names = Vec::new();
170        for entry in std::fs::read_dir(predicates_dir)? {
171            let entry = entry?;
172            if let Some(name) = entry.path().file_stem() {
173                names.push(name.to_string_lossy().to_string());
174            }
175        }
176        Ok(names)
177    }
178}
179
/// Counters describing how a lazy table has been loading and serving elements.
#[derive(Debug, Default, Clone)]
pub struct LazyLoadStats {
    /// Domains inserted into the cache.
    pub domain_loads: usize,
    /// Predicates inserted into the cache.
    pub predicate_loads: usize,
    /// Lookups answered from already-loaded elements.
    pub cache_hits: usize,
    /// Lookups that could not be answered from already-loaded elements.
    pub cache_misses: usize,
    /// Batch preload operations performed.
    pub batch_loads: usize,
}
194
195impl LazyLoadStats {
196    /// Get the cache hit rate.
197    pub fn hit_rate(&self) -> f64 {
198        let total = self.cache_hits + self.cache_misses;
199        if total == 0 {
200            0.0
201        } else {
202            self.cache_hits as f64 / total as f64
203        }
204    }
205}
206
207/// A symbol table with lazy loading support.
208///
209/// Elements are loaded on-demand from a SchemaLoader, reducing memory
210/// usage for large schemas.
211pub struct LazySymbolTable {
212    /// Eagerly loaded symbol table (acts as cache).
213    loaded: Arc<RwLock<SymbolTable>>,
214    /// Loader for on-demand fetching.
215    loader: Arc<dyn SchemaLoader>,
216    /// Loading strategy.
217    strategy: LoadStrategy,
218    /// Statistics.
219    stats: Arc<RwLock<LazyLoadStats>>,
220    /// Set of loaded domain names.
221    loaded_domains: Arc<RwLock<HashSet<String>>>,
222    /// Set of loaded predicate names.
223    loaded_predicates: Arc<RwLock<HashSet<String>>>,
224    /// Access counts for predictive loading.
225    access_counts: Arc<RwLock<HashMap<String, usize>>>,
226}
227
228impl LazySymbolTable {
229    /// Create a new lazy symbol table.
230    ///
231    /// # Examples
232    ///
233    /// ```rust
234    /// use tensorlogic_adapters::{LazySymbolTable, FileSchemaLoader};
235    /// use std::sync::Arc;
236    ///
237    /// let loader = Arc::new(FileSchemaLoader::new("/tmp/schema"));
238    /// let lazy_table = LazySymbolTable::new(loader);
239    /// ```
240    pub fn new(loader: Arc<dyn SchemaLoader>) -> Self {
241        Self {
242            loaded: Arc::new(RwLock::new(SymbolTable::new())),
243            loader,
244            strategy: LoadStrategy::default(),
245            stats: Arc::new(RwLock::new(LazyLoadStats::default())),
246            loaded_domains: Arc::new(RwLock::new(HashSet::new())),
247            loaded_predicates: Arc::new(RwLock::new(HashSet::new())),
248            access_counts: Arc::new(RwLock::new(HashMap::new())),
249        }
250    }
251
252    /// Create a lazy table with a specific loading strategy.
253    pub fn with_strategy(loader: Arc<dyn SchemaLoader>, strategy: LoadStrategy) -> Self {
254        let mut table = Self::new(loader);
255        table.strategy = strategy;
256        table
257    }
258
259    /// Get a domain, loading it if necessary.
260    pub fn get_domain(&self, name: &str) -> Result<Option<DomainInfo>> {
261        // Check if already loaded
262        {
263            let loaded_set = self
264                .loaded_domains
265                .read()
266                .expect("lock should not be poisoned");
267            if loaded_set.contains(name) {
268                let table = self.loaded.read().expect("lock should not be poisoned");
269                let mut stats = self.stats.write().expect("lock should not be poisoned");
270                stats.cache_hits += 1;
271                return Ok(table.get_domain(name).cloned());
272            }
273        }
274
275        // Check if domain exists
276        if !self.loader.has_domain(name) {
277            let mut stats = self.stats.write().expect("lock should not be poisoned");
278            stats.cache_misses += 1;
279            return Ok(None);
280        }
281
282        // Load domain
283        self.load_domain_internal(name)?;
284
285        let table = self.loaded.read().expect("lock should not be poisoned");
286        Ok(table.get_domain(name).cloned())
287    }
288
289    /// Get a predicate, loading it if necessary.
290    pub fn get_predicate(&self, name: &str) -> Result<Option<PredicateInfo>> {
291        // Check if already loaded
292        {
293            let loaded_set = self
294                .loaded_predicates
295                .read()
296                .expect("lock should not be poisoned");
297            if loaded_set.contains(name) {
298                let table = self.loaded.read().expect("lock should not be poisoned");
299                let mut stats = self.stats.write().expect("lock should not be poisoned");
300                stats.cache_hits += 1;
301                return Ok(table.get_predicate(name).cloned());
302            }
303        }
304
305        // Check if predicate exists
306        if !self.loader.has_predicate(name) {
307            let mut stats = self.stats.write().expect("lock should not be poisoned");
308            stats.cache_misses += 1;
309            return Ok(None);
310        }
311
312        // Load predicate
313        self.load_predicate_internal(name)?;
314
315        let table = self.loaded.read().expect("lock should not be poisoned");
316        Ok(table.get_predicate(name).cloned())
317    }
318
319    /// List all available domains (without loading them).
320    pub fn list_domains(&self) -> Result<Vec<String>> {
321        self.loader.list_domains()
322    }
323
324    /// List all available predicates (without loading them).
325    pub fn list_predicates(&self) -> Result<Vec<String>> {
326        self.loader.list_predicates()
327    }
328
329    /// Preload a batch of domains.
330    pub fn preload_domains(&self, names: &[String]) -> Result<()> {
331        let domains = self.loader.load_domains_batch(names)?;
332        let mut table = self.loaded.write().expect("lock should not be poisoned");
333        let mut loaded_set = self
334            .loaded_domains
335            .write()
336            .expect("lock should not be poisoned");
337        let mut stats = self.stats.write().expect("lock should not be poisoned");
338
339        for domain in domains {
340            let name = domain.name.clone();
341            table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
342            loaded_set.insert(name);
343            stats.domain_loads += 1;
344        }
345        stats.batch_loads += 1;
346
347        Ok(())
348    }
349
350    /// Preload a batch of predicates.
351    pub fn preload_predicates(&self, names: &[String]) -> Result<()> {
352        let predicates = self.loader.load_predicates_batch(names)?;
353        let mut table = self.loaded.write().expect("lock should not be poisoned");
354        let mut loaded_set = self
355            .loaded_predicates
356            .write()
357            .expect("lock should not be poisoned");
358        let mut stats = self.stats.write().expect("lock should not be poisoned");
359
360        for predicate in predicates {
361            let name = predicate.name.clone();
362            table
363                .add_predicate(predicate)
364                .map_err(|e| anyhow::anyhow!(e))?;
365            loaded_set.insert(name);
366            stats.predicate_loads += 1;
367        }
368        stats.batch_loads += 1;
369
370        Ok(())
371    }
372
373    /// Get loading statistics.
374    pub fn stats(&self) -> LazyLoadStats {
375        self.stats
376            .read()
377            .expect("lock should not be poisoned")
378            .clone()
379    }
380
381    /// Clear the cache and force reload.
382    pub fn clear_cache(&self) {
383        let mut table = self.loaded.write().expect("lock should not be poisoned");
384        *table = SymbolTable::new();
385        self.loaded_domains
386            .write()
387            .expect("lock should not be poisoned")
388            .clear();
389        self.loaded_predicates
390            .write()
391            .expect("lock should not be poisoned")
392            .clear();
393        self.access_counts
394            .write()
395            .expect("lock should not be poisoned")
396            .clear();
397    }
398
399    /// Get the number of loaded domains.
400    pub fn loaded_domain_count(&self) -> usize {
401        self.loaded_domains
402            .read()
403            .expect("lock should not be poisoned")
404            .len()
405    }
406
407    /// Get the number of loaded predicates.
408    pub fn loaded_predicate_count(&self) -> usize {
409        self.loaded_predicates
410            .read()
411            .expect("lock should not be poisoned")
412            .len()
413    }
414
415    /// Get a read-only reference to the loaded symbol table.
416    ///
417    /// Note: This only includes loaded elements.
418    pub fn as_symbol_table(&self) -> Arc<RwLock<SymbolTable>> {
419        Arc::clone(&self.loaded)
420    }
421
422    fn load_domain_internal(&self, name: &str) -> Result<()> {
423        let domain = self.loader.load_domain(name)?;
424        let mut table = self.loaded.write().expect("lock should not be poisoned");
425        let mut loaded_set = self
426            .loaded_domains
427            .write()
428            .expect("lock should not be poisoned");
429        let mut stats = self.stats.write().expect("lock should not be poisoned");
430
431        table.add_domain(domain).map_err(|e| anyhow::anyhow!(e))?;
432        loaded_set.insert(name.to_string());
433        stats.domain_loads += 1;
434        stats.cache_misses += 1;
435
436        // Track access for predictive loading
437        {
438            let mut counts = self
439                .access_counts
440                .write()
441                .expect("lock should not be poisoned");
442            *counts.entry(name.to_string()).or_insert(0) += 1;
443        }
444
445        Ok(())
446    }
447
448    fn load_predicate_internal(&self, name: &str) -> Result<()> {
449        let predicate = self.loader.load_predicate(name)?;
450        let mut table = self.loaded.write().expect("lock should not be poisoned");
451        let mut loaded_set = self
452            .loaded_predicates
453            .write()
454            .expect("lock should not be poisoned");
455        let mut stats = self.stats.write().expect("lock should not be poisoned");
456
457        table
458            .add_predicate(predicate)
459            .map_err(|e| anyhow::anyhow!(e))?;
460        loaded_set.insert(name.to_string());
461        stats.predicate_loads += 1;
462        stats.cache_misses += 1;
463
464        // Track access for predictive loading
465        {
466            let mut counts = self
467                .access_counts
468                .write()
469                .expect("lock should not be poisoned");
470            *counts.entry(name.to_string()).or_insert(0) += 1;
471        }
472
473        Ok(())
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory loader with two domains (`Person`, `Location`) and one
    /// predicate (`at`) over them.
    struct MockLoader {
        domains: HashMap<String, DomainInfo>,
        predicates: HashMap<String, PredicateInfo>,
    }

    impl MockLoader {
        fn new() -> Self {
            let mut domains = HashMap::new();
            domains.insert("Person".to_string(), DomainInfo::new("Person", 100));
            domains.insert("Location".to_string(), DomainInfo::new("Location", 50));

            let mut predicates = HashMap::new();
            predicates.insert(
                "at".to_string(),
                PredicateInfo::new("at", vec!["Person".to_string(), "Location".to_string()]),
            );

            Self {
                domains,
                predicates,
            }
        }
    }

    impl SchemaLoader for MockLoader {
        fn load_domain(&self, name: &str) -> Result<DomainInfo> {
            self.domains
                .get(name)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Domain not found: {}", name))
        }

        fn load_predicate(&self, name: &str) -> Result<PredicateInfo> {
            self.predicates
                .get(name)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Predicate not found: {}", name))
        }

        fn has_domain(&self, name: &str) -> bool {
            self.domains.contains_key(name)
        }

        fn has_predicate(&self, name: &str) -> bool {
            self.predicates.contains_key(name)
        }

        fn list_domains(&self) -> Result<Vec<String>> {
            Ok(self.domains.keys().cloned().collect())
        }

        fn list_predicates(&self) -> Result<Vec<String>> {
            Ok(self.predicates.keys().cloned().collect())
        }
    }

    /// Fresh lazy table over a new `MockLoader`.
    fn mock_table() -> LazySymbolTable {
        LazySymbolTable::new(Arc::new(MockLoader::new()))
    }

    #[test]
    fn test_lazy_load_domain() {
        let lazy_table = mock_table();

        let domain = lazy_table.get_domain("Person").expect("unwrap");
        assert!(domain.is_some());
        assert_eq!(domain.expect("unwrap").name, "Person");
    }

    #[test]
    fn test_lazy_load_predicate() {
        let lazy_table = mock_table();

        // The predicate's argument domains are loaded first.
        lazy_table.get_domain("Person").expect("unwrap");
        lazy_table.get_domain("Location").expect("unwrap");

        let predicate = lazy_table.get_predicate("at").expect("unwrap");
        assert!(predicate.is_some());
        assert_eq!(predicate.expect("unwrap").name, "at");
    }

    #[test]
    fn test_missing_domain_returns_none() {
        let lazy_table = mock_table();

        let domain = lazy_table.get_domain("Nowhere").expect("unwrap");
        assert!(domain.is_none());
        assert_eq!(lazy_table.loaded_domain_count(), 0);

        let stats = lazy_table.stats();
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.domain_loads, 0);
    }

    #[test]
    fn test_missing_predicate_returns_none() {
        let lazy_table = mock_table();

        let predicate = lazy_table.get_predicate("nowhere").expect("unwrap");
        assert!(predicate.is_none());
        assert_eq!(lazy_table.loaded_predicate_count(), 0);
        assert_eq!(lazy_table.stats().cache_misses, 1);
    }

    #[test]
    fn test_cache_hits() {
        let lazy_table = mock_table();

        // First access (miss)
        lazy_table.get_domain("Person").expect("unwrap");

        // Second access (hit)
        lazy_table.get_domain("Person").expect("unwrap");

        let stats = lazy_table.stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    #[test]
    fn test_predicate_cache_hit() {
        let lazy_table = mock_table();
        lazy_table.get_domain("Person").expect("unwrap");
        lazy_table.get_domain("Location").expect("unwrap");

        lazy_table.get_predicate("at").expect("unwrap");
        let again = lazy_table.get_predicate("at").expect("unwrap");
        assert_eq!(again.expect("unwrap").name, "at");

        let stats = lazy_table.stats();
        assert_eq!(stats.predicate_loads, 1);
        assert_eq!(stats.cache_hits, 1);
    }

    #[test]
    fn test_list_domains() {
        let lazy_table = mock_table();

        let domains = lazy_table.list_domains().expect("unwrap");
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"Person".to_string()));
        assert!(domains.contains(&"Location".to_string()));
    }

    #[test]
    fn test_list_predicates() {
        let predicates = mock_table().list_predicates().expect("unwrap");
        assert_eq!(predicates, vec!["at".to_string()]);
    }

    #[test]
    fn test_preload_domains() {
        let lazy_table = mock_table();

        let names = vec!["Person".to_string(), "Location".to_string()];
        lazy_table.preload_domains(&names).expect("unwrap");

        assert_eq!(lazy_table.loaded_domain_count(), 2);

        let stats = lazy_table.stats();
        assert_eq!(stats.batch_loads, 1);
    }

    #[test]
    fn test_preload_predicates() {
        let lazy_table = mock_table();
        lazy_table
            .preload_domains(&["Person".to_string(), "Location".to_string()])
            .expect("unwrap");
        lazy_table
            .preload_predicates(&["at".to_string()])
            .expect("unwrap");

        assert_eq!(lazy_table.loaded_predicate_count(), 1);

        let stats = lazy_table.stats();
        assert_eq!(stats.predicate_loads, 1);
        assert_eq!(stats.batch_loads, 2);
    }

    #[test]
    fn test_clear_cache() {
        let lazy_table = mock_table();

        lazy_table.get_domain("Person").expect("unwrap");
        assert_eq!(lazy_table.loaded_domain_count(), 1);

        lazy_table.clear_cache();
        assert_eq!(lazy_table.loaded_domain_count(), 0);
    }

    #[test]
    fn test_reload_after_clear_cache() {
        let lazy_table = mock_table();

        lazy_table.get_domain("Person").expect("unwrap");
        lazy_table.clear_cache();

        let domain = lazy_table.get_domain("Person").expect("unwrap");
        assert!(domain.is_some());
        // Statistics survive `clear_cache`, so both loads are counted.
        assert_eq!(lazy_table.stats().domain_loads, 2);
    }

    #[test]
    fn test_load_strategy() {
        let loader = Arc::new(MockLoader::new());
        let strategy = LoadStrategy::Predictive {
            access_threshold: 5,
        };
        let lazy_table = LazySymbolTable::with_strategy(loader, strategy);

        lazy_table.get_domain("Person").expect("unwrap");
        assert_eq!(lazy_table.loaded_domain_count(), 1);
    }

    #[test]
    fn test_hit_rate() {
        let mut stats = LazyLoadStats::default();
        assert_eq!(stats.hit_rate(), 0.0);

        stats.cache_hits = 8;
        stats.cache_misses = 2;
        assert!((stats.hit_rate() - 0.8).abs() < 0.01);
    }
}