telltale_language/extensions/
timeout.rs1use super::{
7 CodegenContext, ExtensionRegistry, ExtensionValidationError, GrammarExtension, ParseContext,
8 ParseError, ProjectionContext, ProtocolExtension, StatementParser,
9};
10use crate::ast::{LocalType, Role};
11use crate::compiler::projection::ProjectionError;
12use std::any::{Any, TypeId};
13use std::time::Duration;
14
15#[derive(Debug)]
17pub struct TimeoutGrammarExtension;
18
19impl GrammarExtension for TimeoutGrammarExtension {
20 fn grammar_rules(&self) -> &'static str {
21 r#"
22timeout_stmt = { "timeout" ~ timeout_duration ~ timeout_roles ~ "{" ~ protocol_body ~ "}" }
23timeout_duration = { integer ~ time_unit? }
24time_unit = { "ms" | "s" | "m" | "h" }
25timeout_roles = { "(" ~ role_list ~ ")" | role_ref }
26"#
27 }
28
29 fn statement_rules(&self) -> Vec<&'static str> {
30 vec!["timeout_stmt"]
31 }
32
33 fn priority(&self) -> u32 {
34 200 }
36
37 fn extension_id(&self) -> &'static str {
38 "timeout"
39 }
40}
41
42#[derive(Debug)]
44pub struct TimeoutStatementParser;
45
46impl StatementParser for TimeoutStatementParser {
47 fn can_parse(&self, rule_name: &str) -> bool {
48 rule_name == "timeout_stmt"
49 }
50
51 fn supported_rules(&self) -> Vec<String> {
52 vec!["timeout_stmt".to_string()]
53 }
54
55 fn parse_statement(
56 &self,
57 rule_name: &str,
58 _content: &str,
59 context: &ParseContext,
60 ) -> Result<Box<dyn ProtocolExtension>, ParseError> {
61 if rule_name != "timeout_stmt" {
62 return Err(ParseError::InvalidSyntax {
63 details: format!("Expected timeout_stmt, got {}", rule_name),
64 });
65 }
66
67 let timeout_protocol = self.parse_timeout_content(_content, context)?;
70 Ok(Box::new(timeout_protocol))
71 }
72}
73
74impl TimeoutStatementParser {
75 fn parse_timeout_content(
76 &self,
77 content: &str,
78 context: &ParseContext,
79 ) -> Result<TimeoutProtocol, ParseError> {
80 let duration_ms = self.extract_duration(content)?;
82 let roles = self.extract_roles(content, context)?;
83
84 Ok(TimeoutProtocol {
86 duration: Duration::from_millis(duration_ms),
87 role_names: roles.iter().map(|r| r.name().to_string()).collect(),
88 body_repr: "End".to_string(),
89 })
90 }
91
92 fn extract_duration(&self, content: &str) -> Result<u64, ParseError> {
93 let duration_str = content
96 .split_whitespace()
97 .find(|s| s.chars().all(|c| c.is_ascii_digit()))
98 .ok_or_else(|| ParseError::InvalidSyntax {
99 details: "Could not find timeout duration".to_string(),
100 })?;
101
102 duration_str.parse().map_err(|_| ParseError::InvalidSyntax {
103 details: "Invalid timeout duration format".to_string(),
104 })
105 }
106
107 fn extract_roles(
108 &self,
109 _content: &str,
110 context: &ParseContext,
111 ) -> Result<Vec<Role>, ParseError> {
112 Ok(context.declared_roles.to_vec())
115 }
116}
117
/// Parsed `timeout` statement: a duration applied to a set of roles.
#[derive(Clone, Debug)]
pub struct TimeoutProtocol {
    /// How long the guarded body may run.
    pub duration: Duration,
    /// Names of the roles the timeout applies to.
    pub role_names: Vec<String>,
    /// Textual representation of the guarded body.
    pub body_repr: String,
}
127
128impl ProtocolExtension for TimeoutProtocol {
129 fn type_name(&self) -> &'static str {
130 "TimeoutProtocol"
131 }
132
133 fn mentions_role(&self, role: &Role) -> bool {
134 self.role_names
135 .iter()
136 .any(|name| name == &role.name().to_string())
137 }
138
139 fn validate(&self, all_roles: &[Role]) -> Result<(), ExtensionValidationError> {
140 for role_name in &self.role_names {
142 if !all_roles.iter().any(|r| &r.name().to_string() == role_name) {
143 return Err(ExtensionValidationError::UndeclaredRole {
144 role: role_name.clone(),
145 });
146 }
147 }
148
149 if self.duration.is_zero() {
151 return Err(ExtensionValidationError::InvalidStructure {
152 reason: "Timeout duration cannot be zero".to_string(),
153 });
154 }
155
156 if self.duration > Duration::from_secs(3600) {
157 return Err(ExtensionValidationError::InvalidStructure {
158 reason: "Timeout duration too long (max 1 hour)".to_string(),
159 });
160 }
161
162 Ok(())
163 }
164
165 fn project(
166 &self,
167 role: &Role,
168 _context: &ProjectionContext,
169 ) -> Result<LocalType, ProjectionError> {
170 if self
171 .role_names
172 .iter()
173 .any(|name| name == &role.name().to_string())
174 {
175 Ok(LocalType::Timeout {
178 duration: self.duration,
179 body: Box::new(LocalType::End),
180 })
181 } else {
182 Ok(LocalType::End)
184 }
185 }
186
187 fn generate_code(&self, _context: &CodegenContext) -> proc_macro2::TokenStream {
188 let duration_ms = u64::try_from(self.duration.as_millis()).unwrap_or(u64::MAX);
189 let _role_names = &self.role_names;
190
191 quote::quote! {
192 .with_timeout(
194 Duration::from_millis(#duration_ms),
195 )
197 }
198 }
199
200 fn as_any(&self) -> &dyn Any {
201 self
202 }
203
204 fn as_any_mut(&mut self) -> &mut dyn Any {
205 self
206 }
207
208 fn type_id(&self) -> TypeId {
209 TypeId::of::<Self>()
210 }
211
212 fn clone_box(&self) -> Box<dyn ProtocolExtension> {
213 Box::new(self.clone())
214 }
215}
216
217pub fn register_timeout_extension(
223 registry: &mut ExtensionRegistry,
224) -> Result<(), crate::extensions::ParseError> {
225 registry.register_grammar(TimeoutGrammarExtension)?;
226 registry.register_parser(TimeoutStatementParser, "timeout".to_string());
227 Ok(())
228}
229
230impl LocalType {
232 pub fn timeout(duration: Duration, body: LocalType) -> Self {
233 Self::Timeout {
234 duration,
235 body: Box::new(body),
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use proc_macro2::Span;

    /// Builds a `Role` from a plain identifier for test fixtures.
    fn role(name: &str) -> Role {
        Role::new(proc_macro2::Ident::new(name, Span::call_site())).unwrap()
    }

    /// Builds a `TimeoutProtocol` with an `End` body.
    fn protocol(duration: Duration, role_names: &[&str]) -> TimeoutProtocol {
        TimeoutProtocol {
            duration,
            role_names: role_names.iter().map(|n| n.to_string()).collect(),
            body_repr: "End".to_string(),
        }
    }

    #[test]
    fn test_timeout_grammar_extension() {
        let ext = TimeoutGrammarExtension;
        assert_eq!(ext.extension_id(), "timeout");
        assert!(ext.statement_rules().contains(&"timeout_stmt"));
        assert!(ext.grammar_rules().contains("timeout_stmt"));
    }

    #[test]
    fn test_timeout_statement_parser() {
        let parser = TimeoutStatementParser;
        assert!(parser.can_parse("timeout_stmt"));
        assert!(!parser.can_parse("unknown_stmt"));
    }

    #[test]
    fn test_timeout_protocol() {
        let timeout_protocol = protocol(Duration::from_millis(5000), &["Alice"]);

        assert_eq!(timeout_protocol.type_name(), "TimeoutProtocol");
        assert!(timeout_protocol.mentions_role(&role("Alice")));
        assert!(!timeout_protocol.mentions_role(&role("Bob")));
    }

    #[test]
    fn test_timeout_validation() {
        let roles = vec![role("Alice")];

        let valid_timeout = protocol(Duration::from_millis(5000), &["Alice"]);
        assert!(valid_timeout.validate(&roles).is_ok());

        let invalid_timeout = protocol(Duration::ZERO, &["Alice"]);
        assert!(invalid_timeout.validate(&roles).is_err());
    }

    #[test]
    fn test_timeout_validation_upper_bound() {
        let roles = vec![role("Alice")];

        // Exactly one hour is allowed; anything longer is rejected.
        assert!(protocol(Duration::from_secs(3600), &["Alice"])
            .validate(&roles)
            .is_ok());
        assert!(protocol(Duration::from_secs(3601), &["Alice"])
            .validate(&roles)
            .is_err());
    }

    #[test]
    fn test_timeout_validation_undeclared_role() {
        let roles = vec![role("Alice")];
        let result = protocol(Duration::from_millis(5000), &["Mallory"]).validate(&roles);

        assert!(matches!(
            result,
            Err(ExtensionValidationError::UndeclaredRole { role }) if role == "Mallory"
        ));
    }

    #[test]
    fn test_extract_duration_bare_number() {
        let parser = TimeoutStatementParser;
        assert_eq!(
            parser.extract_duration("timeout 500 Alice { End }").ok(),
            Some(500)
        );
        assert!(parser.extract_duration("timeout Alice { End }").is_err());
    }

    #[test]
    fn test_local_type_timeout_constructor() {
        let duration = Duration::from_millis(250);
        assert!(matches!(
            LocalType::timeout(duration, LocalType::End),
            LocalType::Timeout { duration: d, .. } if d == duration
        ));
    }

    #[test]
    fn test_extension_registration() {
        let mut registry = ExtensionRegistry::new();
        register_timeout_extension(&mut registry).expect("extension should register");

        assert!(registry.can_handle("timeout_stmt"));
    }
}