1use crate::CompilerContext;
30use tensorlogic_ir::TLExpr;
31
/// Static estimate of the memory needed to evaluate a tensor-logic expression.
///
/// Produced by `estimate_memory` / `estimate_batch_memory`; the estimator uses
/// saturating arithmetic, so counts clamp at `usize::MAX` instead of overflowing.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryEstimate {
    /// Sum of the sizes, in bytes, of every tensor materialised during evaluation.
    pub total_bytes: usize,
    /// Estimated number of bytes live at the same time; this is the figure that
    /// [`MemoryEstimate::exceeds_limit`] and [`MemoryEstimate::suggest_batch_size`]
    /// compare against a budget.
    pub peak_bytes: usize,
    /// Number of tensors materialised (leaf inputs as well as intermediate results).
    pub intermediate_count: usize,
    /// Element count of the largest single tensor.
    pub max_tensor_size: usize,
    /// Total element count of the input (leaf) tensors.
    pub total_elements: usize,
}

impl MemoryEstimate {
    /// Bytes per mebibyte; the `*_mb` helpers and the `Display` output use binary megabytes.
    const BYTES_PER_MB: f64 = 1024.0 * 1024.0;

    /// Returns [`MemoryEstimate::total_bytes`] expressed in mebibytes.
    pub fn total_mb(&self) -> f64 {
        self.total_bytes as f64 / Self::BYTES_PER_MB
    }

    /// Returns [`MemoryEstimate::peak_bytes`] expressed in mebibytes.
    pub fn peak_mb(&self) -> f64 {
        self.peak_bytes as f64 / Self::BYTES_PER_MB
    }

    /// Returns `true` when the peak usage is strictly greater than `limit_bytes`.
    pub fn exceeds_limit(&self, limit_bytes: usize) -> bool {
        self.peak_bytes > limit_bytes
    }

    /// Suggests how many items fit into `budget_bytes`, given that this estimate
    /// was measured for a batch of `current_batch` items.
    ///
    /// The per-item cost is `peak_bytes / current_batch` rounded **up**. Rounding
    /// down would underestimate it and could suggest a batch that overshoots the
    /// budget (5 bytes over 3 items would count as 1 byte per item).
    ///
    /// * A `current_batch` of 0 is treated as 1.
    /// * An estimate with `peak_bytes == 0` returns `current_batch` unchanged.
    /// * The result is never below 1, even if a single item exceeds the budget.
    pub fn suggest_batch_size(&self, budget_bytes: usize, current_batch: usize) -> usize {
        if self.peak_bytes == 0 {
            return current_batch;
        }

        // Non-zero: peak_bytes > 0 and the divisor is at least 1, so the ceiling is >= 1.
        let memory_per_item = self.peak_bytes.div_ceil(current_batch.max(1));
        (budget_bytes / memory_per_item).max(1)
    }
}

impl std::fmt::Display for MemoryEstimate {
    /// Renders a multi-line, human-readable summary of the estimate.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Memory Estimate:")?;
        writeln!(
            f,
            " Total: {:.2} MB ({} bytes)",
            self.total_mb(),
            self.total_bytes
        )?;
        writeln!(
            f,
            " Peak: {:.2} MB ({} bytes)",
            self.peak_mb(),
            self.peak_bytes
        )?;
        writeln!(f, " Intermediates: {}", self.intermediate_count)?;
        writeln!(f, " Max tensor size: {} elements", self.max_tensor_size)?;
        writeln!(f, " Total elements: {}", self.total_elements)?;
        Ok(())
    }
}
99
100pub fn estimate_memory(expr: &TLExpr, ctx: &CompilerContext) -> MemoryEstimate {
114 let mut estimate = MemoryEstimate::default();
115 let mut current_memory = 0usize;
116
117 estimate_memory_impl(expr, ctx, &mut estimate, &mut current_memory);
118
119 estimate.peak_bytes = estimate.peak_bytes.max(estimate.total_bytes);
121
122 estimate
123}
124
125fn estimate_memory_impl(
126 expr: &TLExpr,
127 ctx: &CompilerContext,
128 estimate: &mut MemoryEstimate,
129 current_memory: &mut usize,
130) -> usize {
131 const ELEM_SIZE: usize = 8;
133
134 match expr {
135 TLExpr::Pred { args, .. } => {
136 let mut size = 1usize;
138 for arg in args {
139 if let tensorlogic_ir::Term::Var(v) = arg {
140 if let Some(domain_name) = ctx.var_to_domain.get(v) {
142 if let Some(info) = ctx.domains.get(domain_name) {
143 size = size.saturating_mul(info.cardinality);
144 }
145 } else {
146 size = size.saturating_mul(100);
148 }
149 }
150 }
151
152 let bytes = size.saturating_mul(ELEM_SIZE);
153 estimate.total_bytes = estimate.total_bytes.saturating_add(bytes);
154 estimate.total_elements = estimate.total_elements.saturating_add(size);
155 estimate.max_tensor_size = estimate.max_tensor_size.max(size);
156 estimate.intermediate_count += 1;
157
158 *current_memory = current_memory.saturating_add(bytes);
159 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
160
161 bytes
162 }
163
164 TLExpr::Constant(_) => {
165 let bytes = ELEM_SIZE;
167 estimate.total_bytes = estimate.total_bytes.saturating_add(bytes);
168 estimate.total_elements = estimate.total_elements.saturating_add(1);
169 estimate.max_tensor_size = estimate.max_tensor_size.max(1);
170 estimate.intermediate_count += 1;
171
172 *current_memory = current_memory.saturating_add(bytes);
173 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
174
175 bytes
176 }
177
178 TLExpr::Add(lhs, rhs)
180 | TLExpr::Sub(lhs, rhs)
181 | TLExpr::Mul(lhs, rhs)
182 | TLExpr::Div(lhs, rhs)
183 | TLExpr::Min(lhs, rhs)
184 | TLExpr::Max(lhs, rhs) => {
185 let lhs_bytes = estimate_memory_impl(lhs, ctx, estimate, current_memory);
186 let rhs_bytes = estimate_memory_impl(rhs, ctx, estimate, current_memory);
187
188 let result_bytes = lhs_bytes.max(rhs_bytes);
190 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
191 estimate.intermediate_count += 1;
192
193 *current_memory = current_memory.saturating_add(result_bytes);
194 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
195
196 *current_memory = current_memory.saturating_sub(lhs_bytes);
198 *current_memory = current_memory.saturating_sub(rhs_bytes);
199
200 result_bytes
201 }
202
203 TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
205 let lhs_bytes = estimate_memory_impl(lhs, ctx, estimate, current_memory);
206 let rhs_bytes = estimate_memory_impl(rhs, ctx, estimate, current_memory);
207
208 let result_bytes = lhs_bytes.max(rhs_bytes);
209 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
210 estimate.intermediate_count += 1;
211
212 *current_memory = current_memory.saturating_add(result_bytes);
213 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
214
215 *current_memory = current_memory.saturating_sub(lhs_bytes);
216 *current_memory = current_memory.saturating_sub(rhs_bytes);
217
218 result_bytes
219 }
220
221 TLExpr::Eq(lhs, rhs)
223 | TLExpr::Lt(lhs, rhs)
224 | TLExpr::Lte(lhs, rhs)
225 | TLExpr::Gt(lhs, rhs)
226 | TLExpr::Gte(lhs, rhs) => {
227 let lhs_bytes = estimate_memory_impl(lhs, ctx, estimate, current_memory);
228 let rhs_bytes = estimate_memory_impl(rhs, ctx, estimate, current_memory);
229
230 let result_bytes = lhs_bytes.max(rhs_bytes);
231 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
232 estimate.intermediate_count += 1;
233
234 *current_memory = current_memory.saturating_add(result_bytes);
235 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
236
237 *current_memory = current_memory.saturating_sub(lhs_bytes);
238 *current_memory = current_memory.saturating_sub(rhs_bytes);
239
240 result_bytes
241 }
242
243 TLExpr::Not(inner)
245 | TLExpr::Abs(inner)
246 | TLExpr::Sqrt(inner)
247 | TLExpr::Exp(inner)
248 | TLExpr::Log(inner)
249 | TLExpr::Score(inner)
250 | TLExpr::Floor(inner)
251 | TLExpr::Ceil(inner)
252 | TLExpr::Round(inner)
253 | TLExpr::Sin(inner)
254 | TLExpr::Cos(inner)
255 | TLExpr::Tan(inner) => {
256 let inner_bytes = estimate_memory_impl(inner, ctx, estimate, current_memory);
257
258 estimate.total_bytes = estimate.total_bytes.saturating_add(inner_bytes);
260 estimate.intermediate_count += 1;
261
262 *current_memory = current_memory.saturating_add(inner_bytes);
263 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
264
265 *current_memory = current_memory.saturating_sub(inner_bytes);
266
267 inner_bytes
268 }
269
270 TLExpr::Pow(base, exp) | TLExpr::Mod(base, exp) => {
271 let base_bytes = estimate_memory_impl(base, ctx, estimate, current_memory);
272 let exp_bytes = estimate_memory_impl(exp, ctx, estimate, current_memory);
273
274 let result_bytes = base_bytes.max(exp_bytes);
275 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
276 estimate.intermediate_count += 1;
277
278 *current_memory = current_memory.saturating_add(result_bytes);
279 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
280
281 *current_memory = current_memory.saturating_sub(base_bytes);
282 *current_memory = current_memory.saturating_sub(exp_bytes);
283
284 result_bytes
285 }
286
287 TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
289 let domain_size = ctx
291 .domains
292 .get(domain)
293 .map(|info| info.cardinality)
294 .unwrap_or(100);
295
296 let body_bytes = estimate_memory_impl(body, ctx, estimate, current_memory);
297
298 let result_bytes = body_bytes / domain_size.max(1);
300 let result_bytes = result_bytes.max(ELEM_SIZE); estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
303 estimate.intermediate_count += 1;
304
305 let _ = var; *current_memory = current_memory.saturating_add(result_bytes);
309 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
310
311 *current_memory = current_memory.saturating_sub(body_bytes);
312
313 result_bytes
314 }
315
316 TLExpr::Let { value, body, .. } => {
317 let value_bytes = estimate_memory_impl(value, ctx, estimate, current_memory);
318 let body_bytes = estimate_memory_impl(body, ctx, estimate, current_memory);
319
320 *current_memory = current_memory.saturating_sub(value_bytes);
322
323 body_bytes
324 }
325
326 TLExpr::IfThenElse {
327 condition,
328 then_branch,
329 else_branch,
330 } => {
331 let cond_bytes = estimate_memory_impl(condition, ctx, estimate, current_memory);
332 let then_bytes = estimate_memory_impl(then_branch, ctx, estimate, current_memory);
333 let else_bytes = estimate_memory_impl(else_branch, ctx, estimate, current_memory);
334
335 let result_bytes = then_bytes.max(else_bytes);
337 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
338 estimate.intermediate_count += 1;
339
340 *current_memory = current_memory.saturating_add(result_bytes);
341 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
342
343 *current_memory = current_memory.saturating_sub(cond_bytes);
344 *current_memory = current_memory.saturating_sub(then_bytes);
345 *current_memory = current_memory.saturating_sub(else_bytes);
346
347 result_bytes
348 }
349
350 TLExpr::Box(inner)
352 | TLExpr::Diamond(inner)
353 | TLExpr::Next(inner)
354 | TLExpr::Eventually(inner)
355 | TLExpr::Always(inner)
356 | TLExpr::FuzzyNot { expr: inner, .. }
357 | TLExpr::WeightedRule { rule: inner, .. } => {
358 let inner_bytes = estimate_memory_impl(inner, ctx, estimate, current_memory);
359
360 let result_bytes = inner_bytes;
362 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
363 estimate.intermediate_count += 1;
364
365 *current_memory = current_memory.saturating_add(result_bytes);
366 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
367
368 *current_memory = current_memory.saturating_sub(inner_bytes);
369
370 result_bytes
371 }
372
373 TLExpr::Until { before, after } | TLExpr::WeakUntil { before, after } => {
374 let lhs_bytes = estimate_memory_impl(before, ctx, estimate, current_memory);
375 let rhs_bytes = estimate_memory_impl(after, ctx, estimate, current_memory);
376
377 let result_bytes = lhs_bytes.max(rhs_bytes);
378 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
379 estimate.intermediate_count += 1;
380
381 *current_memory = current_memory.saturating_add(result_bytes);
382 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
383
384 *current_memory = current_memory.saturating_sub(lhs_bytes);
385 *current_memory = current_memory.saturating_sub(rhs_bytes);
386
387 result_bytes
388 }
389
390 TLExpr::Release { released, releaser } | TLExpr::StrongRelease { released, releaser } => {
391 let lhs_bytes = estimate_memory_impl(released, ctx, estimate, current_memory);
392 let rhs_bytes = estimate_memory_impl(releaser, ctx, estimate, current_memory);
393
394 let result_bytes = lhs_bytes.max(rhs_bytes);
395 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
396 estimate.intermediate_count += 1;
397
398 *current_memory = current_memory.saturating_add(result_bytes);
399 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
400
401 *current_memory = current_memory.saturating_sub(lhs_bytes);
402 *current_memory = current_memory.saturating_sub(rhs_bytes);
403
404 result_bytes
405 }
406
407 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
408 let lhs_bytes = estimate_memory_impl(left, ctx, estimate, current_memory);
409 let rhs_bytes = estimate_memory_impl(right, ctx, estimate, current_memory);
410
411 let result_bytes = lhs_bytes.max(rhs_bytes);
412 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
413 estimate.intermediate_count += 1;
414
415 *current_memory = current_memory.saturating_add(result_bytes);
416 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
417
418 *current_memory = current_memory.saturating_sub(lhs_bytes);
419 *current_memory = current_memory.saturating_sub(rhs_bytes);
420
421 result_bytes
422 }
423
424 TLExpr::FuzzyImplication {
425 premise,
426 conclusion,
427 ..
428 } => {
429 let lhs_bytes = estimate_memory_impl(premise, ctx, estimate, current_memory);
430 let rhs_bytes = estimate_memory_impl(conclusion, ctx, estimate, current_memory);
431
432 let result_bytes = lhs_bytes.max(rhs_bytes);
433 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
434 estimate.intermediate_count += 1;
435
436 *current_memory = current_memory.saturating_add(result_bytes);
437 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
438
439 *current_memory = current_memory.saturating_sub(lhs_bytes);
440 *current_memory = current_memory.saturating_sub(rhs_bytes);
441
442 result_bytes
443 }
444
445 TLExpr::Aggregate {
446 var, domain, body, ..
447 }
448 | TLExpr::SoftExists {
449 var, domain, body, ..
450 }
451 | TLExpr::SoftForAll {
452 var, domain, body, ..
453 } => {
454 let domain_size = ctx
455 .domains
456 .get(domain)
457 .map(|info| info.cardinality)
458 .unwrap_or(100);
459
460 let body_bytes = estimate_memory_impl(body, ctx, estimate, current_memory);
461
462 let result_bytes = body_bytes / domain_size.max(1);
463 let result_bytes = result_bytes.max(ELEM_SIZE);
464
465 estimate.total_bytes = estimate.total_bytes.saturating_add(result_bytes);
466 estimate.intermediate_count += 1;
467
468 let _ = var;
469
470 *current_memory = current_memory.saturating_add(result_bytes);
471 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
472
473 *current_memory = current_memory.saturating_sub(body_bytes);
474
475 result_bytes
476 }
477
478 TLExpr::ProbabilisticChoice { alternatives } => {
479 let mut max_bytes = 0;
480 for (_, expr) in alternatives {
481 let bytes = estimate_memory_impl(expr, ctx, estimate, current_memory);
482 max_bytes = max_bytes.max(bytes);
483 }
484
485 estimate.total_bytes = estimate.total_bytes.saturating_add(max_bytes);
486 estimate.intermediate_count += 1;
487
488 *current_memory = current_memory.saturating_add(max_bytes);
489 estimate.peak_bytes = estimate.peak_bytes.max(*current_memory);
490
491 max_bytes
492 }
493
494 _ => {
496 const ELEM_SIZE: usize = 8;
497 ELEM_SIZE
498 }
499 }
500}
501
502pub fn estimate_batch_memory(
506 expr: &TLExpr,
507 ctx: &CompilerContext,
508 batch_size: usize,
509) -> MemoryEstimate {
510 let single = estimate_memory(expr, ctx);
511
512 MemoryEstimate {
513 total_bytes: single.total_bytes.saturating_mul(batch_size),
514 peak_bytes: single.peak_bytes.saturating_mul(batch_size),
515 intermediate_count: single.intermediate_count,
516 max_tensor_size: single.max_tensor_size.saturating_mul(batch_size),
517 total_elements: single.total_elements.saturating_mul(batch_size),
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    #[test]
    fn test_constant_memory() {
        let ctx = CompilerContext::new();
        let expr = TLExpr::Constant(1.0);
        let estimate = estimate_memory(&expr, &ctx);

        // A constant is a single 8-byte scalar.
        assert_eq!(estimate.total_bytes, 8);
        assert_eq!(estimate.total_elements, 1);
        assert_eq!(estimate.intermediate_count, 1);
    }

    #[test]
    fn test_predicate_memory() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx.add_domain("Thing", 50);
        ctx.bind_var("x", "Person").unwrap();
        ctx.bind_var("y", "Thing").unwrap();

        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let estimate = estimate_memory(&expr, &ctx);

        // knows(x, y) spans Person x Thing = 100 * 50 = 5000 elements of 8 bytes.
        assert_eq!(estimate.total_bytes, 40_000);
        assert_eq!(estimate.total_elements, 5000);
        assert_eq!(estimate.max_tensor_size, 5000);
        assert_eq!(estimate.intermediate_count, 1);
        assert_eq!(estimate.peak_bytes, 40_000);
    }

    #[test]
    fn test_binary_operation_memory() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 100);
        ctx.bind_var("x", "D").unwrap();

        let a = TLExpr::pred("a", vec![Term::var("x")]);
        let b = TLExpr::pred("b", vec![Term::var("x")]);
        let expr = TLExpr::add(a, b);

        let estimate = estimate_memory(&expr, &ctx);

        // a(x) and b(x) are 100-element tensors (800 bytes each) and the sum is
        // another 800 bytes: 3 tensors, 2400 bytes in total.
        assert_eq!(estimate.total_bytes, 2400);
        assert_eq!(estimate.total_elements, 200);
        assert_eq!(estimate.max_tensor_size, 100);
        assert_eq!(estimate.intermediate_count, 3);
        // Both operands are still live when the sum is allocated.
        assert_eq!(estimate.peak_bytes, 2400);
    }

    #[test]
    fn test_quantifier_reduction() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx.add_domain("Thing", 50);
        ctx.bind_var("x", "Person").unwrap();
        ctx.bind_var("y", "Thing").unwrap();

        let pred = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let expr = TLExpr::exists("y", "Thing", pred);

        let estimate = estimate_memory(&expr, &ctx);

        // knows(x, y) is 5000 elements (40000 bytes); reducing over Thing (50)
        // leaves 100 elements (800 bytes).
        assert_eq!(estimate.total_bytes, 40_800);
        assert_eq!(estimate.intermediate_count, 2);
        assert_eq!(estimate.max_tensor_size, 5000);
        // The body is still live when the reduced result is allocated.
        assert_eq!(estimate.peak_bytes, 40_800);
    }

    #[test]
    fn test_peak_memory() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 1000);
        ctx.bind_var("x", "D").unwrap();

        // Build a left-deep chain: ((((a + b0) + b1) + b2) + b3) + b4.
        let mut expr = TLExpr::pred("a", vec![Term::var("x")]);
        for i in 0..5 {
            let name = format!("b{}", i);
            let other = TLExpr::pred(&name, vec![Term::var("x")]);
            expr = TLExpr::add(expr, other);
        }

        let estimate = estimate_memory(&expr, &ctx);

        // 6 leaves and 5 sums, each a 1000-element (8000-byte) tensor.
        assert_eq!(estimate.intermediate_count, 11);
        assert_eq!(estimate.total_bytes, 11 * 8000);
        // At least one full tensor is resident, and the peak can never exceed
        // everything that was ever allocated.
        assert!(estimate.peak_bytes >= 1000 * 8);
        assert!(estimate.peak_bytes <= estimate.total_bytes);
    }

    #[test]
    fn test_batch_estimation() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 100);
        ctx.bind_var("x", "D").unwrap();

        let expr = TLExpr::pred("a", vec![Term::var("x")]);
        let single = estimate_memory(&expr, &ctx);
        let batch = estimate_batch_memory(&expr, &ctx, 10);

        assert_eq!(batch.total_bytes, single.total_bytes * 10);
        assert_eq!(batch.total_elements, single.total_elements * 10);
        assert_eq!(batch.peak_bytes, single.peak_bytes * 10);
        assert_eq!(batch.max_tensor_size, single.max_tensor_size * 10);
        // The computation graph does not change with the batch size.
        assert_eq!(batch.intermediate_count, single.intermediate_count);
    }

    #[test]
    fn test_memory_limit_check() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10000);
        ctx.bind_var("x", "D").unwrap();
        ctx.bind_var("y", "D").unwrap();

        let expr = TLExpr::pred("big", vec![Term::var("x"), Term::var("y")]);
        let estimate = estimate_memory(&expr, &ctx);

        // big(x, y) is 10000 * 10000 * 8 = 800,000,000 bytes (about 763 MiB).
        let mb = 1024 * 1024;
        assert!(estimate.exceeds_limit(100 * mb));
        assert!(!estimate.exceeds_limit(1000 * mb));
    }

    #[test]
    fn test_batch_size_suggestion() {
        let estimate = MemoryEstimate {
            total_bytes: 1000,
            peak_bytes: 1000,
            intermediate_count: 5,
            max_tensor_size: 100,
            total_elements: 500,
        };

        // 1000 bytes per item fits 5 items into a 5000-byte budget.
        let suggested = estimate.suggest_batch_size(5000, 1);
        assert_eq!(suggested, 5);
    }

    #[test]
    fn test_display() {
        let estimate = MemoryEstimate {
            total_bytes: 1024 * 1024,
            peak_bytes: 2 * 1024 * 1024,
            intermediate_count: 10,
            max_tensor_size: 10000,
            total_elements: 50000,
        };

        let display = format!("{}", estimate);
        assert!(display.contains("Memory Estimate"));
        assert!(display.contains("MB"));
    }

    #[test]
    fn test_nested_quantifiers() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("A", 100);
        ctx.add_domain("B", 50);
        ctx.bind_var("x", "A").unwrap();
        ctx.bind_var("y", "B").unwrap();

        let pred = TLExpr::pred("rel", vec![Term::var("x"), Term::var("y")]);
        let expr = TLExpr::exists("x", "A", TLExpr::forall("y", "B", pred));

        let estimate = estimate_memory(&expr, &ctx);

        // rel(x, y) = 5000 elements (40000 bytes); forall over B (50) leaves
        // 800 bytes; exists over A (100) leaves a single 8-byte scalar.
        assert_eq!(estimate.total_bytes, 40_808);
        assert_eq!(estimate.intermediate_count, 3);
        assert_eq!(estimate.max_tensor_size, 5000);
        assert!(estimate.peak_bytes >= 40_000);
        assert!(estimate.peak_bytes <= estimate.total_bytes);
    }

    #[test]
    fn test_mb_conversion() {
        let estimate = MemoryEstimate {
            total_bytes: 1024 * 1024 * 10,
            peak_bytes: 1024 * 1024 * 20,
            ..Default::default()
        };

        assert!((estimate.total_mb() - 10.0).abs() < 0.001);
        assert!((estimate.peak_mb() - 20.0).abs() < 0.001);
    }
}