1use crate::gqlgen::{
4 codegen::{
5 decorate_type,
6 shared::{field_rename_annotation, keyword_replace},
7 },
8 deprecation::DeprecationStrategy,
9 query::{
10 fragment_is_recursive, full_path_prefix, BoundQuery, InlineFragment, OperationId,
11 ResolvedFragment, ResolvedFragmentId, SelectedField, Selection, SelectionId,
12 },
13 schema::{Schema, TypeId},
14 type_qualifiers::GraphqlTypeQualifier,
15 GraphQLClientCodegenOptions,
16};
17use heck::*;
18use proc_macro2::{Ident, Span, TokenStream};
19use quote::quote;
20use std::borrow::Cow;
21
22pub(crate) fn render_response_data_fields<'a>(
23 operation_id: OperationId,
24 options: &'a GraphQLClientCodegenOptions,
25 query: &'a BoundQuery<'a>,
26) -> ExpandedSelection<'a> {
27 let operation = query.query.get_operation(operation_id);
28 let mut expanded_selection = ExpandedSelection {
29 query,
30 types: Vec::with_capacity(8),
31 aliases: Vec::new(),
32 variants: Vec::new(),
33 fields: Vec::with_capacity(operation.selection_set.len()),
34 options,
35 };
36
37 let response_data_type_id = expanded_selection.push_type(ExpandedType {
38 name: Cow::Borrowed("InputData"),
39 });
40
41 calculate_selection(
42 &mut expanded_selection,
43 &operation.selection_set,
44 response_data_type_id,
45 TypeId::Object(operation.object_id),
46 options,
47 );
48
49 expanded_selection
50}
51
52pub(super) fn render_fragment<'a>(
53 fragment_id: ResolvedFragmentId,
54 options: &'a GraphQLClientCodegenOptions,
55 query: &'a BoundQuery<'a>,
56) -> ExpandedSelection<'a> {
57 let fragment = query.query.get_fragment(fragment_id);
58 let mut expanded_selection = ExpandedSelection {
59 query,
60 aliases: Vec::new(),
61 types: Vec::with_capacity(8),
62 variants: Vec::new(),
63 fields: Vec::with_capacity(fragment.selection_set.len()),
64 options,
65 };
66
67 let response_type_id = expanded_selection.push_type(ExpandedType {
68 name: fragment.name.as_str().into(),
69 });
70
71 calculate_selection(
72 &mut expanded_selection,
73 &fragment.selection_set,
74 response_type_id,
75 fragment.on,
76 options,
77 );
78
79 expanded_selection
80}
81
82enum VariantSelection<'a> {
84 InlineFragment(&'a InlineFragment),
85 FragmentSpread((ResolvedFragmentId, &'a ResolvedFragment)),
86}
87
88impl<'a> VariantSelection<'a> {
89 fn from_selection(
91 selection: &'a Selection,
92 type_id: TypeId,
93 query: &BoundQuery<'a>,
94 ) -> Option<VariantSelection<'a>> {
95 match selection {
96 Selection::InlineFragment(inline_fragment) => {
97 Some(VariantSelection::InlineFragment(inline_fragment))
98 }
99 Selection::FragmentSpread(fragment_id) => {
100 let fragment = query.query.get_fragment(*fragment_id);
101
102 if fragment.on == type_id {
103 None
105 } else {
106 Some(VariantSelection::FragmentSpread((*fragment_id, fragment)))
108 }
109 }
110 Selection::Field(_) | Selection::Typename => None,
111 }
112 }
113
114 fn variant_type_id(&self) -> TypeId {
115 match self {
116 VariantSelection::InlineFragment(f) => f.type_id,
117 VariantSelection::FragmentSpread((_id, f)) => f.on,
118 }
119 }
120}
121
122fn calculate_selection<'a>(
123 context: &mut ExpandedSelection<'a>,
124 selection_set: &[SelectionId],
125 struct_id: ResponseTypeId,
126 type_id: TypeId,
127 options: &'a GraphQLClientCodegenOptions,
128) {
129 if selection_set.len() == 1 {
132 if let Selection::FragmentSpread(fragment_id) =
133 context.query.query.get_selection(selection_set[0])
134 {
135 let fragment = context.query.query.get_fragment(*fragment_id);
136 context.push_type_alias(TypeAlias {
137 name: &fragment.name,
138 struct_id,
139 boxed: fragment_is_recursive(*fragment_id, context.query.query),
140 });
141 return;
142 }
143 }
144
145 {
147 let variants: Option<Cow<'_, [TypeId]>> = match type_id {
148 TypeId::Interface(interface_id) => {
149 let variants = context
150 .query
151 .schema
152 .objects()
153 .filter(|(_, obj)| obj.implements_interfaces.contains(&interface_id))
154 .map(|(id, _)| TypeId::Object(id));
155
156 Some(variants.collect::<Vec<TypeId>>().into())
157 }
158 TypeId::Union(union_id) => {
159 let union = context.schema().get_union(union_id);
160 Some(union.variants.as_slice().into())
161 }
162 _ => None,
163 };
164
165 if let Some(variants) = variants {
166 let variant_selections: Vec<(SelectionId, &Selection, VariantSelection<'_>)> =
167 selection_set
168 .iter()
169 .map(|id| (id, context.query.query.get_selection(*id)))
170 .filter_map(|(id, selection)| {
171 VariantSelection::from_selection(selection, type_id, context.query)
172 .map(|variant_selection| (*id, selection, variant_selection))
173 })
174 .collect();
175
176 for variant_type_id in variants.as_ref() {
180 let variant_name_str = variant_type_id.name(context.schema());
181
182 let variant_selections: Vec<_> = variant_selections
183 .iter()
184 .filter(|(_id, _selection_ref, variant)| {
185 variant.variant_type_id() == *variant_type_id
186 })
187 .collect();
188
189 if let Some((selection_id, selection, _variant)) = variant_selections.get(0) {
190 let mut variant_struct_name_str =
191 full_path_prefix(*selection_id, context.query);
192 variant_struct_name_str.reserve(2 + variant_name_str.len());
193 variant_struct_name_str.push_str("On");
194 variant_struct_name_str.push_str(variant_name_str);
195
196 context.push_variant(ExpandedVariant {
197 name: variant_name_str.into(),
198 variant_type: Some(variant_struct_name_str.clone().into()),
199 on: struct_id,
200 is_default_variant: false,
201 });
202
203 let expanded_type = ExpandedType {
204 name: variant_struct_name_str.into(),
205 };
206
207 let struct_id = context.push_type(expanded_type);
208
209 if variant_selections.len() == 1 {
210 if let VariantSelection::FragmentSpread((fragment_id, fragment)) =
211 variant_selections[0].2
212 {
213 context.push_type_alias(TypeAlias {
214 boxed: fragment_is_recursive(fragment_id, context.query.query),
215 name: &fragment.name,
216 struct_id,
217 });
218 continue;
219 }
220 }
221
222 for (_selection_id, _selection, variant_selection) in variant_selections {
223 match variant_selection {
224 VariantSelection::InlineFragment(_) => {
225 calculate_selection(
226 context,
227 selection.subselection(),
228 struct_id,
229 *variant_type_id,
230 options,
231 );
232 }
233 VariantSelection::FragmentSpread((fragment_id, fragment)) => context
234 .push_field(ExpandedField {
235 field_type: fragment.name.as_str().into(),
236 field_type_qualifiers: &[GraphqlTypeQualifier::Required],
237 flatten: true,
238 graphql_name: None,
239 rust_name: fragment.name.to_snake_case().into(),
240 struct_id,
241 deprecation: None,
242 boxed: fragment_is_recursive(*fragment_id, context.query.query),
243 }),
244 }
245 }
246 } else {
247 context.push_variant(ExpandedVariant {
248 name: variant_name_str.into(),
249 on: struct_id,
250 variant_type: None,
251 is_default_variant: false,
252 });
253 }
254 }
255
256 if *options.fragments_other_variant() {
257 context.push_variant(ExpandedVariant {
258 name: "Unknown".into(),
259 on: struct_id,
260 variant_type: None,
261 is_default_variant: true,
262 });
263 }
264 }
265 }
266
267 for id in selection_set {
268 let selection = context.query.query.get_selection(*id);
269
270 match selection {
271 Selection::Field(field) => {
272 let (graphql_name, rust_name) = context.field_name(field);
273 let schema_field = field.schema_field(context.schema());
274 let field_type_id = schema_field.r#type.id;
275
276 match field_type_id {
277 TypeId::Enum(enm) => {
278 context.push_field(ExpandedField {
279 graphql_name: Some(graphql_name),
280 rust_name,
281 struct_id,
282 field_type: options
283 .normalization()
284 .field_type(&context.schema().get_enum(enm).name),
285 field_type_qualifiers: &schema_field.r#type.qualifiers,
286 flatten: false,
287 deprecation: schema_field.deprecation(),
288 boxed: false,
289 });
290 }
291 TypeId::Scalar(scalar) => {
292 context.push_field(ExpandedField {
293 field_type: options
294 .normalization()
295 .field_type(context.schema().get_scalar(scalar).name.as_str()),
296 field_type_qualifiers: &field
297 .schema_field(context.schema())
298 .r#type
299 .qualifiers,
300 graphql_name: Some(graphql_name),
301 struct_id,
302 rust_name,
303 flatten: false,
304 deprecation: schema_field.deprecation(),
305 boxed: false,
306 });
307 }
308 TypeId::Object(_) | TypeId::Interface(_) | TypeId::Union(_) => {
309 let struct_name_string = full_path_prefix(*id, context.query);
310
311 context.push_field(ExpandedField {
312 struct_id,
313 graphql_name: Some(graphql_name),
314 rust_name,
315 field_type_qualifiers: &schema_field.r#type.qualifiers,
316 field_type: Cow::Owned(struct_name_string.clone()),
317 flatten: false,
318 boxed: false,
319 deprecation: schema_field.deprecation(),
320 });
321
322 let type_id = context.push_type(ExpandedType {
323 name: Cow::Owned(struct_name_string),
324 });
325
326 calculate_selection(
327 context,
328 selection.subselection(),
329 type_id,
330 field_type_id,
331 options,
332 );
333 }
334 TypeId::Input(_) => unreachable!("field selection on input type"),
335 };
336 }
337 Selection::Typename => (),
338 Selection::InlineFragment(_inline) => (),
339 Selection::FragmentSpread(fragment_id) => {
340 let fragment = context.query.query.get_fragment(*fragment_id);
344
345 if fragment.on != type_id {
350 continue;
351 }
352
353 let original_field_name = fragment.name.to_snake_case();
354 let final_field_name = keyword_replace(original_field_name);
355
356 context.push_field(ExpandedField {
357 field_type: fragment.name.as_str().into(),
358 field_type_qualifiers: &[GraphqlTypeQualifier::Required],
359 graphql_name: None,
360 rust_name: final_field_name,
361 struct_id,
362 flatten: true,
363 deprecation: None,
364 boxed: fragment_is_recursive(*fragment_id, context.query.query),
365 });
366
367 }
370 }
371 }
372}
373
/// Identifier of a response type registered in an `ExpandedSelection`: the
/// index of the type in its list of types. Fields, variants and aliases refer
/// back to their owning type through this id.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ResponseTypeId(u32);
376
377struct TypeAlias<'a> {
378 name: &'a str,
379 struct_id: ResponseTypeId,
380 boxed: bool,
381}
382
383struct ExpandedField<'a> {
384 graphql_name: Option<&'a str>,
385 rust_name: Cow<'a, str>,
386 field_type: Cow<'a, str>,
387 field_type_qualifiers: &'a [GraphqlTypeQualifier],
388 struct_id: ResponseTypeId,
389 flatten: bool,
390 deprecation: Option<Option<&'a str>>,
391 boxed: bool,
392}
393
394impl<'a> ExpandedField<'a> {
395 fn render(&self, options: &GraphQLClientCodegenOptions) -> Option<TokenStream> {
396 let ident = Ident::new(&self.rust_name, Span::call_site());
397 let qualified_type = decorate_type(
398 &Ident::new(&self.field_type, Span::call_site()),
399 self.field_type_qualifiers,
400 );
401
402 let qualified_type = if self.boxed {
403 quote!(Box<#qualified_type>)
404 } else {
405 qualified_type
406 };
407
408 let optional_skip_serializing_none = if *options.skip_serializing_none()
409 && self
410 .field_type_qualifiers
411 .get(0)
412 .map(|qualifier| !qualifier.is_required())
413 .unwrap_or(false)
414 {
415 Some(quote!(#[serde(skip_serializing_if = "Option::is_none")]))
416 } else {
417 None
418 };
419
420 let optional_rename = self
421 .graphql_name
422 .as_ref()
423 .map(|graphql_name| field_rename_annotation(graphql_name, &self.rust_name));
424 let optional_flatten = if self.flatten {
425 Some(quote!(#[serde(flatten)]))
426 } else {
427 None
428 };
429
430 let optional_deprecation_annotation =
431 match (self.deprecation, options.deprecation_strategy()) {
432 (None, _) | (Some(_), DeprecationStrategy::Allow) => None,
433 (Some(msg), DeprecationStrategy::Warn) => {
434 let optional_msg = msg.map(|msg| quote!((note = #msg)));
435
436 Some(quote!(#[deprecated#optional_msg]))
437 }
438 (Some(_), DeprecationStrategy::Deny) => return None,
439 };
440
441 let tokens = quote! {
442 #optional_skip_serializing_none
443 #optional_flatten
444 #optional_rename
445 #optional_deprecation_annotation
446 pub #ident: #qualified_type
447 };
448
449 Some(tokens)
450 }
451}
452
453struct ExpandedVariant<'a> {
454 name: Cow<'a, str>,
455 variant_type: Option<Cow<'a, str>>,
456 on: ResponseTypeId,
457 is_default_variant: bool,
458}
459
460impl<'a> ExpandedVariant<'a> {
461 fn render(&self) -> TokenStream {
462 let name_ident = Ident::new(&self.name, Span::call_site());
463 let optional_type_ident = self.variant_type.as_ref().map(|variant_type| {
464 let ident = Ident::new(variant_type, Span::call_site());
465 quote!((#ident))
466 });
467
468 if self.is_default_variant {
469 quote! {
470 #[serde(other)]
471 #name_ident #optional_type_ident
472 }
473 } else {
474 quote!(#name_ident #optional_type_ident)
475 }
476 }
477}
478
/// A response type (struct, enum, or alias) to be generated.
pub(crate) struct ExpandedType<'a> {
    /// Rust name of the generated type.
    name: Cow<'a, str>,
}
482
483pub(crate) struct ExpandedSelection<'a> {
484 query: &'a BoundQuery<'a>,
485 types: Vec<ExpandedType<'a>>,
486 fields: Vec<ExpandedField<'a>>,
487 variants: Vec<ExpandedVariant<'a>>,
488 aliases: Vec<TypeAlias<'a>>,
489 options: &'a GraphQLClientCodegenOptions,
490}
491
492impl<'a> ExpandedSelection<'a> {
493 pub(crate) fn schema(&self) -> &'a Schema {
494 self.query.schema
495 }
496
497 fn push_type(&mut self, tpe: ExpandedType<'a>) -> ResponseTypeId {
498 let id = self.types.len();
499 self.types.push(tpe);
500
501 ResponseTypeId(id as u32)
502 }
503
504 fn push_field(&mut self, field: ExpandedField<'a>) {
505 self.fields.push(field);
506 }
507
508 fn push_type_alias(&mut self, alias: TypeAlias<'a>) {
509 self.aliases.push(alias)
510 }
511
512 fn push_variant(&mut self, variant: ExpandedVariant<'a>) {
513 self.variants.push(variant);
514 }
515
516 pub(crate) fn field_name(&self, field: &'a SelectedField) -> (&'a str, Cow<'a, str>) {
518 let name = field
519 .alias()
520 .unwrap_or_else(|| &field.schema_field(self.query.schema).name);
521 let snake_case_name = name.to_snake_case();
522 let final_name = keyword_replace(snake_case_name);
523
524 (name, final_name)
525 }
526
527 fn types(&self) -> impl Iterator<Item = (ResponseTypeId, &ExpandedType<'_>)> {
528 self.types
529 .iter()
530 .enumerate()
531 .map(|(idx, ty)| (ResponseTypeId(idx as u32), ty))
532 }
533
534 pub fn render(&self, response_derives: &impl quote::ToTokens) -> TokenStream {
535 let mut items = Vec::with_capacity(self.types.len());
536
537 for (type_id, ty) in self.types() {
538 let struct_name = Ident::new(&ty.name, Span::call_site());
539
540 if let Some(alias) = self.aliases.iter().find(|alias| alias.struct_id == type_id) {
542 let fragment_name = Ident::new(alias.name, Span::call_site());
543 let fragment_name = if alias.boxed {
544 quote!(Box<#fragment_name>)
545 } else {
546 quote!(#fragment_name)
547 };
548 let item = quote! {
549 pub type #struct_name = #fragment_name;
550 };
551 items.push(item);
552 continue;
553 }
554
555 let mut fields = self
556 .fields
557 .iter()
558 .filter(|field| field.struct_id == type_id)
559 .filter_map(|field| field.render(self.options))
560 .peekable();
561
562 let on_variants: Vec<TokenStream> = self
563 .variants
564 .iter()
565 .filter(|variant| variant.on == type_id)
566 .map(|variant| variant.render())
567 .collect();
568
569 if fields.peek().is_none() {
572 let item = quote! {
573 #response_derives
574 #[serde(tag = "__typename")]
575 pub enum #struct_name {
576 #(#on_variants),*
577 }
578 };
579 items.push(item);
580 continue;
581 }
582
583 let (on_field, on_enum) = if !on_variants.is_empty() {
584 let enum_name = Ident::new(&format!("{}On", ty.name), Span::call_site());
585
586 let on_field = quote!(#[serde(flatten)] pub on: #enum_name);
587
588 let on_enum = quote!(
589 #response_derives
590 #[serde(tag = "__typename")]
591 pub enum #enum_name {
592 #(#on_variants,)*
593 }
594 );
595
596 (Some(on_field), Some(on_enum))
597 } else {
598 (None, None)
599 };
600
601 let tokens = quote! {
602 #response_derives
603 pub struct #struct_name {
604 #(#fields,)*
605 #on_field
606 }
607
608 #on_enum
609 };
610
611 items.push(tokens);
612 }
613
614 quote!(#(#items)*)
615 }
616}