Skip to main content

shape_ast/parser/
functions.rs

1//! Function and annotation parsing for Shape
2//!
3//! This module handles parsing of:
4//! - Function definitions with parameters and return types
5//! - Function parameters with default values
6//! - Annotations (@warmup, @strategy, etc.)
7
8use crate::ast::{
9    Annotation, BuiltinFunctionDecl, ForeignFunctionDef, FunctionDef, FunctionParameter,
10    NativeAbiBinding,
11};
12use crate::error::Result;
13use pest::iterators::Pair;
14
15use super::expressions;
16use super::statements;
17use super::string_literals::parse_string_literal;
18use super::types;
19use super::types::parse_type_annotation;
20use super::{Rule, pair_span};
21
22/// Parse annotations
23pub fn parse_annotations(pair: Pair<Rule>) -> Result<Vec<Annotation>> {
24    let mut annotations = vec![];
25
26    for annotation_pair in pair.into_inner() {
27        if annotation_pair.as_rule() == Rule::annotation {
28            annotations.push(parse_annotation(annotation_pair)?);
29        }
30    }
31
32    Ok(annotations)
33}
34
35/// Parse a single annotation
36pub fn parse_annotation(pair: Pair<Rule>) -> Result<Annotation> {
37    let span = pair_span(&pair);
38    let mut name = String::new();
39    let mut args = Vec::new();
40
41    for inner_pair in pair.into_inner() {
42        match inner_pair.as_rule() {
43            Rule::annotation_name | Rule::ident => {
44                name = inner_pair.as_str().to_string();
45            }
46            Rule::annotation_args => {
47                for arg_pair in inner_pair.into_inner() {
48                    if arg_pair.as_rule() == Rule::expression {
49                        args.push(expressions::parse_expression(arg_pair)?);
50                    }
51                }
52            }
53            Rule::expression => {
54                args.push(expressions::parse_expression(inner_pair)?);
55            }
56            _ => {}
57        }
58    }
59
60    Ok(Annotation { name, args, span })
61}
62
63/// Parse a function parameter
64pub fn parse_function_param(pair: Pair<Rule>) -> Result<FunctionParameter> {
65    let mut pattern = None;
66    let mut is_const = false;
67    let mut is_reference = false;
68    let mut is_mut_reference = false;
69    let mut is_out = false;
70    let mut type_annotation = None;
71    let mut default_value = None;
72
73    for inner_pair in pair.into_inner() {
74        match inner_pair.as_rule() {
75            Rule::param_const_keyword => {
76                is_const = true;
77            }
78            Rule::param_ref_keyword => {
79                is_reference = true;
80                // Check for &mut: param_ref_keyword contains optional param_mut_keyword
81                for child in inner_pair.into_inner() {
82                    if child.as_rule() == Rule::param_mut_keyword {
83                        is_mut_reference = true;
84                    }
85                }
86            }
87            Rule::param_out_keyword => {
88                is_out = true;
89            }
90            Rule::destructure_pattern => {
91                pattern = Some(super::items::parse_pattern(inner_pair)?);
92            }
93            Rule::type_annotation => {
94                type_annotation = Some(parse_type_annotation(inner_pair)?);
95            }
96            Rule::expression => {
97                default_value = Some(expressions::parse_expression(inner_pair)?);
98            }
99            _ => {}
100        }
101    }
102
103    let pattern = pattern.ok_or_else(|| crate::error::ShapeError::ParseError {
104        message: "expected pattern in function parameter".to_string(),
105        location: None,
106    })?;
107
108    Ok(FunctionParameter {
109        pattern,
110        is_const,
111        is_reference,
112        is_mut_reference,
113        is_out,
114        type_annotation,
115        default_value,
116    })
117}
118
119/// Parse a function definition
120pub fn parse_function_def(pair: Pair<Rule>) -> Result<FunctionDef> {
121    let mut name = String::new();
122    let mut name_span = crate::ast::Span::DUMMY;
123    let mut type_params = None;
124    let mut params = vec![];
125    let mut return_type = None;
126    let mut where_clause = None;
127    let mut body = vec![];
128    let mut annotations = vec![];
129    let mut is_async = false;
130    let mut is_comptime = false;
131
132    // Parse all parts sequentially (can't use find() as it consumes the iterator)
133    for inner_pair in pair.into_inner() {
134        match inner_pair.as_rule() {
135            Rule::annotations => {
136                annotations = parse_annotations(inner_pair)?;
137            }
138            Rule::async_keyword => {
139                is_async = true;
140            }
141            Rule::comptime_keyword => {
142                is_comptime = true;
143            }
144            Rule::ident => {
145                if name.is_empty() {
146                    name = inner_pair.as_str().to_string();
147                    name_span = pair_span(&inner_pair);
148                }
149            }
150            Rule::type_params => {
151                type_params = Some(types::parse_type_params(inner_pair)?);
152            }
153            Rule::function_params => {
154                for param_pair in inner_pair.into_inner() {
155                    if param_pair.as_rule() == Rule::function_param {
156                        params.push(parse_function_param(param_pair)?);
157                    }
158                }
159            }
160            Rule::return_type => {
161                // Skip the "->" and get the type annotation
162                if let Some(type_pair) = inner_pair.into_inner().next() {
163                    return_type = Some(parse_type_annotation(type_pair)?);
164                }
165            }
166            Rule::where_clause => {
167                where_clause = Some(parse_where_clause(inner_pair)?);
168            }
169            Rule::function_body => {
170                // Parse all statements in the function body
171                body = statements::parse_statements(inner_pair.into_inner())?;
172            }
173            _ => {}
174        }
175    }
176
177    Ok(FunctionDef {
178        name,
179        name_span,
180        doc_comment: None,
181        type_params,
182        params,
183        return_type,
184        where_clause,
185        body,
186        annotations,
187        is_async,
188        is_comptime,
189    })
190}
191
192/// Parse a declaration-only builtin function definition.
193///
194/// Grammar:
195/// `builtin fn name<T>(params...) -> ReturnType;`
196pub fn parse_builtin_function_decl(pair: Pair<Rule>) -> Result<BuiltinFunctionDecl> {
197    let mut name = String::new();
198    let mut name_span = crate::ast::Span::DUMMY;
199    let mut type_params = None;
200    let mut params = vec![];
201    let mut return_type = None;
202
203    for inner_pair in pair.into_inner() {
204        match inner_pair.as_rule() {
205            Rule::ident => {
206                if name.is_empty() {
207                    name = inner_pair.as_str().to_string();
208                    name_span = pair_span(&inner_pair);
209                }
210            }
211            Rule::type_params => {
212                type_params = Some(types::parse_type_params(inner_pair)?);
213            }
214            Rule::function_params => {
215                for param_pair in inner_pair.into_inner() {
216                    if param_pair.as_rule() == Rule::function_param {
217                        params.push(parse_function_param(param_pair)?);
218                    }
219                }
220            }
221            Rule::return_type => {
222                if let Some(type_pair) = inner_pair.into_inner().next() {
223                    return_type = Some(parse_type_annotation(type_pair)?);
224                }
225            }
226            _ => {}
227        }
228    }
229
230    let return_type = return_type.ok_or_else(|| crate::error::ShapeError::ParseError {
231        message: "builtin function declaration requires an explicit return type".to_string(),
232        location: None,
233    })?;
234
235    Ok(BuiltinFunctionDecl {
236        name,
237        name_span,
238        doc_comment: None,
239        type_params,
240        params,
241        return_type,
242    })
243}
244
245/// Parse a foreign function definition: `fn python analyze(data: DataTable) -> number { ... }`
246pub fn parse_foreign_function_def(pair: Pair<Rule>) -> Result<ForeignFunctionDef> {
247    let mut language = String::new();
248    let mut language_span = crate::ast::Span::DUMMY;
249    let mut name = String::new();
250    let mut name_span = crate::ast::Span::DUMMY;
251    let mut type_params = None;
252    let mut params = vec![];
253    let mut return_type = None;
254    let mut body_text = String::new();
255    let mut body_span = crate::ast::Span::DUMMY;
256    let mut annotations = vec![];
257    let mut is_async = false;
258
259    for inner_pair in pair.into_inner() {
260        match inner_pair.as_rule() {
261            Rule::annotations => {
262                annotations = parse_annotations(inner_pair)?;
263            }
264            Rule::async_keyword => {
265                is_async = true;
266            }
267            Rule::function_keyword => {}
268            Rule::foreign_language_id => {
269                language = inner_pair.as_str().to_string();
270                language_span = pair_span(&inner_pair);
271            }
272            Rule::ident => {
273                if name.is_empty() {
274                    name = inner_pair.as_str().to_string();
275                    name_span = pair_span(&inner_pair);
276                }
277            }
278            Rule::type_params => {
279                type_params = Some(types::parse_type_params(inner_pair)?);
280            }
281            Rule::function_params => {
282                for param_pair in inner_pair.into_inner() {
283                    if param_pair.as_rule() == Rule::function_param {
284                        params.push(parse_function_param(param_pair)?);
285                    }
286                }
287            }
288            Rule::return_type => {
289                if let Some(type_pair) = inner_pair.into_inner().next() {
290                    return_type = Some(parse_type_annotation(type_pair)?);
291                }
292            }
293            Rule::foreign_body => {
294                body_span = pair_span(&inner_pair);
295                body_text = dedent_foreign_body(inner_pair.as_str());
296            }
297            _ => {}
298        }
299    }
300
301    Ok(ForeignFunctionDef {
302        language,
303        language_span,
304        name,
305        name_span,
306        doc_comment: None,
307        type_params,
308        params,
309        return_type,
310        body_text,
311        body_span,
312        annotations,
313        is_async,
314        native_abi: None,
315    })
316}
317
318/// Parse a native ABI declaration:
319/// `extern "C" fn name(args...) -> Ret from "library" [as "symbol"];`
320pub fn parse_extern_native_function_def(pair: Pair<Rule>) -> Result<ForeignFunctionDef> {
321    let mut abi = String::new();
322    let mut abi_span = crate::ast::Span::DUMMY;
323    let mut name = String::new();
324    let mut name_span = crate::ast::Span::DUMMY;
325    let mut type_params = None;
326    let mut params = Vec::new();
327    let mut return_type = None;
328    let mut library: Option<String> = None;
329    let mut symbol: Option<String> = None;
330    let mut annotations = Vec::new();
331    let mut is_async = false;
332
333    for inner_pair in pair.into_inner() {
334        match inner_pair.as_rule() {
335            Rule::annotations => {
336                annotations = parse_annotations(inner_pair)?;
337            }
338            Rule::async_keyword => {
339                is_async = true;
340            }
341            Rule::extern_abi => {
342                abi_span = pair_span(&inner_pair);
343                abi = parse_extern_abi(inner_pair)?;
344            }
345            Rule::function_keyword => {}
346            Rule::ident => {
347                if name.is_empty() {
348                    name = inner_pair.as_str().to_string();
349                    name_span = pair_span(&inner_pair);
350                }
351            }
352            Rule::type_params => {
353                type_params = Some(types::parse_type_params(inner_pair)?);
354            }
355            Rule::function_params => {
356                for param_pair in inner_pair.into_inner() {
357                    if param_pair.as_rule() == Rule::function_param {
358                        params.push(parse_function_param(param_pair)?);
359                    }
360                }
361            }
362            Rule::return_type => {
363                if let Some(type_pair) = inner_pair.into_inner().next() {
364                    return_type = Some(parse_type_annotation(type_pair)?);
365                }
366            }
367            Rule::extern_native_link => {
368                for link_part in inner_pair.into_inner() {
369                    match link_part.as_rule() {
370                        Rule::extern_native_library => {
371                            library = Some(parse_string_literal(link_part.as_str())?);
372                        }
373                        Rule::extern_native_symbol => {
374                            symbol = Some(parse_string_literal(link_part.as_str())?);
375                        }
376                        _ => {}
377                    }
378                }
379            }
380            _ => {}
381        }
382    }
383
384    let library = library.ok_or_else(|| crate::error::ShapeError::ParseError {
385        message: "extern native declaration requires `from \"library\"`".to_string(),
386        location: None,
387    })?;
388
389    if abi.trim() != "C" {
390        return Err(crate::error::ShapeError::ParseError {
391            message: format!(
392                "unsupported extern ABI '{}': only \"C\" is currently supported",
393                abi
394            ),
395            location: None,
396        });
397    }
398
399    let symbol = symbol.unwrap_or_else(|| name.clone());
400
401    Ok(ForeignFunctionDef {
402        // Keep foreign-language compatibility for downstream compilation/runtime
403        // while carrying explicit native ABI metadata.
404        language: "native".to_string(),
405        language_span: abi_span,
406        name,
407        name_span,
408        doc_comment: None,
409        type_params,
410        params,
411        return_type,
412        body_text: String::new(),
413        body_span: crate::ast::Span::DUMMY,
414        annotations,
415        is_async,
416        native_abi: Some(NativeAbiBinding {
417            abi,
418            library,
419            symbol,
420            package_key: None,
421        }),
422    })
423}
424
425pub(crate) fn parse_extern_abi(pair: Pair<Rule>) -> Result<String> {
426    let inner = pair
427        .into_inner()
428        .next()
429        .ok_or_else(|| crate::error::ShapeError::ParseError {
430            message: "extern declaration is missing ABI name".to_string(),
431            location: None,
432        })?;
433
434    match inner.as_rule() {
435        Rule::string => parse_string_literal(inner.as_str()),
436        Rule::ident => Ok(inner.as_str().to_string()),
437        _ => Err(crate::error::ShapeError::ParseError {
438            message: format!("unsupported extern ABI token: {:?}", inner.as_rule()),
439            location: None,
440        }),
441    }
442}
443
/// Strip common leading whitespace from foreign body text.
///
/// Similar to Python's `textwrap.dedent`. This is critical for Python blocks
/// since the body is indented inside Shape code but needs to be dedented
/// for the foreign language runtime.
///
/// Note: The Pest parser's implicit WHITESPACE rule consumes the newline and
/// leading whitespace between `{` and the first token of `foreign_body`. This
/// means the first line has its leading whitespace eaten by the parser, while
/// subsequent lines retain their original indentation. We compute `min_indent`
/// from lines after the first, then strip that amount only from those lines.
/// The first line is kept as-is.
///
/// Behavior summary:
/// - empty input yields an empty string;
/// - a single line is returned with its leading whitespace trimmed;
/// - otherwise `min_indent` is the smallest leading-whitespace width (in
///   bytes) among the non-blank lines after the first, and at most that many
///   bytes of leading whitespace are removed from each of those lines.
///
/// Indentation is removed one whole character at a time, so a multi-byte
/// whitespace character (e.g. U+3000) that straddles the `min_indent` byte
/// offset is left in place instead of causing a slicing panic.
fn dedent_foreign_body(text: &str) -> String {
    let lines: Vec<&str> = text.lines().collect();
    if lines.is_empty() {
        return String::new();
    }
    if lines.len() == 1 {
        return lines[0].trim_start().to_string();
    }

    // Compute min_indent from lines after the first, since the parser already
    // consumed the first line's leading whitespace. Blank lines are ignored so
    // they cannot force the common indent down to zero.
    let min_indent = lines[1..]
        .iter()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.len() - line.trim_start().len())
        .min()
        .unwrap_or(0);

    // First line: keep as-is (parser already stripped its whitespace).
    // Subsequent lines: strip up to `min_indent` bytes of leading whitespace.
    let mut result = Vec::with_capacity(lines.len());
    result.push(lines[0]);
    for line in &lines[1..] {
        // Byte offset just past the last leading whitespace char that still
        // fits within `min_indent`; always a valid char boundary.
        let cut = line
            .char_indices()
            .map(|(start, ch)| (start + ch.len_utf8(), ch))
            .take_while(|&(end, ch)| ch.is_whitespace() && end <= min_indent)
            .last()
            .map_or(0, |(end, _)| end);
        result.push(&line[cut..]);
    }
    result.join("\n")
}
488
489/// Parse a where clause: `where T: Bound1 + Bound2, U: Bound3`
490pub fn parse_where_clause(pair: Pair<Rule>) -> Result<Vec<crate::ast::types::WherePredicate>> {
491    let mut predicates = Vec::new();
492    for child in pair.into_inner() {
493        if child.as_rule() == Rule::where_predicate {
494            predicates.push(parse_where_predicate(child)?);
495        }
496    }
497    Ok(predicates)
498}
499
500fn parse_where_predicate(pair: Pair<Rule>) -> Result<crate::ast::types::WherePredicate> {
501    let mut inner = pair.into_inner();
502
503    let name_pair = inner
504        .next()
505        .ok_or_else(|| crate::error::ShapeError::ParseError {
506            message: "expected type parameter name in where predicate".to_string(),
507            location: None,
508        })?;
509    let type_name = name_pair.as_str().to_string();
510
511    let mut bounds = Vec::new();
512    for remaining in inner {
513        if remaining.as_rule() == Rule::trait_bound_list {
514            for bound_ident in remaining.into_inner() {
515                if bound_ident.as_rule() == Rule::ident {
516                    bounds.push(bound_ident.as_str().to_string());
517                }
518            }
519        }
520    }
521
522    Ok(crate::ast::types::WherePredicate { type_name, bounds })
523}