torn_api_codegen/model/
mod.rs

1use std::{cell::RefCell, rc::Rc};
2
3use indexmap::IndexMap;
4use newtype::Newtype;
5use object::Object;
6use parameter::Parameter;
7use path::{Path, PrettySegments};
8use proc_macro2::TokenStream;
9use r#enum::Enum;
10use scope::Scope;
11
12use crate::openapi::{r#type::OpenApiType, schema::OpenApiSchema};
13
14pub mod r#enum;
15pub mod newtype;
16pub mod object;
17pub mod parameter;
18pub mod path;
19pub mod scope;
20pub mod union;
21
/// A code-generation model resolved from a single OpenAPI component schema by [`resolve`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Model {
    /// Built by `Newtype::from_schema` for schemas with a non-`object` `type` and no `enum`.
    Newtype(Newtype),
    /// Built from a schema's `enum` list or from its `one_of` variants.
    Enum(Enum),
    /// Built from an `object`-typed schema or from its `all_of` parts.
    Object(Object),
    /// The schema could not be mapped to a supported model; no code is generated for it.
    Unresolved,
}
29
30impl Model {
31    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
32        match self {
33            Self::Enum(r#enum) => r#enum.is_display(resolved),
34            Self::Newtype(_) => true,
35            _ => false,
36        }
37    }
38}
39
/// The resolved form of an [`OpenApiSchema`], from which models, requests, parameters and
/// scopes are generated.
#[derive(Default)]
pub struct ResolvedSchema {
    /// Resolved component schemas keyed by component name, in the schema's iteration order.
    pub models: IndexMap<String, Model>,
    /// Resolved request paths keyed by the path string.
    pub paths: IndexMap<String, Path>,
    /// Reusable parameters resolved from `components.parameters`.
    pub parameters: Vec<Parameter>,

    /// Collector for non-fatal problems found during resolution and request generation.
    pub warnings: WarningReporter,
}
48
/// A non-fatal problem found while resolving the schema or generating code from it.
#[derive(Debug, Clone)]
pub struct Warning {
    /// Human-readable description of the problem.
    pub message: String,
    /// Location of the problem as a list of segments, outermost first.
    pub path: Vec<String>,
}

impl std::fmt::Display for Warning {
    /// Formats as `in <segment>.<segment>...: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "in {}: {}", self.path.join("."), self.message)
    }
}

/// Storage shared by every handle derived (via `clone` or `child`) from one root reporter.
#[derive(Default)]
struct WarningReporterState {
    warnings: Vec<Warning>,
}

/// Collects [`Warning`]s and tags each one with the location it was reported from.
///
/// All handles derived from the same root share one warning list. Each handle carries its
/// own location path rather than sharing a mutable stack. As a result, sibling children,
/// clones, and handles dropped in any order cannot corrupt each other's paths, and no
/// `Drop` bookkeeping is needed.
#[derive(Clone, Default)]
pub struct WarningReporter {
    state: Rc<RefCell<WarningReporterState>>,
    path: Vec<String>,
}

impl WarningReporter {
    /// Creates a root reporter with an empty location path and no warnings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `message` as a warning at this reporter's location.
    fn push(&self, message: impl ToString) {
        let warning = Warning {
            message: message.to_string(),
            path: self.path.clone(),
        };
        self.state.borrow_mut().warnings.push(warning);
    }

    /// Returns a reporter for the nested location `name`, sharing this reporter's warnings.
    ///
    /// `self` is left unchanged, so any number of children may be alive at once.
    fn child(&self, name: impl ToString) -> WarningReporter {
        let mut path = self.path.clone();
        path.push(name.to_string());

        Self {
            state: Rc::clone(&self.state),
            path,
        }
    }

    /// Returns `true` if no warning has been recorded through any related handle.
    pub fn is_empty(&self) -> bool {
        self.state.borrow().warnings.is_empty()
    }

    /// Returns a snapshot of every warning recorded so far, in recording order.
    pub fn get_warnings(&self) -> Vec<Warning> {
        self.state.borrow().warnings.clone()
    }
}
108
109impl ResolvedSchema {
110    pub fn from_open_api(schema: &OpenApiSchema) -> Self {
111        let mut result = Self::default();
112
113        for (name, r#type) in &schema.components.schemas {
114            result.models.insert(
115                name.to_string(),
116                resolve(r#type, name, &schema.components.schemas, &result.warnings),
117            );
118        }
119
120        for (path, body) in &schema.paths {
121            result.paths.insert(
122                path.to_string(),
123                Path::from_schema(
124                    path,
125                    body,
126                    &schema.components.parameters,
127                    result.warnings.child(path),
128                )
129                .unwrap(),
130            );
131        }
132
133        for (name, param) in &schema.components.parameters {
134            result
135                .parameters
136                .push(Parameter::from_schema(name, param).unwrap());
137        }
138
139        result
140    }
141
142    pub fn codegen_models(&self) -> TokenStream {
143        let mut output = TokenStream::default();
144
145        for model in self.models.values() {
146            output.extend(model.codegen(self));
147        }
148
149        output
150    }
151
152    pub fn codegen_requests(&self) -> TokenStream {
153        let mut output = TokenStream::default();
154
155        for path in self.paths.values() {
156            output.extend(
157                path.codegen_request(self, self.warnings.child(PrettySegments(&path.segments))),
158            );
159        }
160
161        output
162    }
163
164    pub fn codegen_parameters(&self) -> TokenStream {
165        let mut output = TokenStream::default();
166
167        for param in &self.parameters {
168            output.extend(param.codegen(self));
169        }
170
171        output
172    }
173
174    pub fn codegen_scopes(&self) -> TokenStream {
175        let mut output = TokenStream::default();
176
177        let scopes = Scope::from_paths(self.paths.values().cloned().collect());
178
179        for scope in scopes {
180            output.extend(scope.codegen());
181        }
182
183        output
184    }
185}
186
187pub fn resolve(
188    r#type: &OpenApiType,
189    name: &str,
190    schemas: &IndexMap<&str, OpenApiType>,
191    warnings: &WarningReporter,
192) -> Model {
193    match r#type {
194        OpenApiType {
195            r#enum: Some(_), ..
196        } => Enum::from_schema(name, r#type).map_or(Model::Unresolved, Model::Enum),
197        OpenApiType {
198            r#type: Some("object"),
199            ..
200        } => Model::Object(Object::from_schema_object(
201            name,
202            r#type,
203            schemas,
204            warnings.child(name),
205        )),
206        OpenApiType {
207            r#type: Some(_), ..
208        } => Newtype::from_schema(name, r#type).map_or(Model::Unresolved, Model::Newtype),
209        OpenApiType {
210            one_of: Some(types),
211            ..
212        } => Enum::from_one_of(name, types).map_or(Model::Unresolved, Model::Enum),
213        OpenApiType {
214            all_of: Some(types),
215            ..
216        } => Model::Object(Object::from_all_of(
217            name,
218            types,
219            schemas,
220            warnings.child(name),
221        )),
222        _ => Model::Unresolved,
223    }
224}
225
226impl Model {
227    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
228        match self {
229            Self::Newtype(newtype) => newtype.codegen(),
230            Self::Enum(r#enum) => r#enum.codegen(resolved),
231            Self::Object(object) => object.codegen(resolved),
232            Self::Unresolved => None,
233        }
234    }
235}
236
#[cfg(test)]
mod test {
    use super::*;
    use crate::openapi::schema::test::get_schema;

    #[test]
    fn resolve_newtypes() {
        let schema = get_schema();
        let schemas = &schema.components.schemas;
        let reporter = WarningReporter::new();

        // Looks up the component called `name` and resolves it with the shared reporter.
        let resolve_named =
            |name: &str| resolve(schemas.get(name).unwrap(), name, schemas, &reporter);

        let user_id = resolve_named("UserId");
        assert!(reporter.is_empty());
        assert_eq!(
            user_id,
            Model::Newtype(Newtype {
                name: "UserId".to_owned(),
                description: None,
                inner: newtype::NewtypeInner::I32,
                copy: true,
                ord: true
            })
        );

        let attack_code = resolve_named("AttackCode");
        assert!(reporter.is_empty());
        assert_eq!(
            attack_code,
            Model::Newtype(Newtype {
                name: "AttackCode".to_owned(),
                description: None,
                inner: newtype::NewtypeInner::Str,
                copy: false,
                ord: false
            })
        );
    }

    #[test]
    fn resolve_all() {
        let schema = get_schema();
        let schemas = &schema.components.schemas;

        // A schema counts as a failure if it is unresolved or produced any warning.
        let unresolved: Vec<_> = schemas
            .iter()
            .filter(|&(name, desc)| {
                let reporter = WarningReporter::new();
                resolve(desc, name, schemas, &reporter) == Model::Unresolved
                    || !reporter.is_empty()
            })
            .map(|(name, _)| name)
            .collect();

        assert!(
            unresolved.is_empty(),
            "Failed to resolve {}/{} types. Could not resolve [{}]",
            unresolved.len(),
            schemas.len(),
            unresolved
                .iter()
                .map(|u| format!("`{u}`"))
                .collect::<Vec<_>>()
                .join(", ")
        );
    }
}