Skip to main content

tensorlogic_compiler/optimize/
memory_estimation.rs

1//! Memory estimation for tensor expressions.
2//!
3//! This module provides tools for estimating the memory footprint of
4//! compiled tensor expressions. This is useful for:
5//!
6//! - Planning batch sizes
7//! - Deciding on execution strategies
8//! - GPU memory allocation
9//! - Optimization decisions
10//!
11//! # Usage
12//!
13//! ```
14//! use tensorlogic_compiler::optimize::{estimate_memory, MemoryEstimate};
15//! use tensorlogic_compiler::CompilerContext;
16//! use tensorlogic_ir::{TLExpr, Term};
17//!
18//! let mut ctx = CompilerContext::new();
19//! ctx.add_domain("Person", 1000);
20//! ctx.add_domain("Thing", 500);
21//!
22//! let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
23//! let estimate = estimate_memory(&expr, &ctx);
24//!
25//! println!("Estimated memory: {} bytes", estimate.total_bytes);
26//! println!("Peak memory: {} bytes", estimate.peak_bytes);
27//! ```
28
29use crate::CompilerContext;
30use tensorlogic_ir::TLExpr;
31
/// Detailed memory estimate for an expression.
///
/// Produced by `estimate_memory` / `estimate_batch_memory`. All byte counts
/// assume 8-byte (`f64`) tensor elements.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryEstimate {
    /// Total memory needed for all tensors the walk materialises (in bytes),
    /// i.e. the sum over every leaf and intermediate tensor.
    pub total_bytes: usize,
    /// Peak memory usage during execution (in bytes): the largest amount of
    /// tensor memory simultaneously live at any point of evaluation.
    pub peak_bytes: usize,
    /// Number of tensors materialised (leaves and intermediate results).
    pub intermediate_count: usize,
    /// Maximum tensor size (in elements)
    pub max_tensor_size: usize,
    /// Total number of elements across all tensors
    pub total_elements: usize,
}
46
47impl MemoryEstimate {
48    /// Get total memory in megabytes.
49    pub fn total_mb(&self) -> f64 {
50        self.total_bytes as f64 / (1024.0 * 1024.0)
51    }
52
53    /// Get peak memory in megabytes.
54    pub fn peak_mb(&self) -> f64 {
55        self.peak_bytes as f64 / (1024.0 * 1024.0)
56    }
57
58    /// Check if this exceeds a memory limit (in bytes).
59    pub fn exceeds_limit(&self, limit_bytes: usize) -> bool {
60        self.peak_bytes > limit_bytes
61    }
62
63    /// Suggest optimal batch size given a memory budget.
64    pub fn suggest_batch_size(&self, budget_bytes: usize, current_batch: usize) -> usize {
65        if self.peak_bytes == 0 {
66            return current_batch;
67        }
68
69        let memory_per_item = self.peak_bytes / current_batch.max(1);
70        if memory_per_item == 0 {
71            return current_batch;
72        }
73
74        (budget_bytes / memory_per_item).max(1)
75    }
76}
77
78impl std::fmt::Display for MemoryEstimate {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        writeln!(f, "Memory Estimate:")?;
81        writeln!(
82            f,
83            "  Total: {:.2} MB ({} bytes)",
84            self.total_mb(),
85            self.total_bytes
86        )?;
87        writeln!(
88            f,
89            "  Peak: {:.2} MB ({} bytes)",
90            self.peak_mb(),
91            self.peak_bytes
92        )?;
93        writeln!(f, "  Intermediates: {}", self.intermediate_count)?;
94        writeln!(f, "  Max tensor size: {} elements", self.max_tensor_size)?;
95        writeln!(f, "  Total elements: {}", self.total_elements)?;
96        Ok(())
97    }
98}
99
100/// Estimate memory usage for an expression.
101///
102/// This function analyzes the expression and estimates memory usage
103/// based on domain sizes in the compiler context.
104///
105/// # Arguments
106///
107/// * `expr` - The expression to analyze
108/// * `ctx` - The compiler context containing domain sizes
109///
110/// # Returns
111///
112/// A memory estimate
113pub fn estimate_memory(expr: &TLExpr, ctx: &CompilerContext) -> MemoryEstimate {
114    let mut estimate = MemoryEstimate::default();
115    let mut current_memory = 0usize;
116
117    estimate_memory_impl(expr, ctx, &mut estimate, &mut current_memory);
118
119    // Ensure peak is at least total
120    estimate.peak_bytes = estimate.peak_bytes.max(estimate.total_bytes);
121
122    estimate
123}
124
125fn estimate_memory_impl(
126    expr: &TLExpr,
127    ctx: &CompilerContext,
128    estimate: &mut MemoryEstimate,
129    current_memory: &mut usize,
130) -> usize {
131    // Size of f64 in bytes
132    const ELEM_SIZE: usize = 8;
133
134    match expr {
135        TLExpr::Pred { args, .. } => {
136            // Estimate tensor size from argument domains
137            let mut size = 1usize;
138            for arg in args {
139                if let tensorlogic_ir::Term::Var(v) = arg {
140                    // Get domain size for this variable
141                    if let Some(domain_name) = ctx.var_to_domain.get(v) {
142                        if let Some(info) = ctx.domains.get(domain_name) {
143                            size = size.saturating_mul(info.cardinality);
144                        }
145                    } else {
146                        // Unknown domain, assume default size
147                        size = size.saturating_mul(100);
148                    }
149                }
150            }
151
152            let bytes = size.saturating_mul(ELEM_SIZE);
153            estimate.total_bytes = estimate.total_bytes.saturating_add(bytes);
154            estimate.total_elements = estimate.total_elements.saturating_add(size);
155            estimate.max_tensor_size = estimate.max_tensor_size.max(size);
156            estimate.intermediate_count += 1;
157
158            *current_memory = current_memory.saturating_add(bytes);
159            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
160
161            bytes
162        }
163
164        TLExpr::Constant(_) => {
165            // Scalar constant: 8 bytes
166            let bytes = ELEM_SIZE;
167            estimate.total_bytes = estimate.total_bytes.saturating_add(bytes);
168            estimate.total_elements = estimate.total_elements.saturating_add(1);
169            estimate.max_tensor_size = estimate.max_tensor_size.max(1);
170            estimate.intermediate_count += 1;
171
172            *current_memory = current_memory.saturating_add(bytes);
173            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
174
175            bytes
176        }
177
178        // Binary operations: result has same shape as operands (broadcast assumed)
179        TLExpr::Add(lhs, rhs)
180        | TLExpr::Sub(lhs, rhs)
181        | TLExpr::Mul(lhs, rhs)
182        | TLExpr::Div(lhs, rhs)
183        | TLExpr::Min(lhs, rhs)
184        | TLExpr::Max(lhs, rhs) => {
185            let lhs_bytes = estimate_memory_impl(lhs, ctx, estimate, current_memory);
186            let rhs_bytes = estimate_memory_impl(rhs, ctx, estimate, current_memory);
187
188            // Result tensor: max of operand sizes
189            let result_bytes = lhs_bytes.max(rhs_bytes);
190            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
191            estimate.intermediate_count += 1;
192
193            *current_memory = current_memory.saturating_add(result_bytes);
194            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
195
196            // Free intermediate tensors
197            *current_memory = current_memory.saturating_sub(lhs_bytes);
198            *current_memory = current_memory.saturating_sub(rhs_bytes);
199
200            result_bytes
201        }
202
203        // Logic operations
204        TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
205            let lhs_bytes = estimate_memory_impl(lhs, ctx, estimate, current_memory);
206            let rhs_bytes = estimate_memory_impl(rhs, ctx, estimate, current_memory);
207
208            let result_bytes = lhs_bytes.max(rhs_bytes);
209            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
210            estimate.intermediate_count += 1;
211
212            *current_memory = current_memory.saturating_add(result_bytes);
213            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
214
215            *current_memory = current_memory.saturating_sub(lhs_bytes);
216            *current_memory = current_memory.saturating_sub(rhs_bytes);
217
218            result_bytes
219        }
220
221        // Comparison operations
222        TLExpr::Eq(lhs, rhs)
223        | TLExpr::Lt(lhs, rhs)
224        | TLExpr::Lte(lhs, rhs)
225        | TLExpr::Gt(lhs, rhs)
226        | TLExpr::Gte(lhs, rhs) => {
227            let lhs_bytes = estimate_memory_impl(lhs, ctx, estimate, current_memory);
228            let rhs_bytes = estimate_memory_impl(rhs, ctx, estimate, current_memory);
229
230            let result_bytes = lhs_bytes.max(rhs_bytes);
231            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
232            estimate.intermediate_count += 1;
233
234            *current_memory = current_memory.saturating_add(result_bytes);
235            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
236
237            *current_memory = current_memory.saturating_sub(lhs_bytes);
238            *current_memory = current_memory.saturating_sub(rhs_bytes);
239
240            result_bytes
241        }
242
243        // Unary operations: same shape as input
244        TLExpr::Not(inner)
245        | TLExpr::Abs(inner)
246        | TLExpr::Sqrt(inner)
247        | TLExpr::Exp(inner)
248        | TLExpr::Log(inner)
249        | TLExpr::Score(inner)
250        | TLExpr::Floor(inner)
251        | TLExpr::Ceil(inner)
252        | TLExpr::Round(inner)
253        | TLExpr::Sin(inner)
254        | TLExpr::Cos(inner)
255        | TLExpr::Tan(inner) => {
256            let inner_bytes = estimate_memory_impl(inner, ctx, estimate, current_memory);
257
258            // Result tensor: same size as input
259            estimate.total_bytes = estimate.total_bytes.saturating_add(inner_bytes);
260            estimate.intermediate_count += 1;
261
262            *current_memory = current_memory.saturating_add(inner_bytes);
263            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
264
265            *current_memory = current_memory.saturating_sub(inner_bytes);
266
267            inner_bytes
268        }
269
270        TLExpr::Pow(base, exp) | TLExpr::Mod(base, exp) => {
271            let base_bytes = estimate_memory_impl(base, ctx, estimate, current_memory);
272            let exp_bytes = estimate_memory_impl(exp, ctx, estimate, current_memory);
273
274            let result_bytes = base_bytes.max(exp_bytes);
275            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
276            estimate.intermediate_count += 1;
277
278            *current_memory = current_memory.saturating_add(result_bytes);
279            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
280
281            *current_memory = current_memory.saturating_sub(base_bytes);
282            *current_memory = current_memory.saturating_sub(exp_bytes);
283
284            result_bytes
285        }
286
287        // Quantifiers: reduce along one dimension
288        TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
289            // Get domain size for reduction
290            let domain_size = ctx
291                .domains
292                .get(domain)
293                .map(|info| info.cardinality)
294                .unwrap_or(100);
295
296            let body_bytes = estimate_memory_impl(body, ctx, estimate, current_memory);
297
298            // Result is reduced: divide by domain size
299            let result_bytes = body_bytes / domain_size.max(1);
300            let result_bytes = result_bytes.max(ELEM_SIZE); // At least one element
301
302            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
303            estimate.intermediate_count += 1;
304
305            // Account for domain dimension in variable
306            let _ = var; // Just to silence unused warning
307
308            *current_memory = current_memory.saturating_add(result_bytes);
309            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
310
311            *current_memory = current_memory.saturating_sub(body_bytes);
312
313            result_bytes
314        }
315
316        TLExpr::Let { value, body, .. } => {
317            let value_bytes = estimate_memory_impl(value, ctx, estimate, current_memory);
318            let body_bytes = estimate_memory_impl(body, ctx, estimate, current_memory);
319
320            // Let keeps value alive during body evaluation
321            *current_memory = current_memory.saturating_sub(value_bytes);
322
323            body_bytes
324        }
325
326        TLExpr::IfThenElse {
327            condition,
328            then_branch,
329            else_branch,
330        } => {
331            let cond_bytes = estimate_memory_impl(condition, ctx, estimate, current_memory);
332            let then_bytes = estimate_memory_impl(then_branch, ctx, estimate, current_memory);
333            let else_bytes = estimate_memory_impl(else_branch, ctx, estimate, current_memory);
334
335            // Result is max of branches
336            let result_bytes = then_bytes.max(else_bytes);
337            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
338            estimate.intermediate_count += 1;
339
340            *current_memory = current_memory.saturating_add(result_bytes);
341            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
342
343            *current_memory = current_memory.saturating_sub(cond_bytes);
344            *current_memory = current_memory.saturating_sub(then_bytes);
345            *current_memory = current_memory.saturating_sub(else_bytes);
346
347            result_bytes
348        }
349
350        // Modal/Temporal: similar to quantifiers
351        TLExpr::Box(inner)
352        | TLExpr::Diamond(inner)
353        | TLExpr::Next(inner)
354        | TLExpr::Eventually(inner)
355        | TLExpr::Always(inner)
356        | TLExpr::FuzzyNot { expr: inner, .. }
357        | TLExpr::WeightedRule { rule: inner, .. } => {
358            let inner_bytes = estimate_memory_impl(inner, ctx, estimate, current_memory);
359
360            // Typically reduces one dimension
361            let result_bytes = inner_bytes;
362            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
363            estimate.intermediate_count += 1;
364
365            *current_memory = current_memory.saturating_add(result_bytes);
366            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
367
368            *current_memory = current_memory.saturating_sub(inner_bytes);
369
370            result_bytes
371        }
372
373        TLExpr::Until { before, after } | TLExpr::WeakUntil { before, after } => {
374            let lhs_bytes = estimate_memory_impl(before, ctx, estimate, current_memory);
375            let rhs_bytes = estimate_memory_impl(after, ctx, estimate, current_memory);
376
377            let result_bytes = lhs_bytes.max(rhs_bytes);
378            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
379            estimate.intermediate_count += 1;
380
381            *current_memory = current_memory.saturating_add(result_bytes);
382            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
383
384            *current_memory = current_memory.saturating_sub(lhs_bytes);
385            *current_memory = current_memory.saturating_sub(rhs_bytes);
386
387            result_bytes
388        }
389
390        TLExpr::Release { released, releaser } | TLExpr::StrongRelease { released, releaser } => {
391            let lhs_bytes = estimate_memory_impl(released, ctx, estimate, current_memory);
392            let rhs_bytes = estimate_memory_impl(releaser, ctx, estimate, current_memory);
393
394            let result_bytes = lhs_bytes.max(rhs_bytes);
395            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
396            estimate.intermediate_count += 1;
397
398            *current_memory = current_memory.saturating_add(result_bytes);
399            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
400
401            *current_memory = current_memory.saturating_sub(lhs_bytes);
402            *current_memory = current_memory.saturating_sub(rhs_bytes);
403
404            result_bytes
405        }
406
407        TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
408            let lhs_bytes = estimate_memory_impl(left, ctx, estimate, current_memory);
409            let rhs_bytes = estimate_memory_impl(right, ctx, estimate, current_memory);
410
411            let result_bytes = lhs_bytes.max(rhs_bytes);
412            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
413            estimate.intermediate_count += 1;
414
415            *current_memory = current_memory.saturating_add(result_bytes);
416            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
417
418            *current_memory = current_memory.saturating_sub(lhs_bytes);
419            *current_memory = current_memory.saturating_sub(rhs_bytes);
420
421            result_bytes
422        }
423
424        TLExpr::FuzzyImplication {
425            premise,
426            conclusion,
427            ..
428        } => {
429            let lhs_bytes = estimate_memory_impl(premise, ctx, estimate, current_memory);
430            let rhs_bytes = estimate_memory_impl(conclusion, ctx, estimate, current_memory);
431
432            let result_bytes = lhs_bytes.max(rhs_bytes);
433            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
434            estimate.intermediate_count += 1;
435
436            *current_memory = current_memory.saturating_add(result_bytes);
437            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
438
439            *current_memory = current_memory.saturating_sub(lhs_bytes);
440            *current_memory = current_memory.saturating_sub(rhs_bytes);
441
442            result_bytes
443        }
444
445        TLExpr::Aggregate {
446            var, domain, body, ..
447        }
448        | TLExpr::SoftExists {
449            var, domain, body, ..
450        }
451        | TLExpr::SoftForAll {
452            var, domain, body, ..
453        } => {
454            let domain_size = ctx
455                .domains
456                .get(domain)
457                .map(|info| info.cardinality)
458                .unwrap_or(100);
459
460            let body_bytes = estimate_memory_impl(body, ctx, estimate, current_memory);
461
462            let result_bytes = body_bytes / domain_size.max(1);
463            let result_bytes = result_bytes.max(ELEM_SIZE);
464
465            estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
466            estimate.intermediate_count += 1;
467
468            let _ = var;
469
470            *current_memory = current_memory.saturating_add(result_bytes);
471            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
472
473            *current_memory = current_memory.saturating_sub(body_bytes);
474
475            result_bytes
476        }
477
478        TLExpr::ProbabilisticChoice { alternatives } => {
479            let mut max_bytes = 0;
480            for (_, expr) in alternatives {
481                let bytes = estimate_memory_impl(expr, ctx, estimate, current_memory);
482                max_bytes = max_bytes.max(bytes);
483            }
484
485            estimate.total_bytes = estimate.total_bytes.saturating_add(max_bytes);
486            estimate.intermediate_count += 1;
487
488            *current_memory = current_memory.saturating_add(max_bytes);
489            estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
490
491            max_bytes
492        }
493
494        // All other expression types (enhancements)
495        _ => {
496            const ELEM_SIZE: usize = 8;
497            ELEM_SIZE
498        }
499    }
500}
501
502/// Estimate memory for a batch of similar expressions.
503///
504/// Useful for planning batch execution.
505pub fn estimate_batch_memory(
506    expr: &TLExpr,
507    ctx: &CompilerContext,
508    batch_size: usize,
509) -> MemoryEstimate {
510    let single = estimate_memory(expr, ctx);
511
512    MemoryEstimate {
513        total_bytes: single.total_bytes.saturating_mul(batch_size),
514        peak_bytes: single.peak_bytes.saturating_mul(batch_size),
515        intermediate_count: single.intermediate_count,
516        max_tensor_size: single.max_tensor_size.saturating_mul(batch_size),
517        total_elements: single.total_elements.saturating_mul(batch_size),
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    #[test]
    fn test_constant_memory() {
        let ctx = CompilerContext::new();
        let expr = TLExpr::Constant(1.0);
        let estimate = estimate_memory(&expr, &ctx);

        assert_eq!(estimate.total_bytes, 8); // One f64
        assert_eq!(estimate.total_elements, 1);
        assert_eq!(estimate.peak_bytes, 8);
        assert_eq!(estimate.intermediate_count, 1);
    }

    #[test]
    fn test_predicate_memory() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx.add_domain("Thing", 50);
        ctx.bind_var("x", "Person").unwrap();
        ctx.bind_var("y", "Thing").unwrap();

        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let estimate = estimate_memory(&expr, &ctx);

        // 100 * 50 = 5000 elements * 8 bytes = 40000 bytes
        assert_eq!(estimate.total_bytes, 40000);
        assert_eq!(estimate.max_tensor_size, 5000);
        assert_eq!(estimate.total_elements, 5000);
        assert_eq!(estimate.peak_bytes, 40000);
    }

    #[test]
    fn test_binary_operation_memory() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 100);
        ctx.bind_var("x", "D").unwrap();

        let a = TLExpr::pred("a", vec![Term::var("x")]);
        let b = TLExpr::pred("b", vec![Term::var("x")]);
        let expr = TLExpr::add(a, b);

        let estimate = estimate_memory(&expr, &ctx);

        // Two 100-element inputs plus one 100-element result, 8 bytes each:
        // 3 * 100 * 8 = 2400 bytes. All three are live when the result is
        // allocated, so here the peak equals the total.
        assert_eq!(estimate.total_bytes, 2400);
        assert_eq!(estimate.peak_bytes, 2400);
        assert_eq!(estimate.intermediate_count, 3);
        assert_eq!(estimate.max_tensor_size, 100);
    }

    #[test]
    fn test_quantifier_reduction() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx.add_domain("Thing", 50);
        ctx.bind_var("x", "Person").unwrap();
        ctx.bind_var("y", "Thing").unwrap();

        let pred = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let expr = TLExpr::exists("y", "Thing", pred);

        let estimate = estimate_memory(&expr, &ctx);

        // Input: 100 * 50 = 5000 elements (40000 bytes).
        // Output after reducing over Thing (50): 100 elements (800 bytes).
        assert_eq!(estimate.total_bytes, 40_000 + 800);
        assert_eq!(estimate.intermediate_count, 2);
        assert_eq!(estimate.max_tensor_size, 5000);
        // The body is still live while the reduced result is allocated.
        assert_eq!(estimate.peak_bytes, 40_800);
    }

    #[test]
    fn test_peak_memory() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 1000);
        ctx.bind_var("x", "D").unwrap();

        // Build deep expression tree
        let mut expr = TLExpr::pred("a", vec![Term::var("x")]);
        for i in 0..5 {
            let name = format!("b{}", i);
            let other = TLExpr::pred(&name, vec![Term::var("x")]);
            expr = TLExpr::add(expr, other);
        }

        let estimate = estimate_memory(&expr, &ctx);

        // 6 leaf predicates + 5 sums, each 1000 elements * 8 bytes.
        assert_eq!(estimate.intermediate_count, 11);
        assert_eq!(estimate.total_bytes, 11 * 1000 * 8);
        // At least two operands and one result are live at the same time.
        assert!(estimate.peak_bytes >= 3 * 1000 * 8);
        assert!(estimate.peak_bytes <= estimate.total_bytes);
    }

    #[test]
    fn test_batch_estimation() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 100);
        ctx.bind_var("x", "D").unwrap();

        let expr = TLExpr::pred("a", vec![Term::var("x")]);
        let single = estimate_memory(&expr, &ctx);
        let batch = estimate_batch_memory(&expr, &ctx, 10);

        assert_eq!(batch.total_bytes, single.total_bytes * 10);
        assert_eq!(batch.total_elements, single.total_elements * 10);
        assert_eq!(batch.peak_bytes, single.peak_bytes * 10);
        assert_eq!(batch.max_tensor_size, single.max_tensor_size * 10);
        // Batching widens tensors but does not add new ones.
        assert_eq!(batch.intermediate_count, single.intermediate_count);
    }

    #[test]
    fn test_memory_limit_check() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10000);
        ctx.bind_var("x", "D").unwrap();
        ctx.bind_var("y", "D").unwrap();

        let expr = TLExpr::pred("big", vec![Term::var("x"), Term::var("y")]);
        let estimate = estimate_memory(&expr, &ctx);

        // 10000 * 10000 * 8 = 800_000_000 bytes (~763 MiB)
        let mb = 1024 * 1024;
        assert!(estimate.exceeds_limit(100 * mb)); // Should exceed 100 MiB
        assert!(!estimate.exceeds_limit(1000 * mb)); // Should not exceed 1000 MiB
    }

    #[test]
    fn test_batch_size_suggestion() {
        let estimate = MemoryEstimate {
            total_bytes: 1000,
            peak_bytes: 1000,
            intermediate_count: 5,
            max_tensor_size: 100,
            total_elements: 500,
        };

        // With 5000 byte budget and 1000 bytes per item (assuming batch=1)
        let suggested = estimate.suggest_batch_size(5000, 1);
        assert_eq!(suggested, 5); // Can fit 5 items

        // Estimated for 2 items => 500 bytes per item => 10 items fit.
        assert_eq!(estimate.suggest_batch_size(5000, 2), 10);
        // A budget smaller than one item still suggests at least 1.
        assert_eq!(estimate.suggest_batch_size(500, 1), 1);

        // No measurable cost: the current batch size is kept.
        let empty = MemoryEstimate::default();
        assert_eq!(empty.suggest_batch_size(5000, 7), 7);
    }

    #[test]
    fn test_display() {
        let estimate = MemoryEstimate {
            total_bytes: 1024 * 1024,    // 1MB
            peak_bytes: 2 * 1024 * 1024, // 2MB
            intermediate_count: 10,
            max_tensor_size: 10000,
            total_elements: 50000,
        };

        let display = format!("{}", estimate);
        assert!(display.contains("Memory Estimate"));
        assert!(display.contains("MB"));
        assert_eq!(
            display,
            "Memory Estimate:\n  Total: 1.00 MB (1048576 bytes)\n  Peak: 2.00 MB (2097152 bytes)\n  Intermediates: 10\n  Max tensor size: 10000 elements\n  Total elements: 50000\n"
        );
    }

    #[test]
    fn test_nested_quantifiers() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("A", 100);
        ctx.add_domain("B", 50);
        ctx.bind_var("x", "A").unwrap();
        ctx.bind_var("y", "B").unwrap();

        let pred = TLExpr::pred("rel", vec![Term::var("x"), Term::var("y")]);
        let expr = TLExpr::exists("x", "A", TLExpr::forall("y", "B", pred));

        let estimate = estimate_memory(&expr, &ctx);

        // 5000-element input (40000 bytes) -> reduce over B (50): 100 elements
        // (800 bytes) -> reduce over A (100): 1 element (8 bytes).
        assert_eq!(estimate.total_bytes, 40_000 + 800 + 8);
        assert_eq!(estimate.intermediate_count, 3);
        assert_eq!(estimate.max_tensor_size, 5000);
    }

    #[test]
    fn test_mb_conversion() {
        let estimate = MemoryEstimate {
            total_bytes: 1024 * 1024 * 10, // 10MB
            peak_bytes: 1024 * 1024 * 20,  // 20MB
            ..Default::default()
        };

        assert!((estimate.total_mb() - 10.0).abs() < 0.001);
        assert!((estimate.peak_mb() - 20.0).abs() < 0.001);
    }
}