Skip to main content

tensorlogic_ir/
effect_system.rs

1//! Effect system for tracking computational effects in TensorLogic expressions.
2//!
3//! This module provides an effect system that tracks various kinds of computational
4//! effects in logical expressions and tensor operations, enabling:
5//!
6//! - **Effect tracking**: Know which operations have side effects
7//! - **Differentiability**: Track which operations support gradient computation
8//! - **Probabilistic reasoning**: Distinguish deterministic from stochastic operations
9//! - **Memory safety**: Track memory access patterns
10//! - **Effect polymorphism**: Functions parametric over effects
11//!
12//! # Examples
13//!
14//! ```
15//! use tensorlogic_ir::effect_system::{Effect, EffectSet, ComputationalEffect};
16//!
17//! // Pure computation (no side effects)
18//! let pure_effect = EffectSet::pure();
19//! assert!(pure_effect.is_pure());
20//!
21//! // Differentiable operation
22//! let diff_effect = EffectSet::new()
23//!     .with(Effect::Computational(ComputationalEffect::Pure))
24//!     .with(Effect::Differentiable);
25//!
26//! // Combine effects
27//! let combined = pure_effect.union(&diff_effect);
28//! assert!(combined.contains(&Effect::Differentiable));
29//! ```
30
31use serde::{Deserialize, Serialize};
32use std::collections::HashSet;
33use std::fmt;
34
35use crate::IrError;
36
/// Computational purity effects.
///
/// Classifies an operation as referentially transparent (`Pure`), possibly
/// side-effecting (`Impure`), or touching external state (`IO`). Both `Impure`
/// and `IO` count as impure for `Effect::is_impure`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ComputationalEffect {
    /// Pure computation (no side effects, referentially transparent)
    Pure,
    /// Impure computation (may have side effects)
    Impure,
    /// I/O operations (reading/writing external state)
    IO,
}
47
48impl fmt::Display for ComputationalEffect {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        match self {
51            ComputationalEffect::Pure => write!(f, "Pure"),
52            ComputationalEffect::Impure => write!(f, "Impure"),
53            ComputationalEffect::IO => write!(f, "IO"),
54        }
55    }
56}
57
/// Memory access effects.
///
/// Describes how an operation touches memory. Wrapped by `Effect::Memory`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MemoryEffect {
    /// Read-only memory access
    ReadOnly,
    /// Read-write memory access
    ReadWrite,
    /// Memory allocation
    Allocating,
}
68
69impl fmt::Display for MemoryEffect {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        match self {
72            MemoryEffect::ReadOnly => write!(f, "ReadOnly"),
73            MemoryEffect::ReadWrite => write!(f, "ReadWrite"),
74            MemoryEffect::Allocating => write!(f, "Allocating"),
75        }
76    }
77}
78
/// Probabilistic effects.
///
/// Distinguishes deterministic operations from ones that involve randomness.
/// Wrapped by `Effect::Probabilistic`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ProbabilisticEffect {
    /// Deterministic computation (same inputs → same outputs)
    Deterministic,
    /// Stochastic computation (involves randomness)
    Stochastic,
}
87
88impl fmt::Display for ProbabilisticEffect {
89    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90        match self {
91            ProbabilisticEffect::Deterministic => write!(f, "Deterministic"),
92            ProbabilisticEffect::Stochastic => write!(f, "Stochastic"),
93        }
94    }
95}
96
/// Individual effect kinds.
///
/// Effects are normally combined in an [`EffectSet`]. The `Computational`,
/// `Memory` and `Probabilistic` variants wrap a more specific kind. The other
/// variants are standalone flags, plus `Custom` for user-defined effects.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Effect {
    /// Computational purity
    Computational(ComputationalEffect),
    /// Memory access pattern
    Memory(MemoryEffect),
    /// Probabilistic behavior
    Probabilistic(ProbabilisticEffect),
    /// Supports automatic differentiation.
    ///
    /// `EffectSet::is_differentiable` also requires that the set does not
    /// contain `NonDifferentiable`.
    Differentiable,
    /// Does not support automatic differentiation
    NonDifferentiable,
    /// Asynchronous computation
    Async,
    /// Parallel computation
    Parallel,
    /// Custom user-defined effect, identified (and displayed) by its name
    Custom(String),
}
117
118impl Effect {
119    /// Check if this effect is pure
120    pub fn is_pure(&self) -> bool {
121        matches!(self, Effect::Computational(ComputationalEffect::Pure))
122    }
123
124    /// Check if this effect is impure
125    pub fn is_impure(&self) -> bool {
126        matches!(
127            self,
128            Effect::Computational(ComputationalEffect::Impure | ComputationalEffect::IO)
129        )
130    }
131
132    /// Check if this effect is differentiable
133    pub fn is_differentiable(&self) -> bool {
134        matches!(self, Effect::Differentiable)
135    }
136
137    /// Check if this effect is stochastic
138    pub fn is_stochastic(&self) -> bool {
139        matches!(self, Effect::Probabilistic(ProbabilisticEffect::Stochastic))
140    }
141}
142
143impl fmt::Display for Effect {
144    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145        match self {
146            Effect::Computational(e) => write!(f, "{}", e),
147            Effect::Memory(e) => write!(f, "{}", e),
148            Effect::Probabilistic(e) => write!(f, "{}", e),
149            Effect::Differentiable => write!(f, "Diff"),
150            Effect::NonDifferentiable => write!(f, "NonDiff"),
151            Effect::Async => write!(f, "Async"),
152            Effect::Parallel => write!(f, "Parallel"),
153            Effect::Custom(name) => write!(f, "{}", name),
154        }
155    }
156}
157
/// Set of effects for an expression or operation.
///
/// Backed by a `HashSet`, so duplicate effects collapse and iteration order is
/// unspecified. Equality is set equality, independent of insertion order.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct EffectSet {
    effects: HashSet<Effect>,
}
163
164impl EffectSet {
165    /// Create an empty effect set
166    pub fn new() -> Self {
167        EffectSet {
168            effects: HashSet::new(),
169        }
170    }
171
172    /// Create a pure effect set (pure + deterministic + differentiable)
173    pub fn pure() -> Self {
174        let mut effects = HashSet::new();
175        effects.insert(Effect::Computational(ComputationalEffect::Pure));
176        effects.insert(Effect::Probabilistic(ProbabilisticEffect::Deterministic));
177        effects.insert(Effect::Memory(MemoryEffect::ReadOnly));
178        EffectSet { effects }
179    }
180
181    /// Create an impure effect set
182    pub fn impure() -> Self {
183        let mut effects = HashSet::new();
184        effects.insert(Effect::Computational(ComputationalEffect::Impure));
185        EffectSet { effects }
186    }
187
188    /// Create a differentiable effect set
189    pub fn differentiable() -> Self {
190        let mut effects = HashSet::new();
191        effects.insert(Effect::Differentiable);
192        EffectSet { effects }
193    }
194
195    /// Create a stochastic effect set
196    pub fn stochastic() -> Self {
197        let mut effects = HashSet::new();
198        effects.insert(Effect::Probabilistic(ProbabilisticEffect::Stochastic));
199        EffectSet { effects }
200    }
201
202    /// Add an effect to this set
203    pub fn with(mut self, effect: Effect) -> Self {
204        self.effects.insert(effect);
205        self
206    }
207
208    /// Add multiple effects
209    pub fn with_all(mut self, effects: impl IntoIterator<Item = Effect>) -> Self {
210        self.effects.extend(effects);
211        self
212    }
213
214    /// Check if this set contains a specific effect
215    pub fn contains(&self, effect: &Effect) -> bool {
216        self.effects.contains(effect)
217    }
218
219    /// Check if this effect set is pure (contains Pure computational effect and no impure effects)
220    pub fn is_pure(&self) -> bool {
221        // Either empty or contains Pure and no impure effects
222        if self.effects.is_empty() {
223            return true;
224        }
225
226        let has_pure = self
227            .effects
228            .iter()
229            .any(|e| matches!(e, Effect::Computational(ComputationalEffect::Pure)));
230
231        let has_impure = self.effects.iter().any(|e| {
232            matches!(
233                e,
234                Effect::Computational(ComputationalEffect::Impure | ComputationalEffect::IO)
235            )
236        });
237
238        has_pure && !has_impure
239    }
240
241    /// Check if this effect set is impure
242    pub fn is_impure(&self) -> bool {
243        self.effects.iter().any(|e| e.is_impure())
244    }
245
246    /// Check if this effect set is differentiable
247    pub fn is_differentiable(&self) -> bool {
248        self.effects.iter().any(|e| e.is_differentiable())
249            && !self
250                .effects
251                .iter()
252                .any(|e| matches!(e, Effect::NonDifferentiable))
253    }
254
255    /// Check if this effect set is stochastic
256    pub fn is_stochastic(&self) -> bool {
257        self.effects.iter().any(|e| e.is_stochastic())
258    }
259
260    /// Get all effects in this set
261    pub fn effects(&self) -> impl Iterator<Item = &Effect> {
262        self.effects.iter()
263    }
264
265    /// Union of two effect sets
266    pub fn union(&self, other: &EffectSet) -> EffectSet {
267        let mut effects = self.effects.clone();
268        effects.extend(other.effects.iter().cloned());
269        EffectSet { effects }
270    }
271
272    /// Intersection of two effect sets
273    pub fn intersection(&self, other: &EffectSet) -> EffectSet {
274        let effects = self.effects.intersection(&other.effects).cloned().collect();
275        EffectSet { effects }
276    }
277
278    /// Check if this effect set is a subset of another (subtyping)
279    pub fn is_subset_of(&self, other: &EffectSet) -> bool {
280        self.effects.is_subset(&other.effects)
281    }
282
283    /// Check if two effect sets are compatible
284    pub fn is_compatible_with(&self, other: &EffectSet) -> bool {
285        // Compatible if no conflicting effects
286        !self.has_conflicts_with(other)
287    }
288
289    /// Check if there are conflicting effects
290    fn has_conflicts_with(&self, other: &EffectSet) -> bool {
291        // Pure and Impure conflict
292        if (self.contains(&Effect::Computational(ComputationalEffect::Pure)) && other.is_impure())
293            || (other.contains(&Effect::Computational(ComputationalEffect::Pure))
294                && self.is_impure())
295        {
296            return true;
297        }
298
299        // Differentiable and NonDifferentiable conflict
300        if (self.contains(&Effect::Differentiable) && other.contains(&Effect::NonDifferentiable))
301            || (other.contains(&Effect::Differentiable)
302                && self.contains(&Effect::NonDifferentiable))
303        {
304            return true;
305        }
306
307        false
308    }
309
310    /// Number of effects in this set
311    pub fn len(&self) -> usize {
312        self.effects.len()
313    }
314
315    /// Check if effect set is empty
316    pub fn is_empty(&self) -> bool {
317        self.effects.is_empty()
318    }
319}
320
321impl Default for EffectSet {
322    fn default() -> Self {
323        Self::new()
324    }
325}
326
327impl fmt::Display for EffectSet {
328    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329        if self.effects.is_empty() {
330            return write!(f, "{{}}");
331        }
332
333        write!(f, "{{")?;
334        let mut first = true;
335        for effect in &self.effects {
336            if !first {
337                write!(f, ", ")?;
338            }
339            write!(f, "{}", effect)?;
340            first = false;
341        }
342        write!(f, "}}")
343    }
344}
345
/// Effect variable for effect polymorphism.
///
/// Wraps the variable's name; `Display` renders it as `ε` followed by the name.
/// Used as the key type of `EffectSubstitution`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EffectVar(pub String);
349
350impl EffectVar {
351    pub fn new(name: impl Into<String>) -> Self {
352        EffectVar(name.into())
353    }
354}
355
356impl fmt::Display for EffectVar {
357    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358        write!(f, "ε{}", self.0)
359    }
360}
361
/// Effect scheme for effect polymorphism (analogous to type schemes).
///
/// A scheme is one of:
/// - a concrete [`EffectSet`];
/// - an [`EffectVar`] to be bound through an `EffectSubstitution`;
/// - a symbolic union of two schemes.
///
/// `substitute` keeps unions symbolic; only `evaluate` flattens them into a
/// single set.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum EffectScheme {
    /// Concrete effect set
    Concrete(EffectSet),
    /// Effect variable (for polymorphism)
    Variable(EffectVar),
    /// Union of effect schemes
    Union(Box<EffectScheme>, Box<EffectScheme>),
}
372
373impl EffectScheme {
374    /// Create a concrete effect scheme
375    pub fn concrete(effects: EffectSet) -> Self {
376        EffectScheme::Concrete(effects)
377    }
378
379    /// Create an effect variable
380    pub fn variable(name: impl Into<String>) -> Self {
381        EffectScheme::Variable(EffectVar::new(name))
382    }
383
384    /// Create a union of two effect schemes
385    pub fn union(e1: EffectScheme, e2: EffectScheme) -> Self {
386        EffectScheme::Union(Box::new(e1), Box::new(e2))
387    }
388
389    /// Substitute effect variables with concrete effect sets
390    pub fn substitute(&self, subst: &EffectSubstitution) -> EffectScheme {
391        match self {
392            EffectScheme::Concrete(effects) => EffectScheme::Concrete(effects.clone()),
393            EffectScheme::Variable(var) => {
394                if let Some(effects) = subst.get(var) {
395                    EffectScheme::Concrete(effects.clone())
396                } else {
397                    EffectScheme::Variable(var.clone())
398                }
399            }
400            EffectScheme::Union(e1, e2) => {
401                let s1 = e1.substitute(subst);
402                let s2 = e2.substitute(subst);
403                EffectScheme::union(s1, s2)
404            }
405        }
406    }
407
408    /// Evaluate to a concrete effect set (if possible)
409    pub fn evaluate(&self, subst: &EffectSubstitution) -> Result<EffectSet, IrError> {
410        match self {
411            EffectScheme::Concrete(effects) => Ok(effects.clone()),
412            EffectScheme::Variable(var) => {
413                subst
414                    .get(var)
415                    .cloned()
416                    .ok_or_else(|| IrError::UnboundVariable {
417                        var: format!("effect variable {}", var),
418                    })
419            }
420            EffectScheme::Union(e1, e2) => {
421                let effects1 = e1.evaluate(subst)?;
422                let effects2 = e2.evaluate(subst)?;
423                Ok(effects1.union(&effects2))
424            }
425        }
426    }
427}
428
429impl fmt::Display for EffectScheme {
430    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
431        match self {
432            EffectScheme::Concrete(effects) => write!(f, "{}", effects),
433            EffectScheme::Variable(var) => write!(f, "{}", var),
434            EffectScheme::Union(e1, e2) => write!(f, "({} ∪ {})", e1, e2),
435        }
436    }
437}
438
/// Substitution mapping effect variables to effect sets.
///
/// `EffectScheme::substitute` leaves variables missing from the map in place.
/// `EffectScheme::evaluate` reports them as `IrError::UnboundVariable`.
pub type EffectSubstitution = std::collections::HashMap<EffectVar, EffectSet>;
441
/// Effect annotation for expressions.
///
/// Pairs an effect scheme with an optional human-readable description.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct EffectAnnotation {
    /// The effect scheme for this expression
    pub scheme: EffectScheme,
    /// Optional description
    pub description: Option<String>,
}
450
451impl EffectAnnotation {
452    pub fn new(scheme: EffectScheme) -> Self {
453        EffectAnnotation {
454            scheme,
455            description: None,
456        }
457    }
458
459    pub fn with_description(mut self, description: impl Into<String>) -> Self {
460        self.description = Some(description.into());
461        self
462    }
463
464    /// Create a pure effect annotation
465    pub fn pure() -> Self {
466        EffectAnnotation::new(EffectScheme::concrete(EffectSet::pure()))
467    }
468
469    /// Create a differentiable effect annotation
470    pub fn differentiable() -> Self {
471        EffectAnnotation::new(EffectScheme::concrete(EffectSet::differentiable()))
472    }
473}
474
475/// Infer effects for common operations
476pub fn infer_operation_effects(op_name: &str) -> EffectSet {
477    match op_name {
478        // Pure logical operations
479        "and" | "or" | "not" | "implies" => EffectSet::pure().with(Effect::Differentiable),
480
481        // Arithmetic operations (pure and differentiable)
482        "add" | "subtract" | "multiply" | "divide" => {
483            EffectSet::pure().with(Effect::Differentiable)
484        }
485
486        // Quantifiers (pure but may not be differentiable)
487        "exists" | "forall" => EffectSet::pure(),
488
489        // Comparisons (pure but not differentiable)
490        "equal" | "less_than" | "greater_than" => EffectSet::pure().with(Effect::NonDifferentiable),
491
492        // Sampling operations (stochastic)
493        "sample" | "random" => EffectSet::stochastic().with(Effect::NonDifferentiable),
494
495        // I/O operations
496        "read" | "write" => EffectSet::new()
497            .with(Effect::Computational(ComputationalEffect::IO))
498            .with(Effect::Memory(MemoryEffect::ReadWrite)),
499
500        // Default: conservative (impure, non-differentiable)
501        _ => EffectSet::impure().with(Effect::NonDifferentiable),
502    }
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_effect_creation() {
        let pure = Effect::Computational(ComputationalEffect::Pure);
        assert!(pure.is_pure());
        assert!(!pure.is_impure());

        let impure = Effect::Computational(ComputationalEffect::Impure);
        assert!(!impure.is_pure());
        assert!(impure.is_impure());

        let diff = Effect::Differentiable;
        assert!(diff.is_differentiable());
    }

    #[test]
    fn test_effect_set_pure() {
        let pure_set = EffectSet::pure();
        assert!(pure_set.is_pure());
        assert!(!pure_set.is_impure());
        assert!(pure_set.contains(&Effect::Computational(ComputationalEffect::Pure)));
    }

    #[test]
    fn test_empty_effect_set_is_pure() {
        let empty = EffectSet::new();
        assert!(empty.is_empty());
        assert!(empty.is_pure());
        assert!(!empty.is_impure());
        assert!(!empty.is_differentiable());
    }

    #[test]
    fn test_effect_set_differentiable() {
        let diff_set = EffectSet::differentiable();
        assert!(diff_set.is_differentiable());
        assert!(diff_set.contains(&Effect::Differentiable));
    }

    #[test]
    fn test_effect_set_union() {
        let pure = EffectSet::pure();
        let diff = EffectSet::differentiable();
        let combined = pure.union(&diff);

        assert!(combined.contains(&Effect::Computational(ComputationalEffect::Pure)));
        assert!(combined.contains(&Effect::Differentiable));
    }

    #[test]
    fn test_effect_set_intersection() {
        let set1 = EffectSet::pure().with(Effect::Differentiable);
        let set2 = EffectSet::differentiable();
        let intersection = set1.intersection(&set2);

        assert!(intersection.contains(&Effect::Differentiable));
        assert!(!intersection.contains(&Effect::Computational(ComputationalEffect::Pure)));
    }

    #[test]
    fn test_effect_set_subset() {
        let small = EffectSet::pure();
        let large = EffectSet::pure().with(Effect::Differentiable);

        assert!(small.is_subset_of(&large));
        assert!(!large.is_subset_of(&small));
    }

    #[test]
    fn test_effect_conflicts() {
        let pure = EffectSet::pure();
        let impure = EffectSet::impure();

        assert!(!pure.is_compatible_with(&impure));
        assert!(!impure.is_compatible_with(&pure));
    }

    #[test]
    fn test_io_conflicts_with_pure() {
        let pure = EffectSet::pure();
        let io = infer_operation_effects("write");

        assert!(!pure.is_compatible_with(&io));
        assert!(!io.is_compatible_with(&pure));
    }

    #[test]
    fn test_effect_scheme_concrete() {
        let scheme = EffectScheme::concrete(EffectSet::pure());
        let subst = EffectSubstitution::new();
        let effects = scheme.evaluate(&subst).unwrap();

        assert!(effects.is_pure());
    }

    #[test]
    fn test_effect_scheme_variable() {
        let var = EffectVar::new("e1");
        let scheme = EffectScheme::Variable(var.clone());

        let mut subst = EffectSubstitution::new();
        subst.insert(var, EffectSet::pure());

        let effects = scheme.evaluate(&subst).unwrap();
        assert!(effects.is_pure());
    }

    #[test]
    fn test_effect_scheme_unbound_variable() {
        let scheme = EffectScheme::variable("missing");
        let subst = EffectSubstitution::new();

        assert!(scheme.evaluate(&subst).is_err());
    }

    #[test]
    fn test_effect_scheme_substitute_partial() {
        let bound = EffectVar::new("a");
        let scheme = EffectScheme::union(
            EffectScheme::Variable(bound.clone()),
            EffectScheme::variable("b"),
        );
        let mut subst = EffectSubstitution::new();
        subst.insert(bound, EffectSet::pure());

        // Bound variables become concrete; unbound ones are left untouched.
        assert_eq!(
            scheme.substitute(&subst),
            EffectScheme::union(
                EffectScheme::concrete(EffectSet::pure()),
                EffectScheme::variable("b"),
            )
        );
    }

    #[test]
    fn test_effect_scheme_union() {
        let scheme1 = EffectScheme::concrete(EffectSet::pure());
        let scheme2 = EffectScheme::concrete(EffectSet::differentiable());
        let union_scheme = EffectScheme::union(scheme1, scheme2);

        let subst = EffectSubstitution::new();
        let effects = union_scheme.evaluate(&subst).unwrap();

        assert!(effects.is_pure());
        assert!(effects.is_differentiable());
    }

    #[test]
    fn test_effect_annotation() {
        let annotation = EffectAnnotation::pure().with_description("Pure computation");

        assert_eq!(annotation.description.as_deref(), Some("Pure computation"));
    }

    #[test]
    fn test_infer_operation_effects() {
        let and_effects = infer_operation_effects("and");
        assert!(and_effects.is_pure());
        assert!(and_effects.is_differentiable());

        let sample_effects = infer_operation_effects("sample");
        assert!(sample_effects.is_stochastic());

        let io_effects = infer_operation_effects("read");
        assert!(io_effects.is_impure());
    }

    #[test]
    fn test_effect_set_stochastic() {
        let stochastic = EffectSet::stochastic();
        assert!(stochastic.is_stochastic());
        assert!(stochastic.contains(&Effect::Probabilistic(ProbabilisticEffect::Stochastic)));
    }

    #[test]
    fn test_memory_effects() {
        let read_only = Effect::Memory(MemoryEffect::ReadOnly);
        let read_write = Effect::Memory(MemoryEffect::ReadWrite);

        let set1 = EffectSet::new().with(read_only);
        let set2 = EffectSet::new().with(read_write);

        assert_ne!(set1, set2);
    }

    #[test]
    fn test_custom_effect() {
        let custom = Effect::Custom("GPUCompute".to_string());
        let set = EffectSet::new().with(custom.clone());

        assert!(set.contains(&custom));
    }

    #[test]
    fn test_effect_display() {
        let pure = Effect::Computational(ComputationalEffect::Pure);
        assert_eq!(pure.to_string(), "Pure");

        let diff = Effect::Differentiable;
        assert_eq!(diff.to_string(), "Diff");

        let custom = Effect::Custom("MyEffect".to_string());
        assert_eq!(custom.to_string(), "MyEffect");
    }

    #[test]
    fn test_effect_set_display() {
        let set = EffectSet::pure().with(Effect::Differentiable);
        let display = set.to_string();

        // Both members must appear; `||` would let a missing member slip through.
        assert!(display.contains("Pure") && display.contains("Diff"));
        assert!(display.starts_with('{'));
        assert!(display.ends_with('}'));
    }

    #[test]
    fn test_empty_effect_set_display() {
        assert_eq!(EffectSet::new().to_string(), "{}");
    }

    #[test]
    fn test_effect_var_display() {
        let var = EffectVar::new("1");
        assert_eq!(var.to_string(), "ε1");
    }

    #[test]
    fn test_non_differentiable_conflicts() {
        let diff = EffectSet::new().with(Effect::Differentiable);
        let non_diff = EffectSet::new().with(Effect::NonDifferentiable);

        assert!(!diff.is_compatible_with(&non_diff));
    }
}