1use std::sync::Arc;
40use wasmtime::component::{Component, Linker};
41use wasmtime::{Config, Engine, Store};
42use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
43
44wasmtime::component::bindgen!({
46 world: "mutation-plugin",
47 path: "wit",
48});
49
50pub use ryo_plugin_api::{
52 Capture, MatchResult, MutationCategory, MutationManifest, NodeKind, TextEdit, TransformContext,
53 TransformDef, TransformError, TypeHint, CURRENT_API_VERSION,
54};
55
56#[derive(Debug, thiserror::Error)]
58pub enum LoaderError {
59 #[error("Failed to create WASM engine: {0}")]
61 EngineCreation(#[source] wasmtime::Error),
62
63 #[error("Failed to add WASI to linker: {0}")]
65 WasiSetup(#[source] wasmtime::Error),
66
67 #[error("Failed to parse WASM component: {0}")]
69 ComponentParse(#[source] wasmtime::Error),
70
71 #[error("Failed to set fuel limit: {0}")]
73 FuelSetup(#[source] wasmtime::Error),
74
75 #[error("Failed to instantiate WASM component: {0}")]
77 Instantiation(#[source] wasmtime::Error),
78
79 #[error("API version mismatch: expected {expected}, got {actual}")]
81 ApiVersionMismatch { expected: u32, actual: u32 },
82
83 #[error("Failed to call WASM function '{function}': {source}")]
85 FunctionCall {
86 function: &'static str,
87 #[source]
88 source: wasmtime::Error,
89 },
90
91 #[error("Transform error: {0}")]
93 TransformError(String),
94
95 #[error("IO error: {0}")]
97 Io(#[from] std::io::Error),
98}
99
100pub struct PluginLoader {
111 engine: Engine,
112 linker: Arc<Linker<PluginState>>,
113}
114
115struct PluginState {
117 wasi_ctx: WasiCtx,
118 resource_table: wasmtime::component::ResourceTable,
119}
120
121impl WasiView for PluginState {
122 fn ctx(&mut self) -> WasiCtxView<'_> {
123 WasiCtxView {
124 ctx: &mut self.wasi_ctx,
125 table: &mut self.resource_table,
126 }
127 }
128}
129
130impl PluginLoader {
131 pub fn new() -> Result<Self, LoaderError> {
139 let mut config = Config::new();
140
141 config.wasm_component_model(true);
143
144 config.consume_fuel(true);
146
147 config.max_wasm_stack(1024 * 1024);
149
150 let engine = Engine::new(&config).map_err(LoaderError::EngineCreation)?;
152
153 let mut linker = Linker::new(&engine);
155 wasmtime_wasi::p2::add_to_linker_sync(&mut linker).map_err(LoaderError::WasiSetup)?;
156
157 Ok(Self {
158 engine,
159 linker: Arc::new(linker),
160 })
161 }
162
163 pub fn load(&self, wasm_bytes: &[u8]) -> Result<LoadedPlugin, LoaderError> {
173 let component =
175 Component::new(&self.engine, wasm_bytes).map_err(LoaderError::ComponentParse)?;
176
177 let wasi_ctx = WasiCtxBuilder::new()
179 .inherit_stdout() .inherit_stderr()
181 .build();
182
183 let mut store = Store::new(
185 &self.engine,
186 PluginState {
187 wasi_ctx,
188 resource_table: wasmtime::component::ResourceTable::new(),
189 },
190 );
191
192 store.set_fuel(10_000_000).map_err(LoaderError::FuelSetup)?;
194
195 let bindings = MutationPlugin::instantiate(&mut store, &component, &self.linker)
197 .map_err(LoaderError::Instantiation)?;
198
199 let iface = bindings.ryo_transform_mutation();
201
202 let wasm_manifest =
204 iface
205 .call_get_manifest(&mut store)
206 .map_err(|e| LoaderError::FunctionCall {
207 function: "get-manifest",
208 source: e,
209 })?;
210
211 let expected_version = CURRENT_API_VERSION;
213 if wasm_manifest.api_version != expected_version {
214 return Err(LoaderError::ApiVersionMismatch {
215 expected: expected_version,
216 actual: wasm_manifest.api_version,
217 });
218 }
219
220 let manifest = convert_manifest(&wasm_manifest);
222
223 let additional_patterns =
225 iface
226 .call_get_pattern_source(&mut store)
227 .map_err(|e| LoaderError::FunctionCall {
228 function: "get-pattern-source",
229 source: e,
230 })?;
231
232 tracing::info!("Loaded mutation plugin: {}", manifest.name);
233
234 Ok(LoadedPlugin {
235 manifest,
236 additional_patterns,
237 bindings,
238 store,
239 })
240 }
241}
242
243pub struct LoadedPlugin {
248 pub manifest: MutationManifest,
250 pub additional_patterns: String,
252 bindings: MutationPlugin,
254 store: Store<PluginState>,
256}
257
258impl LoadedPlugin {
259 pub fn execute_transform(
268 &mut self,
269 matches: Vec<MatchResult>,
270 context: TransformContext,
271 ) -> Result<Vec<TextEdit>, LoaderError> {
272 self.store
274 .set_fuel(1_000_000)
275 .map_err(LoaderError::FuelSetup)?;
276
277 let wasm_matches = matches
279 .iter()
280 .map(convert_match_to_wasm)
281 .collect::<Vec<_>>();
282 let wasm_context = convert_context_to_wasm(&context);
283
284 let iface = self.bindings.ryo_transform_mutation();
286
287 let result = iface
289 .call_execute_transform(&mut self.store, &wasm_matches, &wasm_context)
290 .map_err(|e| LoaderError::FunctionCall {
291 function: "execute-transform",
292 source: e,
293 })?;
294
295 match result {
297 Ok(edits) => Ok(edits.into_iter().map(convert_text_edit).collect()),
298 Err(e) => Err(LoaderError::TransformError(format_transform_error(&e))),
299 }
300 }
301}
302
303fn convert_manifest(
308 wasm: &exports::ryo::transform::mutation::MutationManifest,
309) -> MutationManifest {
310 MutationManifest {
311 api_version: wasm.api_version,
312 name: wasm.name.clone(),
313 description: wasm.description.clone(),
314 category: convert_category(&wasm.category),
315 tier: wasm.tier,
316 pattern: wasm.pattern.clone(),
317 transform: convert_transform_def(&wasm.transform),
318 }
319}
320
321fn convert_category(
322 wasm: &exports::ryo::transform::mutation::MutationCategory,
323) -> MutationCategory {
324 match wasm {
325 exports::ryo::transform::mutation::MutationCategory::Idiom => MutationCategory::Idiom,
326 exports::ryo::transform::mutation::MutationCategory::Refactor => MutationCategory::Refactor,
327 exports::ryo::transform::mutation::MutationCategory::Generate => MutationCategory::Generate,
328 exports::ryo::transform::mutation::MutationCategory::Custom => MutationCategory::Custom,
329 }
330}
331
332fn convert_transform_def(wasm: &exports::ryo::transform::mutation::TransformDef) -> TransformDef {
333 match wasm {
334 exports::ryo::transform::mutation::TransformDef::Template(t) => {
335 TransformDef::Template(t.clone())
336 }
337 exports::ryo::transform::mutation::TransformDef::WasmExecute => TransformDef::WasmExecute,
338 }
339}
340
341fn convert_match_to_wasm(m: &MatchResult) -> exports::ryo::transform::mutation::MatchResult {
342 exports::ryo::transform::mutation::MatchResult {
343 kind: convert_node_kind_to_wasm(&m.kind),
344 start_byte: m.start_byte,
345 end_byte: m.end_byte,
346 captures: m.captures.iter().map(convert_capture_to_wasm).collect(),
347 }
348}
349
350fn convert_node_kind_to_wasm(k: &NodeKind) -> exports::ryo::transform::types::NodeKind {
351 match k {
352 NodeKind::FnCall => exports::ryo::transform::types::NodeKind::FnCall,
353 NodeKind::MethodCall => exports::ryo::transform::types::NodeKind::MethodCall,
354 NodeKind::MatchExpr => exports::ryo::transform::types::NodeKind::MatchExpr,
355 NodeKind::IfExpr => exports::ryo::transform::types::NodeKind::IfExpr,
356 NodeKind::IfLetExpr => exports::ryo::transform::types::NodeKind::IfLetExpr,
357 NodeKind::LoopExpr => exports::ryo::transform::types::NodeKind::LoopExpr,
358 NodeKind::ForExpr => exports::ryo::transform::types::NodeKind::ForExpr,
359 NodeKind::WhileExpr => exports::ryo::transform::types::NodeKind::WhileExpr,
360 NodeKind::Block => exports::ryo::transform::types::NodeKind::Block,
361 NodeKind::Ident => exports::ryo::transform::types::NodeKind::Ident,
362 NodeKind::Literal => exports::ryo::transform::types::NodeKind::Literal,
363 NodeKind::BinaryExpr => exports::ryo::transform::types::NodeKind::BinaryExpr,
364 NodeKind::UnaryExpr => exports::ryo::transform::types::NodeKind::UnaryExpr,
365 NodeKind::FieldAccess => exports::ryo::transform::types::NodeKind::FieldAccess,
366 NodeKind::IndexExpr => exports::ryo::transform::types::NodeKind::IndexExpr,
367 NodeKind::Closure => exports::ryo::transform::types::NodeKind::Closure,
368 NodeKind::StructExpr => exports::ryo::transform::types::NodeKind::StructExpr,
369 NodeKind::TupleExpr => exports::ryo::transform::types::NodeKind::TupleExpr,
370 NodeKind::ArrayExpr => exports::ryo::transform::types::NodeKind::ArrayExpr,
371 NodeKind::Path => exports::ryo::transform::types::NodeKind::Path,
372 NodeKind::TypePath => exports::ryo::transform::types::NodeKind::TypePath,
373 }
374}
375
376fn convert_capture_to_wasm(c: &Capture) -> exports::ryo::transform::types::Capture {
377 exports::ryo::transform::types::Capture {
378 name: c.name.clone(),
379 start_byte: c.start_byte,
380 end_byte: c.end_byte,
381 text: c.text.clone(),
382 }
383}
384
385fn convert_context_to_wasm(
386 ctx: &TransformContext,
387) -> exports::ryo::transform::mutation::TransformContext {
388 exports::ryo::transform::mutation::TransformContext {
389 file_path: ctx.file_path.clone(),
390 source_text: ctx.source_text.clone(),
391 type_hints: ctx
392 .type_hints
393 .iter()
394 .map(convert_type_hint_to_wasm)
395 .collect(),
396 fn_return_type: ctx.fn_return_type.clone(),
397 }
398}
399
400fn convert_type_hint_to_wasm(h: &TypeHint) -> exports::ryo::transform::types::TypeHint {
401 exports::ryo::transform::types::TypeHint {
402 node_id: h.node_id,
403 type_name: h.type_name.clone(),
404 is_result: h.is_result,
405 is_option: h.is_option,
406 is_copy: h.is_copy,
407 is_iterator: h.is_iterator,
408 }
409}
410
411fn convert_text_edit(e: exports::ryo::transform::types::TextEdit) -> TextEdit {
412 TextEdit {
413 start_byte: e.start_byte,
414 end_byte: e.end_byte,
415 replacement: e.replacement,
416 }
417}
418
419fn format_transform_error(e: &exports::ryo::transform::mutation::TransformError) -> String {
420 match e {
421 exports::ryo::transform::mutation::TransformError::MissingCapture(name) => {
422 format!("Missing capture: {}", name)
423 }
424 exports::ryo::transform::mutation::TransformError::InvalidContext(msg) => {
425 format!("Invalid context: {}", msg)
426 }
427 exports::ryo::transform::mutation::TransformError::TypeMismatch(msg) => {
428 format!("Type mismatch: {}", msg)
429 }
430 exports::ryo::transform::mutation::TransformError::PatternNotApplicable(msg) => {
431 format!("Pattern not applicable: {}", msg)
432 }
433 exports::ryo::transform::mutation::TransformError::Internal(msg) => {
434 format!("Internal error: {}", msg)
435 }
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine configuration and WASI linker setup must succeed with defaults.
    #[test]
    fn test_loader_creation() {
        // Surface the actual error text instead of a bare `is_ok()` failure.
        if let Err(err) = PluginLoader::new() {
            panic!("PluginLoader::new failed: {}", err);
        }
    }

    /// Bytes that are not a component must be rejected by the parse step,
    /// which `load` runs before creating any store or instance.
    #[test]
    fn test_load_rejects_invalid_component_bytes() {
        let loader = PluginLoader::new().expect("PluginLoader::new should succeed");
        match loader.load(b"this is not a wasm component") {
            Err(LoaderError::ComponentParse(_)) => {}
            Err(other) => panic!("expected ComponentParse, got: {}", other),
            Ok(_) => panic!("invalid bytes were accepted as a plugin"),
        }
    }
}