wick_component_codegen/
generate.rs1pub(crate) mod config;
2mod dependency;
3mod expand_type;
4mod f;
5mod ids;
6mod module;
7mod templates;
8
9use std::cell::RefCell;
10use std::rc::Rc;
11
12use anyhow::Result;
13pub use config::configure;
14use expand_type::expand_type;
15use ids::*;
16use itertools::Itertools;
17use module::Module;
18use proc_macro2::{Ident, TokenStream};
19use quote::quote;
20use templates::TypeOptions;
21use wick_config::{FetchOptions, WickConfiguration};
22use wick_interface_types::{OperationSignature, OperationSignatures, TypeDefinition};
23
/// The two sides of an operation: `In` and `Out`.
///
/// NOTE(review): not referenced in the visible part of this module; presumably
/// consumed by the `templates`/`expand_type` submodules to distinguish inbound
/// and outbound port types — confirm at the call sites.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum Direction {
    /// The inbound side.
    In,
    /// The outbound side.
    Out,
}
29
30fn put_impl_in_path(module: &Rc<RefCell<Module>>, mut path_parts_reverse: Vec<&str>, implementation: TokenStream) {
31 if let Some(next) = path_parts_reverse.pop() {
32 let module = module.borrow_mut().get_or_add(next);
33 put_impl_in_path(&module, path_parts_reverse, implementation);
34 } else {
35 module.borrow_mut().add(implementation);
36 }
37}
38
39fn gen_types<'a>(
40 module_name: &str,
41 config: &mut config::Config,
42 ty: impl Iterator<Item = &'a TypeDefinition>,
43) -> TokenStream {
44 let types = ty
45 .map(|v| templates::type_def(config, v, TypeOptions::empty()))
46 .collect_vec();
47 let root = Module::new(module_name);
48 for (mod_parts, implementation) in types {
49 put_impl_in_path(&root, mod_parts, implementation);
50 }
51
52 let borrowed = root.borrow();
53 borrowed.codegen()
54}
55
56fn gen_wrapper_fns<'a>(
57 config: &mut config::Config,
58 component: &Ident,
59 op: impl Iterator<Item = &'a OperationSignature>,
60) -> Vec<TokenStream> {
61 op.map(|op| templates::gen_wrapper_fn(config, component, op))
62 .collect_vec()
63}
64
65fn gen_trait_fns<'a>(
66 config: &mut config::Config,
67 op: impl Iterator<Item = &'a OperationSignature>,
68) -> Vec<TokenStream> {
69 op.map(|op| {
70 let op_name = id(&snake(op.name()));
71 let op_config = templates::op_config(config, &generic_config_id(), op);
72 let op_output = templates::op_outgoing(config, "Outputs", op.outputs());
73 let op_input = templates::op_incoming(config, "Inputs", op.inputs());
74 let trait_sig = templates::trait_signature(config, op);
75 let desc = format!("Types associated with the `{}` operation", op.name());
76 quote! {
77 #[doc = #desc]
78 pub mod #op_name {
79 #[allow(unused)]
80 use super::*;
81 #op_config
82 #op_output
83 #op_input
84 #trait_sig
85 }
86
87 }
88 })
89 .collect_vec()
90}
91
92#[allow(clippy::needless_pass_by_value, clippy::too_many_lines)]
93fn codegen(wick_config: WickConfiguration, gen_config: &mut config::Config) -> Result<String> {
94 let (ops, types, required, imported, root_config) = match &wick_config {
95 wick_config::WickConfiguration::Component(comp) => {
96 let types = comp
97 .types()?
98 .into_iter()
99 .sorted_by(|a, b| a.name().cmp(b.name()))
100 .collect();
101 let root_config = comp.config().to_owned();
102 let requires = comp.requires().clone().to_vec();
103 let ops = comp.component().operation_signatures();
104 let imports = comp.import().to_vec();
105 (ops, types, requires, imports, Some(root_config))
106 }
107 wick_config::WickConfiguration::Types(config) => (
108 config.operation_signatures(),
109 config.types().to_vec(),
110 Default::default(),
111 Default::default(),
112 None,
113 ),
114 _ => panic!("Code generation only supports `wick/component` and `wick/types` configurations"),
115 };
116
117 let component_name = id("Component");
118 let wrapper_fns = gen_wrapper_fns(gen_config, &component_name, ops.iter());
119 let trait_defs = gen_trait_fns(gen_config, ops.iter());
120 let typedefs = gen_types("types", gen_config, types.iter());
121
122 let init = (!ops.is_empty())
123 .then(|| templates::gen_component_impls(gen_config, &component_name, ops.iter(), &required, &imported));
124
125 let root_config = templates::component_config(gen_config, root_config);
126
127 let imports = gen_config.deps.iter().map(|dep| quote! { #dep }).collect_vec();
128 let imports = quote! { #( #imports )* };
129
130 let components = gen_config.components.then(|| {
131 quote! {
132 #[derive(Default, Clone)]
133 #[doc = "The struct that the component implementation hinges around"]
134 pub struct #component_name;
135 impl #component_name {
136 #( #wrapper_fns )*
137 }
138 }
139 });
140
141 let expanded = quote! {
142 #imports
143
144 #[allow(unused)]
145 pub(crate) use wick_component::*;
146
147 #[allow(unused)]
148 pub(crate) use wick_component::WickStream;
149 pub use wick_component::flow_component::Context;
150
151 #init
152
153 #root_config
154
155 #[doc = "Additional generated types"]
156 #typedefs
157 #( #trait_defs )*
158 #components
159 };
160 let source = expanded.to_string();
161 match syn::parse_file(source.as_str()) {
162 Ok(reparsed) => {
163 let formatted = prettyplease::unparse(&reparsed);
164 Ok(formatted)
165 }
166 Err(e) => {
167 println!("Failed to parse generated code: {}", e);
168 Ok(source)
170 }
171 }
172}
173
174pub fn build(config: config::Config) -> Result<()> {
175 let rt = tokio::runtime::Runtime::new()?;
176 rt.block_on(async_build(config))
177}
178
179pub async fn async_build(mut config: config::Config) -> Result<()> {
180 let path = config.spec.as_path().to_str().unwrap();
181 let wick_config = wick_config::WickConfiguration::fetch_uninitialized_tree(path, FetchOptions::default())
182 .await?
183 .element
184 .into_inner();
185
186 let src = codegen(wick_config, &mut config)?;
187 tokio::fs::create_dir_all(&config.out_dir).await?;
188 let target = config.out_dir.join("mod.rs");
189 println!("Writing to {}", target.display());
190 tokio::fs::write(target, src).await?;
191 Ok(())
192}
193
194#[cfg(test)]
195mod test {
196 use anyhow::Result;
200
201 use super::*;
202 use crate::generate::config::ConfigBuilder;
203
204 #[tokio::test]
205 async fn test_build() -> Result<()> {
206 let mut config = ConfigBuilder::new().spec("./tests/testdata/component.yaml").build()?;
207 let wick_config = WickConfiguration::fetch(&config.spec, Default::default())
208 .await
209 .unwrap()
210 .finish()?;
211
212 let src = codegen(wick_config, &mut config)?;
213
214 assert!(src.contains("pub struct Component"));
215
216 Ok(())
217 }
218}