Skip to main content

shape_ast/parser/
functions.rs

1//! Function and annotation parsing for Shape
2//!
3//! This module handles parsing of:
4//! - Function definitions with parameters and return types
5//! - Function parameters with default values
6//! - Annotations (@warmup, @strategy, etc.)
7
8use crate::ast::{
9    Annotation, BuiltinFunctionDecl, ForeignFunctionDef, FunctionDef, FunctionParameter,
10    NativeAbiBinding,
11};
12use crate::error::Result;
13use pest::iterators::Pair;
14
15use super::expressions;
16use super::statements;
17use super::string_literals::parse_string_literal;
18use super::types;
19use super::types::parse_type_annotation;
20use super::{Rule, pair_span};
21
22/// Parse annotations
23pub fn parse_annotations(pair: Pair<Rule>) -> Result<Vec<Annotation>> {
24    let mut annotations = vec![];
25
26    for annotation_pair in pair.into_inner() {
27        if annotation_pair.as_rule() == Rule::annotation {
28            annotations.push(parse_annotation(annotation_pair)?);
29        }
30    }
31
32    Ok(annotations)
33}
34
35/// Parse a single annotation
36pub fn parse_annotation(pair: Pair<Rule>) -> Result<Annotation> {
37    let span = pair_span(&pair);
38    let mut name = String::new();
39    let mut args = Vec::new();
40
41    for inner_pair in pair.into_inner() {
42        match inner_pair.as_rule() {
43            Rule::annotation_ref => {
44                name = inner_pair.as_str().to_string();
45            }
46            Rule::annotation_name | Rule::ident => {
47                name = inner_pair.as_str().to_string();
48            }
49            Rule::annotation_args => {
50                for arg_pair in inner_pair.into_inner() {
51                    if arg_pair.as_rule() == Rule::expression {
52                        args.push(expressions::parse_expression(arg_pair)?);
53                    }
54                }
55            }
56            Rule::expression => {
57                args.push(expressions::parse_expression(inner_pair)?);
58            }
59            _ => {}
60        }
61    }
62
63    Ok(Annotation { name, args, span })
64}
65
66/// Parse a function parameter
67pub fn parse_function_param(pair: Pair<Rule>) -> Result<FunctionParameter> {
68    let mut pattern = None;
69    let mut is_const = false;
70    let mut is_reference = false;
71    let mut is_mut_reference = false;
72    let mut is_out = false;
73    let mut type_annotation = None;
74    let mut default_value = None;
75
76    for inner_pair in pair.into_inner() {
77        match inner_pair.as_rule() {
78            Rule::param_const_keyword => {
79                is_const = true;
80            }
81            Rule::param_ref_keyword => {
82                is_reference = true;
83                // Check for &mut: param_ref_keyword contains optional param_mut_keyword
84                for child in inner_pair.into_inner() {
85                    if child.as_rule() == Rule::param_mut_keyword {
86                        is_mut_reference = true;
87                    }
88                }
89            }
90            Rule::param_out_keyword => {
91                is_out = true;
92            }
93            Rule::destructure_pattern => {
94                pattern = Some(super::items::parse_pattern(inner_pair)?);
95            }
96            Rule::type_annotation => {
97                type_annotation = Some(parse_type_annotation(inner_pair)?);
98            }
99            Rule::expression => {
100                default_value = Some(expressions::parse_expression(inner_pair)?);
101            }
102            _ => {}
103        }
104    }
105
106    let pattern = pattern.ok_or_else(|| crate::error::ShapeError::ParseError {
107        message: "expected pattern in function parameter".to_string(),
108        location: None,
109    })?;
110
111    Ok(FunctionParameter {
112        pattern,
113        is_const,
114        is_reference,
115        is_mut_reference,
116        is_out,
117        type_annotation,
118        default_value,
119    })
120}
121
122/// Parse a function definition
123pub fn parse_function_def(pair: Pair<Rule>) -> Result<FunctionDef> {
124    let mut name = String::new();
125    let mut name_span = crate::ast::Span::DUMMY;
126    let mut type_params = None;
127    let mut params = vec![];
128    let mut return_type = None;
129    let mut where_clause = None;
130    let mut body = vec![];
131    let mut annotations = vec![];
132    let mut is_async = false;
133    let mut is_comptime = false;
134
135    // Parse all parts sequentially (can't use find() as it consumes the iterator)
136    for inner_pair in pair.into_inner() {
137        match inner_pair.as_rule() {
138            Rule::annotations => {
139                annotations = parse_annotations(inner_pair)?;
140            }
141            Rule::async_keyword => {
142                is_async = true;
143            }
144            Rule::comptime_keyword => {
145                is_comptime = true;
146            }
147            Rule::ident => {
148                if name.is_empty() {
149                    name = inner_pair.as_str().to_string();
150                    name_span = pair_span(&inner_pair);
151                }
152            }
153            Rule::type_params => {
154                type_params = Some(types::parse_type_params(inner_pair)?);
155            }
156            Rule::function_params => {
157                for param_pair in inner_pair.into_inner() {
158                    if param_pair.as_rule() == Rule::function_param {
159                        params.push(parse_function_param(param_pair)?);
160                    }
161                }
162            }
163            Rule::return_type => {
164                // Skip the "->" and get the type annotation
165                if let Some(type_pair) = inner_pair.into_inner().next() {
166                    return_type = Some(parse_type_annotation(type_pair)?);
167                }
168            }
169            Rule::where_clause => {
170                where_clause = Some(parse_where_clause(inner_pair)?);
171            }
172            Rule::function_body => {
173                // Parse all statements in the function body
174                body = statements::parse_statements(inner_pair.into_inner())?;
175            }
176            _ => {}
177        }
178    }
179
180    Ok(FunctionDef {
181        name,
182        name_span,
183        declaring_module_path: None,
184        doc_comment: None,
185        type_params,
186        params,
187        return_type,
188        where_clause,
189        body,
190        annotations,
191        is_async,
192        is_comptime,
193    })
194}
195
196/// Parse a declaration-only builtin function definition.
197///
198/// Grammar:
199/// `builtin fn name<T>(params...) -> ReturnType;`
200pub fn parse_builtin_function_decl(pair: Pair<Rule>) -> Result<BuiltinFunctionDecl> {
201    let mut name = String::new();
202    let mut name_span = crate::ast::Span::DUMMY;
203    let mut type_params = None;
204    let mut params = vec![];
205    let mut return_type = None;
206
207    for inner_pair in pair.into_inner() {
208        match inner_pair.as_rule() {
209            Rule::ident => {
210                if name.is_empty() {
211                    name = inner_pair.as_str().to_string();
212                    name_span = pair_span(&inner_pair);
213                }
214            }
215            Rule::type_params => {
216                type_params = Some(types::parse_type_params(inner_pair)?);
217            }
218            Rule::function_params => {
219                for param_pair in inner_pair.into_inner() {
220                    if param_pair.as_rule() == Rule::function_param {
221                        params.push(parse_function_param(param_pair)?);
222                    }
223                }
224            }
225            Rule::return_type => {
226                if let Some(type_pair) = inner_pair.into_inner().next() {
227                    return_type = Some(parse_type_annotation(type_pair)?);
228                }
229            }
230            _ => {}
231        }
232    }
233
234    let return_type = return_type.ok_or_else(|| crate::error::ShapeError::ParseError {
235        message: "builtin function declaration requires an explicit return type".to_string(),
236        location: None,
237    })?;
238
239    Ok(BuiltinFunctionDecl {
240        name,
241        name_span,
242        doc_comment: None,
243        type_params,
244        params,
245        return_type,
246    })
247}
248
249/// Parse a foreign function definition: `fn python analyze(data: DataTable) -> number { ... }`
250pub fn parse_foreign_function_def(pair: Pair<Rule>) -> Result<ForeignFunctionDef> {
251    let mut language = String::new();
252    let mut language_span = crate::ast::Span::DUMMY;
253    let mut name = String::new();
254    let mut name_span = crate::ast::Span::DUMMY;
255    let mut type_params = None;
256    let mut params = vec![];
257    let mut return_type = None;
258    let mut body_text = String::new();
259    let mut body_span = crate::ast::Span::DUMMY;
260    let mut annotations = vec![];
261    let mut is_async = false;
262
263    for inner_pair in pair.into_inner() {
264        match inner_pair.as_rule() {
265            Rule::annotations => {
266                annotations = parse_annotations(inner_pair)?;
267            }
268            Rule::async_keyword => {
269                is_async = true;
270            }
271            Rule::function_keyword => {}
272            Rule::foreign_language_id => {
273                language = inner_pair.as_str().to_string();
274                language_span = pair_span(&inner_pair);
275            }
276            Rule::ident => {
277                if name.is_empty() {
278                    name = inner_pair.as_str().to_string();
279                    name_span = pair_span(&inner_pair);
280                }
281            }
282            Rule::type_params => {
283                type_params = Some(types::parse_type_params(inner_pair)?);
284            }
285            Rule::function_params => {
286                for param_pair in inner_pair.into_inner() {
287                    if param_pair.as_rule() == Rule::function_param {
288                        params.push(parse_function_param(param_pair)?);
289                    }
290                }
291            }
292            Rule::return_type => {
293                if let Some(type_pair) = inner_pair.into_inner().next() {
294                    return_type = Some(parse_type_annotation(type_pair)?);
295                }
296            }
297            Rule::foreign_body => {
298                body_span = pair_span(&inner_pair);
299                body_text = dedent_foreign_body(inner_pair.as_str());
300            }
301            _ => {}
302        }
303    }
304
305    Ok(ForeignFunctionDef {
306        language,
307        language_span,
308        name,
309        name_span,
310        doc_comment: None,
311        type_params,
312        params,
313        return_type,
314        body_text,
315        body_span,
316        annotations,
317        is_async,
318        native_abi: None,
319    })
320}
321
322/// Parse a native ABI declaration:
323/// `extern "C" fn name(args...) -> Ret from "library" [as "symbol"];`
324pub fn parse_extern_native_function_def(pair: Pair<Rule>) -> Result<ForeignFunctionDef> {
325    let mut abi = String::new();
326    let mut abi_span = crate::ast::Span::DUMMY;
327    let mut name = String::new();
328    let mut name_span = crate::ast::Span::DUMMY;
329    let mut type_params = None;
330    let mut params = Vec::new();
331    let mut return_type = None;
332    let mut library: Option<String> = None;
333    let mut symbol: Option<String> = None;
334    let mut annotations = Vec::new();
335    let mut is_async = false;
336
337    for inner_pair in pair.into_inner() {
338        match inner_pair.as_rule() {
339            Rule::annotations => {
340                annotations = parse_annotations(inner_pair)?;
341            }
342            Rule::async_keyword => {
343                is_async = true;
344            }
345            Rule::extern_abi => {
346                abi_span = pair_span(&inner_pair);
347                abi = parse_extern_abi(inner_pair)?;
348            }
349            Rule::function_keyword => {}
350            Rule::ident => {
351                if name.is_empty() {
352                    name = inner_pair.as_str().to_string();
353                    name_span = pair_span(&inner_pair);
354                }
355            }
356            Rule::type_params => {
357                type_params = Some(types::parse_type_params(inner_pair)?);
358            }
359            Rule::function_params => {
360                for param_pair in inner_pair.into_inner() {
361                    if param_pair.as_rule() == Rule::function_param {
362                        params.push(parse_function_param(param_pair)?);
363                    }
364                }
365            }
366            Rule::return_type => {
367                if let Some(type_pair) = inner_pair.into_inner().next() {
368                    return_type = Some(parse_type_annotation(type_pair)?);
369                }
370            }
371            Rule::extern_native_link => {
372                for link_part in inner_pair.into_inner() {
373                    match link_part.as_rule() {
374                        Rule::extern_native_library => {
375                            library = Some(parse_string_literal(link_part.as_str())?);
376                        }
377                        Rule::extern_native_symbol => {
378                            symbol = Some(parse_string_literal(link_part.as_str())?);
379                        }
380                        _ => {}
381                    }
382                }
383            }
384            _ => {}
385        }
386    }
387
388    let library = library.ok_or_else(|| crate::error::ShapeError::ParseError {
389        message: "extern native declaration requires `from \"library\"`".to_string(),
390        location: None,
391    })?;
392
393    if abi.trim() != "C" {
394        return Err(crate::error::ShapeError::ParseError {
395            message: format!(
396                "unsupported extern ABI '{}': only \"C\" is currently supported",
397                abi
398            ),
399            location: None,
400        });
401    }
402
403    let symbol = symbol.unwrap_or_else(|| name.clone());
404
405    Ok(ForeignFunctionDef {
406        // Keep foreign-language compatibility for downstream compilation/runtime
407        // while carrying explicit native ABI metadata.
408        language: "native".to_string(),
409        language_span: abi_span,
410        name,
411        name_span,
412        doc_comment: None,
413        type_params,
414        params,
415        return_type,
416        body_text: String::new(),
417        body_span: crate::ast::Span::DUMMY,
418        annotations,
419        is_async,
420        native_abi: Some(NativeAbiBinding {
421            abi,
422            library,
423            symbol,
424            package_key: None,
425        }),
426    })
427}
428
429pub(crate) fn parse_extern_abi(pair: Pair<Rule>) -> Result<String> {
430    let inner = pair
431        .into_inner()
432        .next()
433        .ok_or_else(|| crate::error::ShapeError::ParseError {
434            message: "extern declaration is missing ABI name".to_string(),
435            location: None,
436        })?;
437
438    match inner.as_rule() {
439        Rule::string => parse_string_literal(inner.as_str()),
440        Rule::ident => Ok(inner.as_str().to_string()),
441        _ => Err(crate::error::ShapeError::ParseError {
442            message: format!("unsupported extern ABI token: {:?}", inner.as_rule()),
443            location: None,
444        }),
445    }
446}
447
/// Strip common leading whitespace from foreign body text.
///
/// Similar to Python's `textwrap.dedent`. This is critical for Python blocks
/// since the body is indented inside Shape code but needs to be dedented
/// for the foreign language runtime.
///
/// Note: The Pest parser's implicit WHITESPACE rule consumes the newline and
/// leading whitespace between `{` and the first token of `foreign_body`. This
/// means the first line has its leading whitespace eaten by the parser, while
/// subsequent lines retain their original indentation. We compute `min_indent`
/// (in bytes) from the non-blank lines after the first, then strip at most
/// that amount from those lines. The first line is kept as-is, except that a
/// single-line body has any leading whitespace trimmed.
///
/// Only whole leading whitespace characters are ever removed, so a line
/// indented with multi-byte whitespace (e.g. U+3000) can neither lose
/// non-whitespace content nor be sliced in the middle of a character.
fn dedent_foreign_body(text: &str) -> String {
    let lines: Vec<&str> = text.lines().collect();
    let (first, rest) = match lines.split_first() {
        Some(split) => split,
        None => return String::new(),
    };
    if rest.is_empty() {
        return first.trim_start().to_string();
    }

    // Compute min_indent from lines after the first, since the parser already
    // consumed the first line's leading whitespace. Blank lines do not count.
    let min_indent = rest
        .iter()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.len() - line.trim_start().len())
        .min()
        .unwrap_or(0);

    let mut result: Vec<&str> = Vec::with_capacity(lines.len());
    result.push(*first);
    result.extend(rest.iter().map(|line| strip_leading_indent(line, min_indent)));
    result.join("\n")
}

/// Remove leading whitespace characters from `line`, up to `max_bytes` bytes.
///
/// The cut always falls on a character boundary inside the leading
/// whitespace run; a whitespace character that would straddle `max_bytes`
/// is left in place.
fn strip_leading_indent(line: &str, max_bytes: usize) -> &str {
    let cut = line
        .char_indices()
        .take_while(|(_, ch)| ch.is_whitespace())
        .map(|(start, ch)| start + ch.len_utf8())
        .take_while(|&end| end <= max_bytes)
        .last()
        .unwrap_or(0);
    &line[cut..]
}
492
493/// Parse a where clause: `where T: Bound1 + Bound2, U: Bound3`
494pub fn parse_where_clause(pair: Pair<Rule>) -> Result<Vec<crate::ast::types::WherePredicate>> {
495    let mut predicates = Vec::new();
496    for child in pair.into_inner() {
497        if child.as_rule() == Rule::where_predicate {
498            predicates.push(parse_where_predicate(child)?);
499        }
500    }
501    Ok(predicates)
502}
503
504fn parse_where_predicate(pair: Pair<Rule>) -> Result<crate::ast::types::WherePredicate> {
505    let mut inner = pair.into_inner();
506
507    let name_pair = inner
508        .next()
509        .ok_or_else(|| crate::error::ShapeError::ParseError {
510            message: "expected type parameter name in where predicate".to_string(),
511            location: None,
512        })?;
513    let type_name = name_pair.as_str().to_string();
514
515    let mut bounds = Vec::new();
516    for remaining in inner {
517        if remaining.as_rule() == Rule::trait_bound_list {
518            for bound_ident in remaining.into_inner() {
519                if bound_ident.as_rule() == Rule::qualified_ident {
520                    bounds.push(bound_ident.as_str().into());
521                }
522            }
523        }
524    }
525
526    Ok(crate::ast::types::WherePredicate { type_name, bounds })
527}