telltale_language/extensions/
timeout.rs1use super::{
7 CodegenContext, ExtensionRegistry, ExtensionValidationError, GrammarExtension, ParseContext,
8 ParseError, ProjectionContext, ProtocolExtension, StatementParser,
9};
10use crate::ast::{LocalType, Role};
11use crate::compiler::projection::ProjectionError;
12use std::any::{Any, TypeId};
13use std::time::Duration;
14
15#[derive(Debug)]
17pub struct TimeoutGrammarExtension;
18
19impl GrammarExtension for TimeoutGrammarExtension {
20 fn grammar_rules(&self) -> &'static str {
21 r#"
22timeout_ext_stmt = { "timeout" ~ timeout_ext_duration ~ timeout_ext_roles ~ "{" ~ protocol_body ~ "}" }
23timeout_ext_duration = { integer ~ timeout_ext_time_unit? }
24timeout_ext_time_unit = { "ms" | "s" | "m" | "h" }
25timeout_ext_roles = { "(" ~ role_list ~ ")" | role_ref }
26"#
27 }
28
29 fn statement_rules(&self) -> Vec<&'static str> {
30 vec!["timeout_ext_stmt"]
31 }
32
33 fn priority(&self) -> u32 {
34 200 }
36
37 fn extension_id(&self) -> &'static str {
38 "timeout"
39 }
40}
41
42#[derive(Debug)]
44pub struct TimeoutStatementParser;
45
46impl StatementParser for TimeoutStatementParser {
47 fn can_parse(&self, rule_name: &str) -> bool {
48 rule_name == "timeout_ext_stmt"
49 }
50
51 fn supported_rules(&self) -> Vec<String> {
52 vec!["timeout_ext_stmt".to_string()]
53 }
54
55 fn parse_statement(
56 &self,
57 rule_name: &str,
58 _content: &str,
59 context: &ParseContext,
60 ) -> Result<Box<dyn ProtocolExtension>, ParseError> {
61 if rule_name != "timeout_ext_stmt" {
62 return Err(ParseError::InvalidSyntax {
63 details: format!("Expected timeout_ext_stmt, got {}", rule_name),
64 });
65 }
66
67 let timeout_protocol = self.parse_timeout_content(_content, context)?;
70 Ok(Box::new(timeout_protocol))
71 }
72}
73
74impl TimeoutStatementParser {
75 fn parse_timeout_content(
76 &self,
77 content: &str,
78 context: &ParseContext,
79 ) -> Result<TimeoutProtocol, ParseError> {
80 let duration_ms = self.extract_duration(content)?;
82 let roles = self.extract_roles(content, context)?;
83
84 Ok(TimeoutProtocol {
86 duration: Duration::from_millis(duration_ms),
87 role_names: roles.iter().map(|r| r.name().to_string()).collect(),
88 body_repr: "End".to_string(),
89 })
90 }
91
92 fn extract_duration(&self, content: &str) -> Result<u64, ParseError> {
93 let duration_str = content
96 .split_whitespace()
97 .find(|s| s.chars().all(|c| c.is_ascii_digit()))
98 .ok_or_else(|| ParseError::InvalidSyntax {
99 details: "Could not find timeout duration".to_string(),
100 })?;
101
102 duration_str.parse().map_err(|_| ParseError::InvalidSyntax {
103 details: "Invalid timeout duration format".to_string(),
104 })
105 }
106
107 fn extract_roles(
108 &self,
109 _content: &str,
110 context: &ParseContext,
111 ) -> Result<Vec<Role>, ParseError> {
112 Ok(context.declared_roles.to_vec())
115 }
116}
117
/// Parsed form of a `timeout` statement: how long the guarded body may run and
/// which roles the timeout applies to.
///
/// All fields are plain data, so the type derives `PartialEq`/`Eq` to allow
/// parsed results to be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutProtocol {
    /// Time allowed for the guarded body before the timeout fires.
    pub duration: Duration,
    /// Names of the roles the timeout applies to.
    pub role_names: Vec<String>,
    /// Textual representation of the guarded body.
    pub body_repr: String,
}
127
128impl ProtocolExtension for TimeoutProtocol {
129 fn type_name(&self) -> &'static str {
130 "TimeoutProtocol"
131 }
132
133 fn mentions_role(&self, role: &Role) -> bool {
134 self.role_names
135 .iter()
136 .any(|name| name == &role.name().to_string())
137 }
138
139 fn validate(&self, all_roles: &[Role]) -> Result<(), ExtensionValidationError> {
140 for role_name in &self.role_names {
142 if !all_roles.iter().any(|r| &r.name().to_string() == role_name) {
143 return Err(ExtensionValidationError::UndeclaredRole {
144 role: role_name.clone(),
145 });
146 }
147 }
148
149 if self.duration.is_zero() {
151 return Err(ExtensionValidationError::InvalidStructure {
152 reason: "Timeout duration cannot be zero".to_string(),
153 });
154 }
155
156 if self.duration > Duration::from_secs(3600) {
157 return Err(ExtensionValidationError::InvalidStructure {
158 reason: "Timeout duration too long (max 1 hour)".to_string(),
159 });
160 }
161
162 Ok(())
163 }
164
165 fn project(
166 &self,
167 role: &Role,
168 _context: &ProjectionContext,
169 ) -> Result<LocalType, ProjectionError> {
170 if self
171 .role_names
172 .iter()
173 .any(|name| name == &role.name().to_string())
174 {
175 Ok(LocalType::Timeout {
178 duration: self.duration,
179 body: Box::new(LocalType::End),
180 on_timeout: Box::new(LocalType::End),
181 on_cancel: None,
182 })
183 } else {
184 Ok(LocalType::End)
186 }
187 }
188
189 fn generate_code(&self, _context: &CodegenContext) -> proc_macro2::TokenStream {
190 let duration_ms = u64::try_from(self.duration.as_millis()).unwrap_or(u64::MAX);
191 let _role_names = &self.role_names;
192
193 quote::quote! {
194 .with_timeout(
196 Duration::from_millis(#duration_ms),
197 )
199 }
200 }
201
202 fn as_any(&self) -> &dyn Any {
203 self
204 }
205
206 fn as_any_mut(&mut self) -> &mut dyn Any {
207 self
208 }
209
210 fn type_id(&self) -> TypeId {
211 TypeId::of::<Self>()
212 }
213
214 fn clone_box(&self) -> Box<dyn ProtocolExtension> {
215 Box::new(self.clone())
216 }
217}
218
219pub fn register_timeout_extension(
225 registry: &mut ExtensionRegistry,
226) -> Result<(), crate::extensions::ParseError> {
227 registry.register_grammar(TimeoutGrammarExtension)?;
228 registry.register_parser(TimeoutStatementParser, "timeout".to_string());
229 Ok(())
230}
231
232impl LocalType {
234 pub fn timeout(duration: Duration, body: LocalType) -> Self {
235 Self::Timeout {
236 duration,
237 body: Box::new(body),
238 on_timeout: Box::new(LocalType::End),
239 on_cancel: None,
240 }
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timeout_grammar_extension() {
        let ext = TimeoutGrammarExtension;
        assert_eq!(ext.extension_id(), "timeout");
        assert!(ext.statement_rules().contains(&"timeout_ext_stmt"));
        assert!(ext.grammar_rules().contains("timeout_ext_stmt"));
    }

    #[test]
    fn test_timeout_statement_parser() {
        let parser = TimeoutStatementParser;
        assert!(parser.can_parse("timeout_ext_stmt"));
        assert!(!parser.can_parse("unknown_stmt"));
    }

    #[test]
    fn test_timeout_protocol() {
        let timeout_protocol = TimeoutProtocol {
            duration: Duration::from_millis(5000),
            role_names: vec!["Alice".to_string()],
            body_repr: "End".to_string(),
        };

        assert_eq!(timeout_protocol.type_name(), "TimeoutProtocol");

        use proc_macro2::Span;
        let alice = Role::new(proc_macro2::Ident::new("Alice", Span::call_site())).unwrap();
        let bob = Role::new(proc_macro2::Ident::new("Bob", Span::call_site())).unwrap();

        assert!(timeout_protocol.mentions_role(&alice));
        assert!(!timeout_protocol.mentions_role(&bob));
    }

    #[test]
    fn test_timeout_validation() {
        use proc_macro2::Span;
        let roles = vec![Role::new(proc_macro2::Ident::new("Alice", Span::call_site())).unwrap()];

        let valid_timeout = TimeoutProtocol {
            duration: Duration::from_millis(5000),
            role_names: roles.iter().map(|r| r.name().to_string()).collect(),
            body_repr: "End".to_string(),
        };

        assert!(valid_timeout.validate(&roles).is_ok());

        let invalid_timeout = TimeoutProtocol {
            duration: Duration::ZERO,
            role_names: roles.iter().map(|r| r.name().to_string()).collect(),
            body_repr: "End".to_string(),
        };

        assert!(invalid_timeout.validate(&roles).is_err());
    }

    /// A timeout naming a role that is not declared must be rejected with
    /// `UndeclaredRole` carrying that role's name.
    #[test]
    fn test_timeout_validation_rejects_undeclared_role() {
        use proc_macro2::Span;
        let roles = vec![Role::new(proc_macro2::Ident::new("Alice", Span::call_site())).unwrap()];

        let timeout = TimeoutProtocol {
            duration: Duration::from_millis(5000),
            role_names: vec!["Bob".to_string()],
            body_repr: "End".to_string(),
        };

        assert!(matches!(
            timeout.validate(&roles),
            Err(ExtensionValidationError::UndeclaredRole { ref role }) if role.as_str() == "Bob"
        ));
    }

    /// The one-hour limit is inclusive: exactly 3600 s passes, anything longer fails.
    #[test]
    fn test_timeout_validation_duration_upper_bound() {
        use proc_macro2::Span;
        let roles = vec![Role::new(proc_macro2::Ident::new("Alice", Span::call_site())).unwrap()];
        let role_names: Vec<String> = roles.iter().map(|r| r.name().to_string()).collect();

        let at_limit = TimeoutProtocol {
            duration: Duration::from_secs(3600),
            role_names: role_names.clone(),
            body_repr: "End".to_string(),
        };
        assert!(at_limit.validate(&roles).is_ok());

        let over_limit = TimeoutProtocol {
            duration: Duration::from_secs(3601),
            role_names,
            body_repr: "End".to_string(),
        };
        assert!(matches!(
            over_limit.validate(&roles),
            Err(ExtensionValidationError::InvalidStructure { .. })
        ));
    }

    #[test]
    fn test_extension_registration() {
        let mut registry = ExtensionRegistry::new();
        register_timeout_extension(&mut registry).expect("extension should register");

        assert!(registry.can_handle("timeout_ext_stmt"));
    }
}
311}