Skip to main content

shape_runtime/
type_methods.rs

1//! Type method registry for storing user-defined methods on types
2//!
3//! This module provides the infrastructure for storing and retrieving
4//! methods that have been added to types via the `extend` statement.
5
6use crate::type_system::annotation_to_string;
7use shape_ast::ast::{MethodDef, TypeName};
8use shape_value::ValueWord;
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11
12/// Registry for storing type methods
13#[derive(Debug, Clone)]
14pub struct TypeMethodRegistry {
15    /// Methods stored by type name and method name
16    /// The Vec allows for method overloading
17    methods: Arc<RwLock<HashMap<String, HashMap<String, Vec<MethodDef>>>>>,
18}
19
20impl TypeMethodRegistry {
21    /// Create a new empty registry
22    pub fn new() -> Self {
23        Self {
24            methods: Arc::new(RwLock::new(HashMap::new())),
25        }
26    }
27
28    /// Register a method for a type
29    pub fn register_method(&self, type_name: &TypeName, method: MethodDef) {
30        let mut methods = self.methods.write().unwrap();
31
32        // Get the type name as a string
33        let type_str = match type_name {
34            TypeName::Simple(name) => name.to_string(),
35            TypeName::Generic { name, type_args } => {
36                // Convert generic types with their full signature
37                // e.g., "Table<Row>", "Vec<Number>"
38                if type_args.is_empty() {
39                    name.to_string()
40                } else {
41                    // Convert type arguments to strings
42                    let type_arg_strs: Vec<String> =
43                        type_args.iter().map(annotation_to_string).collect();
44                    format!("{}<{}>", name, type_arg_strs.join(", "))
45                }
46            }
47        };
48
49        // Get or create the method map for this type
50        let type_methods = methods.entry(type_str.clone()).or_default();
51
52        // Add the method to the overload list
53        type_methods
54            .entry(method.name.clone())
55            .or_default()
56            .push(method);
57    }
58
59    /// Get all methods for a type with a given name
60    pub fn get_methods(&self, type_name: &str, method_name: &str) -> Option<Vec<MethodDef>> {
61        let methods = self.methods.read().unwrap();
62        methods
63            .get(type_name)
64            .and_then(|type_methods| type_methods.get(method_name))
65            .cloned()
66    }
67
68    /// Get the type name for a value
69    pub fn get_value_type_name(value: &ValueWord) -> String {
70        value.type_name().to_string()
71    }
72
73    /// Get all methods for a type
74    pub fn get_all_methods(&self, type_name: &str) -> Vec<MethodDef> {
75        let methods = self.methods.read().unwrap();
76
77        methods
78            .get(type_name)
79            .map(|type_methods| type_methods.values().flatten().cloned().collect())
80            .unwrap_or_default()
81    }
82
83    /// Check if a type has any methods registered
84    pub fn has_type(&self, type_name: &str) -> bool {
85        let methods = self.methods.read().unwrap();
86        methods.contains_key(type_name)
87    }
88
89    /// Get all registered type names
90    pub fn get_registered_types(&self) -> Vec<String> {
91        let methods = self.methods.read().unwrap();
92        methods.keys().cloned().collect()
93    }
94
95    /// Get a debug string representation of the registry state
96    pub fn debug_state(&self) -> String {
97        let methods = self.methods.read().unwrap();
98        let mut output = String::new();
99
100        output.push_str("TypeMethodRegistry State:\n");
101        output.push_str(&format!("  Total registered types: {}\n", methods.len()));
102
103        if methods.is_empty() {
104            output.push_str("  (No types registered)\n");
105        } else {
106            for (type_name, type_methods) in methods.iter() {
107                output.push_str(&format!("  Type: {}\n", type_name));
108                for (method_name, overloads) in type_methods.iter() {
109                    output.push_str(&format!(
110                        "    Method: {} ({} overloads)\n",
111                        method_name,
112                        overloads.len()
113                    ));
114                    for (i, overload) in overloads.iter().enumerate() {
115                        output.push_str(&format!(
116                            "      Overload {}: {} params",
117                            i + 1,
118                            overload.params.len()
119                        ));
120                        if overload.when_clause.is_some() {
121                            output.push_str(" (with when clause)");
122                        }
123                        output.push('\n');
124                    }
125                }
126            }
127        }
128
129        output
130    }
131}
132
133impl Default for TypeMethodRegistry {
134    fn default() -> Self {
135        Self::new()
136    }
137}