rusty_gql/types/schema.rs

1use std::{collections::HashMap, ops::Deref, sync::Arc};
2
3use graphql_parser::schema::TypeDefinition as ParserTypeDefinition;
4
5use crate::{
6    error::GqlError, CustomDirective, EnumType, GqlDirective, InputObjectType, InterfaceType,
7    ObjectType, UnionType,
8};
9
10use super::{
11    argument::InputValueType, directive::DirectiveDefinition, field::FieldType,
12    introspection::introspection_sdl, scalar::ScalarType, type_definition::TypeDefinition,
13    EnumTypeValue,
14};
15
16pub struct SchemaInner {
17    pub queries: HashMap<String, FieldType>,
18    pub mutations: HashMap<String, FieldType>,
19    pub subscriptions: HashMap<String, FieldType>,
20    pub directives: HashMap<String, DirectiveDefinition>,
21    pub type_definitions: HashMap<String, TypeDefinition>,
22    pub interfaces: HashMap<String, InterfaceType>,
23    pub query_type_name: String,
24    pub mutation_type_name: String,
25    pub subscription_type_name: String,
26    pub custom_directives: HashMap<&'static str, Box<dyn CustomDirective>>,
27}
28
29pub struct Schema(Arc<SchemaInner>);
30
31impl Schema {
32    pub fn new(schema: SchemaInner) -> Self {
33        Schema(Arc::new(schema))
34    }
35}
36
37impl Deref for Schema {
38    type Target = SchemaInner;
39
40    fn deref(&self) -> &Self::Target {
41        &self.0
42    }
43}
44
45pub fn build_schema(
46    schema_documents: &[&str],
47    custom_directives: HashMap<&'static str, Box<dyn CustomDirective>>,
48) -> Result<Schema, GqlError> {
49    let mut queries = HashMap::new();
50    let mut mutations = HashMap::new();
51    let mut subscriptions = HashMap::new();
52    let mut type_definitions = HashMap::new();
53    let mut directives = HashMap::new();
54    let mut extensions = Vec::new();
55    let mut schema_definition = None;
56    let mut interfaces = HashMap::new();
57
58    type_definitions.insert(
59        "String".to_string(),
60        TypeDefinition::Scalar(ScalarType::string_scalar()),
61    );
62    type_definitions.insert(
63        "Int".to_string(),
64        TypeDefinition::Scalar(ScalarType::int_scalar()),
65    );
66    type_definitions.insert(
67        "Float".to_string(),
68        TypeDefinition::Scalar(ScalarType::float_scalar()),
69    );
70    type_definitions.insert(
71        "Boolean".to_string(),
72        TypeDefinition::Scalar(ScalarType::boolean_scalar()),
73    );
74    type_definitions.insert(
75        "ID".to_string(),
76        TypeDefinition::Scalar(ScalarType::id_scalar()),
77    );
78
79    directives.insert("skip".to_string(), DirectiveDefinition::skip_directive());
80    directives.insert(
81        "include".to_string(),
82        DirectiveDefinition::include_directive(),
83    );
84    directives.insert(
85        "deprecated".to_string(),
86        DirectiveDefinition::deprecated_directive(),
87    );
88
89    let mut definitions = schema_documents.to_vec();
90    definitions.push(introspection_sdl());
91
92    for doc in definitions {
93        let parsed_schema =
94            graphql_parser::parse_schema::<String>(doc).expect("failed to parse graphql schema");
95        for node in parsed_schema.definitions {
96            match node {
97                graphql_parser::schema::Definition::SchemaDefinition(schema_def) => {
98                    schema_definition = Some(schema_def);
99                }
100                graphql_parser::schema::Definition::TypeDefinition(ty_def) => {
101                    let gql_def = TypeDefinition::from_schema_type_def(&ty_def);
102                    type_definitions.insert(gql_def.name().to_string(), gql_def);
103
104                    if let ParserTypeDefinition::Interface(interface) = &ty_def {
105                        interfaces.insert(
106                            interface.name.to_string(),
107                            InterfaceType::from(interface.clone()),
108                        );
109                    }
110                }
111                graphql_parser::schema::Definition::TypeExtension(ext) => {
112                    extensions.push(ext);
113                }
114                graphql_parser::schema::Definition::DirectiveDefinition(directive) => {
115                    let arguments = InputValueType::from_vec_input_value(directive.arguments);
116                    let result = DirectiveDefinition {
117                        position: directive.position,
118                        name: directive.name,
119                        description: directive.description,
120                        arguments,
121                        locations: directive.locations,
122                    };
123                    directives.insert(result.name.to_string(), result);
124                }
125            }
126        }
127    }
128
129    for ext in extensions {
130        match ext {
131            graphql_parser::schema::TypeExtension::Scalar(scalar_ext) => {
132                let original_name = scalar_ext.name.clone();
133                match type_definitions.get(&original_name) {
134                    Some(original_scalar) => {
135                        if let TypeDefinition::Scalar(original) = original_scalar {
136                            let mut extended_directives = original.directives.clone();
137                            let directives =
138                                GqlDirective::from_vec_directive(scalar_ext.directives);
139                            extended_directives.extend(directives);
140
141                            let extended_scalar = ScalarType {
142                                position: original.position,
143                                description: original.description.clone(),
144                                name: original_name.clone(),
145                                directives: extended_directives,
146                            };
147                            type_definitions
148                                .insert(original_name, TypeDefinition::Scalar(extended_scalar));
149                        }
150                    }
151                    None => {
152                        return Err(GqlError::new(
153                            format!("The {} scalar to extend is not found", original_name),
154                            None,
155                        ))
156                    }
157                }
158            }
159            graphql_parser::schema::TypeExtension::Object(obj_ext) => {
160                let original_name = obj_ext.name.clone();
161                match type_definitions.get(&original_name) {
162                    Some(original_obj) => {
163                        if let TypeDefinition::Object(original) = original_obj {
164                            let mut extended_directives = original.directives.clone();
165                            let directives = GqlDirective::from_vec_directive(obj_ext.directives);
166                            extended_directives.extend(directives);
167
168                            let mut extended_fields = original.fields.clone();
169                            let fields = FieldType::from_vec_field(obj_ext.fields);
170                            extended_fields.extend(fields);
171
172                            let mut extended_impl_interfaces =
173                                original.implements_interfaces.clone();
174                            extended_impl_interfaces.extend(obj_ext.implements_interfaces.clone());
175
176                            let extended_obj = ObjectType {
177                                position: original.position,
178                                description: original.description.clone(),
179                                name: original_name.clone(),
180                                directives: extended_directives,
181                                fields: extended_fields,
182                                implements_interfaces: extended_impl_interfaces,
183                            };
184                            type_definitions.insert(
185                                original_name.to_string(),
186                                TypeDefinition::Object(extended_obj),
187                            );
188                        }
189                    }
190                    None => {
191                        return Err(GqlError::new(
192                            format!("The {} object to extend is not found", original_name),
193                            None,
194                        ))
195                    }
196                }
197            }
198            graphql_parser::schema::TypeExtension::Interface(inter_ext) => {
199                let original_name = inter_ext.name.clone();
200                match type_definitions.get(&original_name) {
201                    Some(original_interface) => {
202                        if let TypeDefinition::Interface(original) = original_interface {
203                            let mut extended_directives = original.directives.clone();
204                            let directives = GqlDirective::from_vec_directive(inter_ext.directives);
205                            extended_directives.extend(directives);
206
207                            let mut extended_fields = original.fields.clone();
208                            let fields = FieldType::from_vec_field(inter_ext.fields);
209                            extended_fields.extend(fields);
210
211                            let extended_interface = InterfaceType {
212                                position: original.position,
213                                description: original.description.clone(),
214                                name: original_name.clone(),
215                                directives: extended_directives,
216                                fields: extended_fields,
217                            };
218                            type_definitions.insert(
219                                original_name.to_string(),
220                                TypeDefinition::Interface(extended_interface.clone()),
221                            );
222                            interfaces
223                                .insert(original_name.to_string(), extended_interface.clone());
224                        }
225                    }
226                    None => {
227                        return Err(GqlError::new(
228                            format!("The {} interface to extend is not found", original_name),
229                            None,
230                        ))
231                    }
232                }
233            }
234            graphql_parser::schema::TypeExtension::Union(union_ext) => {
235                let original_name = union_ext.name.clone();
236                match type_definitions.get(&original_name) {
237                    Some(original_union) => {
238                        if let TypeDefinition::Union(original) = original_union {
239                            let mut extended_directives = original.directives.clone();
240                            let directives =
241                                GqlDirective::from_vec_directive(union_ext.directives.clone());
242                            extended_directives.extend(directives);
243
244                            let mut extended_types = original.types.clone();
245                            extended_types.extend(union_ext.types.clone());
246
247                            let extended_union = UnionType {
248                                position: original.position,
249                                description: original.description.clone(),
250                                name: original_name.clone(),
251                                directives: extended_directives,
252                                types: extended_types,
253                            };
254                            type_definitions.insert(
255                                original_name.to_string(),
256                                TypeDefinition::Union(extended_union),
257                            );
258                        }
259                    }
260                    None => {
261                        return Err(GqlError::new(
262                            format!("The {} union to extend is not found", original_name),
263                            None,
264                        ))
265                    }
266                }
267            }
268            graphql_parser::schema::TypeExtension::Enum(enum_ext) => {
269                let original_name = enum_ext.name.clone();
270                match type_definitions.get(&original_name) {
271                    Some(original_enum) => {
272                        if let TypeDefinition::Enum(original) = original_enum {
273                            let mut extended_directives = original.directives.clone();
274                            let directives =
275                                GqlDirective::from_vec_directive(enum_ext.directives.clone());
276                            extended_directives.extend(directives);
277
278                            let mut extended_values = original.values.clone();
279                            let values: Vec<EnumTypeValue> = enum_ext
280                                .values
281                                .into_iter()
282                                .map(EnumTypeValue::from)
283                                .collect();
284                            extended_values.extend(values);
285
286                            let extended_enum = EnumType {
287                                position: original.position,
288                                description: original.description.clone(),
289                                name: original_name.clone(),
290                                directives: extended_directives,
291                                values: extended_values,
292                            };
293                            type_definitions.insert(
294                                original_name.to_string(),
295                                TypeDefinition::Enum(extended_enum),
296                            );
297                        }
298                    }
299                    None => {
300                        return Err(GqlError::new(
301                            format!("The {} enum to extend is not found", original_name),
302                            None,
303                        ))
304                    }
305                }
306            }
307            graphql_parser::schema::TypeExtension::InputObject(input_ext) => {
308                let original_name = input_ext.name.clone();
309                match type_definitions.get(&original_name) {
310                    Some(original_input) => {
311                        if let TypeDefinition::InputObject(original) = original_input {
312                            let mut extended_directives = original.directives.clone();
313                            let directives =
314                                GqlDirective::from_vec_directive(input_ext.directives.clone());
315                            extended_directives.extend(directives);
316
317                            let mut extended_fields = original.fields.clone();
318                            let fields = InputValueType::from_vec_input_value(input_ext.fields);
319                            extended_fields.extend(fields);
320
321                            let extended_input = InputObjectType {
322                                position: original.position,
323                                description: original.description.clone(),
324                                name: original_name.clone(),
325                                directives: extended_directives,
326                                fields: extended_fields,
327                            };
328                            type_definitions.insert(
329                                original_name.to_string(),
330                                TypeDefinition::InputObject(extended_input),
331                            );
332                        }
333                    }
334                    None => {
335                        return Err(GqlError::new(
336                            format!("The {} input object to extend is not found", original_name),
337                            None,
338                        ))
339                    }
340                }
341            }
342        }
343    }
344
345    let mut query_type_name = "Query".to_string();
346    let mut mutation_type_name = "Mutation".to_string();
347    let mut subscription_type_name = "Subscription".to_string();
348
349    if let Some(def) = schema_definition {
350        if let Some(query) = def.query {
351            query_type_name = query;
352        }
353        if let Some(mutation) = def.mutation {
354            mutation_type_name = mutation;
355        }
356        if let Some(subscription) = def.subscription {
357            subscription_type_name = subscription;
358        }
359    }
360
361    match type_definitions.get(&query_type_name) {
362        Some(query_def) => {
363            if let TypeDefinition::Object(def) = query_def {
364                for f in &def.fields {
365                    queries.insert(f.name.to_string(), f.clone());
366                }
367            }
368        }
369        None => {
370            return Err(GqlError::new("Query type must be defined", None));
371        }
372    }
373
374    if let Some(TypeDefinition::Object(mutation_def)) = type_definitions.get(&mutation_type_name) {
375        for f in &mutation_def.fields {
376            mutations.insert(f.name.to_string(), f.clone());
377        }
378    }
379
380    if let Some(TypeDefinition::Object(subscription_def)) =
381        type_definitions.get(&subscription_type_name)
382    {
383        for f in &subscription_def.fields {
384            subscriptions.insert(f.name.to_string(), f.clone());
385        }
386    }
387
388    Ok(Schema(Arc::new(SchemaInner {
389        queries,
390        mutations,
391        subscriptions,
392        directives,
393        type_definitions,
394        query_type_name,
395        mutation_type_name,
396        subscription_type_name,
397        interfaces,
398        custom_directives,
399    })))
400}
401
#[cfg(test)]
mod tests {
    use std::fs;

    use super::build_schema;

    /// Builds the GitHub fixture schema, then a base schema merged with an
    /// extension document, and checks that expected entries are present.
    #[test]
    fn it_works() {
        let github = fs::read_to_string("tests/schemas/github.graphql").unwrap();
        let github_schema = build_schema(&[github.as_str()], Default::default()).unwrap();

        assert!(github_schema.queries.contains_key("repository"));
        assert!(github_schema.type_definitions.contains_key("AddCommentInput"));

        // The second document extends types declared in the first.
        let contents: Vec<String> = [
            "tests/schemas/test_schema.graphql",
            "tests/schemas/extend.graphql",
        ]
        .iter()
        .map(|path| fs::read_to_string(path).unwrap())
        .collect();
        let documents: Vec<&str> = contents.iter().map(String::as_str).collect();
        let extended_schema = build_schema(&documents, Default::default()).unwrap();

        assert!(extended_schema.queries.contains_key("pets"));
        assert!(extended_schema.queries.contains_key("authors"));
    }
}