traverse_solidity/
builder.rs

1//! Builder for Solidity AST
2
3use crate::ast::*;
4use std::error::Error;
5use std::fmt;
6
/// Error type for failures while assembling a Solidity AST.
///
/// Each variant carries a free-form message that is printed after a
/// variant-specific prefix by the `Display` implementation.
#[derive(Debug)]
pub enum BuilderError {
    /// Message is rendered as `Build error: <message>`.
    BuildError(String),
    /// Message is rendered as `Builder error: <message>`.
    Other(String),
}

impl fmt::Display for BuilderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::BuildError(detail) => write!(f, "Build error: {}", detail),
            Self::Other(detail) => write!(f, "Builder error: {}", detail),
        }
    }
}

// Marker impl: the `Debug` + `Display` impls above are all `Error` needs.
impl Error for BuilderError {}
23
24#[derive(Debug, Default, Clone)]
25pub struct SolidityBuilder {
26    items: Vec<SourceUnitItem>,
27}
28
29impl SolidityBuilder {
30    pub fn new() -> Self {
31        Default::default()
32    }
33
34    fn add_item(&mut self, item: SourceUnitItem) -> &mut Self {
35        self.items.push(item);
36        self
37    }
38
39    pub fn build(self) -> SourceUnit {
40        SourceUnit { items: self.items }
41    }
42
43    pub fn item_count(&self) -> usize {
44        self.items.len()
45    }
46
47    pub fn pragma(&mut self, name: impl Into<String>, value: impl Into<String>) -> &mut Self {
48        let tokens = vec![name.into(), value.into()];
49        self.add_item(SourceUnitItem::Pragma(PragmaDirective { tokens }))
50    }
51
52    pub fn import(&mut self, path: impl Into<String>) -> &mut Self {
53        self.add_item(SourceUnitItem::Import(ImportDirective {
54            path: path.into(),
55            alias: None,
56            symbols: None,
57        }))
58    }
59
60    pub fn import_as(&mut self, path: impl Into<String>, alias: impl Into<String>) -> &mut Self {
61        self.add_item(SourceUnitItem::Import(ImportDirective {
62            path: path.into(),
63            alias: Some(alias.into()),
64            symbols: None,
65        }))
66    }
67
68    pub fn import_symbols(
69        &mut self,
70        path: impl Into<String>,
71        symbols: Vec<(String, Option<String>)>,
72    ) -> &mut Self {
73        let import_symbols = symbols
74            .into_iter()
75            .map(|(name, alias)| ImportSymbol { name, alias })
76            .collect();
77
78        self.add_item(SourceUnitItem::Import(ImportDirective {
79            path: path.into(),
80            alias: None,
81            symbols: Some(import_symbols),
82        }))
83    }
84
85    pub fn contract<F>(&mut self, name: impl Into<String>, build_contract: F) -> &mut Self
86    where
87        F: FnOnce(&mut ContractBuilder),
88    {
89        let mut contract_builder = ContractBuilder::new(name.into(), false);
90        build_contract(&mut contract_builder);
91
92        self.add_item(SourceUnitItem::Contract(contract_builder.build()))
93    }
94
95    pub fn abstract_contract<F>(&mut self, name: impl Into<String>, build_contract: F) -> &mut Self
96    where
97        F: FnOnce(&mut ContractBuilder),
98    {
99        let mut contract_builder = ContractBuilder::new(name.into(), true);
100        build_contract(&mut contract_builder);
101
102        self.add_item(SourceUnitItem::Contract(contract_builder.build()))
103    }
104
105    pub fn interface<F>(&mut self, name: impl Into<String>, build_interface: F) -> &mut Self
106    where
107        F: FnOnce(&mut ContractBuilder),
108    {
109        let mut interface_builder = ContractBuilder::new(name.into(), false);
110        build_interface(&mut interface_builder);
111
112        let contract = interface_builder.build();
113        self.add_item(SourceUnitItem::Interface(InterfaceDefinition {
114            name: contract.name,
115            inheritance: contract.inheritance,
116            body: contract.body,
117        }))
118    }
119
120    pub fn library<F>(&mut self, name: impl Into<String>, build_library: F) -> &mut Self
121    where
122        F: FnOnce(&mut ContractBuilder),
123    {
124        let mut library_builder = ContractBuilder::new(name.into(), false);
125        build_library(&mut library_builder);
126
127        let contract = library_builder.build();
128        self.add_item(SourceUnitItem::Library(LibraryDefinition {
129            name: contract.name,
130            body: contract.body,
131        }))
132    }
133
134    pub fn struct_def<F>(&mut self, name: impl Into<String>, build_struct: F) -> &mut Self
135    where
136        F: FnOnce(&mut StructBuilder),
137    {
138        let mut struct_builder = StructBuilder::new(name.into());
139        build_struct(&mut struct_builder);
140
141        self.add_item(SourceUnitItem::Struct(struct_builder.build()))
142    }
143
144    pub fn enum_def(
145        &mut self,
146        name: impl Into<String>,
147        values: Vec<impl Into<String>>,
148    ) -> &mut Self {
149        let enum_values = values.into_iter().map(Into::into).collect();
150
151        self.add_item(SourceUnitItem::Enum(EnumDefinition {
152            name: name.into(),
153            values: enum_values,
154        }))
155    }
156
157    pub fn error_def<F>(&mut self, name: impl Into<String>, build_error: F) -> &mut Self
158    where
159        F: FnOnce(&mut ErrorBuilder),
160    {
161        let mut error_builder = ErrorBuilder::new(name.into());
162        build_error(&mut error_builder);
163
164        self.add_item(SourceUnitItem::Error(error_builder.build()))
165    }
166
167    pub fn event_def<F>(&mut self, name: impl Into<String>, build_event: F) -> &mut Self
168    where
169        F: FnOnce(&mut EventBuilder),
170    {
171        let mut event_builder = EventBuilder::new(name.into());
172        build_event(&mut event_builder);
173
174        self.add_item(SourceUnitItem::Event(event_builder.build()))
175    }
176}
177
178#[derive(Debug, Clone)]
179pub struct ContractBuilder {
180    name: String,
181    is_abstract: bool,
182    inheritance: Vec<InheritanceSpecifier>,
183    body: Vec<ContractBodyElement>,
184}
185
186impl ContractBuilder {
187    pub fn new(name: String, is_abstract: bool) -> Self {
188        Self {
189            name,
190            is_abstract,
191            inheritance: Vec::new(),
192            body: Vec::new(),
193        }
194    }
195
196    pub fn build(self) -> ContractDefinition {
197        ContractDefinition {
198            is_abstract: self.is_abstract,
199            name: self.name,
200            inheritance: self.inheritance,
201            body: self.body,
202        }
203    }
204
205    pub fn inherits(&mut self, name: impl Into<String>) -> &mut Self {
206        self.inheritance.push(InheritanceSpecifier {
207            name: IdentifierPath::single(name.into()),
208            arguments: None,
209        });
210        self
211    }
212
213    pub fn inherits_with_args(
214        &mut self,
215        name: impl Into<String>,
216        args: Vec<Expression>,
217    ) -> &mut Self {
218        self.inheritance.push(InheritanceSpecifier {
219            name: IdentifierPath::single(name.into()),
220            arguments: Some(args),
221        });
222        self
223    }
224
225    pub fn state_variable(
226        &mut self,
227        type_name: TypeName,
228        name: impl Into<String>,
229        visibility: Option<Visibility>,
230        initial_value: Option<Expression>,
231    ) -> &mut Self {
232        self.body.push(ContractBodyElement::StateVariable(
233            StateVariableDeclaration {
234                type_name,
235                visibility,
236                is_constant: false,
237                is_immutable: false,
238                is_transient: false,
239                override_specifier: None,
240                name: name.into(),
241                initial_value,
242            },
243        ));
244        self
245    }
246
247    pub fn constant_variable(
248        &mut self,
249        type_name: TypeName,
250        name: impl Into<String>,
251        initial_value: Expression,
252    ) -> &mut Self {
253        self.body.push(ContractBodyElement::StateVariable(
254            StateVariableDeclaration {
255                type_name,
256                visibility: None,
257                is_constant: true,
258                is_immutable: false,
259                is_transient: false,
260                override_specifier: None,
261                name: name.into(),
262                initial_value: Some(initial_value),
263            },
264        ));
265        self
266    }
267
268    pub fn function<F>(&mut self, name: impl Into<String>, build_function: F) -> &mut Self
269    where
270        F: FnOnce(&mut FunctionBuilder),
271    {
272        let mut function_builder = FunctionBuilder::new(Some(name.into()));
273        build_function(&mut function_builder);
274
275        self.body
276            .push(ContractBodyElement::Function(function_builder.build()));
277        self
278    }
279
280    pub fn constructor<F>(&mut self, build_constructor: F) -> &mut Self
281    where
282        F: FnOnce(&mut ConstructorBuilder),
283    {
284        let mut constructor_builder = ConstructorBuilder::new();
285        build_constructor(&mut constructor_builder);
286
287        self.body.push(ContractBodyElement::Constructor(
288            constructor_builder.build(),
289        ));
290        self
291    }
292
293    pub fn modifier<F>(&mut self, name: impl Into<String>, build_modifier: F) -> &mut Self
294    where
295        F: FnOnce(&mut ModifierBuilder),
296    {
297        let mut modifier_builder = ModifierBuilder::new(name.into());
298        build_modifier(&mut modifier_builder);
299
300        self.body
301            .push(ContractBodyElement::Modifier(modifier_builder.build()));
302        self
303    }
304
305    pub fn event<F>(&mut self, name: impl Into<String>, build_event: F) -> &mut Self
306    where
307        F: FnOnce(&mut EventBuilder),
308    {
309        let mut event_builder = EventBuilder::new(name.into());
310        build_event(&mut event_builder);
311
312        self.body
313            .push(ContractBodyElement::Event(event_builder.build()));
314        self
315    }
316
317    pub fn error<F>(&mut self, name: impl Into<String>, build_error: F) -> &mut Self
318    where
319        F: FnOnce(&mut ErrorBuilder),
320    {
321        let mut error_builder = ErrorBuilder::new(name.into());
322        build_error(&mut error_builder);
323
324        self.body
325            .push(ContractBodyElement::Error(error_builder.build()));
326        self
327    }
328
329    pub fn struct_def<F>(&mut self, name: impl Into<String>, build_struct: F) -> &mut Self
330    where
331        F: FnOnce(&mut StructBuilder),
332    {
333        let mut struct_builder = StructBuilder::new(name.into());
334        build_struct(&mut struct_builder);
335
336        self.body
337            .push(ContractBodyElement::Struct(struct_builder.build()));
338        self
339    }
340
341    pub fn enum_def(
342        &mut self,
343        name: impl Into<String>,
344        values: Vec<impl Into<String>>,
345    ) -> &mut Self {
346        let enum_values = values.into_iter().map(Into::into).collect();
347
348        self.body.push(ContractBodyElement::Enum(EnumDefinition {
349            name: name.into(),
350            values: enum_values,
351        }));
352        self
353    }
354}
355
356#[derive(Debug, Clone)]
357pub struct FunctionBuilder {
358    name: Option<String>,
359    parameters: Vec<Parameter>,
360    visibility: Option<Visibility>,
361    state_mutability: Option<StateMutability>,
362    modifiers: Vec<ModifierInvocation>,
363    is_virtual: bool,
364    override_specifier: Option<OverrideSpecifier>,
365    returns: Option<Vec<Parameter>>,
366    body: Option<Block>,
367}
368
369impl FunctionBuilder {
370    pub fn new(name: Option<String>) -> Self {
371        Self {
372            name,
373            parameters: Vec::new(),
374            visibility: None,
375            state_mutability: None,
376            modifiers: Vec::new(),
377            is_virtual: false,
378            override_specifier: None,
379            returns: None,
380            body: None,
381        }
382    }
383
384    pub fn build(self) -> FunctionDefinition {
385        FunctionDefinition {
386            name: self.name,
387            parameters: self.parameters,
388            visibility: self.visibility,
389            state_mutability: self.state_mutability,
390            modifiers: self.modifiers,
391            is_virtual: self.is_virtual,
392            override_specifier: self.override_specifier,
393            returns: self.returns,
394            body: self.body,
395        }
396    }
397
398    pub fn parameter(&mut self, type_name: TypeName, name: impl Into<String>) -> &mut Self {
399        self.parameters.push(Parameter {
400            type_name,
401            data_location: None,
402            name: Some(name.into()),
403        });
404        self
405    }
406
407    pub fn parameter_with_location(
408        &mut self,
409        type_name: TypeName,
410        name: impl Into<String>,
411        location: DataLocation,
412    ) -> &mut Self {
413        self.parameters.push(Parameter {
414            type_name,
415            data_location: Some(location),
416            name: Some(name.into()),
417        });
418        self
419    }
420
421    pub fn visibility(&mut self, visibility: Visibility) -> &mut Self {
422        self.visibility = Some(visibility);
423        self
424    }
425
426    pub fn state_mutability(&mut self, mutability: StateMutability) -> &mut Self {
427        self.state_mutability = Some(mutability);
428        self
429    }
430
431    pub fn modifier(&mut self, name: impl Into<String>) -> &mut Self {
432        self.modifiers.push(ModifierInvocation {
433            name: IdentifierPath::single(name.into()),
434            arguments: None,
435        });
436        self
437    }
438
439    pub fn modifier_with_args(
440        &mut self,
441        name: impl Into<String>,
442        args: Vec<Expression>,
443    ) -> &mut Self {
444        self.modifiers.push(ModifierInvocation {
445            name: IdentifierPath::single(name.into()),
446            arguments: Some(args),
447        });
448        self
449    }
450
451    pub fn virtual_fn(&mut self) -> &mut Self {
452        self.is_virtual = true;
453        self
454    }
455
456    pub fn override_fn(&mut self, overrides: Vec<String>) -> &mut Self {
457        self.override_specifier = Some(OverrideSpecifier {
458            overrides: overrides.into_iter().map(IdentifierPath::single).collect(),
459        });
460        self
461    }
462
463    pub fn returns(&mut self, parameters: Vec<Parameter>) -> &mut Self {
464        self.returns = Some(parameters);
465        self
466    }
467
468    pub fn body<F>(&mut self, build_body: F) -> &mut Self
469    where
470        F: FnOnce(&mut BlockBuilder),
471    {
472        let mut block_builder = BlockBuilder::new();
473        build_body(&mut block_builder);
474
475        self.body = Some(block_builder.build());
476        self
477    }
478}
479
480#[derive(Debug, Clone)]
481pub struct ConstructorBuilder {
482    parameters: Vec<Parameter>,
483    modifiers: Vec<ModifierInvocation>,
484    is_payable: bool,
485    visibility: Option<Visibility>,
486    body: Option<Block>,
487}
488
489impl ConstructorBuilder {
490    pub fn new() -> Self {
491        Self {
492            parameters: Vec::new(),
493            modifiers: Vec::new(),
494            is_payable: false,
495            visibility: None,
496            body: None,
497        }
498    }
499
500    pub fn build(self) -> ConstructorDefinition {
501        ConstructorDefinition {
502            parameters: self.parameters,
503            modifiers: self.modifiers,
504            is_payable: self.is_payable,
505            visibility: self.visibility,
506            body: self.body.unwrap_or(Block {
507                statements: Vec::new(),
508            }),
509        }
510    }
511
512    pub fn parameter(&mut self, type_name: TypeName, name: impl Into<String>) -> &mut Self {
513        self.parameters.push(Parameter {
514            type_name,
515            data_location: None,
516            name: Some(name.into()),
517        });
518        self
519    }
520
521    /// Adds a modifier.
522    pub fn modifier(&mut self, name: impl Into<String>) -> &mut Self {
523        self.modifiers.push(ModifierInvocation {
524            name: IdentifierPath::single(name.into()),
525            arguments: None,
526        });
527        self
528    }
529
530    pub fn payable(&mut self) -> &mut Self {
531        self.is_payable = true;
532        self
533    }
534
535    pub fn visibility(&mut self, visibility: Visibility) -> &mut Self {
536        self.visibility = Some(visibility);
537        self
538    }
539
540    pub fn body<F>(&mut self, build_body: F) -> &mut Self
541    where
542        F: FnOnce(&mut BlockBuilder),
543    {
544        let mut block_builder = BlockBuilder::new();
545        build_body(&mut block_builder);
546
547        self.body = Some(block_builder.build());
548        self
549    }
550}
551
552impl Default for ConstructorBuilder {
553    fn default() -> Self {
554        Self::new()
555    }
556}
557
558#[derive(Debug, Clone)]
559pub struct ModifierBuilder {
560    name: String,
561    parameters: Vec<Parameter>,
562    is_virtual: bool,
563    override_specifier: Option<OverrideSpecifier>,
564    body: Option<Block>,
565}
566
567impl ModifierBuilder {
568    pub fn new(name: String) -> Self {
569        Self {
570            name,
571            parameters: Vec::new(),
572            is_virtual: false,
573            override_specifier: None,
574            body: None,
575        }
576    }
577
578    pub fn build(self) -> ModifierDefinition {
579        ModifierDefinition {
580            name: self.name,
581            parameters: self.parameters,
582            is_virtual: self.is_virtual,
583            override_specifier: self.override_specifier,
584            body: self.body,
585        }
586    }
587
588    pub fn parameter(&mut self, type_name: TypeName, name: impl Into<String>) -> &mut Self {
589        self.parameters.push(Parameter {
590            type_name,
591            data_location: None,
592            name: Some(name.into()),
593        });
594        self
595    }
596
597    pub fn virtual_modifier(&mut self) -> &mut Self {
598        self.is_virtual = true;
599        self
600    }
601
602    pub fn override_modifier(&mut self, overrides: Vec<String>) -> &mut Self {
603        self.override_specifier = Some(OverrideSpecifier {
604            overrides: overrides.into_iter().map(IdentifierPath::single).collect(),
605        });
606        self
607    }
608
609    pub fn body<F>(&mut self, build_body: F) -> &mut Self
610    where
611        F: FnOnce(&mut BlockBuilder),
612    {
613        let mut block_builder = BlockBuilder::new();
614        build_body(&mut block_builder);
615
616        self.body = Some(block_builder.build());
617        self
618    }
619}
620
621#[derive(Debug, Clone)]
622pub struct StructBuilder {
623    name: String,
624    members: Vec<StructMember>,
625}
626
627impl StructBuilder {
628    pub fn new(name: String) -> Self {
629        Self {
630            name,
631            members: Vec::new(),
632        }
633    }
634
635    pub fn build(self) -> StructDefinition {
636        StructDefinition {
637            name: self.name,
638            members: self.members,
639        }
640    }
641
642    /// Adds a member.
643    pub fn member(&mut self, type_name: TypeName, name: impl Into<String>) -> &mut Self {
644        self.members.push(StructMember {
645            type_name,
646            name: name.into(),
647        });
648        self
649    }
650}
651
652#[derive(Debug, Clone)]
653pub struct EventBuilder {
654    name: String,
655    parameters: Vec<EventParameter>,
656    is_anonymous: bool,
657}
658
659impl EventBuilder {
660    pub fn new(name: String) -> Self {
661        Self {
662            name,
663            parameters: Vec::new(),
664            is_anonymous: false,
665        }
666    }
667
668    pub fn build(self) -> EventDefinition {
669        EventDefinition {
670            name: self.name,
671            parameters: self.parameters,
672            is_anonymous: self.is_anonymous,
673        }
674    }
675
676    pub fn parameter(&mut self, type_name: TypeName, name: Option<String>) -> &mut Self {
677        self.parameters.push(EventParameter {
678            type_name,
679            is_indexed: false,
680            name,
681        });
682        self
683    }
684
685    pub fn indexed_parameter(&mut self, type_name: TypeName, name: Option<String>) -> &mut Self {
686        self.parameters.push(EventParameter {
687            type_name,
688            is_indexed: true,
689            name,
690        });
691        self
692    }
693
694    pub fn anonymous(&mut self) -> &mut Self {
695        self.is_anonymous = true;
696        self
697    }
698}
699
700#[derive(Debug, Clone)]
701pub struct ErrorBuilder {
702    name: String,
703    parameters: Vec<ErrorParameter>,
704}
705
706impl ErrorBuilder {
707    pub fn new(name: String) -> Self {
708        Self {
709            name,
710            parameters: Vec::new(),
711        }
712    }
713
714    pub fn build(self) -> ErrorDefinition {
715        ErrorDefinition {
716            name: self.name,
717            parameters: self.parameters,
718        }
719    }
720
721    pub fn parameter(&mut self, type_name: TypeName, name: Option<String>) -> &mut Self {
722        self.parameters.push(ErrorParameter { type_name, name });
723        self
724    }
725}
726
727#[derive(Debug, Clone)]
728pub struct BlockBuilder {
729    statements: Vec<Statement>,
730}
731
732impl BlockBuilder {
733    pub fn new() -> Self {
734        Self {
735            statements: Vec::new(),
736        }
737    }
738
739    pub fn build(self) -> Block {
740        Block {
741            statements: self.statements,
742        }
743    }
744
745    pub fn statement(&mut self, statement: Statement) -> &mut Self {
746        self.statements.push(statement);
747        self
748    }
749
750    /// Adds an expression statement.
751    pub fn expression(&mut self, expression: Expression) -> &mut Self {
752        self.statements
753            .push(Statement::Expression(ExpressionStatement { expression }));
754        self
755    }
756
757    pub fn variable_declaration(
758        &mut self,
759        type_name: TypeName,
760        name: impl Into<String>,
761        initial_value: Option<Expression>,
762    ) -> &mut Self {
763        self.variable_declaration_with_location(type_name, name, None, initial_value)
764    }
765
766    pub fn variable_declaration_with_location(
767        &mut self,
768        type_name: TypeName,
769        name: impl Into<String>,
770        data_location: Option<DataLocation>,
771        initial_value: Option<Expression>,
772    ) -> &mut Self {
773        let declaration = VariableDeclaration {
774            type_name,
775            data_location,
776            name: name.into(),
777        };
778
779        self.statements
780            .push(Statement::Variable(VariableDeclarationStatement {
781                declaration,
782                initial_value,
783            }));
784        self
785    }
786
787    pub fn assignment(&mut self, left: impl Into<String>, right: impl Into<String>) -> &mut Self {
788        let assignment = Expression::Assignment(AssignmentExpression {
789            left: Box::new(Expression::Identifier(left.into())),
790            operator: AssignmentOperator::Assign,
791            right: Box::new(Expression::Identifier(right.into())),
792        });
793
794        self.expression(assignment)
795    }
796
797    pub fn return_statement(&mut self, expression: Option<Expression>) -> &mut Self {
798        self.statements
799            .push(Statement::Return(ReturnStatement { expression }));
800        self
801    }
802
803    pub fn if_statement<F>(&mut self, condition: Expression, build_then: F) -> &mut Self
804    where
805        F: FnOnce(&mut BlockBuilder),
806    {
807        let mut then_builder = BlockBuilder::new();
808        build_then(&mut then_builder);
809
810        self.statements.push(Statement::If(IfStatement {
811            condition,
812            then_statement: Box::new(Statement::Block(then_builder.build())),
813            else_statement: None,
814        }));
815        self
816    }
817
818    pub fn if_else_statement<F, G>(
819        &mut self,
820        condition: Expression,
821        build_then: F,
822        build_else: G,
823    ) -> &mut Self
824    where
825        F: FnOnce(&mut BlockBuilder),
826        G: FnOnce(&mut BlockBuilder),
827    {
828        let mut then_builder = BlockBuilder::new();
829        build_then(&mut then_builder);
830
831        let mut else_builder = BlockBuilder::new();
832        build_else(&mut else_builder);
833
834        self.statements.push(Statement::If(IfStatement {
835            condition,
836            then_statement: Box::new(Statement::Block(then_builder.build())),
837            else_statement: Some(Box::new(Statement::Block(else_builder.build()))),
838        }));
839        self
840    }
841}
842
843impl Default for BlockBuilder {
844    fn default() -> Self {
845        Self::new()
846    }
847}
848
849pub fn uint256() -> TypeName {
850    TypeName::Elementary(ElementaryTypeName::UnsignedInteger(Some(256)))
851}
852
853pub fn uint(bits: u16) -> TypeName {
854    TypeName::Elementary(ElementaryTypeName::UnsignedInteger(Some(bits)))
855}
856
857pub fn int256() -> TypeName {
858    TypeName::Elementary(ElementaryTypeName::SignedInteger(Some(256)))
859}
860
861pub fn int(bits: u16) -> TypeName {
862    TypeName::Elementary(ElementaryTypeName::SignedInteger(Some(bits)))
863}
864
865pub fn address() -> TypeName {
866    TypeName::Elementary(ElementaryTypeName::Address)
867}
868
869pub fn bool() -> TypeName {
870    TypeName::Elementary(ElementaryTypeName::Bool)
871}
872
873pub fn string() -> TypeName {
874    TypeName::Elementary(ElementaryTypeName::String)
875}
876
877pub fn bytes() -> TypeName {
878    TypeName::Elementary(ElementaryTypeName::Bytes)
879}
880
881/// Helper function to determine if a type requires a data location specifier
882pub fn requires_data_location(type_name: &TypeName) -> bool {
883    match type_name {
884        TypeName::Elementary(elem) => match elem {
885            ElementaryTypeName::String | ElementaryTypeName::Bytes => true,
886            _ => false,
887        },
888        TypeName::Array(_, _) => true,
889        TypeName::UserDefined(_) => true, // Structs require data location
890        TypeName::Mapping(_) => false, // Mappings are always storage
891        TypeName::Function(_) => false,
892    }
893}
894
895/// Helper function to get the appropriate data location for a type in local variable context
896pub fn get_default_data_location(type_name: &TypeName) -> Option<DataLocation> {
897    if requires_data_location(type_name) {
898        Some(DataLocation::Memory)
899    } else {
900        None
901    }
902}
903
904pub fn bytes_fixed(size: u8) -> TypeName {
905    TypeName::Elementary(ElementaryTypeName::FixedBytes(Some(size)))
906}
907
908pub fn array(element_type: TypeName, size: Option<Expression>) -> TypeName {
909    TypeName::Array(Box::new(element_type), size.map(Box::new))
910}
911
912pub fn mapping(key_type: TypeName, value_type: TypeName) -> TypeName {
913    TypeName::Mapping(MappingType {
914        key_type: Box::new(key_type),
915        key_name: None,
916        value_type: Box::new(value_type),
917        value_name: None,
918    })
919}
920
921pub fn user_type(name: impl Into<String>) -> TypeName {
922    TypeName::UserDefined(IdentifierPath::single(name.into()))
923}
924
925pub fn number(value: impl Into<String>) -> Expression {
926    Expression::Literal(Literal::Number(NumberLiteral {
927        value: value.into(),
928        sub_denomination: None,
929    }))
930}
931
932pub fn string_literal(value: impl Into<String>) -> Expression {
933    Expression::Literal(Literal::String(StringLiteral {
934        value: value.into(),
935    }))
936}
937
938pub fn boolean(value: bool) -> Expression {
939    Expression::Literal(Literal::Boolean(value))
940}
941
942pub fn identifier(name: impl Into<String>) -> Expression {
943    Expression::Identifier(name.into())
944}
945
946pub fn binary(left: Expression, operator: BinaryOperator, right: Expression) -> Expression {
947    Expression::Binary(BinaryExpression {
948        left: Box::new(left),
949        operator,
950        right: Box::new(right),
951    })
952}
953
954pub fn add(left: Expression, right: Expression) -> Expression {
955    binary(left, BinaryOperator::Add, right)
956}
957
958pub fn sub(left: Expression, right: Expression) -> Expression {
959    binary(left, BinaryOperator::Sub, right)
960}
961
962pub fn mul(left: Expression, right: Expression) -> Expression {
963    binary(left, BinaryOperator::Mul, right)
964}
965
966pub fn div(left: Expression, right: Expression) -> Expression {
967    binary(left, BinaryOperator::Div, right)
968}
969
970pub fn eq(left: Expression, right: Expression) -> Expression {
971    binary(left, BinaryOperator::Equal, right)
972}
973
974pub fn gt(left: Expression, right: Expression) -> Expression {
975    binary(left, BinaryOperator::GreaterThan, right)
976}
977
978pub fn lt(left: Expression, right: Expression) -> Expression {
979    binary(left, BinaryOperator::LessThan, right)
980}
981
982pub fn and(left: Expression, right: Expression) -> Expression {
983    binary(left, BinaryOperator::And, right)
984}
985
986pub fn or(left: Expression, right: Expression) -> Expression {
987    binary(left, BinaryOperator::Or, right)
988}
989
990#[cfg(test)]
991mod tests {
992    use super::*;
993
994    #[test]
995    fn test_basic_contract_builder() {
996        let mut builder = SolidityBuilder::new();
997        builder
998            .pragma("solidity", "^0.8.0")
999            .contract("MyContract", |contract| {
1000                contract
1001                    .state_variable(uint256(), "value", Some(Visibility::Public), None)
1002                    .function("setValue", |func| {
1003                        func.parameter(uint256(), "_value")
1004                            .visibility(Visibility::Public)
1005                            .body(|body| {
1006                                body.assignment("value", "_value");
1007                            });
1008                    });
1009            });
1010
1011        let source_unit = builder.build();
1012        assert_eq!(source_unit.items.len(), 2);
1013
1014        if let SourceUnitItem::Pragma(pragma) = &source_unit.items[0] {
1015            assert_eq!(pragma.tokens, vec!["solidity", "^0.8.0"]);
1016        } else {
1017            panic!("Expected pragma directive");
1018        }
1019
1020        if let SourceUnitItem::Contract(contract) = &source_unit.items[1] {
1021            assert_eq!(contract.name, "MyContract");
1022            assert!(!contract.is_abstract);
1023            assert_eq!(contract.body.len(), 2);
1024        } else {
1025            panic!("Expected contract definition");
1026        }
1027    }
1028
1029    #[test]
1030    fn test_struct_builder() {
1031        let mut builder = SolidityBuilder::new();
1032        builder.struct_def("Person", |s| {
1033            s.member(string(), "name")
1034                .member(uint256(), "age")
1035                .member(address(), "wallet");
1036        });
1037
1038        let source_unit = builder.build();
1039        assert_eq!(source_unit.items.len(), 1);
1040
1041        if let SourceUnitItem::Struct(struct_def) = &source_unit.items[0] {
1042            assert_eq!(struct_def.name, "Person");
1043            assert_eq!(struct_def.members.len(), 3);
1044            assert_eq!(struct_def.members[0].name, "name");
1045            assert_eq!(struct_def.members[1].name, "age");
1046            assert_eq!(struct_def.members[2].name, "wallet");
1047        } else {
1048            panic!("Expected struct definition");
1049        }
1050    }
1051
1052    #[test]
1053    fn test_event_builder() {
1054        let mut builder = SolidityBuilder::new();
1055        builder.event_def("Transfer", |event| {
1056            event
1057                .indexed_parameter(address(), Some("from".to_string()))
1058                .indexed_parameter(address(), Some("to".to_string()))
1059                .parameter(uint256(), Some("value".to_string()));
1060        });
1061
1062        let source_unit = builder.build();
1063        assert_eq!(source_unit.items.len(), 1);
1064
1065        if let SourceUnitItem::Event(event_def) = &source_unit.items[0] {
1066            assert_eq!(event_def.name, "Transfer");
1067            assert_eq!(event_def.parameters.len(), 3);
1068            assert!(event_def.parameters[0].is_indexed);
1069            assert!(event_def.parameters[1].is_indexed);
1070            assert!(!event_def.parameters[2].is_indexed);
1071        } else {
1072            panic!("Expected event definition");
1073        }
1074    }
1075
1076    #[test]
1077    fn test_expression_helpers() {
1078        let expr = add(identifier("a"), mul(identifier("b"), number("10")));
1079
1080        if let Expression::Binary(binary) = expr {
1081            assert!(matches!(binary.operator, BinaryOperator::Add));
1082            if let Expression::Binary(right_binary) = *binary.right {
1083                assert!(matches!(right_binary.operator, BinaryOperator::Mul));
1084            } else {
1085                panic!("Expected multiplication on the right side");
1086            }
1087        } else {
1088            panic!("Expected binary expression");
1089        }
1090    }
1091
1092    #[test]
1093    fn test_type_helpers() {
1094        assert!(matches!(
1095            uint256(),
1096            TypeName::Elementary(ElementaryTypeName::UnsignedInteger(Some(256)))
1097        ));
1098        assert!(matches!(
1099            address(),
1100            TypeName::Elementary(ElementaryTypeName::Address)
1101        ));
1102        assert!(matches!(
1103            bool(),
1104            TypeName::Elementary(ElementaryTypeName::Bool)
1105        ));
1106
1107        let arr_type = array(uint256(), Some(number("10")));
1108        if let TypeName::Array(element_type, size) = arr_type {
1109            assert!(matches!(
1110                *element_type,
1111                TypeName::Elementary(ElementaryTypeName::UnsignedInteger(Some(256)))
1112            ));
1113            assert!(size.is_some());
1114        } else {
1115            panic!("Expected array type");
1116        }
1117
1118        let map_type = mapping(address(), uint256());
1119        if let TypeName::Mapping(mapping_type) = map_type {
1120            assert!(matches!(
1121                *mapping_type.key_type,
1122                TypeName::Elementary(ElementaryTypeName::Address)
1123            ));
1124            assert!(matches!(
1125                *mapping_type.value_type,
1126                TypeName::Elementary(ElementaryTypeName::UnsignedInteger(Some(256)))
1127            ));
1128        } else {
1129            panic!("Expected mapping type");
1130        }
1131    }
1132
1133    #[test]
1134    fn test_complex_contract() {
1135        let mut builder = SolidityBuilder::new();
1136        builder
1137            .pragma("solidity", "^0.8.0")
1138            .import("./IERC20.sol")
1139            .contract("Token", |contract| {
1140                contract
1141                    .inherits("IERC20")
1142                    .state_variable(
1143                        mapping(address(), uint256()),
1144                        "balances",
1145                        Some(Visibility::Private),
1146                        None,
1147                    )
1148                    .state_variable(uint256(), "totalSupply", Some(Visibility::Public), None)
1149                    .constructor(|constructor| {
1150                        constructor
1151                            .parameter(uint256(), "initialSupply")
1152                            .body(|body| {
1153                                body.assignment("totalSupply", "initialSupply");
1154                            });
1155                    })
1156                    .function("transfer", |func| {
1157                        func.parameter(address(), "to")
1158                            .parameter(uint256(), "amount")
1159                            .visibility(Visibility::Public)
1160                            .returns(vec![Parameter {
1161                                type_name: bool(),
1162                                data_location: None,
1163                                name: None,
1164                            }])
1165                            .body(|body| {
1166                                body.if_statement(
1167                                    gt(identifier("balances[msg.sender]"), identifier("amount")),
1168                                    |then_block| {
1169                                        then_block
1170                                            .assignment(
1171                                                "balances[msg.sender]",
1172                                                "balances[msg.sender] - amount",
1173                                            )
1174                                            .assignment("balances[to]", "balances[to] + amount")
1175                                            .return_statement(Some(boolean(true)));
1176                                    },
1177                                );
1178                            });
1179                    })
1180                    .event("Transfer", |event| {
1181                        event
1182                            .indexed_parameter(address(), Some("from".to_string()))
1183                            .indexed_parameter(address(), Some("to".to_string()))
1184                            .parameter(uint256(), Some("value".to_string()));
1185                    });
1186            });
1187
1188        let source_unit = builder.build();
1189        assert_eq!(source_unit.items.len(), 3); // pragma, import, contract
1190
1191        if let SourceUnitItem::Contract(contract) = &source_unit.items[2] {
1192            assert_eq!(contract.name, "Token");
1193            assert_eq!(contract.inheritance.len(), 1);
1194            assert!(!contract.body.is_empty());
1195        } else {
1196            panic!("Expected contract definition");
1197        }
1198    }
1199}