Skip to main content

seqc/typechecker/
state.rs

//! Constructors, accessors, and bookkeeping for TypeChecker.
2use std::collections::HashMap;
3
4use crate::call_graph::CallGraph;
5use crate::types::{Effect, StackType, Type, UnionTypeInfo, VariantInfo};
6
7use super::{TypeChecker, format_line_prefix};
8
9impl TypeChecker {
10    pub fn new() -> Self {
11        TypeChecker {
12            env: HashMap::new(),
13            unions: HashMap::new(),
14            fresh_counter: std::cell::Cell::new(0),
15            quotation_types: std::cell::RefCell::new(HashMap::new()),
16            expected_quotation_type: std::cell::RefCell::new(None),
17            current_word: std::cell::RefCell::new(None),
18            statement_top_types: std::cell::RefCell::new(HashMap::new()),
19            call_graph: None,
20            current_aux_stack: std::cell::RefCell::new(StackType::Empty),
21            aux_max_depths: std::cell::RefCell::new(HashMap::new()),
22            quotation_aux_depths: std::cell::RefCell::new(HashMap::new()),
23            quotation_id_stack: std::cell::RefCell::new(Vec::new()),
24            resolved_sugar: std::cell::RefCell::new(HashMap::new()),
25        }
26    }
27
28    /// Set the call graph for mutual recursion detection.
29    ///
30    /// When set, the type checker can detect divergent branches caused by
31    /// mutual recursion (e.g., even/odd pattern) in addition to direct recursion.
32    pub fn set_call_graph(&mut self, call_graph: CallGraph) {
33        self.call_graph = Some(call_graph);
34    }
35
36    /// Get line info prefix for error messages (e.g., "at line 42: " or "")
37    pub(super) fn line_prefix(&self) -> String {
38        self.current_word
39            .borrow()
40            .as_ref()
41            .and_then(|(_, line)| line.map(format_line_prefix))
42            .unwrap_or_default()
43    }
44
45    /// Look up a union type by name
46    pub fn get_union(&self, name: &str) -> Option<&UnionTypeInfo> {
47        self.unions.get(name)
48    }
49
50    /// Get all registered union types
51    pub fn get_unions(&self) -> &HashMap<String, UnionTypeInfo> {
52        &self.unions
53    }
54
55    /// Find variant info by name across all unions
56    ///
57    /// Returns (union_name, variant_info) for the variant
58    pub(super) fn find_variant(&self, variant_name: &str) -> Option<(&str, &VariantInfo)> {
59        for (union_name, union_info) in &self.unions {
60            for variant in &union_info.variants {
61                if variant.name == variant_name {
62                    return Some((union_name.as_str(), variant));
63                }
64            }
65        }
66        None
67    }
68
69    /// Register external word effects (e.g., from included modules or FFI).
70    ///
71    /// All external words must have explicit stack effects for type safety.
72    pub fn register_external_words(&mut self, words: &[(&str, &Effect)]) {
73        for (name, effect) in words {
74            self.env.insert(name.to_string(), (*effect).clone());
75        }
76    }
77
78    /// Register external union type names (e.g., from included modules).
79    ///
80    /// This allows field types in union definitions to reference types from includes.
81    /// We only register the name as a valid type; we don't need full variant info
82    /// since the actual union definition lives in the included file.
83    pub fn register_external_unions(&mut self, union_names: &[&str]) {
84        for name in union_names {
85            // Insert a placeholder union with no variants
86            // This makes is_valid_type_name() return true for this type
87            self.unions.insert(
88                name.to_string(),
89                UnionTypeInfo {
90                    name: name.to_string(),
91                    variants: vec![],
92                },
93            );
94        }
95    }
96
97    /// Extract the type map (quotation ID -> inferred type)
98    ///
99    /// This should be called after check_program() to get the inferred types
100    /// for all quotations in the program. The map is used by codegen to generate
101    /// appropriate code for Quotations vs Closures.
102    pub fn take_quotation_types(&self) -> HashMap<usize, Type> {
103        self.quotation_types.replace(HashMap::new())
104    }
105
106    /// Extract per-statement type info for codegen optimization (Issue #186)
107    /// Returns map of (word_name, statement_index) -> top-of-stack type
108    pub fn take_statement_top_types(&self) -> HashMap<(String, usize), Type> {
109        self.statement_top_types.replace(HashMap::new())
110    }
111
112    /// Extract resolved arithmetic sugar for codegen
113    /// Maps (line, column) -> concrete operation name
114    pub fn take_resolved_sugar(&self) -> HashMap<(usize, usize), String> {
115        self.resolved_sugar.replace(HashMap::new())
116    }
117
118    /// Extract per-word aux stack max depths for codegen alloca sizing (Issue #350)
119    pub fn take_aux_max_depths(&self) -> HashMap<String, usize> {
120        self.aux_max_depths.replace(HashMap::new())
121    }
122
123    /// Extract per-quotation aux stack max depths for codegen alloca sizing (Issue #393)
124    /// Maps quotation_id -> max_depth
125    pub fn take_quotation_aux_depths(&self) -> HashMap<usize, usize> {
126        self.quotation_aux_depths.replace(HashMap::new())
127    }
128
129    /// Count the number of concrete types in a StackType (for aux depth tracking)
130    pub(super) fn capture_statement_type(
131        &self,
132        word_name: &str,
133        stmt_index: usize,
134        stack: &StackType,
135    ) {
136        if let Some(top_type) = Self::get_trivially_copyable_top(stack) {
137            self.statement_top_types
138                .borrow_mut()
139                .insert((word_name.to_string(), stmt_index), top_type);
140        }
141    }
142}