1use crate::{
2 Block, FunctionAttribute, FunctionAttributes, Mutability, ParameterList, Parameters, SolIdent,
3 Spanned, Stmt, Type, VariableDeclaration, VariableDefinition, Visibility, kw,
4};
5use proc_macro2::Span;
6use std::{
7 fmt,
8 hash::{Hash, Hasher},
9 num::NonZeroU16,
10};
11use syn::{
12 Attribute, Error, Result, Token, parenthesized,
13 parse::{Parse, ParseStream},
14 token::{Brace, Paren},
15};
16
17#[derive(Clone)]
23pub struct ItemFunction {
24 pub attrs: Vec<Attribute>,
26 pub kind: FunctionKind,
27 pub name: Option<SolIdent>,
28 pub paren_token: Option<Paren>,
31 pub parameters: ParameterList,
32 pub attributes: FunctionAttributes,
34 pub returns: Option<Returns>,
36 pub body: FunctionBody,
37}
38
39impl fmt::Display for ItemFunction {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 f.write_str(self.kind.as_str())?;
42 if let Some(name) = &self.name {
43 f.write_str(" ")?;
44 name.fmt(f)?;
45 }
46 write!(f, "({})", self.parameters)?;
47
48 if !self.attributes.is_empty() {
49 write!(f, " {}", self.attributes)?;
50 }
51
52 if let Some(returns) = &self.returns {
53 write!(f, " {returns}")?;
54 }
55
56 if !self.body.is_empty() {
57 f.write_str(" ")?;
58 }
59 f.write_str(self.body.as_str())
60 }
61}
62
63impl fmt::Debug for ItemFunction {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("ItemFunction")
66 .field("attrs", &self.attrs)
67 .field("kind", &self.kind)
68 .field("name", &self.name)
69 .field("arguments", &self.parameters)
70 .field("attributes", &self.attributes)
71 .field("returns", &self.returns)
72 .field("body", &self.body)
73 .finish()
74 }
75}
76
77impl Parse for ItemFunction {
78 fn parse(input: ParseStream<'_>) -> Result<Self> {
79 let attrs = input.call(Attribute::parse_outer)?;
80 let kind: FunctionKind = input.parse()?;
81 let name = input.call(SolIdent::parse_opt)?;
82
83 let (paren_token, parameters) = if kind.is_modifier() && !input.peek(Paren) {
84 (None, ParameterList::new())
85 } else {
86 let content;
87 (Some(parenthesized!(content in input)), content.parse()?)
88 };
89
90 let attributes = input.parse()?;
91 let returns = input.call(Returns::parse_opt)?;
92 let body = input.parse()?;
93
94 Ok(Self { attrs, kind, name, paren_token, parameters, attributes, returns, body })
95 }
96}
97
98impl Spanned for ItemFunction {
99 fn span(&self) -> Span {
100 if let Some(name) = &self.name { name.span() } else { self.kind.span() }
101 }
102
103 fn set_span(&mut self, span: Span) {
104 self.kind.set_span(span);
105 if let Some(name) = &mut self.name {
106 name.set_span(span);
107 }
108 }
109}
110
111impl ItemFunction {
112 pub fn new(kind: FunctionKind, name: Option<SolIdent>) -> Self {
114 let span = name.as_ref().map_or_else(|| kind.span(), |name| name.span());
115 Self {
116 attrs: Vec::new(),
117 kind,
118 name,
119 paren_token: Some(Paren(span)),
120 parameters: Parameters::new(),
121 attributes: FunctionAttributes::new(),
122 returns: None,
123 body: FunctionBody::Empty(Token),
124 }
125 }
126
127 pub fn new_getter(name: SolIdent, ty: Type) -> Self {
141 let span = name.span();
142 let kind = FunctionKind::new_function(span);
143 let mut function = Self::new(kind, Some(name.clone()));
144
145 function.attributes.0 = vec![
147 FunctionAttribute::Visibility(Visibility::new_public(span)),
148 FunctionAttribute::Mutability(Mutability::new_view(span)),
149 ];
150
151 let mut ty = ty;
154 let mut return_name = None;
155 let mut first = true;
156 loop {
157 match ty {
158 Type::Mapping(map) => {
160 let key = VariableDeclaration::new_with(*map.key, None, map.key_name);
161 function.parameters.push(key);
162 return_name = map.value_name;
163 ty = *map.value;
164 }
165 Type::Array(array) => {
167 let uint256 = Type::Uint(span, NonZeroU16::new(256));
168 function.parameters.push(VariableDeclaration::new(uint256));
169 ty = *array.ty;
170 }
171 _ => {
172 if first {
173 return_name = Some(name);
174 }
175 break;
176 }
177 }
178 first = false;
179 }
180 let mut returns = ParameterList::new();
181 returns.push(VariableDeclaration::new_with(ty, None, return_name));
182 function.returns = Some(Returns::new(span, returns));
183
184 function
185 }
186
187 pub fn from_variable_definition(var: VariableDefinition) -> Self {
195 let mut function = Self::new_getter(var.name, var.ty);
196 function.attrs = var.attrs;
197 function
198 }
199
200 #[track_caller]
207 pub fn name(&self) -> &SolIdent {
208 match &self.name {
209 Some(name) => name,
210 None => panic!("function has no name: {self:?}"),
211 }
212 }
213
214 pub fn is_void(&self) -> bool {
216 match &self.returns {
217 None => true,
218 Some(returns) => returns.returns.is_empty(),
219 }
220 }
221
222 pub fn has_implementation(&self) -> bool {
224 matches!(self.body, FunctionBody::Block(_))
225 }
226
227 pub fn call_type(&self) -> Type {
229 Type::Tuple(self.parameters.types().cloned().collect())
230 }
231
232 pub fn return_type(&self) -> Option<Type> {
234 self.returns.as_ref().map(|returns| Type::Tuple(returns.returns.types().cloned().collect()))
235 }
236
237 pub fn body(&self) -> Option<&[Stmt]> {
239 match &self.body {
240 FunctionBody::Block(block) => Some(&block.stmts),
241 _ => None,
242 }
243 }
244
245 pub fn body_mut(&mut self) -> Option<&mut Vec<Stmt>> {
247 match &mut self.body {
248 FunctionBody::Block(block) => Some(&mut block.stmts),
249 _ => None,
250 }
251 }
252
253 #[allow(clippy::result_large_err)]
254 pub fn into_body(self) -> std::result::Result<Vec<Stmt>, Self> {
255 match self.body {
256 FunctionBody::Block(block) => Ok(block.stmts),
257 _ => Err(self),
258 }
259 }
260}
261
// The kind of a function-like item. Each variant wraps the keyword token that introduces it.
kw_enum! {
    pub enum FunctionKind {
        // `constructor`
        Constructor(kw::constructor),
        // `function`
        Function(kw::function),
        // `fallback`
        Fallback(kw::fallback),
        // `receive`
        Receive(kw::receive),
        // `modifier`
        Modifier(kw::modifier),
    }
}
272
273#[derive(Clone)]
275pub struct Returns {
276 pub returns_token: kw::returns,
277 pub paren_token: Paren,
278 pub returns: ParameterList,
280}
281
282impl PartialEq for Returns {
283 fn eq(&self, other: &Self) -> bool {
284 self.returns == other.returns
285 }
286}
287
288impl Eq for Returns {}
289
290impl Hash for Returns {
291 fn hash<H: Hasher>(&self, state: &mut H) {
292 self.returns.hash(state);
293 }
294}
295
296impl fmt::Display for Returns {
297 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298 f.write_str("returns (")?;
299 self.returns.fmt(f)?;
300 f.write_str(")")
301 }
302}
303
304impl fmt::Debug for Returns {
305 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306 f.debug_tuple("Returns").field(&self.returns).finish()
307 }
308}
309
310impl Parse for Returns {
311 fn parse(input: ParseStream<'_>) -> Result<Self> {
312 let content;
313 let this = Self {
314 returns_token: input.parse()?,
315 paren_token: parenthesized!(content in input),
316 returns: content.parse()?,
317 };
318 if this.returns.is_empty() {
319 Err(Error::new(this.paren_token.span.join(), "expected at least one return type"))
320 } else {
321 Ok(this)
322 }
323 }
324}
325
326impl Spanned for Returns {
327 fn span(&self) -> Span {
328 let span = self.returns_token.span;
329 span.join(self.paren_token.span.join()).unwrap_or(span)
330 }
331
332 fn set_span(&mut self, span: Span) {
333 self.returns_token.span = span;
334 self.paren_token = Paren(span);
335 }
336}
337
338impl Returns {
339 pub fn new(span: Span, returns: ParameterList) -> Self {
340 Self { returns_token: kw::returns(span), paren_token: Paren(span), returns }
341 }
342
343 pub fn parse_opt(input: ParseStream<'_>) -> Result<Option<Self>> {
344 if input.peek(kw::returns) { input.parse().map(Some) } else { Ok(None) }
345 }
346}
347
348#[derive(Clone)]
350pub enum FunctionBody {
351 Empty(Token![;]),
353 Block(Block),
355}
356
357impl fmt::Display for FunctionBody {
358 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359 f.write_str(self.as_str())
360 }
361}
362
363impl fmt::Debug for FunctionBody {
364 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365 f.write_str("FunctionBody::")?;
366 match self {
367 Self::Empty(_) => f.write_str("Empty"),
368 Self::Block(block) => block.fmt(f),
369 }
370 }
371}
372
373impl Parse for FunctionBody {
374 fn parse(input: ParseStream<'_>) -> Result<Self> {
375 let lookahead = input.lookahead1();
376 if lookahead.peek(Brace) {
377 input.parse().map(Self::Block)
378 } else if lookahead.peek(Token![;]) {
379 input.parse().map(Self::Empty)
380 } else {
381 Err(lookahead.error())
382 }
383 }
384}
385
386impl FunctionBody {
387 #[inline]
389 pub fn is_empty(&self) -> bool {
390 matches!(self, Self::Empty(_))
391 }
392
393 #[inline]
395 pub fn as_str(&self) -> &'static str {
396 match self {
397 Self::Empty(_) => ";",
398 Self::Block(_) => "{ <stmts> }",
400 }
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{
        error::Error,
        io::Write,
        process::{Command, Stdio},
    };
    use syn::parse_quote;

    /// A modifier parses both with and without a parameter list; only the former records
    /// parentheses.
    #[test]
    fn modifiers() {
        let without_parens: ItemFunction = parse_quote! {
            modifier noParens {
                _;
            }
        };
        let with_parens: ItemFunction = parse_quote! {
            modifier withParens() {
                _;
            }
        };
        let expected_kind = FunctionKind::new_modifier(Span::call_site());
        assert_eq!(without_parens.kind, expected_kind);
        assert_eq!(without_parens.kind, with_parens.kind);
        assert_eq!(without_parens.paren_token, None);
        assert_eq!(with_parens.paren_token, Some(Default::default()));
    }

    /// Each variable definition must produce the getter written next to it.
    #[test]
    #[cfg_attr(miri, ignore = "takes too long")]
    fn getters() {
        let solc_usable = run_solc();

        macro_rules! test_getters {
            ($($var:literal => $f:literal),* $(,)?) => {
                for (var, f) in [$(($var, $f)),*] {
                    test_getter(var, f, solc_usable);
                }
            };
        }

        test_getters! {
            "bool public simple;"
                => "function simple() public view returns (bool simple);",
            "bool public constant simpleConstant = false;"
                => "function simpleConstant() public view returns (bool simpleConstant);",

            "mapping(address => bool) public map;"
                => "function map(address) public view returns (bool);",
            "mapping(address a => bool b) public mapWithNames;"
                => "function mapWithNames(address a) public view returns (bool b);",
            "mapping(uint256 k1 => mapping(uint256 k2 => bool v) ignored) public nested2;"
                => "function nested2(uint256 k1, uint256 k2) public view returns (bool v);",
            "mapping(uint256 k1 => mapping(uint256 k2 => mapping(uint256 k3 => bool v) ignored1) ignored2) public nested3;"
                => "function nested3(uint256 k1, uint256 k2, uint256 k3) public view returns (bool v);",

            "bool[] public boolArray;"
                => "function boolArray(uint256) public view returns(bool);",
            "mapping(bool => bytes2)[] public mapArray;"
                => "function mapArray(uint256, bool) public view returns(bytes2);",
            "mapping(bool => mapping(address => int[])[])[][] public nestedMapArray;"
                => "function nestedMapArray(uint256, uint256, bool, uint256, address, uint256) public view returns(int);",
        }
    }

    /// Checks that the getter generated from `var_s` has the same debug output as the function
    /// parsed from `fn_s`, and, when `run_solc` is set, that `solc` emits the same ABI for both.
    fn test_getter(var_s: &str, fn_s: &str, run_solc: bool) {
        let definition: VariableDefinition = syn::parse_str(var_s).unwrap();
        let generated = ItemFunction::from_variable_definition(definition);
        let expected: ItemFunction = syn::parse_str(fn_s).unwrap();
        assert_eq!(format!("{generated:#?}"), format!("{expected:#?}"), "{var_s}");

        // NOTE(review): the `simple` cases are excluded from the ABI comparison, presumably
        // because solc's ABI for them differs from the named-return form - confirm.
        if !run_solc || var_s.contains("simple") {
            return;
        }

        // Compile both before inspecting either; an error from the variable takes precedence.
        let var_abi = wrap_and_compile(var_s, true);
        let fn_abi = wrap_and_compile(fn_s, false);
        let var_abi = var_abi.unwrap_or_else(|e| panic!("{e}"));
        let fn_abi = fn_abi.unwrap_or_else(|e| panic!("{e}"));
        assert_eq!(var_abi.trim(), fn_abi.trim(), "\nleft: {var_s:?}\nright: {fn_s:?}");
    }

    /// Returns `true` if a `solc` of version 0.8.18 or newer can be run.
    // NOTE(review): 0.8.18 is presumably the first release accepting the named mapping keys
    // used by the test inputs - confirm against the Solidity release notes.
    fn run_solc() -> bool {
        matches!(get_solc_version(), Some(version) if version >= (0, 8, 18))
    }

    /// Runs `solc --version` and extracts `(major, minor, patch)` from its output.
    ///
    /// Returns `None` if `solc` cannot be run, exits unsuccessfully, prints non-UTF-8 output, or
    /// the output lacks the `": 0."` / `'+'` markers. Panics if a version component is missing
    /// or is not a number.
    fn get_solc_version() -> Option<(u16, u16, u16)> {
        let output = Command::new("solc").arg("--version").output().ok()?;
        if !output.status.success() {
            return None;
        }
        let stdout = String::from_utf8(output.stdout).ok()?;

        // The version sits between ": " and the first following '+'.
        let tail = &stdout[stdout.find(": 0.")? + 2..];
        let version = &tail[..tail.find('+')?];

        let mut parts = version.split('.');
        let mut next_part = || parts.next().unwrap().parse::<u16>().expect("bad solc version");
        let major = next_part();
        let minor = next_part();
        let patch = next_part();
        Some((major, minor, patch))
    }

    /// Wraps `s` in a contract, feeds it to `solc --abi --pretty-json -` on stdin and returns
    /// solc's stdout, or its stderr as the error when solc exits unsuccessfully.
    ///
    /// Variables (`var == true`) go into a plain contract. Functions go into an abstract
    /// contract with `virtual` inserted before `returns`.
    fn wrap_and_compile(s: &str, var: bool) -> std::result::Result<String, Box<dyn Error>> {
        let contract = match var {
            true => format!("contract C {{ {s} }}"),
            false => {
                format!("abstract contract C {{ {} }}", s.replace("returns", "virtual returns"))
            }
        };

        let mut child = Command::new("solc")
            .args(["--abi", "--pretty-json", "-"])
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()?;
        child.stdin.as_mut().unwrap().write_all(contract.as_bytes())?;

        let output = child.wait_with_output()?;
        if !output.status.success() {
            return Err(String::from_utf8(output.stderr)?.into());
        }
        Ok(String::from_utf8(output.stdout)?)
    }
}