1use serde::{Deserialize, Serialize};
32use std::collections::HashSet;
33use std::fmt;
34
35use crate::IrError;
36
37#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
39pub enum ComputationalEffect {
40 Pure,
42 Impure,
44 IO,
46}
47
48impl fmt::Display for ComputationalEffect {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 match self {
51 ComputationalEffect::Pure => write!(f, "Pure"),
52 ComputationalEffect::Impure => write!(f, "Impure"),
53 ComputationalEffect::IO => write!(f, "IO"),
54 }
55 }
56}
57
58#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub enum MemoryEffect {
61 ReadOnly,
63 ReadWrite,
65 Allocating,
67}
68
69impl fmt::Display for MemoryEffect {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 match self {
72 MemoryEffect::ReadOnly => write!(f, "ReadOnly"),
73 MemoryEffect::ReadWrite => write!(f, "ReadWrite"),
74 MemoryEffect::Allocating => write!(f, "Allocating"),
75 }
76 }
77}
78
79#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
81pub enum ProbabilisticEffect {
82 Deterministic,
84 Stochastic,
86}
87
88impl fmt::Display for ProbabilisticEffect {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 match self {
91 ProbabilisticEffect::Deterministic => write!(f, "Deterministic"),
92 ProbabilisticEffect::Stochastic => write!(f, "Stochastic"),
93 }
94 }
95}
96
97#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
99pub enum Effect {
100 Computational(ComputationalEffect),
102 Memory(MemoryEffect),
104 Probabilistic(ProbabilisticEffect),
106 Differentiable,
108 NonDifferentiable,
110 Async,
112 Parallel,
114 Custom(String),
116}
117
118impl Effect {
119 pub fn is_pure(&self) -> bool {
121 matches!(self, Effect::Computational(ComputationalEffect::Pure))
122 }
123
124 pub fn is_impure(&self) -> bool {
126 matches!(
127 self,
128 Effect::Computational(ComputationalEffect::Impure | ComputationalEffect::IO)
129 )
130 }
131
132 pub fn is_differentiable(&self) -> bool {
134 matches!(self, Effect::Differentiable)
135 }
136
137 pub fn is_stochastic(&self) -> bool {
139 matches!(self, Effect::Probabilistic(ProbabilisticEffect::Stochastic))
140 }
141}
142
143impl fmt::Display for Effect {
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 match self {
146 Effect::Computational(e) => write!(f, "{}", e),
147 Effect::Memory(e) => write!(f, "{}", e),
148 Effect::Probabilistic(e) => write!(f, "{}", e),
149 Effect::Differentiable => write!(f, "Diff"),
150 Effect::NonDifferentiable => write!(f, "NonDiff"),
151 Effect::Async => write!(f, "Async"),
152 Effect::Parallel => write!(f, "Parallel"),
153 Effect::Custom(name) => write!(f, "{}", name),
154 }
155 }
156}
157
158#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
160pub struct EffectSet {
161 effects: HashSet<Effect>,
162}
163
164impl EffectSet {
165 pub fn new() -> Self {
167 EffectSet {
168 effects: HashSet::new(),
169 }
170 }
171
172 pub fn pure() -> Self {
174 let mut effects = HashSet::new();
175 effects.insert(Effect::Computational(ComputationalEffect::Pure));
176 effects.insert(Effect::Probabilistic(ProbabilisticEffect::Deterministic));
177 effects.insert(Effect::Memory(MemoryEffect::ReadOnly));
178 EffectSet { effects }
179 }
180
181 pub fn impure() -> Self {
183 let mut effects = HashSet::new();
184 effects.insert(Effect::Computational(ComputationalEffect::Impure));
185 EffectSet { effects }
186 }
187
188 pub fn differentiable() -> Self {
190 let mut effects = HashSet::new();
191 effects.insert(Effect::Differentiable);
192 EffectSet { effects }
193 }
194
195 pub fn stochastic() -> Self {
197 let mut effects = HashSet::new();
198 effects.insert(Effect::Probabilistic(ProbabilisticEffect::Stochastic));
199 EffectSet { effects }
200 }
201
202 pub fn with(mut self, effect: Effect) -> Self {
204 self.effects.insert(effect);
205 self
206 }
207
208 pub fn with_all(mut self, effects: impl IntoIterator<Item = Effect>) -> Self {
210 self.effects.extend(effects);
211 self
212 }
213
214 pub fn contains(&self, effect: &Effect) -> bool {
216 self.effects.contains(effect)
217 }
218
219 pub fn is_pure(&self) -> bool {
221 if self.effects.is_empty() {
223 return true;
224 }
225
226 let has_pure = self
227 .effects
228 .iter()
229 .any(|e| matches!(e, Effect::Computational(ComputationalEffect::Pure)));
230
231 let has_impure = self.effects.iter().any(|e| {
232 matches!(
233 e,
234 Effect::Computational(ComputationalEffect::Impure | ComputationalEffect::IO)
235 )
236 });
237
238 has_pure && !has_impure
239 }
240
241 pub fn is_impure(&self) -> bool {
243 self.effects.iter().any(|e| e.is_impure())
244 }
245
246 pub fn is_differentiable(&self) -> bool {
248 self.effects.iter().any(|e| e.is_differentiable())
249 && !self
250 .effects
251 .iter()
252 .any(|e| matches!(e, Effect::NonDifferentiable))
253 }
254
255 pub fn is_stochastic(&self) -> bool {
257 self.effects.iter().any(|e| e.is_stochastic())
258 }
259
260 pub fn effects(&self) -> impl Iterator<Item = &Effect> {
262 self.effects.iter()
263 }
264
265 pub fn union(&self, other: &EffectSet) -> EffectSet {
267 let mut effects = self.effects.clone();
268 effects.extend(other.effects.iter().cloned());
269 EffectSet { effects }
270 }
271
272 pub fn intersection(&self, other: &EffectSet) -> EffectSet {
274 let effects = self.effects.intersection(&other.effects).cloned().collect();
275 EffectSet { effects }
276 }
277
278 pub fn is_subset_of(&self, other: &EffectSet) -> bool {
280 self.effects.is_subset(&other.effects)
281 }
282
283 pub fn is_compatible_with(&self, other: &EffectSet) -> bool {
285 !self.has_conflicts_with(other)
287 }
288
289 fn has_conflicts_with(&self, other: &EffectSet) -> bool {
291 if (self.contains(&Effect::Computational(ComputationalEffect::Pure)) && other.is_impure())
293 || (other.contains(&Effect::Computational(ComputationalEffect::Pure))
294 && self.is_impure())
295 {
296 return true;
297 }
298
299 if (self.contains(&Effect::Differentiable) && other.contains(&Effect::NonDifferentiable))
301 || (other.contains(&Effect::Differentiable)
302 && self.contains(&Effect::NonDifferentiable))
303 {
304 return true;
305 }
306
307 false
308 }
309
310 pub fn len(&self) -> usize {
312 self.effects.len()
313 }
314
315 pub fn is_empty(&self) -> bool {
317 self.effects.is_empty()
318 }
319}
320
321impl Default for EffectSet {
322 fn default() -> Self {
323 Self::new()
324 }
325}
326
327impl fmt::Display for EffectSet {
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 if self.effects.is_empty() {
330 return write!(f, "{{}}");
331 }
332
333 write!(f, "{{")?;
334 let mut first = true;
335 for effect in &self.effects {
336 if !first {
337 write!(f, ", ")?;
338 }
339 write!(f, "{}", effect)?;
340 first = false;
341 }
342 write!(f, "}}")
343 }
344}
345
346#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
348pub struct EffectVar(pub String);
349
350impl EffectVar {
351 pub fn new(name: impl Into<String>) -> Self {
352 EffectVar(name.into())
353 }
354}
355
356impl fmt::Display for EffectVar {
357 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358 write!(f, "ε{}", self.0)
359 }
360}
361
362#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
364pub enum EffectScheme {
365 Concrete(EffectSet),
367 Variable(EffectVar),
369 Union(Box<EffectScheme>, Box<EffectScheme>),
371}
372
373impl EffectScheme {
374 pub fn concrete(effects: EffectSet) -> Self {
376 EffectScheme::Concrete(effects)
377 }
378
379 pub fn variable(name: impl Into<String>) -> Self {
381 EffectScheme::Variable(EffectVar::new(name))
382 }
383
384 pub fn union(e1: EffectScheme, e2: EffectScheme) -> Self {
386 EffectScheme::Union(Box::new(e1), Box::new(e2))
387 }
388
389 pub fn substitute(&self, subst: &EffectSubstitution) -> EffectScheme {
391 match self {
392 EffectScheme::Concrete(effects) => EffectScheme::Concrete(effects.clone()),
393 EffectScheme::Variable(var) => {
394 if let Some(effects) = subst.get(var) {
395 EffectScheme::Concrete(effects.clone())
396 } else {
397 EffectScheme::Variable(var.clone())
398 }
399 }
400 EffectScheme::Union(e1, e2) => {
401 let s1 = e1.substitute(subst);
402 let s2 = e2.substitute(subst);
403 EffectScheme::union(s1, s2)
404 }
405 }
406 }
407
408 pub fn evaluate(&self, subst: &EffectSubstitution) -> Result<EffectSet, IrError> {
410 match self {
411 EffectScheme::Concrete(effects) => Ok(effects.clone()),
412 EffectScheme::Variable(var) => {
413 subst
414 .get(var)
415 .cloned()
416 .ok_or_else(|| IrError::UnboundVariable {
417 var: format!("effect variable {}", var),
418 })
419 }
420 EffectScheme::Union(e1, e2) => {
421 let effects1 = e1.evaluate(subst)?;
422 let effects2 = e2.evaluate(subst)?;
423 Ok(effects1.union(&effects2))
424 }
425 }
426 }
427}
428
429impl fmt::Display for EffectScheme {
430 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
431 match self {
432 EffectScheme::Concrete(effects) => write!(f, "{}", effects),
433 EffectScheme::Variable(var) => write!(f, "{}", var),
434 EffectScheme::Union(e1, e2) => write!(f, "({} ∪ {})", e1, e2),
435 }
436 }
437}
438
439pub type EffectSubstitution = std::collections::HashMap<EffectVar, EffectSet>;
441
442#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
444pub struct EffectAnnotation {
445 pub scheme: EffectScheme,
447 pub description: Option<String>,
449}
450
451impl EffectAnnotation {
452 pub fn new(scheme: EffectScheme) -> Self {
453 EffectAnnotation {
454 scheme,
455 description: None,
456 }
457 }
458
459 pub fn with_description(mut self, description: impl Into<String>) -> Self {
460 self.description = Some(description.into());
461 self
462 }
463
464 pub fn pure() -> Self {
466 EffectAnnotation::new(EffectScheme::concrete(EffectSet::pure()))
467 }
468
469 pub fn differentiable() -> Self {
471 EffectAnnotation::new(EffectScheme::concrete(EffectSet::differentiable()))
472 }
473}
474
475pub fn infer_operation_effects(op_name: &str) -> EffectSet {
477 match op_name {
478 "and" | "or" | "not" | "implies" => EffectSet::pure().with(Effect::Differentiable),
480
481 "add" | "subtract" | "multiply" | "divide" => {
483 EffectSet::pure().with(Effect::Differentiable)
484 }
485
486 "exists" | "forall" => EffectSet::pure(),
488
489 "equal" | "less_than" | "greater_than" => EffectSet::pure().with(Effect::NonDifferentiable),
491
492 "sample" | "random" => EffectSet::stochastic().with(Effect::NonDifferentiable),
494
495 "read" | "write" => EffectSet::new()
497 .with(Effect::Computational(ComputationalEffect::IO))
498 .with(Effect::Memory(MemoryEffect::ReadWrite)),
499
500 _ => EffectSet::impure().with(Effect::NonDifferentiable),
502 }
503}
504
// Unit tests for the effect types, effect sets, schemes and operation inference.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_effect_creation() {
        let pure = Effect::Computational(ComputationalEffect::Pure);
        assert!(pure.is_pure());
        assert!(!pure.is_impure());

        let impure = Effect::Computational(ComputationalEffect::Impure);
        assert!(!impure.is_pure());
        assert!(impure.is_impure());

        let diff = Effect::Differentiable;
        assert!(diff.is_differentiable());
    }

    #[test]
    fn test_effect_set_pure() {
        let pure_set = EffectSet::pure();
        assert!(pure_set.is_pure());
        assert!(!pure_set.is_impure());
        assert!(pure_set.contains(&Effect::Computational(ComputationalEffect::Pure)));
    }

    #[test]
    fn test_effect_set_differentiable() {
        let diff_set = EffectSet::differentiable();
        assert!(diff_set.is_differentiable());
        assert!(diff_set.contains(&Effect::Differentiable));
    }

    #[test]
    fn test_effect_set_union() {
        let pure = EffectSet::pure();
        let diff = EffectSet::differentiable();
        let combined = pure.union(&diff);

        assert!(combined.contains(&Effect::Computational(ComputationalEffect::Pure)));
        assert!(combined.contains(&Effect::Differentiable));
    }

    #[test]
    fn test_effect_set_intersection() {
        let set1 = EffectSet::pure().with(Effect::Differentiable);
        let set2 = EffectSet::differentiable();
        let intersection = set1.intersection(&set2);

        assert!(intersection.contains(&Effect::Differentiable));
        assert!(!intersection.contains(&Effect::Computational(ComputationalEffect::Pure)));
    }

    #[test]
    fn test_effect_set_subset() {
        let small = EffectSet::pure();
        let large = EffectSet::pure().with(Effect::Differentiable);

        assert!(small.is_subset_of(&large));
        assert!(!large.is_subset_of(&small));
    }

    #[test]
    fn test_effect_conflicts() {
        let pure = EffectSet::pure();
        let impure = EffectSet::impure();

        assert!(!pure.is_compatible_with(&impure));
        assert!(!impure.is_compatible_with(&pure));
    }

    #[test]
    fn test_effect_scheme_concrete() {
        let scheme = EffectScheme::concrete(EffectSet::pure());
        let subst = EffectSubstitution::new();
        let effects = scheme.evaluate(&subst).unwrap();

        assert!(effects.is_pure());
    }

    #[test]
    fn test_effect_scheme_variable() {
        let var = EffectVar::new("e1");
        let scheme = EffectScheme::Variable(var.clone());

        let mut subst = EffectSubstitution::new();
        subst.insert(var, EffectSet::pure());

        let effects = scheme.evaluate(&subst).unwrap();
        assert!(effects.is_pure());
    }

    #[test]
    fn test_effect_scheme_union() {
        let scheme1 = EffectScheme::concrete(EffectSet::pure());
        let scheme2 = EffectScheme::concrete(EffectSet::differentiable());
        let union_scheme = EffectScheme::union(scheme1, scheme2);

        let subst = EffectSubstitution::new();
        let effects = union_scheme.evaluate(&subst).unwrap();

        assert!(effects.is_pure());
        assert!(effects.is_differentiable());
    }

    #[test]
    fn test_effect_annotation() {
        let annotation = EffectAnnotation::pure().with_description("Pure computation");

        assert_eq!(annotation.description.as_deref(), Some("Pure computation"));
    }

    #[test]
    fn test_infer_operation_effects() {
        let and_effects = infer_operation_effects("and");
        assert!(and_effects.is_pure());
        assert!(and_effects.is_differentiable());

        let sample_effects = infer_operation_effects("sample");
        assert!(sample_effects.is_stochastic());

        let io_effects = infer_operation_effects("read");
        assert!(io_effects.is_impure());
    }

    #[test]
    fn test_effect_set_stochastic() {
        let stochastic = EffectSet::stochastic();
        assert!(stochastic.is_stochastic());
        assert!(stochastic.contains(&Effect::Probabilistic(ProbabilisticEffect::Stochastic)));
    }

    #[test]
    fn test_memory_effects() {
        let read_only = Effect::Memory(MemoryEffect::ReadOnly);
        let read_write = Effect::Memory(MemoryEffect::ReadWrite);

        let set1 = EffectSet::new().with(read_only);
        let set2 = EffectSet::new().with(read_write);

        assert_ne!(set1, set2);
    }

    #[test]
    fn test_custom_effect() {
        let custom = Effect::Custom("GPUCompute".to_string());
        let set = EffectSet::new().with(custom.clone());

        assert!(set.contains(&custom));
    }

    #[test]
    fn test_effect_display() {
        let pure = Effect::Computational(ComputationalEffect::Pure);
        assert_eq!(pure.to_string(), "Pure");

        let diff = Effect::Differentiable;
        assert_eq!(diff.to_string(), "Diff");

        let custom = Effect::Custom("MyEffect".to_string());
        assert_eq!(custom.to_string(), "MyEffect");
    }

    #[test]
    fn test_effect_set_display() {
        let set = EffectSet::pure().with(Effect::Differentiable);
        let display = set.to_string();

        // Both effects are in the set, so both labels must be rendered. These
        // checks are independent of the order in which the effects are printed.
        assert!(display.contains("Pure") && display.contains("Diff"));
        assert!(display.starts_with('{'));
        assert!(display.ends_with('}'));
        // Pure, Deterministic, ReadOnly and Diff: four effects, three separators.
        assert_eq!(display.matches(", ").count(), 3);
        assert_eq!(EffectSet::new().to_string(), "{}");
    }

    #[test]
    fn test_effect_var_display() {
        let var = EffectVar::new("1");
        assert_eq!(var.to_string(), "ε1");
    }

    #[test]
    fn test_non_differentiable_conflicts() {
        let diff = EffectSet::new().with(Effect::Differentiable);
        let non_diff = EffectSet::new().with(Effect::NonDifferentiable);

        assert!(!diff.is_compatible_with(&non_diff));
    }
}