1use serde::{Deserialize, Serialize};
36use std::collections::{HashMap, HashSet};
37use std::fmt;
38
39use crate::{IrError, ParametricType};
40
41#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
43pub enum Multiplicity {
44 Linear,
46 Affine,
48 Relevant,
50 Unrestricted,
52}
53
54impl Multiplicity {
55 pub fn allows(&self, n: usize) -> bool {
57 match self {
58 Multiplicity::Linear => n == 1,
59 Multiplicity::Affine => n <= 1,
60 Multiplicity::Relevant => n >= 1,
61 Multiplicity::Unrestricted => true,
62 }
63 }
64
65 pub fn is_linear(&self) -> bool {
67 matches!(self, Multiplicity::Linear)
68 }
69
70 pub fn is_unrestricted(&self) -> bool {
72 matches!(self, Multiplicity::Unrestricted)
73 }
74
75 pub fn combine(&self, other: &Multiplicity) -> Multiplicity {
77 match (self, other) {
78 (Multiplicity::Unrestricted, Multiplicity::Unrestricted) => Multiplicity::Unrestricted,
79 (Multiplicity::Linear, Multiplicity::Linear) => Multiplicity::Linear,
80 (Multiplicity::Affine, Multiplicity::Affine) => Multiplicity::Affine,
81 (Multiplicity::Relevant, Multiplicity::Relevant) => Multiplicity::Relevant,
82 (Multiplicity::Linear, _) | (_, Multiplicity::Linear) => Multiplicity::Linear,
84 (Multiplicity::Affine, _) | (_, Multiplicity::Affine) => Multiplicity::Affine,
85 (Multiplicity::Relevant, _) | (_, Multiplicity::Relevant) => Multiplicity::Relevant,
86 }
87 }
88
89 pub fn join(&self, other: &Multiplicity) -> Multiplicity {
91 match (self, other) {
92 (Multiplicity::Unrestricted, _) | (_, Multiplicity::Unrestricted) => {
93 Multiplicity::Unrestricted
94 }
95 (Multiplicity::Relevant, _) | (_, Multiplicity::Relevant) => Multiplicity::Relevant,
96 (Multiplicity::Affine, Multiplicity::Affine) => Multiplicity::Affine,
97 (Multiplicity::Linear, Multiplicity::Linear) => Multiplicity::Linear,
98 (Multiplicity::Affine, Multiplicity::Linear)
99 | (Multiplicity::Linear, Multiplicity::Affine) => Multiplicity::Affine,
100 }
101 }
102}
103
104impl fmt::Display for Multiplicity {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 match self {
107 Multiplicity::Linear => write!(f, "1"),
108 Multiplicity::Affine => write!(f, "0..1"),
109 Multiplicity::Relevant => write!(f, "1.."),
110 Multiplicity::Unrestricted => write!(f, "0.."),
111 }
112 }
113}
114
115#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
117pub struct LinearType {
118 pub base_type: ParametricType,
120 pub multiplicity: Multiplicity,
122}
123
124impl LinearType {
125 pub fn new(base_type: ParametricType, multiplicity: Multiplicity) -> Self {
127 LinearType {
128 base_type,
129 multiplicity,
130 }
131 }
132
133 pub fn linear(type_name: impl Into<String>) -> Self {
135 LinearType {
136 base_type: ParametricType::concrete(type_name),
137 multiplicity: Multiplicity::Linear,
138 }
139 }
140
141 pub fn affine(type_name: impl Into<String>) -> Self {
143 LinearType {
144 base_type: ParametricType::concrete(type_name),
145 multiplicity: Multiplicity::Affine,
146 }
147 }
148
149 pub fn relevant(type_name: impl Into<String>) -> Self {
151 LinearType {
152 base_type: ParametricType::concrete(type_name),
153 multiplicity: Multiplicity::Relevant,
154 }
155 }
156
157 pub fn unrestricted(type_name: impl Into<String>) -> Self {
159 LinearType {
160 base_type: ParametricType::concrete(type_name),
161 multiplicity: Multiplicity::Unrestricted,
162 }
163 }
164
165 pub fn is_linear(&self) -> bool {
167 self.multiplicity.is_linear()
168 }
169
170 pub fn is_unrestricted(&self) -> bool {
172 self.multiplicity.is_unrestricted()
173 }
174
175 pub fn make_unrestricted(mut self) -> Self {
177 self.multiplicity = Multiplicity::Unrestricted;
178 self
179 }
180
181 pub fn make_linear(mut self) -> Self {
183 self.multiplicity = Multiplicity::Linear;
184 self
185 }
186}
187
188impl fmt::Display for LinearType {
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 write!(f, "{}<{}>", self.base_type, self.multiplicity)
191 }
192}
193
194#[derive(Clone, Debug, PartialEq, Eq)]
196pub struct Usage {
197 pub var_name: String,
199 pub use_count: usize,
201 pub expected: Multiplicity,
203}
204
205impl Usage {
206 pub fn new(var_name: impl Into<String>, expected: Multiplicity) -> Self {
207 Usage {
208 var_name: var_name.into(),
209 use_count: 0,
210 expected,
211 }
212 }
213
214 pub fn record_use(&mut self) {
216 self.use_count += 1;
217 }
218
219 pub fn is_valid(&self) -> bool {
221 self.expected.allows(self.use_count)
222 }
223
224 pub fn error_message(&self) -> Option<String> {
226 if self.is_valid() {
227 None
228 } else {
229 Some(format!(
230 "Variable '{}' has multiplicity {} but was used {} times",
231 self.var_name, self.expected, self.use_count
232 ))
233 }
234 }
235}
236
237#[derive(Clone, Debug, Default)]
239pub struct LinearContext {
240 bindings: HashMap<String, LinearType>,
242 usage: HashMap<String, Usage>,
244 consumed: HashSet<String>,
246}
247
248impl LinearContext {
249 pub fn new() -> Self {
250 Self::default()
251 }
252
253 pub fn bind(&mut self, name: impl Into<String>, linear_type: LinearType) {
255 let name = name.into();
256 let multiplicity = linear_type.multiplicity.clone();
257 self.bindings.insert(name.clone(), linear_type);
258 self.usage
259 .insert(name.clone(), Usage::new(name, multiplicity));
260 }
261
262 pub fn use_var(&mut self, name: &str) -> Result<(), IrError> {
264 if self.consumed.contains(name) {
265 return Err(IrError::LinearityViolation(format!(
266 "Variable '{}' already consumed",
267 name
268 )));
269 }
270
271 if let Some(usage) = self.usage.get_mut(name) {
272 usage.record_use();
273
274 #[allow(clippy::collapsible_if)]
276 if usage.expected.is_linear() || matches!(usage.expected, Multiplicity::Affine) {
277 if usage.use_count >= 1 {
278 self.consumed.insert(name.to_string());
279 }
280 }
281
282 Ok(())
283 } else {
284 Err(IrError::UnboundVariable {
285 var: name.to_string(),
286 })
287 }
288 }
289
290 pub fn is_linear(&self, name: &str) -> bool {
292 self.bindings
293 .get(name)
294 .map(|t| t.is_linear())
295 .unwrap_or(false)
296 }
297
298 pub fn is_consumed(&self, name: &str) -> bool {
300 self.consumed.contains(name)
301 }
302
303 pub fn get_type(&self, name: &str) -> Option<&LinearType> {
305 self.bindings.get(name)
306 }
307
308 pub fn validate(&self) -> Result<(), Vec<String>> {
310 let mut errors = Vec::new();
311
312 for usage in self.usage.values() {
313 if let Some(err) = usage.error_message() {
314 errors.push(err);
315 }
316 }
317
318 if errors.is_empty() {
319 Ok(())
320 } else {
321 Err(errors)
322 }
323 }
324
325 pub fn get_unused_required(&self) -> Vec<String> {
327 self.usage
328 .values()
329 .filter(|u| {
330 u.use_count == 0
331 && (u.expected.is_linear() || matches!(u.expected, Multiplicity::Relevant))
332 })
333 .map(|u| u.var_name.clone())
334 .collect()
335 }
336
337 pub fn merge(&self, other: &LinearContext) -> Result<LinearContext, IrError> {
339 let mut merged = LinearContext::new();
340
341 for (name, typ) in &self.bindings {
343 if let Some(other_typ) = other.bindings.get(name) {
344 if typ != other_typ {
345 return Err(IrError::InconsistentTypes {
346 var: name.clone(),
347 type1: format!("{}", typ),
348 type2: format!("{}", other_typ),
349 });
350 }
351 merged.bindings.insert(name.clone(), typ.clone());
352 }
353 }
354
355 for (name, usage1) in &self.usage {
357 if let Some(usage2) = other.usage.get(name) {
358 let min_uses = usage1.use_count.min(usage2.use_count);
361 let max_uses = usage1.use_count.max(usage2.use_count);
362
363 let use_count = match usage1.expected {
364 Multiplicity::Linear | Multiplicity::Relevant => {
365 if usage1.use_count == 0 || usage2.use_count == 0 {
367 return Err(IrError::LinearityViolation(format!(
368 "Variable '{}' must be used in both branches",
369 name
370 )));
371 }
372 min_uses
373 }
374 Multiplicity::Affine | Multiplicity::Unrestricted => max_uses,
375 };
376
377 let mut merged_usage = Usage::new(name, usage1.expected.clone());
378 merged_usage.use_count = use_count;
379 merged.usage.insert(name.clone(), merged_usage);
380 }
381 }
382
383 merged.consumed = self
385 .consumed
386 .intersection(&other.consumed)
387 .cloned()
388 .collect();
389
390 Ok(merged)
391 }
392
393 pub fn split(&mut self, vars: &[String]) -> Result<LinearContext, IrError> {
395 let mut split_ctx = LinearContext::new();
396
397 for var in vars {
398 if let Some(typ) = self.bindings.remove(var) {
399 if typ.is_linear() {
400 split_ctx.bind(var, typ);
402 self.consumed.insert(var.clone());
403 } else if typ.is_unrestricted() {
404 split_ctx.bind(var, typ.clone());
406 self.bindings.insert(var.clone(), typ);
407 } else {
408 return Err(IrError::LinearityViolation(format!(
409 "Cannot split variable '{}' with multiplicity {}",
410 var, typ.multiplicity
411 )));
412 }
413 }
414 }
415
416 Ok(split_ctx)
417 }
418}
419
420#[derive(Clone, Debug)]
422pub struct LinearityChecker {
423 context: LinearContext,
424 errors: Vec<String>,
425}
426
427impl LinearityChecker {
428 pub fn new() -> Self {
429 LinearityChecker {
430 context: LinearContext::new(),
431 errors: Vec::new(),
432 }
433 }
434
435 pub fn bind(&mut self, name: impl Into<String>, linear_type: LinearType) {
437 self.context.bind(name, linear_type);
438 }
439
440 pub fn use_var(&mut self, name: &str) {
442 if let Err(e) = self.context.use_var(name) {
443 self.errors.push(format!("{}", e));
444 }
445 }
446
447 pub fn check(&self) -> Result<(), Vec<String>> {
449 let mut all_errors = self.errors.clone();
450
451 if let Err(mut usage_errors) = self.context.validate() {
452 all_errors.append(&mut usage_errors);
453 }
454
455 if all_errors.is_empty() {
456 Ok(())
457 } else {
458 Err(all_errors)
459 }
460 }
461
462 pub fn context(&self) -> &LinearContext {
464 &self.context
465 }
466
467 pub fn context_mut(&mut self) -> &mut LinearContext {
469 &mut self.context
470 }
471}
472
473impl Default for LinearityChecker {
474 fn default() -> Self {
475 Self::new()
476 }
477}
478
479#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
481pub enum Capability {
482 Read,
484 Write,
486 Execute,
488 Own,
490}
491
492#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
494pub struct LinearResource {
495 pub resource_type: LinearType,
497 pub capabilities: HashSet<Capability>,
499}
500
501impl LinearResource {
502 pub fn new(resource_type: LinearType, capabilities: HashSet<Capability>) -> Self {
503 LinearResource {
504 resource_type,
505 capabilities,
506 }
507 }
508
509 pub fn has_capability(&self, cap: &Capability) -> bool {
511 self.capabilities.contains(cap)
512 }
513
514 pub fn read_only(resource_type: LinearType) -> Self {
516 let mut caps = HashSet::new();
517 caps.insert(Capability::Read);
518 LinearResource::new(resource_type, caps)
519 }
520
521 pub fn read_write(resource_type: LinearType) -> Self {
523 let mut caps = HashSet::new();
524 caps.insert(Capability::Read);
525 caps.insert(Capability::Write);
526 LinearResource::new(resource_type, caps)
527 }
528
529 pub fn owned(resource_type: LinearType) -> Self {
531 let mut caps = HashSet::new();
532 caps.insert(Capability::Read);
533 caps.insert(Capability::Write);
534 caps.insert(Capability::Own);
535 LinearResource::new(resource_type, caps)
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multiplicity_allows() {
        let cases = [
            (Multiplicity::Linear, 1, true),
            (Multiplicity::Linear, 0, false),
            (Multiplicity::Linear, 2, false),
            (Multiplicity::Affine, 0, true),
            (Multiplicity::Affine, 1, true),
            (Multiplicity::Affine, 2, false),
            (Multiplicity::Relevant, 0, false),
            (Multiplicity::Relevant, 1, true),
            (Multiplicity::Relevant, 2, true),
            (Multiplicity::Unrestricted, 0, true),
            (Multiplicity::Unrestricted, 1, true),
            (Multiplicity::Unrestricted, 100, true),
        ];
        for (multiplicity, uses, expected) in cases.iter() {
            assert_eq!(
                multiplicity.allows(*uses),
                *expected,
                "{} with {} uses",
                multiplicity,
                uses
            );
        }
    }

    #[test]
    fn test_multiplicity_combine() {
        let cases = [
            (Multiplicity::Linear, Multiplicity::Linear, Multiplicity::Linear),
            (
                Multiplicity::Unrestricted,
                Multiplicity::Unrestricted,
                Multiplicity::Unrestricted,
            ),
            (Multiplicity::Linear, Multiplicity::Unrestricted, Multiplicity::Linear),
        ];
        for (lhs, rhs, expected) in cases.iter() {
            assert_eq!(&lhs.combine(rhs), expected);
        }
    }

    #[test]
    fn test_linear_type_creation() {
        let tensor = LinearType::linear("Tensor");
        assert!(tensor.is_linear());
        assert!(!tensor.is_unrestricted());

        let int = LinearType::unrestricted("Int");
        assert!(!int.is_linear());
        assert!(int.is_unrestricted());
    }

    #[test]
    fn test_linear_context_basic() {
        let mut ctx = LinearContext::new();
        ctx.bind("x", LinearType::linear("Tensor"));
        assert!(ctx.is_linear("x"));
        assert!(!ctx.is_consumed("x"));

        // The first use consumes a linear value...
        assert!(ctx.use_var("x").is_ok());
        assert!(ctx.is_consumed("x"));

        // ...so a second use must be rejected.
        assert!(ctx.use_var("x").is_err());
    }

    #[test]
    fn test_affine_type_usage() {
        let mut ctx = LinearContext::new();
        ctx.bind("f", LinearType::affine("File"));

        // Dropping an affine value is fine.
        assert!(ctx.validate().is_ok());

        // So is using it exactly once.
        assert!(ctx.use_var("f").is_ok());
        assert!(ctx.validate().is_ok());
    }

    #[test]
    fn test_relevant_type_usage() {
        let mut never_used = LinearContext::new();
        never_used.bind("r", LinearType::relevant("Resource"));
        // A relevant value must be used at least once.
        assert!(never_used.validate().is_err());

        let mut used_twice = LinearContext::new();
        used_twice.bind("r", LinearType::relevant("Resource"));
        for _ in 0..2 {
            assert!(used_twice.use_var("r").is_ok());
        }
        assert!(used_twice.validate().is_ok());
    }

    #[test]
    fn test_unrestricted_type_usage() {
        let mut ctx = LinearContext::new();
        ctx.bind("x", LinearType::unrestricted("Int"));

        for _ in 0..10 {
            assert!(ctx.use_var("x").is_ok());
        }
        assert!(ctx.validate().is_ok());
    }

    #[test]
    fn test_linearity_checker() {
        let mut checker = LinearityChecker::new();
        checker.bind("x", LinearType::linear("Tensor"));
        checker.bind("y", LinearType::unrestricted("Int"));

        checker.use_var("x");
        for _ in 0..2 {
            checker.use_var("y");
        }

        assert!(checker.check().is_ok());
    }

    #[test]
    fn test_linearity_checker_violation() {
        let mut checker = LinearityChecker::new();
        checker.bind("x", LinearType::linear("Tensor"));

        for _ in 0..2 {
            checker.use_var("x");
        }

        assert!(checker.check().is_err());
    }

    #[test]
    fn test_context_merge() {
        let mut left = LinearContext::new();
        let mut right = LinearContext::new();
        left.bind("x", LinearType::unrestricted("Int"));
        right.bind("x", LinearType::unrestricted("Int"));

        left.use_var("x").unwrap();
        right.use_var("x").unwrap();
        right.use_var("x").unwrap();

        assert!(left.merge(&right).is_ok());
    }

    #[test]
    fn test_linear_resource_capabilities() {
        let resource = LinearResource::read_only(LinearType::linear("Tensor"));

        assert!(resource.has_capability(&Capability::Read));
        assert!(!resource.has_capability(&Capability::Write));
        assert!(!resource.has_capability(&Capability::Own));
    }

    #[test]
    fn test_get_unused_required() {
        let mut ctx = LinearContext::new();
        ctx.bind("x", LinearType::linear("Tensor"));
        ctx.bind("y", LinearType::unrestricted("Int"));
        ctx.bind("z", LinearType::relevant("Resource"));

        let mut unused = ctx.get_unused_required();
        unused.sort();
        assert_eq!(unused, vec!["x".to_string(), "z".to_string()]);
    }

    #[test]
    fn test_context_split() {
        let mut ctx = LinearContext::new();
        ctx.bind("x", LinearType::linear("Tensor"));
        ctx.bind("y", LinearType::unrestricted("Int"));

        let split_ctx = ctx
            .split(&["x".to_string()])
            .expect("splitting a fresh linear variable should succeed");

        // x moved into the new context and is consumed in the original.
        assert!(split_ctx.get_type("x").is_some());
        assert!(ctx.is_consumed("x"));

        // y was not requested and is untouched.
        assert!(ctx.get_type("y").is_some());
        assert!(!ctx.is_consumed("y"));
    }

    #[test]
    fn test_linear_type_display() {
        let cases = [
            (LinearType::linear("Tensor"), "Tensor<1>"),
            (LinearType::affine("File"), "File<0..1>"),
            (LinearType::unrestricted("Int"), "Int<0..>"),
        ];
        for (ty, expected) in cases.iter() {
            assert_eq!(ty.to_string(), *expected);
        }
    }
}