1use std::{collections::HashMap, ops::Deref, sync::Arc};
2
3use graphql_parser::schema::TypeDefinition as ParserTypeDefinition;
4
5use crate::{
6 error::GqlError, CustomDirective, EnumType, GqlDirective, InputObjectType, InterfaceType,
7 ObjectType, UnionType,
8};
9
10use super::{
11 argument::InputValueType, directive::DirectiveDefinition, field::FieldType,
12 introspection::introspection_sdl, scalar::ScalarType, type_definition::TypeDefinition,
13 EnumTypeValue,
14};
15
16pub struct SchemaInner {
17 pub queries: HashMap<String, FieldType>,
18 pub mutations: HashMap<String, FieldType>,
19 pub subscriptions: HashMap<String, FieldType>,
20 pub directives: HashMap<String, DirectiveDefinition>,
21 pub type_definitions: HashMap<String, TypeDefinition>,
22 pub interfaces: HashMap<String, InterfaceType>,
23 pub query_type_name: String,
24 pub mutation_type_name: String,
25 pub subscription_type_name: String,
26 pub custom_directives: HashMap<&'static str, Box<dyn CustomDirective>>,
27}
28
29pub struct Schema(Arc<SchemaInner>);
30
31impl Schema {
32 pub fn new(schema: SchemaInner) -> Self {
33 Schema(Arc::new(schema))
34 }
35}
36
37impl Deref for Schema {
38 type Target = SchemaInner;
39
40 fn deref(&self) -> &Self::Target {
41 &self.0
42 }
43}
44
45pub fn build_schema(
46 schema_documents: &[&str],
47 custom_directives: HashMap<&'static str, Box<dyn CustomDirective>>,
48) -> Result<Schema, GqlError> {
49 let mut queries = HashMap::new();
50 let mut mutations = HashMap::new();
51 let mut subscriptions = HashMap::new();
52 let mut type_definitions = HashMap::new();
53 let mut directives = HashMap::new();
54 let mut extensions = Vec::new();
55 let mut schema_definition = None;
56 let mut interfaces = HashMap::new();
57
58 type_definitions.insert(
59 "String".to_string(),
60 TypeDefinition::Scalar(ScalarType::string_scalar()),
61 );
62 type_definitions.insert(
63 "Int".to_string(),
64 TypeDefinition::Scalar(ScalarType::int_scalar()),
65 );
66 type_definitions.insert(
67 "Float".to_string(),
68 TypeDefinition::Scalar(ScalarType::float_scalar()),
69 );
70 type_definitions.insert(
71 "Boolean".to_string(),
72 TypeDefinition::Scalar(ScalarType::boolean_scalar()),
73 );
74 type_definitions.insert(
75 "ID".to_string(),
76 TypeDefinition::Scalar(ScalarType::id_scalar()),
77 );
78
79 directives.insert("skip".to_string(), DirectiveDefinition::skip_directive());
80 directives.insert(
81 "include".to_string(),
82 DirectiveDefinition::include_directive(),
83 );
84 directives.insert(
85 "deprecated".to_string(),
86 DirectiveDefinition::deprecated_directive(),
87 );
88
89 let mut definitions = schema_documents.to_vec();
90 definitions.push(introspection_sdl());
91
92 for doc in definitions {
93 let parsed_schema =
94 graphql_parser::parse_schema::<String>(doc).expect("failed to parse graphql schema");
95 for node in parsed_schema.definitions {
96 match node {
97 graphql_parser::schema::Definition::SchemaDefinition(schema_def) => {
98 schema_definition = Some(schema_def);
99 }
100 graphql_parser::schema::Definition::TypeDefinition(ty_def) => {
101 let gql_def = TypeDefinition::from_schema_type_def(&ty_def);
102 type_definitions.insert(gql_def.name().to_string(), gql_def);
103
104 if let ParserTypeDefinition::Interface(interface) = &ty_def {
105 interfaces.insert(
106 interface.name.to_string(),
107 InterfaceType::from(interface.clone()),
108 );
109 }
110 }
111 graphql_parser::schema::Definition::TypeExtension(ext) => {
112 extensions.push(ext);
113 }
114 graphql_parser::schema::Definition::DirectiveDefinition(directive) => {
115 let arguments = InputValueType::from_vec_input_value(directive.arguments);
116 let result = DirectiveDefinition {
117 position: directive.position,
118 name: directive.name,
119 description: directive.description,
120 arguments,
121 locations: directive.locations,
122 };
123 directives.insert(result.name.to_string(), result);
124 }
125 }
126 }
127 }
128
129 for ext in extensions {
130 match ext {
131 graphql_parser::schema::TypeExtension::Scalar(scalar_ext) => {
132 let original_name = scalar_ext.name.clone();
133 match type_definitions.get(&original_name) {
134 Some(original_scalar) => {
135 if let TypeDefinition::Scalar(original) = original_scalar {
136 let mut extended_directives = original.directives.clone();
137 let directives =
138 GqlDirective::from_vec_directive(scalar_ext.directives);
139 extended_directives.extend(directives);
140
141 let extended_scalar = ScalarType {
142 position: original.position,
143 description: original.description.clone(),
144 name: original_name.clone(),
145 directives: extended_directives,
146 };
147 type_definitions
148 .insert(original_name, TypeDefinition::Scalar(extended_scalar));
149 }
150 }
151 None => {
152 return Err(GqlError::new(
153 format!("The {} scalar to extend is not found", original_name),
154 None,
155 ))
156 }
157 }
158 }
159 graphql_parser::schema::TypeExtension::Object(obj_ext) => {
160 let original_name = obj_ext.name.clone();
161 match type_definitions.get(&original_name) {
162 Some(original_obj) => {
163 if let TypeDefinition::Object(original) = original_obj {
164 let mut extended_directives = original.directives.clone();
165 let directives = GqlDirective::from_vec_directive(obj_ext.directives);
166 extended_directives.extend(directives);
167
168 let mut extended_fields = original.fields.clone();
169 let fields = FieldType::from_vec_field(obj_ext.fields);
170 extended_fields.extend(fields);
171
172 let mut extended_impl_interfaces =
173 original.implements_interfaces.clone();
174 extended_impl_interfaces.extend(obj_ext.implements_interfaces.clone());
175
176 let extended_obj = ObjectType {
177 position: original.position,
178 description: original.description.clone(),
179 name: original_name.clone(),
180 directives: extended_directives,
181 fields: extended_fields,
182 implements_interfaces: extended_impl_interfaces,
183 };
184 type_definitions.insert(
185 original_name.to_string(),
186 TypeDefinition::Object(extended_obj),
187 );
188 }
189 }
190 None => {
191 return Err(GqlError::new(
192 format!("The {} object to extend is not found", original_name),
193 None,
194 ))
195 }
196 }
197 }
198 graphql_parser::schema::TypeExtension::Interface(inter_ext) => {
199 let original_name = inter_ext.name.clone();
200 match type_definitions.get(&original_name) {
201 Some(original_interface) => {
202 if let TypeDefinition::Interface(original) = original_interface {
203 let mut extended_directives = original.directives.clone();
204 let directives = GqlDirective::from_vec_directive(inter_ext.directives);
205 extended_directives.extend(directives);
206
207 let mut extended_fields = original.fields.clone();
208 let fields = FieldType::from_vec_field(inter_ext.fields);
209 extended_fields.extend(fields);
210
211 let extended_interface = InterfaceType {
212 position: original.position,
213 description: original.description.clone(),
214 name: original_name.clone(),
215 directives: extended_directives,
216 fields: extended_fields,
217 };
218 type_definitions.insert(
219 original_name.to_string(),
220 TypeDefinition::Interface(extended_interface.clone()),
221 );
222 interfaces
223 .insert(original_name.to_string(), extended_interface.clone());
224 }
225 }
226 None => {
227 return Err(GqlError::new(
228 format!("The {} interface to extend is not found", original_name),
229 None,
230 ))
231 }
232 }
233 }
234 graphql_parser::schema::TypeExtension::Union(union_ext) => {
235 let original_name = union_ext.name.clone();
236 match type_definitions.get(&original_name) {
237 Some(original_union) => {
238 if let TypeDefinition::Union(original) = original_union {
239 let mut extended_directives = original.directives.clone();
240 let directives =
241 GqlDirective::from_vec_directive(union_ext.directives.clone());
242 extended_directives.extend(directives);
243
244 let mut extended_types = original.types.clone();
245 extended_types.extend(union_ext.types.clone());
246
247 let extended_union = UnionType {
248 position: original.position,
249 description: original.description.clone(),
250 name: original_name.clone(),
251 directives: extended_directives,
252 types: extended_types,
253 };
254 type_definitions.insert(
255 original_name.to_string(),
256 TypeDefinition::Union(extended_union),
257 );
258 }
259 }
260 None => {
261 return Err(GqlError::new(
262 format!("The {} union to extend is not found", original_name),
263 None,
264 ))
265 }
266 }
267 }
268 graphql_parser::schema::TypeExtension::Enum(enum_ext) => {
269 let original_name = enum_ext.name.clone();
270 match type_definitions.get(&original_name) {
271 Some(original_enum) => {
272 if let TypeDefinition::Enum(original) = original_enum {
273 let mut extended_directives = original.directives.clone();
274 let directives =
275 GqlDirective::from_vec_directive(enum_ext.directives.clone());
276 extended_directives.extend(directives);
277
278 let mut extended_values = original.values.clone();
279 let values: Vec<EnumTypeValue> = enum_ext
280 .values
281 .into_iter()
282 .map(EnumTypeValue::from)
283 .collect();
284 extended_values.extend(values);
285
286 let extended_enum = EnumType {
287 position: original.position,
288 description: original.description.clone(),
289 name: original_name.clone(),
290 directives: extended_directives,
291 values: extended_values,
292 };
293 type_definitions.insert(
294 original_name.to_string(),
295 TypeDefinition::Enum(extended_enum),
296 );
297 }
298 }
299 None => {
300 return Err(GqlError::new(
301 format!("The {} enum to extend is not found", original_name),
302 None,
303 ))
304 }
305 }
306 }
307 graphql_parser::schema::TypeExtension::InputObject(input_ext) => {
308 let original_name = input_ext.name.clone();
309 match type_definitions.get(&original_name) {
310 Some(original_input) => {
311 if let TypeDefinition::InputObject(original) = original_input {
312 let mut extended_directives = original.directives.clone();
313 let directives =
314 GqlDirective::from_vec_directive(input_ext.directives.clone());
315 extended_directives.extend(directives);
316
317 let mut extended_fields = original.fields.clone();
318 let fields = InputValueType::from_vec_input_value(input_ext.fields);
319 extended_fields.extend(fields);
320
321 let extended_input = InputObjectType {
322 position: original.position,
323 description: original.description.clone(),
324 name: original_name.clone(),
325 directives: extended_directives,
326 fields: extended_fields,
327 };
328 type_definitions.insert(
329 original_name.to_string(),
330 TypeDefinition::InputObject(extended_input),
331 );
332 }
333 }
334 None => {
335 return Err(GqlError::new(
336 format!("The {} input object to extend is not found", original_name),
337 None,
338 ))
339 }
340 }
341 }
342 }
343 }
344
345 let mut query_type_name = "Query".to_string();
346 let mut mutation_type_name = "Mutation".to_string();
347 let mut subscription_type_name = "Subscription".to_string();
348
349 if let Some(def) = schema_definition {
350 if let Some(query) = def.query {
351 query_type_name = query;
352 }
353 if let Some(mutation) = def.mutation {
354 mutation_type_name = mutation;
355 }
356 if let Some(subscription) = def.subscription {
357 subscription_type_name = subscription;
358 }
359 }
360
361 match type_definitions.get(&query_type_name) {
362 Some(query_def) => {
363 if let TypeDefinition::Object(def) = query_def {
364 for f in &def.fields {
365 queries.insert(f.name.to_string(), f.clone());
366 }
367 }
368 }
369 None => {
370 return Err(GqlError::new("Query type must be defined", None));
371 }
372 }
373
374 if let Some(TypeDefinition::Object(mutation_def)) = type_definitions.get(&mutation_type_name) {
375 for f in &mutation_def.fields {
376 mutations.insert(f.name.to_string(), f.clone());
377 }
378 }
379
380 if let Some(TypeDefinition::Object(subscription_def)) =
381 type_definitions.get(&subscription_type_name)
382 {
383 for f in &subscription_def.fields {
384 subscriptions.insert(f.name.to_string(), f.clone());
385 }
386 }
387
388 Ok(Schema(Arc::new(SchemaInner {
389 queries,
390 mutations,
391 subscriptions,
392 directives,
393 type_definitions,
394 query_type_name,
395 mutation_type_name,
396 subscription_type_name,
397 interfaces,
398 custom_directives,
399 })))
400}
401
#[cfg(test)]
mod tests {
    use std::fs;

    use super::build_schema;

    /// Builds a schema from a single large SDL file, then from a base schema
    /// plus a separate document containing extensions.
    #[test]
    fn it_works() {
        // Single document.
        let github_sdl = fs::read_to_string("tests/schemas/github.graphql").unwrap();
        let github = build_schema(&[github_sdl.as_str()], Default::default()).unwrap();

        assert!(github.queries.contains_key("repository"));
        assert!(github.type_definitions.contains_key("AddCommentInput"));

        // Base document followed by a document that extends it.
        let base_sdl = fs::read_to_string("tests/schemas/test_schema.graphql").unwrap();
        let extend_sdl = fs::read_to_string("tests/schemas/extend.graphql").unwrap();
        let extended = build_schema(
            &[base_sdl.as_str(), extend_sdl.as_str()],
            Default::default(),
        )
        .unwrap();

        assert!(extended.queries.contains_key("pets"));
        assert!(extended.queries.contains_key("authors"));
    }
}