rust_actions_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use serde::Deserialize;
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use syn::parse::{Parse, ParseStream};
7use syn::{parse_macro_input, DeriveInput, ItemFn, FnArg, Type, LitStr, Token};
8
9#[proc_macro_attribute]
10pub fn step(attr: TokenStream, item: TokenStream) -> TokenStream {
11 let step_name = parse_macro_input!(attr as LitStr);
12 let input = parse_macro_input!(item as ItemFn);
13
14 let fn_name = &input.sig.ident;
15
16 let mut params = input.sig.inputs.iter();
17
18 let world_type = match params.next() {
19 Some(FnArg::Typed(pat_type)) => {
20 extract_world_type(&pat_type.ty)
21 }
22 _ => {
23 return syn::Error::new_spanned(
24 &input.sig,
25 "Step function must have a world parameter as first argument"
26 ).to_compile_error().into();
27 }
28 };
29
30 let has_args = params.next().is_some();
31
32 let step_call = if has_args {
33 quote! {
34 let parsed_args = match ::rust_actions::args::FromArgs::from_args(&args) {
35 Ok(a) => a,
36 Err(e) => return Box::pin(async move { Err(e) }),
37 };
38 Box::pin(async move {
39 let result = #fn_name(world, parsed_args).await?;
40 Ok(::rust_actions::outputs::IntoOutputs::into_outputs(result))
41 })
42 }
43 } else {
44 quote! {
45 Box::pin(async move {
46 let result = #fn_name(world).await?;
47 Ok(::rust_actions::outputs::IntoOutputs::into_outputs(result))
48 })
49 }
50 };
51
52 let step_name_str = step_name.value();
53 let erased_fn_name = syn::Ident::new(
54 &format!("__erased_{}", fn_name),
55 fn_name.span()
56 );
57
58 let expanded = quote! {
59 #input
60
61 #[doc(hidden)]
62 #[allow(non_upper_case_globals)]
63 fn #erased_fn_name<'a>(
64 world_any: &'a mut dyn ::std::any::Any,
65 args: ::rust_actions::args::RawArgs,
66 ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::rust_actions::Result<::rust_actions::outputs::StepOutputs>> + Send + 'a>> {
67 let world = match world_any.downcast_mut::<#world_type>() {
68 Some(w) => w,
69 None => {
70 let msg = format!(
71 "World type mismatch: expected {}",
72 ::std::any::type_name::<#world_type>()
73 );
74 return Box::pin(async move {
75 Err(::rust_actions::Error::Custom(msg))
76 });
77 }
78 };
79
80 #step_call
81 }
82
83 ::rust_actions::inventory::submit! {
84 ::rust_actions::registry::ErasedStepDef::new(
85 #step_name_str,
86 {
87 use ::std::any::TypeId;
88 TypeId::of::<#world_type>()
89 },
90 #erased_fn_name,
91 )
92 }
93 };
94
95 TokenStream::from(expanded)
96}
97
98fn extract_world_type(ty: &Type) -> proc_macro2::TokenStream {
99 match ty {
100 Type::Reference(type_ref) => {
101 if let Type::Path(type_path) = &*type_ref.elem {
102 let path = &type_path.path;
103 quote! { #path }
104 } else {
105 quote! { compile_error!("Expected a type path for world parameter") }
106 }
107 }
108 _ => {
109 quote! { compile_error!("World parameter must be a mutable reference") }
110 }
111 }
112}
113
114#[proc_macro_derive(World, attributes(world))]
115pub fn derive_world(input: TokenStream) -> TokenStream {
116 let input = parse_macro_input!(input as DeriveInput);
117 let name = &input.ident;
118
119 let expanded = quote! {
120 impl ::rust_actions::world::World for #name {
121 fn new() -> impl ::std::future::Future<Output = ::rust_actions::Result<Self>> + Send {
122 Self::setup()
123 }
124 }
125 };
126
127 TokenStream::from(expanded)
128}
129
130#[proc_macro_derive(Args, attributes(arg))]
131pub fn derive_args(input: TokenStream) -> TokenStream {
132 let input = parse_macro_input!(input as DeriveInput);
133 let name = &input.ident;
134
135 let expanded = quote! {
136 impl ::rust_actions::args::FromArgs for #name {
137 fn from_args(args: &::rust_actions::args::RawArgs) -> ::rust_actions::Result<Self> {
138 let value = ::rust_actions::serde_json::Value::Object(
139 args.iter()
140 .map(|(k, v)| (k.clone(), v.clone()))
141 .collect()
142 );
143 ::rust_actions::serde_json::from_value(value)
144 .map_err(|e| ::rust_actions::Error::Args(e.to_string()))
145 }
146 }
147 };
148
149 TokenStream::from(expanded)
150}
151
152#[proc_macro_derive(Outputs)]
153pub fn derive_outputs(input: TokenStream) -> TokenStream {
154 let input = parse_macro_input!(input as DeriveInput);
155 let name = &input.ident;
156
157 let expanded = quote! {
158 impl ::rust_actions::outputs::IntoOutputs for #name {
159 fn into_outputs(self) -> ::rust_actions::outputs::StepOutputs {
160 ::rust_actions::serde_json::to_value(&self)
161 .map(|v| ::rust_actions::outputs::StepOutputs::from_value(v))
162 .unwrap_or_default()
163 }
164 }
165 };
166
167 TokenStream::from(expanded)
168}
169
170#[proc_macro_attribute]
171pub fn before_all(_attr: TokenStream, item: TokenStream) -> TokenStream {
172 let input = parse_macro_input!(item as ItemFn);
173 TokenStream::from(quote! { #input })
174}
175
176#[proc_macro_attribute]
177pub fn after_all(_attr: TokenStream, item: TokenStream) -> TokenStream {
178 let input = parse_macro_input!(item as ItemFn);
179 TokenStream::from(quote! { #input })
180}
181
182#[proc_macro_attribute]
183pub fn before_scenario(_attr: TokenStream, item: TokenStream) -> TokenStream {
184 let input = parse_macro_input!(item as ItemFn);
185 TokenStream::from(quote! { #input })
186}
187
188#[proc_macro_attribute]
189pub fn after_scenario(_attr: TokenStream, item: TokenStream) -> TokenStream {
190 let input = parse_macro_input!(item as ItemFn);
191 TokenStream::from(quote! { #input })
192}
193
194#[proc_macro_attribute]
195pub fn before_step(_attr: TokenStream, item: TokenStream) -> TokenStream {
196 let input = parse_macro_input!(item as ItemFn);
197 TokenStream::from(quote! { #input })
198}
199
200#[proc_macro_attribute]
201pub fn after_step(_attr: TokenStream, item: TokenStream) -> TokenStream {
202 let input = parse_macro_input!(item as ItemFn);
203 TokenStream::from(quote! { #input })
204}
205
206struct GenerateTestsArgs {
207 path: LitStr,
208 world_type: syn::Path,
209}
210
211impl Parse for GenerateTestsArgs {
212 fn parse(input: ParseStream) -> syn::Result<Self> {
213 let path: LitStr = input.parse()?;
214 input.parse::<Token![,]>()?;
215 let world_type: syn::Path = input.parse()?;
216 Ok(GenerateTestsArgs { path, world_type })
217 }
218}
219
/// Partial view of a workflow YAML document, covering only the top-level
/// `name` and `on` keys. Used by `is_reusable_workflow` to inspect the
/// trigger section.
#[derive(Debug, Deserialize)]
struct WorkflowHeader {
    /// Top-level `name` key; deserialized but never read.
    #[allow(dead_code)]
    name: Option<String>,
    /// Top-level `on` key; `None` when the key is absent.
    #[serde(default)]
    on: Option<WorkflowTrigger>,
}
227
/// The `on` section of a workflow, reduced to the `workflow_call` entry.
#[derive(Debug, Deserialize)]
struct WorkflowTrigger {
    /// Contents of the `workflow_call` key as a string-keyed map of arbitrary
    /// YAML values; `None` when the key is absent.
    // NOTE(review): a bare `workflow_call:` with no value presumably also
    // deserializes to `None`, which would make such a workflow look
    // non-reusable — confirm against serde_yaml's handling of null.
    #[serde(default)]
    workflow_call: Option<HashMap<String, serde_yaml::Value>>,
}
233
234fn is_reusable_workflow(path: &Path) -> bool {
235 let content = match std::fs::read_to_string(path) {
236 Ok(c) => c,
237 Err(_) => return false,
238 };
239
240 let header: WorkflowHeader = match serde_yaml::from_str(&content) {
241 Ok(h) => h,
242 Err(_) => return false,
243 };
244
245 header
246 .on
247 .as_ref()
248 .map(|t| t.workflow_call.is_some())
249 .unwrap_or(false)
250}
251
252fn discover_yaml_files(dir: &Path) -> Vec<PathBuf> {
253 walkdir::WalkDir::new(dir)
254 .into_iter()
255 .filter_map(|e| e.ok())
256 .filter(|e| {
257 e.path().is_file()
258 && e.path()
259 .extension()
260 .map(|ext| ext == "yaml" || ext == "yml")
261 .unwrap_or(false)
262 })
263 .map(|e| e.path().to_path_buf())
264 .collect()
265}
266
267fn path_to_test_name(path: &Path, base: &Path) -> proc_macro2::Ident {
268 let rel_path = path.strip_prefix(base).unwrap_or(path);
269
270 let name = rel_path
271 .to_string_lossy()
272 .replace(std::path::MAIN_SEPARATOR, "_")
273 .replace(".yaml", "")
274 .replace(".yml", "")
275 .replace('-', "_")
276 .replace('.', "_");
277
278 let name = format!("test_{}", name);
279 proc_macro2::Ident::new(&name, proc_macro2::Span::call_site())
280}
281
282#[proc_macro]
283pub fn generate_tests(input: TokenStream) -> TokenStream {
284 let args = parse_macro_input!(input as GenerateTestsArgs);
285 let workflows_path = args.path.value();
286 let world_type = &args.world_type;
287
288 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
289 .expect("CARGO_MANIFEST_DIR not set");
290 let full_path = Path::new(&manifest_dir).join(&workflows_path);
291
292 if !full_path.exists() {
293 let err = format!("Workflows path does not exist: {}", full_path.display());
294 return syn::Error::new_spanned(&args.path, err)
295 .to_compile_error()
296 .into();
297 }
298
299 let yaml_files = discover_yaml_files(&full_path);
300
301 let tests = yaml_files
302 .iter()
303 .filter(|f| !is_reusable_workflow(f))
304 .map(|file| {
305 let rel_path = file.strip_prefix(&manifest_dir).unwrap_or(file);
306 let test_name = path_to_test_name(file, &full_path);
307 let path_str = rel_path.to_string_lossy();
308
309 quote! {
310 #[::tokio::test(flavor = "current_thread", start_paused = true)]
311 async fn #test_name() {
312 ::rust_actions::prelude::RustActions::<#world_type>::new()
313 .workflow(#path_str)
314 .run()
315 .await;
316 }
317 }
318 });
319
320 let expanded = quote! {
321 #(#tests)*
322 };
323
324 TokenStream::from(expanded)
325}