torn_api_codegen/model/
mod.rs

1use std::{cell::RefCell, rc::Rc};
2
3use indexmap::IndexMap;
4use newtype::Newtype;
5use object::Object;
6use parameter::Parameter;
7use path::{Path, PrettySegments};
8use proc_macro2::TokenStream;
9use r#enum::Enum;
10use scope::Scope;
11
12use crate::openapi::{r#type::OpenApiType, schema::OpenApiSchema};
13
14pub mod r#enum;
15pub mod newtype;
16pub mod object;
17pub mod parameter;
18pub mod path;
19pub mod scope;
20pub mod union;
21
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub enum Model {
24    Newtype(Newtype),
25    Enum(Enum),
26    Object(Object),
27    Unresolved,
28}
29
30impl Model {
31    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
32        match self {
33            Self::Enum(r#enum) => r#enum.is_display(resolved),
34            Self::Newtype(_) => true,
35            _ => false,
36        }
37    }
38}
39
40#[derive(Default)]
41pub struct ResolvedSchema {
42    pub models: IndexMap<String, Model>,
43    pub paths: IndexMap<String, Path>,
44    pub parameters: Vec<Parameter>,
45
46    pub warnings: WarningReporter,
47}
48
49#[derive(Clone)]
50pub struct Warning {
51    pub message: String,
52    pub path: Vec<String>,
53}
54
55impl std::fmt::Display for Warning {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        write!(f, "in {}: {}", self.path.join("."), self.message)
58    }
59}
60
61#[derive(Default)]
62struct WarningReporterState {
63    warnings: Vec<Warning>,
64    path: Vec<String>,
65}
66
67#[derive(Clone, Default)]
68pub struct WarningReporter {
69    state: Rc<RefCell<WarningReporterState>>,
70}
71
72impl WarningReporter {
73    pub fn new() -> Self {
74        Self::default()
75    }
76
77    fn push(&self, message: impl ToString) {
78        let mut state = self.state.borrow_mut();
79        let path = state.path.iter().map(ToString::to_string).collect();
80        state.warnings.push(Warning {
81            message: message.to_string(),
82            path,
83        });
84    }
85
86    fn child(&self, name: impl ToString) -> WarningReporter {
87        self.state.borrow_mut().path.push(name.to_string());
88
89        Self {
90            state: self.state.clone(),
91        }
92    }
93
94    pub fn is_empty(&self) -> bool {
95        self.state.borrow().warnings.is_empty()
96    }
97
98    pub fn get_warnings(&self) -> Vec<Warning> {
99        self.state.borrow().warnings.clone()
100    }
101}
102
103impl Drop for WarningReporter {
104    fn drop(&mut self) {
105        self.state.borrow_mut().path.pop();
106    }
107}
108
109impl ResolvedSchema {
110    pub fn from_open_api(schema: &OpenApiSchema) -> Self {
111        let mut result = Self::default();
112
113        for (name, r#type) in &schema.components.schemas {
114            result.models.insert(
115                name.to_string(),
116                resolve(r#type, name, &schema.components.schemas, &result.warnings),
117            );
118        }
119
120        for (path, body) in &schema.paths {
121            result.paths.insert(
122                path.to_string(),
123                Path::from_schema(
124                    path,
125                    body,
126                    &schema.components.parameters,
127                    result.warnings.child(path),
128                )
129                .unwrap(),
130            );
131        }
132
133        for (name, param) in &schema.components.parameters {
134            result.parameters.push(
135                Parameter::from_schema(name, param, result.warnings.child(name.to_owned()))
136                    .unwrap(),
137            );
138        }
139
140        result
141    }
142
143    pub fn codegen_models(&self) -> TokenStream {
144        let mut output = TokenStream::default();
145
146        for model in self.models.values() {
147            output.extend(model.codegen(self));
148        }
149
150        output
151    }
152
153    pub fn codegen_requests(&self) -> TokenStream {
154        let mut output = TokenStream::default();
155
156        for path in self.paths.values() {
157            output.extend(
158                path.codegen_request(self, self.warnings.child(PrettySegments(&path.segments))),
159            );
160        }
161
162        output
163    }
164
165    pub fn codegen_parameters(&self) -> TokenStream {
166        let mut output = TokenStream::default();
167
168        for param in &self.parameters {
169            output.extend(param.codegen(self));
170        }
171
172        output
173    }
174
175    pub fn codegen_scopes(&self) -> TokenStream {
176        let mut output = TokenStream::default();
177
178        let scopes = Scope::from_paths(self.paths.values().cloned().collect());
179
180        for scope in scopes {
181            output.extend(scope.codegen());
182        }
183
184        output
185    }
186}
187
188pub fn resolve(
189    r#type: &OpenApiType,
190    name: &str,
191    schemas: &IndexMap<&str, OpenApiType>,
192    warnings: &WarningReporter,
193) -> Model {
194    match r#type {
195        OpenApiType {
196            r#enum: Some(_), ..
197        } => Enum::from_schema(name, r#type, warnings.child(name))
198            .map_or(Model::Unresolved, Model::Enum),
199        OpenApiType {
200            r#type: Some("object"),
201            ..
202        } => Model::Object(Object::from_schema_object(
203            name,
204            r#type,
205            schemas,
206            warnings.child(name),
207        )),
208        OpenApiType {
209            r#type: Some(_), ..
210        } => Newtype::from_schema(name, r#type).map_or(Model::Unresolved, Model::Newtype),
211        OpenApiType {
212            one_of: Some(types),
213            ..
214        } => Enum::from_one_of(name, types, warnings.child(name))
215            .map_or(Model::Unresolved, Model::Enum),
216        OpenApiType {
217            all_of: Some(types),
218            ..
219        } => Model::Object(Object::from_all_of(
220            name,
221            types,
222            schemas,
223            warnings.child(name),
224        )),
225        _ => Model::Unresolved,
226    }
227}
228
229impl Model {
230    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
231        match self {
232            Self::Newtype(newtype) => newtype.codegen(),
233            Self::Enum(r#enum) => r#enum.codegen(resolved),
234            Self::Object(object) => object.codegen(resolved),
235            Self::Unresolved => None,
236        }
237    }
238}
239
240#[cfg(test)]
241mod test {
242    use super::*;
243    use crate::openapi::schema::test::get_schema;
244
245    #[test]
246    fn resolve_newtypes() {
247        let schema = get_schema();
248
249        let user_id_schema = schema.components.schemas.get("UserId").unwrap();
250
251        let reporter = WarningReporter::new();
252        let user_id = resolve(
253            user_id_schema,
254            "UserId",
255            &schema.components.schemas,
256            &reporter,
257        );
258        assert!(reporter.is_empty());
259
260        assert_eq!(
261            user_id,
262            Model::Newtype(Newtype {
263                name: "UserId".to_owned(),
264                description: None,
265                inner: newtype::NewtypeInner::I32,
266                copy: true,
267                ord: true
268            })
269        );
270
271        let attack_code_schema = schema.components.schemas.get("AttackCode").unwrap();
272
273        let attack_code = resolve(
274            attack_code_schema,
275            "AttackCode",
276            &schema.components.schemas,
277            &reporter,
278        );
279        assert!(reporter.is_empty());
280
281        assert_eq!(
282            attack_code,
283            Model::Newtype(Newtype {
284                name: "AttackCode".to_owned(),
285                description: None,
286                inner: newtype::NewtypeInner::Str,
287                copy: false,
288                ord: false
289            })
290        );
291    }
292
293    #[test]
294    fn resolve_all() {
295        let schema = get_schema();
296
297        let mut unresolved = vec![];
298        let total = schema.components.schemas.len();
299
300        for (name, desc) in &schema.components.schemas {
301            let reporter = WarningReporter::new();
302            if resolve(desc, name, &schema.components.schemas, &reporter) == Model::Unresolved
303                || !reporter.is_empty()
304            {
305                unresolved.push(name);
306            }
307        }
308
309        if !unresolved.is_empty() {
310            panic!(
311                "Failed to resolve {}/{} types. Could not resolve [{}]",
312                unresolved.len(),
313                total,
314                unresolved
315                    .into_iter()
316                    .map(|u| format!("`{u}`"))
317                    .collect::<Vec<_>>()
318                    .join(", ")
319            )
320        }
321    }
322}