syn_solidity/item/
function.rs

1use crate::{
2    Block, FunctionAttribute, FunctionAttributes, Mutability, ParameterList, Parameters, SolIdent,
3    Spanned, Stmt, Type, VariableDeclaration, VariableDefinition, Visibility, kw,
4};
5use proc_macro2::Span;
6use std::{
7    fmt,
8    hash::{Hash, Hasher},
9    num::NonZeroU16,
10};
11use syn::{
12    Attribute, Error, Result, Token, parenthesized,
13    parse::{Parse, ParseStream},
14    token::{Brace, Paren},
15};
16
17/// A function, constructor, fallback, receive, or modifier definition:
18/// `function helloWorld() external pure returns(string memory);`.
19///
20/// Solidity reference:
21/// <https://docs.soliditylang.org/en/latest/grammar.html#a4.SolidityParser.functionDefinition>
22#[derive(Clone)]
23pub struct ItemFunction {
24    /// The `syn` attributes of the function.
25    pub attrs: Vec<Attribute>,
26    pub kind: FunctionKind,
27    pub name: Option<SolIdent>,
28    /// Parens are optional for modifiers:
29    /// <https://docs.soliditylang.org/en/latest/grammar.html#a4.SolidityParser.modifierDefinition>
30    pub paren_token: Option<Paren>,
31    pub parameters: ParameterList,
32    /// The Solidity attributes of the function.
33    pub attributes: FunctionAttributes,
34    /// The optional return types of the function.
35    pub returns: Option<Returns>,
36    pub body: FunctionBody,
37}
38
39impl fmt::Display for ItemFunction {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        f.write_str(self.kind.as_str())?;
42        if let Some(name) = &self.name {
43            f.write_str(" ")?;
44            name.fmt(f)?;
45        }
46        write!(f, "({})", self.parameters)?;
47
48        if !self.attributes.is_empty() {
49            write!(f, " {}", self.attributes)?;
50        }
51
52        if let Some(returns) = &self.returns {
53            write!(f, " {returns}")?;
54        }
55
56        if !self.body.is_empty() {
57            f.write_str(" ")?;
58        }
59        f.write_str(self.body.as_str())
60    }
61}
62
63impl fmt::Debug for ItemFunction {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        f.debug_struct("ItemFunction")
66            .field("attrs", &self.attrs)
67            .field("kind", &self.kind)
68            .field("name", &self.name)
69            .field("arguments", &self.parameters)
70            .field("attributes", &self.attributes)
71            .field("returns", &self.returns)
72            .field("body", &self.body)
73            .finish()
74    }
75}
76
77impl Parse for ItemFunction {
78    fn parse(input: ParseStream<'_>) -> Result<Self> {
79        let attrs = input.call(Attribute::parse_outer)?;
80        let kind: FunctionKind = input.parse()?;
81        let name = input.call(SolIdent::parse_opt)?;
82
83        let (paren_token, parameters) = if kind.is_modifier() && !input.peek(Paren) {
84            (None, ParameterList::new())
85        } else {
86            let content;
87            (Some(parenthesized!(content in input)), content.parse()?)
88        };
89
90        let attributes = input.parse()?;
91        let returns = input.call(Returns::parse_opt)?;
92        let body = input.parse()?;
93
94        Ok(Self { attrs, kind, name, paren_token, parameters, attributes, returns, body })
95    }
96}
97
98impl Spanned for ItemFunction {
99    fn span(&self) -> Span {
100        if let Some(name) = &self.name { name.span() } else { self.kind.span() }
101    }
102
103    fn set_span(&mut self, span: Span) {
104        self.kind.set_span(span);
105        if let Some(name) = &mut self.name {
106            name.set_span(span);
107        }
108    }
109}
110
111impl ItemFunction {
112    /// Create a new function of the given kind.
113    pub fn new(kind: FunctionKind, name: Option<SolIdent>) -> Self {
114        let span = name.as_ref().map_or_else(|| kind.span(), |name| name.span());
115        Self {
116            attrs: Vec::new(),
117            kind,
118            name,
119            paren_token: Some(Paren(span)),
120            parameters: Parameters::new(),
121            attributes: FunctionAttributes::new(),
122            returns: None,
123            body: FunctionBody::Empty(Token![;](span)),
124        }
125    }
126
127    /// Create a new function with the given name and arguments.
128    ///
129    /// Note that:
130    /// - the type is not validated
131    /// - structs/array of structs in return position are not expanded
132    /// - the body is not set
133    ///
134    /// The attributes are set to `public view`.
135    ///
136    /// See [the Solidity documentation][ref] for more details on how getters
137    /// are generated.
138    ///
139    /// [ref]: https://docs.soliditylang.org/en/latest/contracts.html#getter-functions
140    pub fn new_getter(name: SolIdent, ty: Type) -> Self {
141        let span = name.span();
142        let kind = FunctionKind::new_function(span);
143        let mut function = Self::new(kind, Some(name.clone()));
144
145        // `public view`
146        function.attributes.0 = vec![
147            FunctionAttribute::Visibility(Visibility::new_public(span)),
148            FunctionAttribute::Mutability(Mutability::new_view(span)),
149        ];
150
151        // Recurse into mappings and arrays to generate arguments and the return type.
152        // If the return type is simple, the return value name is set to the variable name.
153        let mut ty = ty;
154        let mut return_name = None;
155        let mut first = true;
156        loop {
157            match ty {
158                // mapping(k => v) -> arguments += k, ty = v
159                Type::Mapping(map) => {
160                    let key = VariableDeclaration::new_with(*map.key, None, map.key_name);
161                    function.parameters.push(key);
162                    return_name = map.value_name;
163                    ty = *map.value;
164                }
165                // inner[] -> arguments += uint256, ty = inner
166                Type::Array(array) => {
167                    let uint256 = Type::Uint(span, NonZeroU16::new(256));
168                    function.parameters.push(VariableDeclaration::new(uint256));
169                    ty = *array.ty;
170                }
171                _ => {
172                    if first {
173                        return_name = Some(name);
174                    }
175                    break;
176                }
177            }
178            first = false;
179        }
180        let mut returns = ParameterList::new();
181        returns.push(VariableDeclaration::new_with(ty, None, return_name));
182        function.returns = Some(Returns::new(span, returns));
183
184        function
185    }
186
187    /// Creates a new function from a variable definition.
188    ///
189    /// The function will have the same name and the variable type's will be the
190    /// return type. The variable attributes are ignored, and instead will
191    /// always generate `public returns`.
192    ///
193    /// See [`new_getter`](Self::new_getter) for more details.
194    pub fn from_variable_definition(var: VariableDefinition) -> Self {
195        let mut function = Self::new_getter(var.name, var.ty);
196        function.attrs = var.attrs;
197        function
198    }
199
200    /// Returns the name of the function.
201    ///
202    /// # Panics
203    ///
204    /// Panics if the function has no name. This is the case when `kind` is not
205    /// `Function`.
206    #[track_caller]
207    pub fn name(&self) -> &SolIdent {
208        match &self.name {
209            Some(name) => name,
210            None => panic!("function has no name: {self:?}"),
211        }
212    }
213
214    /// Returns true if the function returns nothing.
215    pub fn is_void(&self) -> bool {
216        match &self.returns {
217            None => true,
218            Some(returns) => returns.returns.is_empty(),
219        }
220    }
221
222    /// Returns true if the function has a body.
223    pub fn has_implementation(&self) -> bool {
224        matches!(self.body, FunctionBody::Block(_))
225    }
226
227    /// Returns the function's arguments tuple type.
228    pub fn call_type(&self) -> Type {
229        Type::Tuple(self.parameters.types().cloned().collect())
230    }
231
232    /// Returns the function's return tuple type.
233    pub fn return_type(&self) -> Option<Type> {
234        self.returns.as_ref().map(|returns| Type::Tuple(returns.returns.types().cloned().collect()))
235    }
236
237    /// Returns a reference to the function's body, if any.
238    pub fn body(&self) -> Option<&[Stmt]> {
239        match &self.body {
240            FunctionBody::Block(block) => Some(&block.stmts),
241            _ => None,
242        }
243    }
244
245    /// Returns a mutable reference to the function's body, if any.
246    pub fn body_mut(&mut self) -> Option<&mut Vec<Stmt>> {
247        match &mut self.body {
248            FunctionBody::Block(block) => Some(&mut block.stmts),
249            _ => None,
250        }
251    }
252
253    #[allow(clippy::result_large_err)]
254    pub fn into_body(self) -> std::result::Result<Vec<Stmt>, Self> {
255        match self.body {
256            FunctionBody::Block(block) => Ok(block.stmts),
257            _ => Err(self),
258        }
259    }
260}
261
kw_enum! {
    /// The kind of function-like item: constructor, function, fallback, receive, or modifier.
    pub enum FunctionKind {
        Constructor(kw::constructor),
        Function(kw::function),
        Fallback(kw::fallback),
        Receive(kw::receive),
        // The only kind whose parameter list may be omitted entirely when
        // parsing an `ItemFunction`.
        Modifier(kw::modifier),
    }
}
272
273/// The `returns` attribute of a function.
274#[derive(Clone)]
275pub struct Returns {
276    pub returns_token: kw::returns,
277    pub paren_token: Paren,
278    /// The returns of the function. This cannot be parsed empty.
279    pub returns: ParameterList,
280}
281
282impl PartialEq for Returns {
283    fn eq(&self, other: &Self) -> bool {
284        self.returns == other.returns
285    }
286}
287
288impl Eq for Returns {}
289
290impl Hash for Returns {
291    fn hash<H: Hasher>(&self, state: &mut H) {
292        self.returns.hash(state);
293    }
294}
295
296impl fmt::Display for Returns {
297    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298        f.write_str("returns (")?;
299        self.returns.fmt(f)?;
300        f.write_str(")")
301    }
302}
303
304impl fmt::Debug for Returns {
305    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306        f.debug_tuple("Returns").field(&self.returns).finish()
307    }
308}
309
310impl Parse for Returns {
311    fn parse(input: ParseStream<'_>) -> Result<Self> {
312        let content;
313        let this = Self {
314            returns_token: input.parse()?,
315            paren_token: parenthesized!(content in input),
316            returns: content.parse()?,
317        };
318        if this.returns.is_empty() {
319            Err(Error::new(this.paren_token.span.join(), "expected at least one return type"))
320        } else {
321            Ok(this)
322        }
323    }
324}
325
326impl Spanned for Returns {
327    fn span(&self) -> Span {
328        let span = self.returns_token.span;
329        span.join(self.paren_token.span.join()).unwrap_or(span)
330    }
331
332    fn set_span(&mut self, span: Span) {
333        self.returns_token.span = span;
334        self.paren_token = Paren(span);
335    }
336}
337
338impl Returns {
339    pub fn new(span: Span, returns: ParameterList) -> Self {
340        Self { returns_token: kw::returns(span), paren_token: Paren(span), returns }
341    }
342
343    pub fn parse_opt(input: ParseStream<'_>) -> Result<Option<Self>> {
344        if input.peek(kw::returns) { input.parse().map(Some) } else { Ok(None) }
345    }
346}
347
348/// The body of a function.
349#[derive(Clone)]
350pub enum FunctionBody {
351    /// A function without implementation.
352    Empty(Token![;]),
353    /// A function body delimited by curly braces.
354    Block(Block),
355}
356
357impl fmt::Display for FunctionBody {
358    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359        f.write_str(self.as_str())
360    }
361}
362
363impl fmt::Debug for FunctionBody {
364    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365        f.write_str("FunctionBody::")?;
366        match self {
367            Self::Empty(_) => f.write_str("Empty"),
368            Self::Block(block) => block.fmt(f),
369        }
370    }
371}
372
373impl Parse for FunctionBody {
374    fn parse(input: ParseStream<'_>) -> Result<Self> {
375        let lookahead = input.lookahead1();
376        if lookahead.peek(Brace) {
377            input.parse().map(Self::Block)
378        } else if lookahead.peek(Token![;]) {
379            input.parse().map(Self::Empty)
380        } else {
381            Err(lookahead.error())
382        }
383    }
384}
385
386impl FunctionBody {
387    /// Returns `true` if the function body is empty.
388    #[inline]
389    pub fn is_empty(&self) -> bool {
390        matches!(self, Self::Empty(_))
391    }
392
393    /// Returns a string representation of the function body.
394    #[inline]
395    pub fn as_str(&self) -> &'static str {
396        match self {
397            Self::Empty(_) => ";",
398            // TODO: fmt::Display for Stmt
399            Self::Block(_) => "{ <stmts> }",
400        }
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{
        error::Error,
        io::Write,
        process::{Command, Stdio},
    };
    use syn::parse_quote;

    #[test]
    fn modifiers() {
        let none: ItemFunction = parse_quote! {
            modifier noParens {
                _;
            }
        };
        let some: ItemFunction = parse_quote! {
            modifier withParens() {
                _;
            }
        };
        assert_eq!(none.kind, FunctionKind::new_modifier(Span::call_site()));
        assert_eq!(none.kind, some.kind);
        assert_eq!(none.paren_token, None);
        assert_eq!(some.paren_token, Some(Default::default()));
    }

    #[test]
    #[cfg_attr(miri, ignore = "takes too long")]
    fn getters() {
        let run_solc = run_solc();

        macro_rules! test_getters {
            ($($var:literal => $f:literal),* $(,)?) => {
                let vars: &[&str] = &[$($var),*];
                let fns: &[&str] = &[$($f),*];
                for (var, f) in std::iter::zip(vars, fns) {
                    test_getter(var, f, run_solc);
                }
            };
        }

        test_getters! {
            "bool public simple;"
                => "function simple() public view returns (bool simple);",
            "bool public constant simpleConstant = false;"
                => "function simpleConstant() public view returns (bool simpleConstant);",

            "mapping(address => bool) public map;"
                => "function map(address) public view returns (bool);",
            "mapping(address a => bool b) public mapWithNames;"
                => "function mapWithNames(address a) public view returns (bool b);",
            "mapping(uint256 k1 => mapping(uint256 k2 => bool v) ignored) public nested2;"
                => "function nested2(uint256 k1, uint256 k2) public view returns (bool v);",
            "mapping(uint256 k1 => mapping(uint256 k2 => mapping(uint256 k3 => bool v) ignored1) ignored2) public nested3;"
                => "function nested3(uint256 k1, uint256 k2, uint256 k3) public view returns (bool v);",

            "bool[] public boolArray;"
                => "function boolArray(uint256) public view returns(bool);",
            "mapping(bool => bytes2)[] public mapArray;"
                => "function mapArray(uint256, bool) public view returns(bytes2);",
            "mapping(bool => mapping(address => int[])[])[][] public nestedMapArray;"
                => "function nestedMapArray(uint256, uint256, bool, uint256, address, uint256) public view returns(int);",
        }
    }

    #[test]
    fn solc_version_parsing() {
        assert_eq!(
            parse_solc_version(
                "solc, the solidity compiler commandline interface\n\
                 Version: 0.8.19+commit.7dd6d404.Linux.g++\n"
            ),
            Some((0, 8, 19))
        );
        // Prerelease builds must not panic the test suite.
        assert_eq!(
            parse_solc_version("Version: 0.8.26-develop.2024.5.1+commit.1234abcd.Darwin.appleclang\n"),
            Some((0, 8, 26))
        );
        assert_eq!(parse_solc_version("Version: 0.8+commit.1234abcd"), None);
        assert_eq!(parse_solc_version("Version: 0.8.x+commit.1234abcd"), None);
        assert_eq!(parse_solc_version("no version here"), None);
    }

    fn test_getter(var_s: &str, fn_s: &str, run_solc: bool) {
        let var = syn::parse_str::<VariableDefinition>(var_s).unwrap();
        let getter = ItemFunction::from_variable_definition(var);
        let f = syn::parse_str::<ItemFunction>(fn_s).unwrap();
        assert_eq!(format!("{getter:#?}"), format!("{f:#?}"), "{var_s}");

        // Test that the ABIs are the same.
        // Skip `simple` getters since the return type will have a different ABI because Solc
        // doesn't populate the field.
        if run_solc && !var_s.contains("simple") {
            match (wrap_and_compile(var_s, true), wrap_and_compile(fn_s, false)) {
                (Ok(a), Ok(b)) => {
                    assert_eq!(a.trim(), b.trim(), "\nleft:  {var_s:?}\nright: {fn_s:?}")
                }
                (Err(e), _) | (_, Err(e)) => panic!("{e}"),
            }
        }
    }

    /// Returns `true` if a `solc` new enough for these tests is on `PATH`.
    fn run_solc() -> bool {
        let Some(v) = get_solc_version() else { return false };
        // Named keys in mappings: https://soliditylang.org/blog/2023/02/01/solidity-0.8.18-release-announcement/
        v >= (0, 8, 18)
    }

    /// Runs `solc --version` and parses its output. Returns `None` if solc is
    /// missing, fails, or reports a version that cannot be parsed.
    fn get_solc_version() -> Option<(u16, u16, u16)> {
        let output = Command::new("solc").arg("--version").output().ok()?;
        if !output.status.success() {
            return None;
        }
        let stdout = String::from_utf8(output.stdout).ok()?;
        parse_solc_version(&stdout)
    }

    /// Extracts `(major, minor, patch)` from `solc --version` output, e.g.
    /// `Version: 0.8.19+commit.7dd6d404.Linux.g++` or the prerelease form
    /// `Version: 0.8.26-develop.2024.5.1+commit...`. Returns `None` instead of
    /// panicking on unexpected formats.
    fn parse_solc_version(stdout: &str) -> Option<(u16, u16, u16)> {
        let start = stdout.find(": 0.")?;
        let version = &stdout[start + 2..];
        // The numeric triple ends at the prerelease (`-`) or build (`+`) marker.
        let end = version.find(|c: char| c == '-' || c == '+')?;
        let mut parts = version[..end].split('.').map(|s| s.parse::<u16>().ok());
        let major = parts.next()??;
        let minor = parts.next()??;
        let patch = parts.next()??;
        Some((major, minor, patch))
    }

    fn wrap_and_compile(s: &str, var: bool) -> std::result::Result<String, Box<dyn Error>> {
        let contract = if var {
            format!("contract C {{ {s} }}")
        } else {
            format!("abstract contract C {{ {} }}", s.replace("returns", "virtual returns"))
        };
        let mut cmd = Command::new("solc")
            .args(["--abi", "--pretty-json", "-"])
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()?;
        cmd.stdin.as_mut().unwrap().write_all(contract.as_bytes())?;
        // `wait_with_output` closes stdin before waiting, so solc sees EOF.
        let output = cmd.wait_with_output()?;
        if output.status.success() {
            String::from_utf8(output.stdout).map_err(Into::into)
        } else {
            Err(String::from_utf8(output.stderr)?.into())
        }
    }
}