Skip to main content

shape_ast/parser/
functions.rs

1//! Function and annotation parsing for Shape
2//!
3//! This module handles parsing of:
4//! - Function definitions with parameters and return types
5//! - Function parameters with default values
6//! - Annotations (@warmup, @strategy, etc.)
7
8use crate::ast::{
9    Annotation, BuiltinFunctionDecl, ForeignFunctionDef, FunctionDef, FunctionParameter,
10    NativeAbiBinding,
11};
12use crate::error::Result;
13use pest::iterators::Pair;
14
15use super::expressions;
16use super::statements;
17use super::string_literals::parse_string_literal;
18use super::types;
19use super::types::parse_type_annotation;
20use super::{Rule, pair_span};
21
22/// Parse annotations
23pub fn parse_annotations(pair: Pair<Rule>) -> Result<Vec<Annotation>> {
24    let mut annotations = vec![];
25
26    for annotation_pair in pair.into_inner() {
27        if annotation_pair.as_rule() == Rule::annotation {
28            annotations.push(parse_annotation(annotation_pair)?);
29        }
30    }
31
32    Ok(annotations)
33}
34
35/// Parse a single annotation
36pub fn parse_annotation(pair: Pair<Rule>) -> Result<Annotation> {
37    let span = pair_span(&pair);
38    let mut name = String::new();
39    let mut args = Vec::new();
40
41    for inner_pair in pair.into_inner() {
42        match inner_pair.as_rule() {
43            Rule::annotation_name | Rule::ident => {
44                name = inner_pair.as_str().to_string();
45            }
46            Rule::annotation_args => {
47                for arg_pair in inner_pair.into_inner() {
48                    if arg_pair.as_rule() == Rule::expression {
49                        args.push(expressions::parse_expression(arg_pair)?);
50                    }
51                }
52            }
53            Rule::expression => {
54                args.push(expressions::parse_expression(inner_pair)?);
55            }
56            _ => {}
57        }
58    }
59
60    Ok(Annotation { name, args, span })
61}
62
63/// Parse a function parameter
64pub fn parse_function_param(pair: Pair<Rule>) -> Result<FunctionParameter> {
65    let mut pattern = None;
66    let mut is_const = false;
67    let mut is_reference = false;
68    let mut is_mut_reference = false;
69    let mut is_out = false;
70    let mut type_annotation = None;
71    let mut default_value = None;
72
73    for inner_pair in pair.into_inner() {
74        match inner_pair.as_rule() {
75            Rule::param_const_keyword => {
76                is_const = true;
77            }
78            Rule::param_ref_keyword => {
79                is_reference = true;
80                // Check for &mut: param_ref_keyword contains optional param_mut_keyword
81                for child in inner_pair.into_inner() {
82                    if child.as_rule() == Rule::param_mut_keyword {
83                        is_mut_reference = true;
84                    }
85                }
86            }
87            Rule::param_out_keyword => {
88                is_out = true;
89            }
90            Rule::destructure_pattern => {
91                pattern = Some(super::items::parse_pattern(inner_pair)?);
92            }
93            Rule::type_annotation => {
94                type_annotation = Some(parse_type_annotation(inner_pair)?);
95            }
96            Rule::expression => {
97                default_value = Some(expressions::parse_expression(inner_pair)?);
98            }
99            _ => {}
100        }
101    }
102
103    let pattern = pattern.ok_or_else(|| crate::error::ShapeError::ParseError {
104        message: "expected pattern in function parameter".to_string(),
105        location: None,
106    })?;
107
108    Ok(FunctionParameter {
109        pattern,
110        is_const,
111        is_reference,
112        is_mut_reference,
113        is_out,
114        type_annotation,
115        default_value,
116    })
117}
118
119/// Parse a function definition
120pub fn parse_function_def(pair: Pair<Rule>) -> Result<FunctionDef> {
121    let mut name = String::new();
122    let mut name_span = crate::ast::Span::DUMMY;
123    let mut type_params = None;
124    let mut params = vec![];
125    let mut return_type = None;
126    let mut where_clause = None;
127    let mut body = vec![];
128    let mut annotations = vec![];
129    let mut is_async = false;
130    let mut is_comptime = false;
131
132    // Parse all parts sequentially (can't use find() as it consumes the iterator)
133    for inner_pair in pair.into_inner() {
134        match inner_pair.as_rule() {
135            Rule::annotations => {
136                annotations = parse_annotations(inner_pair)?;
137            }
138            Rule::async_keyword => {
139                is_async = true;
140            }
141            Rule::comptime_keyword => {
142                is_comptime = true;
143            }
144            Rule::ident => {
145                if name.is_empty() {
146                    name = inner_pair.as_str().to_string();
147                    name_span = pair_span(&inner_pair);
148                }
149            }
150            Rule::type_params => {
151                type_params = Some(types::parse_type_params(inner_pair)?);
152            }
153            Rule::function_params => {
154                for param_pair in inner_pair.into_inner() {
155                    if param_pair.as_rule() == Rule::function_param {
156                        params.push(parse_function_param(param_pair)?);
157                    }
158                }
159            }
160            Rule::return_type => {
161                // Skip the "->" and get the type annotation
162                if let Some(type_pair) = inner_pair.into_inner().next() {
163                    return_type = Some(parse_type_annotation(type_pair)?);
164                }
165            }
166            Rule::where_clause => {
167                where_clause = Some(parse_where_clause(inner_pair)?);
168            }
169            Rule::function_body => {
170                // Parse all statements in the function body
171                body = statements::parse_statements(inner_pair.into_inner())?;
172            }
173            _ => {}
174        }
175    }
176
177    Ok(FunctionDef {
178        name,
179        name_span,
180        declaring_module_path: None,
181        doc_comment: None,
182        type_params,
183        params,
184        return_type,
185        where_clause,
186        body,
187        annotations,
188        is_async,
189        is_comptime,
190    })
191}
192
193/// Parse a declaration-only builtin function definition.
194///
195/// Grammar:
196/// `builtin fn name<T>(params...) -> ReturnType;`
197pub fn parse_builtin_function_decl(pair: Pair<Rule>) -> Result<BuiltinFunctionDecl> {
198    let mut name = String::new();
199    let mut name_span = crate::ast::Span::DUMMY;
200    let mut type_params = None;
201    let mut params = vec![];
202    let mut return_type = None;
203
204    for inner_pair in pair.into_inner() {
205        match inner_pair.as_rule() {
206            Rule::ident => {
207                if name.is_empty() {
208                    name = inner_pair.as_str().to_string();
209                    name_span = pair_span(&inner_pair);
210                }
211            }
212            Rule::type_params => {
213                type_params = Some(types::parse_type_params(inner_pair)?);
214            }
215            Rule::function_params => {
216                for param_pair in inner_pair.into_inner() {
217                    if param_pair.as_rule() == Rule::function_param {
218                        params.push(parse_function_param(param_pair)?);
219                    }
220                }
221            }
222            Rule::return_type => {
223                if let Some(type_pair) = inner_pair.into_inner().next() {
224                    return_type = Some(parse_type_annotation(type_pair)?);
225                }
226            }
227            _ => {}
228        }
229    }
230
231    let return_type = return_type.ok_or_else(|| crate::error::ShapeError::ParseError {
232        message: "builtin function declaration requires an explicit return type".to_string(),
233        location: None,
234    })?;
235
236    Ok(BuiltinFunctionDecl {
237        name,
238        name_span,
239        doc_comment: None,
240        type_params,
241        params,
242        return_type,
243    })
244}
245
246/// Parse a foreign function definition: `fn python analyze(data: DataTable) -> number { ... }`
247pub fn parse_foreign_function_def(pair: Pair<Rule>) -> Result<ForeignFunctionDef> {
248    let mut language = String::new();
249    let mut language_span = crate::ast::Span::DUMMY;
250    let mut name = String::new();
251    let mut name_span = crate::ast::Span::DUMMY;
252    let mut type_params = None;
253    let mut params = vec![];
254    let mut return_type = None;
255    let mut body_text = String::new();
256    let mut body_span = crate::ast::Span::DUMMY;
257    let mut annotations = vec![];
258    let mut is_async = false;
259
260    for inner_pair in pair.into_inner() {
261        match inner_pair.as_rule() {
262            Rule::annotations => {
263                annotations = parse_annotations(inner_pair)?;
264            }
265            Rule::async_keyword => {
266                is_async = true;
267            }
268            Rule::function_keyword => {}
269            Rule::foreign_language_id => {
270                language = inner_pair.as_str().to_string();
271                language_span = pair_span(&inner_pair);
272            }
273            Rule::ident => {
274                if name.is_empty() {
275                    name = inner_pair.as_str().to_string();
276                    name_span = pair_span(&inner_pair);
277                }
278            }
279            Rule::type_params => {
280                type_params = Some(types::parse_type_params(inner_pair)?);
281            }
282            Rule::function_params => {
283                for param_pair in inner_pair.into_inner() {
284                    if param_pair.as_rule() == Rule::function_param {
285                        params.push(parse_function_param(param_pair)?);
286                    }
287                }
288            }
289            Rule::return_type => {
290                if let Some(type_pair) = inner_pair.into_inner().next() {
291                    return_type = Some(parse_type_annotation(type_pair)?);
292                }
293            }
294            Rule::foreign_body => {
295                body_span = pair_span(&inner_pair);
296                body_text = dedent_foreign_body(inner_pair.as_str());
297            }
298            _ => {}
299        }
300    }
301
302    Ok(ForeignFunctionDef {
303        language,
304        language_span,
305        name,
306        name_span,
307        doc_comment: None,
308        type_params,
309        params,
310        return_type,
311        body_text,
312        body_span,
313        annotations,
314        is_async,
315        native_abi: None,
316    })
317}
318
319/// Parse a native ABI declaration:
320/// `extern "C" fn name(args...) -> Ret from "library" [as "symbol"];`
321pub fn parse_extern_native_function_def(pair: Pair<Rule>) -> Result<ForeignFunctionDef> {
322    let mut abi = String::new();
323    let mut abi_span = crate::ast::Span::DUMMY;
324    let mut name = String::new();
325    let mut name_span = crate::ast::Span::DUMMY;
326    let mut type_params = None;
327    let mut params = Vec::new();
328    let mut return_type = None;
329    let mut library: Option<String> = None;
330    let mut symbol: Option<String> = None;
331    let mut annotations = Vec::new();
332    let mut is_async = false;
333
334    for inner_pair in pair.into_inner() {
335        match inner_pair.as_rule() {
336            Rule::annotations => {
337                annotations = parse_annotations(inner_pair)?;
338            }
339            Rule::async_keyword => {
340                is_async = true;
341            }
342            Rule::extern_abi => {
343                abi_span = pair_span(&inner_pair);
344                abi = parse_extern_abi(inner_pair)?;
345            }
346            Rule::function_keyword => {}
347            Rule::ident => {
348                if name.is_empty() {
349                    name = inner_pair.as_str().to_string();
350                    name_span = pair_span(&inner_pair);
351                }
352            }
353            Rule::type_params => {
354                type_params = Some(types::parse_type_params(inner_pair)?);
355            }
356            Rule::function_params => {
357                for param_pair in inner_pair.into_inner() {
358                    if param_pair.as_rule() == Rule::function_param {
359                        params.push(parse_function_param(param_pair)?);
360                    }
361                }
362            }
363            Rule::return_type => {
364                if let Some(type_pair) = inner_pair.into_inner().next() {
365                    return_type = Some(parse_type_annotation(type_pair)?);
366                }
367            }
368            Rule::extern_native_link => {
369                for link_part in inner_pair.into_inner() {
370                    match link_part.as_rule() {
371                        Rule::extern_native_library => {
372                            library = Some(parse_string_literal(link_part.as_str())?);
373                        }
374                        Rule::extern_native_symbol => {
375                            symbol = Some(parse_string_literal(link_part.as_str())?);
376                        }
377                        _ => {}
378                    }
379                }
380            }
381            _ => {}
382        }
383    }
384
385    let library = library.ok_or_else(|| crate::error::ShapeError::ParseError {
386        message: "extern native declaration requires `from \"library\"`".to_string(),
387        location: None,
388    })?;
389
390    if abi.trim() != "C" {
391        return Err(crate::error::ShapeError::ParseError {
392            message: format!(
393                "unsupported extern ABI '{}': only \"C\" is currently supported",
394                abi
395            ),
396            location: None,
397        });
398    }
399
400    let symbol = symbol.unwrap_or_else(|| name.clone());
401
402    Ok(ForeignFunctionDef {
403        // Keep foreign-language compatibility for downstream compilation/runtime
404        // while carrying explicit native ABI metadata.
405        language: "native".to_string(),
406        language_span: abi_span,
407        name,
408        name_span,
409        doc_comment: None,
410        type_params,
411        params,
412        return_type,
413        body_text: String::new(),
414        body_span: crate::ast::Span::DUMMY,
415        annotations,
416        is_async,
417        native_abi: Some(NativeAbiBinding {
418            abi,
419            library,
420            symbol,
421            package_key: None,
422        }),
423    })
424}
425
426pub(crate) fn parse_extern_abi(pair: Pair<Rule>) -> Result<String> {
427    let inner = pair
428        .into_inner()
429        .next()
430        .ok_or_else(|| crate::error::ShapeError::ParseError {
431            message: "extern declaration is missing ABI name".to_string(),
432            location: None,
433        })?;
434
435    match inner.as_rule() {
436        Rule::string => parse_string_literal(inner.as_str()),
437        Rule::ident => Ok(inner.as_str().to_string()),
438        _ => Err(crate::error::ShapeError::ParseError {
439            message: format!("unsupported extern ABI token: {:?}", inner.as_rule()),
440            location: None,
441        }),
442    }
443}
444
/// Strip common leading whitespace from foreign body text.
///
/// Similar to Python's `textwrap.dedent`. This is critical for Python blocks
/// since the body is indented inside Shape code but needs to be dedented
/// for the foreign language runtime.
///
/// Note: The Pest parser's implicit WHITESPACE rule consumes the newline and
/// leading whitespace between `{` and the first token of `foreign_body`. This
/// means the first line has its leading whitespace eaten by the parser, while
/// subsequent lines retain their original indentation. We compute `min_indent`
/// from lines after the first, then strip that amount only from those lines.
/// The first line is kept as-is.
///
/// Behaviour by input shape:
/// - empty text yields an empty string,
/// - a single line is returned with its leading whitespace trimmed,
/// - otherwise the first line is kept verbatim and every later line loses
///   `min_indent` bytes of indentation (blank lines shorter than that are
///   reduced to an empty line). Lines are re-joined with `\n`.
///
/// `min_indent` is measured in bytes. Indentation may contain multi-byte
/// whitespace (for example U+00A0 or U+3000), so the cut position is rounded
/// up to the next character boundary instead of slicing blindly; this never
/// removes anything but whitespace and cannot panic.
fn dedent_foreign_body(text: &str) -> String {
    let lines: Vec<&str> = text.lines().collect();
    let (first, rest) = match lines.split_first() {
        None => return String::new(),
        Some((only, [])) => return only.trim_start().to_string(),
        Some(split) => split,
    };

    // Smallest indentation (in bytes) among the non-blank lines after the
    // first; the parser already consumed the first line's leading whitespace.
    let min_indent = rest
        .iter()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.len() - line.trim_start().len())
        .min()
        .unwrap_or(0);

    let mut result: Vec<&str> = Vec::with_capacity(lines.len());
    result.push(*first);
    for &line in rest {
        if line.len() < min_indent {
            // Only whitespace-only lines can be shorter than `min_indent`.
            result.push(line.trim());
            continue;
        }
        // `min_indent` may fall inside a multi-byte whitespace character of
        // this line's indentation. Advance to the next boundary so the slice
        // is always valid; `line.len()` is itself a boundary, so this ends.
        let mut cut = min_indent;
        while !line.is_char_boundary(cut) {
            cut += 1;
        }
        result.push(&line[cut..]);
    }
    result.join("\n")
}
489
490/// Parse a where clause: `where T: Bound1 + Bound2, U: Bound3`
491pub fn parse_where_clause(pair: Pair<Rule>) -> Result<Vec<crate::ast::types::WherePredicate>> {
492    let mut predicates = Vec::new();
493    for child in pair.into_inner() {
494        if child.as_rule() == Rule::where_predicate {
495            predicates.push(parse_where_predicate(child)?);
496        }
497    }
498    Ok(predicates)
499}
500
501fn parse_where_predicate(pair: Pair<Rule>) -> Result<crate::ast::types::WherePredicate> {
502    let mut inner = pair.into_inner();
503
504    let name_pair = inner
505        .next()
506        .ok_or_else(|| crate::error::ShapeError::ParseError {
507            message: "expected type parameter name in where predicate".to_string(),
508            location: None,
509        })?;
510    let type_name = name_pair.as_str().to_string();
511
512    let mut bounds = Vec::new();
513    for remaining in inner {
514        if remaining.as_rule() == Rule::trait_bound_list {
515            for bound_ident in remaining.into_inner() {
516                if bound_ident.as_rule() == Rule::ident {
517                    bounds.push(bound_ident.as_str().to_string());
518                }
519            }
520        }
521    }
522
523    Ok(crate::ast::types::WherePredicate { type_name, bounds })
524}