Skip to main content

ryo_executor/engine/impls/
derive.rs

1//! ASTRegApply implementation for derive mutations
2
3use ryo_mutations::basic::{AddDeriveMutation, RemoveDeriveMutation};
4use ryo_mutations::MutationResult;
5use ryo_source::pure::{PureAttrMeta, PureAttribute, PureItem};
6
7use crate::engine::{ASTMutationContext, ASTRegApply, ModificationType, MutationEvent};
8
9impl ASTRegApply for AddDeriveMutation {
10    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
11        let target_id = self.symbol_id;
12
13        // Get the AST for this symbol (O(1) lookup)
14        let item = match ctx.ast_registry.get(target_id) {
15            Some(item) => item.clone(),
16            None => {
17                return MutationResult {
18                    mutation_type: "AddDerive".to_string(),
19                    changes: 0,
20                    description: format!("AST not found for SymbolId({:?})", target_id),
21                };
22            }
23        };
24
25        // Modify the item
26        let (new_item, changes) = match item {
27            PureItem::Struct(mut s) => {
28                let changes = add_derive_to_attrs(&mut s.attrs, &self.derives);
29                (PureItem::Struct(s), changes)
30            }
31            PureItem::Enum(mut e) => {
32                let mut changes = add_derive_to_attrs(&mut e.attrs, &self.derives);
33                // For enum with Default derive, add #[default] to first variant
34                if self.derives.contains(&"Default".to_string()) {
35                    if let Some(first_variant) = e.variants.first_mut() {
36                        let has_default_attr =
37                            first_variant.attrs.iter().any(|a| a.path == "default");
38                        if !has_default_attr {
39                            first_variant.attrs.push(PureAttribute {
40                                path: "default".to_string(),
41                                meta: PureAttrMeta::Path,
42                                is_inner: false,
43                            });
44                            changes += 1;
45                        }
46                    }
47                }
48                (PureItem::Enum(e), changes)
49            }
50            _ => {
51                return MutationResult {
52                    mutation_type: "AddDerive".to_string(),
53                    changes: 0,
54                    description: format!("SymbolId({:?}) is not a struct or enum", target_id),
55                };
56            }
57        };
58
59        if changes > 0 {
60            // Update the registry
61            ctx.set_ast(target_id, new_item);
62
63            // Emit event for each derive added
64            for derive in &self.derives {
65                ctx.emit(MutationEvent::SymbolModified {
66                    id: target_id,
67                    modification: ModificationType::DeriveAdded(derive.clone()),
68                });
69            }
70        }
71
72        MutationResult {
73            mutation_type: "AddDerive".to_string(),
74            changes,
75            description: if changes > 0 {
76                format!(
77                    "Added derive({}) to SymbolId({:?})",
78                    self.derives.join(", "),
79                    target_id
80                )
81            } else {
82                "Derives already present".to_string()
83            },
84        }
85    }
86}
87
88impl ASTRegApply for RemoveDeriveMutation {
89    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
90        let target_id = self.symbol_id;
91
92        // Get the AST for this symbol (O(1) lookup)
93        let item = match ctx.ast_registry.get(target_id) {
94            Some(item) => item.clone(),
95            None => {
96                return MutationResult {
97                    mutation_type: "RemoveDerive".to_string(),
98                    changes: 0,
99                    description: format!("AST not found for SymbolId({:?})", target_id),
100                };
101            }
102        };
103
104        // Modify the item
105        let (new_item, changes) = match item {
106            PureItem::Struct(mut s) => {
107                let changes = remove_derive_from_attrs(&mut s.attrs, &self.derives);
108                (PureItem::Struct(s), changes)
109            }
110            PureItem::Enum(mut e) => {
111                let changes = remove_derive_from_attrs(&mut e.attrs, &self.derives);
112                (PureItem::Enum(e), changes)
113            }
114            _ => {
115                return MutationResult {
116                    mutation_type: "RemoveDerive".to_string(),
117                    changes: 0,
118                    description: format!("SymbolId({:?}) is not a struct or enum", target_id),
119                };
120            }
121        };
122
123        if changes > 0 {
124            // Update the registry
125            ctx.set_ast(target_id, new_item);
126
127            // Emit event for each derive removed
128            for derive in &self.derives {
129                ctx.emit(MutationEvent::SymbolModified {
130                    id: target_id,
131                    modification: ModificationType::DeriveRemoved(derive.clone()),
132                });
133            }
134        }
135
136        MutationResult {
137            mutation_type: "RemoveDerive".to_string(),
138            changes,
139            description: if changes > 0 {
140                format!(
141                    "Removed derive({}) from SymbolId({:?})",
142                    self.derives.join(", "),
143                    target_id
144                )
145            } else {
146                "Derive not found".to_string()
147            },
148        }
149    }
150}
151
152/// Helper to add derive to attrs (shared logic from AddDeriveMutation)
153fn add_derive_to_attrs(attrs: &mut Vec<PureAttribute>, derives: &[String]) -> usize {
154    let existing_derive_idx = attrs.iter().position(|a| a.path == "derive");
155
156    if let Some(idx) = existing_derive_idx {
157        // Extract existing derives from meta
158        let existing_args = match &attrs[idx].meta {
159            PureAttrMeta::List(args) => args.clone(),
160            _ => String::new(),
161        };
162        let mut all_derives: Vec<String> = existing_args
163            .split(',')
164            .map(|s| s.trim().to_string())
165            .filter(|s| !s.is_empty())
166            .collect();
167
168        let mut added = 0;
169        for d in derives {
170            if !all_derives.contains(d) {
171                all_derives.push(d.clone());
172                added += 1;
173            }
174        }
175
176        if added > 0 {
177            attrs[idx].meta = PureAttrMeta::List(all_derives.join(", "));
178        }
179        added
180    } else {
181        // Create new derive attribute
182        attrs.insert(
183            0,
184            PureAttribute {
185                path: "derive".to_string(),
186                meta: PureAttrMeta::List(derives.join(", ")),
187                is_inner: false,
188            },
189        );
190        derives.len()
191    }
192}
193
194/// Helper to remove derive from attrs (shared logic from RemoveDeriveMutation)
195fn remove_derive_from_attrs(attrs: &mut Vec<PureAttribute>, derives: &[String]) -> usize {
196    let existing_derive_idx = attrs.iter().position(|a| a.path == "derive");
197
198    if let Some(idx) = existing_derive_idx {
199        // Extract existing derives from meta
200        let existing_args = match &attrs[idx].meta {
201            PureAttrMeta::List(args) => args.clone(),
202            _ => String::new(),
203        };
204        let remaining: Vec<String> = existing_args
205            .split(',')
206            .map(|s| s.trim().to_string())
207            .filter(|s| !s.is_empty() && !derives.contains(s))
208            .collect();
209
210        let original_count = existing_args
211            .split(',')
212            .filter(|s| !s.trim().is_empty())
213            .count();
214        let removed = original_count - remaining.len();
215
216        if remaining.is_empty() {
217            attrs.remove(idx);
218        } else {
219            attrs[idx].meta = PureAttrMeta::List(remaining.join(", "));
220        }
221        removed
222    } else {
223        0
224    }
225}