Skip to main content

tensorlogic_ir/
linear.rs

1//! Linear type system for resource management in TensorLogic.
2//!
//! This module implements linear types, where values must be used exactly once, along
//! with the related affine (at most once) and relevant (at least once) disciplines. This is
//! crucial for:
5//!
6//! - **Memory management**: Ensuring tensors are properly deallocated
7//! - **In-place operations**: Tracking when tensors can be mutated safely
8//! - **Resource tracking**: Managing GPU memory, file handles, etc.
9//! - **Side effect control**: Ensuring operations execute in the correct order
10//!
11//! # Examples
12//!
13//! ```
14//! use tensorlogic_ir::linear::{LinearType, Multiplicity, LinearContext};
15//!
16//! // Linear type: must be used exactly once
17//! let tensor_handle = LinearType::linear("TensorHandle");
18//!
19//! // Unrestricted type: can be used multiple times
20//! let int_type = LinearType::unrestricted("Int");
21//!
22//! // Check multiplicity constraints
23//! let mut ctx = LinearContext::new();
24//! ctx.bind("x", tensor_handle);
25//! assert!(ctx.is_linear("x"));
26//! ```
27//!
28//! # Multiplicity System
29//!
30//! - **Linear (1)**: Must be used exactly once
31//! - **Affine (0..1)**: Must be used at most once
32//! - **Relevant (1..)**: Must be used at least once
33//! - **Unrestricted (0..)**: Can be used any number of times
34
35use serde::{Deserialize, Serialize};
36use std::collections::{HashMap, HashSet};
37use std::fmt;
38
39use crate::{IrError, ParametricType};
40
/// Multiplicity: how many times a value can be used.
///
/// Each variant constrains the use count from below (0 or 1) and from above (1 or
/// unbounded); see [`Multiplicity::allows`] for the exact rule. The `Display` impl renders
/// the range notation shown on each variant (`1`, `0..1`, `1..`, `0..`).
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Multiplicity {
    /// Linear: must be used exactly once (1)
    Linear,
    /// Affine: must be used at most once (0..1)
    Affine,
    /// Relevant: must be used at least once (1..)
    Relevant,
    /// Unrestricted: can be used any number of times (0..)
    Unrestricted,
}
53
54impl Multiplicity {
55    /// Check if a value with this multiplicity can be used n times
56    pub fn allows(&self, n: usize) -> bool {
57        match self {
58            Multiplicity::Linear => n == 1,
59            Multiplicity::Affine => n <= 1,
60            Multiplicity::Relevant => n >= 1,
61            Multiplicity::Unrestricted => true,
62        }
63    }
64
65    /// Check if this is linear (exactly once)
66    pub fn is_linear(&self) -> bool {
67        matches!(self, Multiplicity::Linear)
68    }
69
70    /// Check if this is unrestricted (any number of times)
71    pub fn is_unrestricted(&self) -> bool {
72        matches!(self, Multiplicity::Unrestricted)
73    }
74
75    /// Combine multiplicities (for products/tuples)
76    pub fn combine(&self, other: &Multiplicity) -> Multiplicity {
77        match (self, other) {
78            (Multiplicity::Unrestricted, Multiplicity::Unrestricted) => Multiplicity::Unrestricted,
79            (Multiplicity::Linear, Multiplicity::Linear) => Multiplicity::Linear,
80            (Multiplicity::Affine, Multiplicity::Affine) => Multiplicity::Affine,
81            (Multiplicity::Relevant, Multiplicity::Relevant) => Multiplicity::Relevant,
82            // Most restrictive wins
83            (Multiplicity::Linear, _) | (_, Multiplicity::Linear) => Multiplicity::Linear,
84            (Multiplicity::Affine, _) | (_, Multiplicity::Affine) => Multiplicity::Affine,
85            (Multiplicity::Relevant, _) | (_, Multiplicity::Relevant) => Multiplicity::Relevant,
86        }
87    }
88
89    /// Join multiplicities (for sums/unions)
90    pub fn join(&self, other: &Multiplicity) -> Multiplicity {
91        match (self, other) {
92            (Multiplicity::Unrestricted, _) | (_, Multiplicity::Unrestricted) => {
93                Multiplicity::Unrestricted
94            }
95            (Multiplicity::Relevant, _) | (_, Multiplicity::Relevant) => Multiplicity::Relevant,
96            (Multiplicity::Affine, Multiplicity::Affine) => Multiplicity::Affine,
97            (Multiplicity::Linear, Multiplicity::Linear) => Multiplicity::Linear,
98            (Multiplicity::Affine, Multiplicity::Linear)
99            | (Multiplicity::Linear, Multiplicity::Affine) => Multiplicity::Affine,
100        }
101    }
102}
103
104impl fmt::Display for Multiplicity {
105    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106        match self {
107            Multiplicity::Linear => write!(f, "1"),
108            Multiplicity::Affine => write!(f, "0..1"),
109            Multiplicity::Relevant => write!(f, "1.."),
110            Multiplicity::Unrestricted => write!(f, "0.."),
111        }
112    }
113}
114
/// Linear type: type with multiplicity constraints.
///
/// Pairs a base [`ParametricType`] with the [`Multiplicity`] that limits how many times a
/// value of that type may be used. Displayed as `Base<multiplicity>`, e.g. `Tensor<1>`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LinearType {
    /// Base type
    pub base_type: ParametricType,
    /// Multiplicity constraint
    pub multiplicity: Multiplicity,
}
123
124impl LinearType {
125    /// Create a new linear type
126    pub fn new(base_type: ParametricType, multiplicity: Multiplicity) -> Self {
127        LinearType {
128            base_type,
129            multiplicity,
130        }
131    }
132
133    /// Create a linear type (must be used exactly once)
134    pub fn linear(type_name: impl Into<String>) -> Self {
135        LinearType {
136            base_type: ParametricType::concrete(type_name),
137            multiplicity: Multiplicity::Linear,
138        }
139    }
140
141    /// Create an affine type (at most once)
142    pub fn affine(type_name: impl Into<String>) -> Self {
143        LinearType {
144            base_type: ParametricType::concrete(type_name),
145            multiplicity: Multiplicity::Affine,
146        }
147    }
148
149    /// Create a relevant type (at least once)
150    pub fn relevant(type_name: impl Into<String>) -> Self {
151        LinearType {
152            base_type: ParametricType::concrete(type_name),
153            multiplicity: Multiplicity::Relevant,
154        }
155    }
156
157    /// Create an unrestricted type (any number of times)
158    pub fn unrestricted(type_name: impl Into<String>) -> Self {
159        LinearType {
160            base_type: ParametricType::concrete(type_name),
161            multiplicity: Multiplicity::Unrestricted,
162        }
163    }
164
165    /// Check if this is a linear type
166    pub fn is_linear(&self) -> bool {
167        self.multiplicity.is_linear()
168    }
169
170    /// Check if this is unrestricted
171    pub fn is_unrestricted(&self) -> bool {
172        self.multiplicity.is_unrestricted()
173    }
174
175    /// Convert to unrestricted (for copying)
176    pub fn make_unrestricted(mut self) -> Self {
177        self.multiplicity = Multiplicity::Unrestricted;
178        self
179    }
180
181    /// Convert to linear
182    pub fn make_linear(mut self) -> Self {
183        self.multiplicity = Multiplicity::Linear;
184        self
185    }
186}
187
188impl fmt::Display for LinearType {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        write!(f, "{}<{}>", self.base_type, self.multiplicity)
191    }
192}
193
/// Usage tracking for linear variables.
///
/// Counts how many times a variable has been used so the count can be checked against the
/// variable's declared multiplicity (see `Usage::is_valid`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Usage {
    /// Variable name
    pub var_name: String,
    /// Number of times used
    pub use_count: usize,
    /// Expected multiplicity
    pub expected: Multiplicity,
}
204
205impl Usage {
206    pub fn new(var_name: impl Into<String>, expected: Multiplicity) -> Self {
207        Usage {
208            var_name: var_name.into(),
209            use_count: 0,
210            expected,
211        }
212    }
213
214    /// Record a use
215    pub fn record_use(&mut self) {
216        self.use_count += 1;
217    }
218
219    /// Check if usage is valid
220    pub fn is_valid(&self) -> bool {
221        self.expected.allows(self.use_count)
222    }
223
224    /// Get error message if invalid
225    pub fn error_message(&self) -> Option<String> {
226        if self.is_valid() {
227            None
228        } else {
229            Some(format!(
230                "Variable '{}' has multiplicity {} but was used {} times",
231                self.var_name, self.expected, self.use_count
232            ))
233        }
234    }
235}
236
/// Linear typing context for tracking variable usage.
///
/// Holds, per variable name, its declared [`LinearType`], a [`Usage`] record counting its
/// uses, and whether it has been consumed (a linear or affine variable after its use, or a
/// linear variable moved out by `split`).
#[derive(Clone, Debug, Default)]
pub struct LinearContext {
    /// Variable bindings with their linear types
    bindings: HashMap<String, LinearType>,
    /// Usage tracking, keyed by variable name
    usage: HashMap<String, Usage>,
    /// Consumed variables (used and invalidated); `use_var` rejects any further use
    consumed: HashSet<String>,
}
247
248impl LinearContext {
249    pub fn new() -> Self {
250        Self::default()
251    }
252
253    /// Bind a variable with a linear type
254    pub fn bind(&mut self, name: impl Into<String>, linear_type: LinearType) {
255        let name = name.into();
256        let multiplicity = linear_type.multiplicity.clone();
257        self.bindings.insert(name.clone(), linear_type);
258        self.usage
259            .insert(name.clone(), Usage::new(name, multiplicity));
260    }
261
262    /// Use a variable (increment use count)
263    pub fn use_var(&mut self, name: &str) -> Result<(), IrError> {
264        if self.consumed.contains(name) {
265            return Err(IrError::LinearityViolation(format!(
266                "Variable '{}' already consumed",
267                name
268            )));
269        }
270
271        if let Some(usage) = self.usage.get_mut(name) {
272            usage.record_use();
273
274            // If linear or affine, mark as consumed after use
275            #[allow(clippy::collapsible_if)]
276            if usage.expected.is_linear() || matches!(usage.expected, Multiplicity::Affine) {
277                if usage.use_count >= 1 {
278                    self.consumed.insert(name.to_string());
279                }
280            }
281
282            Ok(())
283        } else {
284            Err(IrError::UnboundVariable {
285                var: name.to_string(),
286            })
287        }
288    }
289
290    /// Check if a variable is linear
291    pub fn is_linear(&self, name: &str) -> bool {
292        self.bindings
293            .get(name)
294            .map(|t| t.is_linear())
295            .unwrap_or(false)
296    }
297
298    /// Check if a variable is consumed
299    pub fn is_consumed(&self, name: &str) -> bool {
300        self.consumed.contains(name)
301    }
302
303    /// Get the linear type of a variable
304    pub fn get_type(&self, name: &str) -> Option<&LinearType> {
305        self.bindings.get(name)
306    }
307
308    /// Validate all usage counts at the end of scope
309    pub fn validate(&self) -> Result<(), Vec<String>> {
310        let mut errors = Vec::new();
311
312        for usage in self.usage.values() {
313            if let Some(err) = usage.error_message() {
314                errors.push(err);
315            }
316        }
317
318        if errors.is_empty() {
319            Ok(())
320        } else {
321            Err(errors)
322        }
323    }
324
325    /// Get all unused variables with relevant or linear multiplicity
326    pub fn get_unused_required(&self) -> Vec<String> {
327        self.usage
328            .values()
329            .filter(|u| {
330                u.use_count == 0
331                    && (u.expected.is_linear() || matches!(u.expected, Multiplicity::Relevant))
332            })
333            .map(|u| u.var_name.clone())
334            .collect()
335    }
336
337    /// Merge two contexts (for branching control flow)
338    pub fn merge(&self, other: &LinearContext) -> Result<LinearContext, IrError> {
339        let mut merged = LinearContext::new();
340
341        // Merge bindings
342        for (name, typ) in &self.bindings {
343            if let Some(other_typ) = other.bindings.get(name) {
344                if typ != other_typ {
345                    return Err(IrError::InconsistentTypes {
346                        var: name.clone(),
347                        type1: format!("{}", typ),
348                        type2: format!("{}", other_typ),
349                    });
350                }
351                merged.bindings.insert(name.clone(), typ.clone());
352            }
353        }
354
355        // Merge usage: both branches must satisfy constraints
356        for (name, usage1) in &self.usage {
357            if let Some(usage2) = other.usage.get(name) {
358                // For linear/relevant: both branches must use the variable
359                // For affine/unrestricted: either branch can use it
360                let min_uses = usage1.use_count.min(usage2.use_count);
361                let max_uses = usage1.use_count.max(usage2.use_count);
362
363                let use_count = match usage1.expected {
364                    Multiplicity::Linear | Multiplicity::Relevant => {
365                        // Both branches must use it
366                        if usage1.use_count == 0 || usage2.use_count == 0 {
367                            return Err(IrError::LinearityViolation(format!(
368                                "Variable '{}' must be used in both branches",
369                                name
370                            )));
371                        }
372                        min_uses
373                    }
374                    Multiplicity::Affine | Multiplicity::Unrestricted => max_uses,
375                };
376
377                let mut merged_usage = Usage::new(name, usage1.expected.clone());
378                merged_usage.use_count = use_count;
379                merged.usage.insert(name.clone(), merged_usage);
380            }
381        }
382
383        // Merge consumed sets
384        merged.consumed = self
385            .consumed
386            .intersection(&other.consumed)
387            .cloned()
388            .collect();
389
390        Ok(merged)
391    }
392
393    /// Split context for parallel use (e.g., function arguments)
394    pub fn split(&mut self, vars: &[String]) -> Result<LinearContext, IrError> {
395        let mut split_ctx = LinearContext::new();
396
397        for var in vars {
398            if let Some(typ) = self.bindings.remove(var) {
399                if typ.is_linear() {
400                    // Linear types can be moved
401                    split_ctx.bind(var, typ);
402                    self.consumed.insert(var.clone());
403                } else if typ.is_unrestricted() {
404                    // Unrestricted types can be copied
405                    split_ctx.bind(var, typ.clone());
406                    self.bindings.insert(var.clone(), typ);
407                } else {
408                    return Err(IrError::LinearityViolation(format!(
409                        "Cannot split variable '{}' with multiplicity {}",
410                        var, typ.multiplicity
411                    )));
412                }
413            }
414        }
415
416        Ok(split_ctx)
417    }
418}
419
420/// Linearity checker for expressions.
421#[derive(Clone, Debug)]
422pub struct LinearityChecker {
423    context: LinearContext,
424    errors: Vec<String>,
425}
426
427impl LinearityChecker {
428    pub fn new() -> Self {
429        LinearityChecker {
430            context: LinearContext::new(),
431            errors: Vec::new(),
432        }
433    }
434
435    /// Add a linear variable binding
436    pub fn bind(&mut self, name: impl Into<String>, linear_type: LinearType) {
437        self.context.bind(name, linear_type);
438    }
439
440    /// Record a variable use
441    pub fn use_var(&mut self, name: &str) {
442        if let Err(e) = self.context.use_var(name) {
443            self.errors.push(format!("{}", e));
444        }
445    }
446
447    /// Check if all linearity constraints are satisfied
448    pub fn check(&self) -> Result<(), Vec<String>> {
449        let mut all_errors = self.errors.clone();
450
451        if let Err(mut usage_errors) = self.context.validate() {
452            all_errors.append(&mut usage_errors);
453        }
454
455        if all_errors.is_empty() {
456            Ok(())
457        } else {
458            Err(all_errors)
459        }
460    }
461
462    /// Get the current context
463    pub fn context(&self) -> &LinearContext {
464        &self.context
465    }
466
467    /// Get a mutable reference to the context
468    pub fn context_mut(&mut self) -> &mut LinearContext {
469        &mut self.context
470    }
471}
472
473impl Default for LinearityChecker {
474    fn default() -> Self {
475        Self::new()
476    }
477}
478
/// Capability: describes what operations are allowed on a linear resource.
///
/// A [`LinearResource`] carries a set of these; query it with
/// `LinearResource::has_capability`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Capability {
    /// Read access
    Read,
    /// Write access
    Write,
    /// Execute access
    Execute,
    /// Own (can deallocate)
    Own,
}
491
/// Linear resource with capabilities.
///
/// Combines a [`LinearType`] (which constrains how many times the resource may be used)
/// with the set of [`Capability`] values describing the operations permitted on it.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LinearResource {
    /// Resource type
    pub resource_type: LinearType,
    /// Allowed capabilities
    pub capabilities: HashSet<Capability>,
}
500
501impl LinearResource {
502    pub fn new(resource_type: LinearType, capabilities: HashSet<Capability>) -> Self {
503        LinearResource {
504            resource_type,
505            capabilities,
506        }
507    }
508
509    /// Check if a capability is allowed
510    pub fn has_capability(&self, cap: &Capability) -> bool {
511        self.capabilities.contains(cap)
512    }
513
514    /// Create a read-only resource
515    pub fn read_only(resource_type: LinearType) -> Self {
516        let mut caps = HashSet::new();
517        caps.insert(Capability::Read);
518        LinearResource::new(resource_type, caps)
519    }
520
521    /// Create a read-write resource
522    pub fn read_write(resource_type: LinearType) -> Self {
523        let mut caps = HashSet::new();
524        caps.insert(Capability::Read);
525        caps.insert(Capability::Write);
526        LinearResource::new(resource_type, caps)
527    }
528
529    /// Create an owned resource (full access)
530    pub fn owned(resource_type: LinearType) -> Self {
531        let mut caps = HashSet::new();
532        caps.insert(Capability::Read);
533        caps.insert(Capability::Write);
534        caps.insert(Capability::Own);
535        LinearResource::new(resource_type, caps)
536    }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multiplicity_allows() {
        // (multiplicity, counts it must accept, counts it must reject)
        let table: [(Multiplicity, &[usize], &[usize]); 4] = [
            (Multiplicity::Linear, &[1], &[0, 2]),
            (Multiplicity::Affine, &[0, 1], &[2]),
            (Multiplicity::Relevant, &[1, 2], &[0]),
            (Multiplicity::Unrestricted, &[0, 1, 100], &[]),
        ];
        for (mult, accepted, rejected) in table.iter() {
            for &n in accepted.iter() {
                assert!(mult.allows(n), "{} should allow {} uses", mult, n);
            }
            for &n in rejected.iter() {
                assert!(!mult.allows(n), "{} should reject {} uses", mult, n);
            }
        }
    }

    #[test]
    fn test_multiplicity_combine() {
        let expectations = [
            (Multiplicity::Linear, Multiplicity::Linear, Multiplicity::Linear),
            (
                Multiplicity::Unrestricted,
                Multiplicity::Unrestricted,
                Multiplicity::Unrestricted,
            ),
            (Multiplicity::Linear, Multiplicity::Unrestricted, Multiplicity::Linear),
        ];
        for (lhs, rhs, expected) in expectations.iter() {
            assert_eq!(&lhs.combine(rhs), expected);
        }
    }

    #[test]
    fn test_linear_type_creation() {
        let tensor = LinearType::linear("Tensor");
        assert!(tensor.is_linear() && !tensor.is_unrestricted());

        let int = LinearType::unrestricted("Int");
        assert!(!int.is_linear() && int.is_unrestricted());
    }

    #[test]
    fn test_linear_context_basic() {
        let mut ctx = LinearContext::new();
        ctx.bind("x", LinearType::linear("Tensor"));
        assert!(ctx.is_linear("x"));
        assert!(!ctx.is_consumed("x"));

        // The first use succeeds and consumes the variable...
        ctx.use_var("x")
            .expect("first use of a linear variable must succeed");
        assert!(ctx.is_consumed("x"));

        // ...so a second use is rejected.
        assert!(ctx.use_var("x").is_err());
    }

    #[test]
    fn test_affine_type_usage() {
        let mut ctx = LinearContext::new();
        ctx.bind("f", LinearType::affine("File"));

        // Zero uses satisfy an affine binding.
        assert!(ctx.validate().is_ok());

        // So does exactly one.
        ctx.use_var("f").expect("one use of an affine variable is allowed");
        assert!(ctx.validate().is_ok());
    }

    #[test]
    fn test_relevant_type_usage() {
        let mut unused = LinearContext::new();
        unused.bind("r", LinearType::relevant("Resource"));
        // A relevant binding that is never used is an error.
        assert!(unused.validate().is_err());

        let mut reused = LinearContext::new();
        reused.bind("r", LinearType::relevant("Resource"));
        for _ in 0..2 {
            // Relevant bindings may be used repeatedly.
            reused.use_var("r").expect("relevant variables can be reused");
        }
        assert!(reused.validate().is_ok());
    }

    #[test]
    fn test_unrestricted_type_usage() {
        let mut ctx = LinearContext::new();
        ctx.bind("x", LinearType::unrestricted("Int"));

        let all_ok = (0..10).all(|_| ctx.use_var("x").is_ok());
        assert!(all_ok);
        assert!(ctx.validate().is_ok());
    }

    #[test]
    fn test_linearity_checker() {
        let mut checker = LinearityChecker::new();
        checker.bind("x", LinearType::linear("Tensor"));
        checker.bind("y", LinearType::unrestricted("Int"));

        // One use of the linear variable, several of the unrestricted one.
        for name in ["x", "y", "y"].iter() {
            checker.use_var(name);
        }

        assert!(checker.check().is_ok());
    }

    #[test]
    fn test_linearity_checker_violation() {
        let mut checker = LinearityChecker::new();
        checker.bind("x", LinearType::linear("Tensor"));

        // Two uses of a linear variable must be reported.
        checker.use_var("x");
        checker.use_var("x");

        assert!(checker.check().is_err());
    }

    #[test]
    fn test_context_merge() {
        let mut left = LinearContext::new();
        let mut right = LinearContext::new();
        left.bind("x", LinearType::unrestricted("Int"));
        right.bind("x", LinearType::unrestricted("Int"));

        // The branches use the unrestricted variable a different number of times.
        left.use_var("x").unwrap();
        for _ in 0..2 {
            right.use_var("x").unwrap();
        }

        assert!(left.merge(&right).is_ok());
    }

    #[test]
    fn test_linear_resource_capabilities() {
        let resource = LinearResource::read_only(LinearType::linear("Tensor"));

        assert!(resource.has_capability(&Capability::Read));
        for denied in [Capability::Write, Capability::Own].iter() {
            assert!(!resource.has_capability(denied));
        }
    }

    #[test]
    fn test_get_unused_required() {
        let mut ctx = LinearContext::new();
        ctx.bind("x", LinearType::linear("Tensor"));
        ctx.bind("y", LinearType::unrestricted("Int"));
        ctx.bind("z", LinearType::relevant("Resource"));

        // Only the linear and relevant bindings are required, and neither has been used.
        let mut unused = ctx.get_unused_required();
        unused.sort();
        assert_eq!(unused, vec!["x".to_string(), "z".to_string()]);
    }

    #[test]
    fn test_context_split() {
        let mut ctx = LinearContext::new();
        ctx.bind("x", LinearType::linear("Tensor"));
        ctx.bind("y", LinearType::unrestricted("Int"));

        let split_ctx = ctx
            .split(&["x".to_string()])
            .expect("splitting off a linear variable should succeed");

        // x moved into the split context and is consumed here.
        assert!(split_ctx.get_type("x").is_some());
        assert!(ctx.is_consumed("x"));

        // y was not requested, so it is untouched.
        assert!(ctx.get_type("y").is_some());
        assert!(!ctx.is_consumed("y"));
    }

    #[test]
    fn test_linear_type_display() {
        let cases = [
            (LinearType::linear("Tensor"), "Tensor<1>"),
            (LinearType::affine("File"), "File<0..1>"),
            (LinearType::unrestricted("Int"), "Int<0..>"),
        ];
        for (ty, rendered) in cases.iter() {
            assert_eq!(ty.to_string(), *rendered);
        }
    }
}