test_shisho_policy_sdk/gqlgen/codegen/
selection.rs

1//! Code generation for the selection on an operation or a fragment.
2
3use crate::gqlgen::{
4    codegen::{
5        decorate_type,
6        shared::{field_rename_annotation, keyword_replace},
7    },
8    deprecation::DeprecationStrategy,
9    query::{
10        fragment_is_recursive, full_path_prefix, BoundQuery, InlineFragment, OperationId,
11        ResolvedFragment, ResolvedFragmentId, SelectedField, Selection, SelectionId,
12    },
13    schema::{Schema, TypeId},
14    type_qualifiers::GraphqlTypeQualifier,
15    GraphQLClientCodegenOptions,
16};
17use heck::*;
18use proc_macro2::{Ident, Span, TokenStream};
19use quote::quote;
20use std::borrow::Cow;
21
22pub(crate) fn render_response_data_fields<'a>(
23    operation_id: OperationId,
24    options: &'a GraphQLClientCodegenOptions,
25    query: &'a BoundQuery<'a>,
26) -> ExpandedSelection<'a> {
27    let operation = query.query.get_operation(operation_id);
28    let mut expanded_selection = ExpandedSelection {
29        query,
30        types: Vec::with_capacity(8),
31        aliases: Vec::new(),
32        variants: Vec::new(),
33        fields: Vec::with_capacity(operation.selection_set.len()),
34        options,
35    };
36
37    let response_data_type_id = expanded_selection.push_type(ExpandedType {
38        name: Cow::Borrowed("InputData"),
39    });
40
41    calculate_selection(
42        &mut expanded_selection,
43        &operation.selection_set,
44        response_data_type_id,
45        TypeId::Object(operation.object_id),
46        options,
47    );
48
49    expanded_selection
50}
51
52pub(super) fn render_fragment<'a>(
53    fragment_id: ResolvedFragmentId,
54    options: &'a GraphQLClientCodegenOptions,
55    query: &'a BoundQuery<'a>,
56) -> ExpandedSelection<'a> {
57    let fragment = query.query.get_fragment(fragment_id);
58    let mut expanded_selection = ExpandedSelection {
59        query,
60        aliases: Vec::new(),
61        types: Vec::with_capacity(8),
62        variants: Vec::new(),
63        fields: Vec::with_capacity(fragment.selection_set.len()),
64        options,
65    };
66
67    let response_type_id = expanded_selection.push_type(ExpandedType {
68        name: fragment.name.as_str().into(),
69    });
70
71    calculate_selection(
72        &mut expanded_selection,
73        &fragment.selection_set,
74        response_type_id,
75        fragment.on,
76        options,
77    );
78
79    expanded_selection
80}
81
82/// A sub-selection set (spread) on one of the variants of a union or interface.
83enum VariantSelection<'a> {
84    InlineFragment(&'a InlineFragment),
85    FragmentSpread((ResolvedFragmentId, &'a ResolvedFragment)),
86}
87
88impl<'a> VariantSelection<'a> {
89    /// The second argument is the parent type id, so it can be excluded.
90    fn from_selection(
91        selection: &'a Selection,
92        type_id: TypeId,
93        query: &BoundQuery<'a>,
94    ) -> Option<VariantSelection<'a>> {
95        match selection {
96            Selection::InlineFragment(inline_fragment) => {
97                Some(VariantSelection::InlineFragment(inline_fragment))
98            }
99            Selection::FragmentSpread(fragment_id) => {
100                let fragment = query.query.get_fragment(*fragment_id);
101
102                if fragment.on == type_id {
103                    // The selection is on the type itself.
104                    None
105                } else {
106                    // The selection is on one of the variants of the type.
107                    Some(VariantSelection::FragmentSpread((*fragment_id, fragment)))
108                }
109            }
110            Selection::Field(_) | Selection::Typename => None,
111        }
112    }
113
114    fn variant_type_id(&self) -> TypeId {
115        match self {
116            VariantSelection::InlineFragment(f) => f.type_id,
117            VariantSelection::FragmentSpread((_id, f)) => f.on,
118        }
119    }
120}
121
122fn calculate_selection<'a>(
123    context: &mut ExpandedSelection<'a>,
124    selection_set: &[SelectionId],
125    struct_id: ResponseTypeId,
126    type_id: TypeId,
127    options: &'a GraphQLClientCodegenOptions,
128) {
129    // If the selection only contains a fragment, replace the selection with
130    // that fragment.
131    if selection_set.len() == 1 {
132        if let Selection::FragmentSpread(fragment_id) =
133            context.query.query.get_selection(selection_set[0])
134        {
135            let fragment = context.query.query.get_fragment(*fragment_id);
136            context.push_type_alias(TypeAlias {
137                name: &fragment.name,
138                struct_id,
139                boxed: fragment_is_recursive(*fragment_id, context.query.query),
140            });
141            return;
142        }
143    }
144
145    // If we are on a union or an interface, we need to generate an enum that matches the variants _exhaustively_.
146    {
147        let variants: Option<Cow<'_, [TypeId]>> = match type_id {
148            TypeId::Interface(interface_id) => {
149                let variants = context
150                    .query
151                    .schema
152                    .objects()
153                    .filter(|(_, obj)| obj.implements_interfaces.contains(&interface_id))
154                    .map(|(id, _)| TypeId::Object(id));
155
156                Some(variants.collect::<Vec<TypeId>>().into())
157            }
158            TypeId::Union(union_id) => {
159                let union = context.schema().get_union(union_id);
160                Some(union.variants.as_slice().into())
161            }
162            _ => None,
163        };
164
165        if let Some(variants) = variants {
166            let variant_selections: Vec<(SelectionId, &Selection, VariantSelection<'_>)> =
167                selection_set
168                    .iter()
169                    .map(|id| (id, context.query.query.get_selection(*id)))
170                    .filter_map(|(id, selection)| {
171                        VariantSelection::from_selection(selection, type_id, context.query)
172                            .map(|variant_selection| (*id, selection, variant_selection))
173                    })
174                    .collect();
175
176            // For each variant, get the corresponding fragment spreads and
177            // inline fragments, or default to an empty variant (one with no
178            // associated data).
179            for variant_type_id in variants.as_ref() {
180                let variant_name_str = variant_type_id.name(context.schema());
181
182                let variant_selections: Vec<_> = variant_selections
183                    .iter()
184                    .filter(|(_id, _selection_ref, variant)| {
185                        variant.variant_type_id() == *variant_type_id
186                    })
187                    .collect();
188
189                if let Some((selection_id, selection, _variant)) = variant_selections.get(0) {
190                    let mut variant_struct_name_str =
191                        full_path_prefix(*selection_id, context.query);
192                    variant_struct_name_str.reserve(2 + variant_name_str.len());
193                    variant_struct_name_str.push_str("On");
194                    variant_struct_name_str.push_str(variant_name_str);
195
196                    context.push_variant(ExpandedVariant {
197                        name: variant_name_str.into(),
198                        variant_type: Some(variant_struct_name_str.clone().into()),
199                        on: struct_id,
200                        is_default_variant: false,
201                    });
202
203                    let expanded_type = ExpandedType {
204                        name: variant_struct_name_str.into(),
205                    };
206
207                    let struct_id = context.push_type(expanded_type);
208
209                    if variant_selections.len() == 1 {
210                        if let VariantSelection::FragmentSpread((fragment_id, fragment)) =
211                            variant_selections[0].2
212                        {
213                            context.push_type_alias(TypeAlias {
214                                boxed: fragment_is_recursive(fragment_id, context.query.query),
215                                name: &fragment.name,
216                                struct_id,
217                            });
218                            continue;
219                        }
220                    }
221
222                    for (_selection_id, _selection, variant_selection) in variant_selections {
223                        match variant_selection {
224                            VariantSelection::InlineFragment(_) => {
225                                calculate_selection(
226                                    context,
227                                    selection.subselection(),
228                                    struct_id,
229                                    *variant_type_id,
230                                    options,
231                                );
232                            }
233                            VariantSelection::FragmentSpread((fragment_id, fragment)) => context
234                                .push_field(ExpandedField {
235                                    field_type: fragment.name.as_str().into(),
236                                    field_type_qualifiers: &[GraphqlTypeQualifier::Required],
237                                    flatten: true,
238                                    graphql_name: None,
239                                    rust_name: fragment.name.to_snake_case().into(),
240                                    struct_id,
241                                    deprecation: None,
242                                    boxed: fragment_is_recursive(*fragment_id, context.query.query),
243                                }),
244                        }
245                    }
246                } else {
247                    context.push_variant(ExpandedVariant {
248                        name: variant_name_str.into(),
249                        on: struct_id,
250                        variant_type: None,
251                        is_default_variant: false,
252                    });
253                }
254            }
255
256            if *options.fragments_other_variant() {
257                context.push_variant(ExpandedVariant {
258                    name: "Unknown".into(),
259                    on: struct_id,
260                    variant_type: None,
261                    is_default_variant: true,
262                });
263            }
264        }
265    }
266
267    for id in selection_set {
268        let selection = context.query.query.get_selection(*id);
269
270        match selection {
271            Selection::Field(field) => {
272                let (graphql_name, rust_name) = context.field_name(field);
273                let schema_field = field.schema_field(context.schema());
274                let field_type_id = schema_field.r#type.id;
275
276                match field_type_id {
277                    TypeId::Enum(enm) => {
278                        context.push_field(ExpandedField {
279                            graphql_name: Some(graphql_name),
280                            rust_name,
281                            struct_id,
282                            field_type: options
283                                .normalization()
284                                .field_type(&context.schema().get_enum(enm).name),
285                            field_type_qualifiers: &schema_field.r#type.qualifiers,
286                            flatten: false,
287                            deprecation: schema_field.deprecation(),
288                            boxed: false,
289                        });
290                    }
291                    TypeId::Scalar(scalar) => {
292                        context.push_field(ExpandedField {
293                            field_type: options
294                                .normalization()
295                                .field_type(context.schema().get_scalar(scalar).name.as_str()),
296                            field_type_qualifiers: &field
297                                .schema_field(context.schema())
298                                .r#type
299                                .qualifiers,
300                            graphql_name: Some(graphql_name),
301                            struct_id,
302                            rust_name,
303                            flatten: false,
304                            deprecation: schema_field.deprecation(),
305                            boxed: false,
306                        });
307                    }
308                    TypeId::Object(_) | TypeId::Interface(_) | TypeId::Union(_) => {
309                        let struct_name_string = full_path_prefix(*id, context.query);
310
311                        context.push_field(ExpandedField {
312                            struct_id,
313                            graphql_name: Some(graphql_name),
314                            rust_name,
315                            field_type_qualifiers: &schema_field.r#type.qualifiers,
316                            field_type: Cow::Owned(struct_name_string.clone()),
317                            flatten: false,
318                            boxed: false,
319                            deprecation: schema_field.deprecation(),
320                        });
321
322                        let type_id = context.push_type(ExpandedType {
323                            name: Cow::Owned(struct_name_string),
324                        });
325
326                        calculate_selection(
327                            context,
328                            selection.subselection(),
329                            type_id,
330                            field_type_id,
331                            options,
332                        );
333                    }
334                    TypeId::Input(_) => unreachable!("field selection on input type"),
335                };
336            }
337            Selection::Typename => (),
338            Selection::InlineFragment(_inline) => (),
339            Selection::FragmentSpread(fragment_id) => {
340                // Here we only render fragments that are directly on the type
341                // itself, and not on one of its variants.
342
343                let fragment = context.query.query.get_fragment(*fragment_id);
344
345                // Assuming the query was validated properly, a fragment spread
346                // is either on the field's type itself, or on one of the
347                // variants (union or interfaces). If it's not directly a field
348                // on the struct, it will be handled in the `on` variants.
349                if fragment.on != type_id {
350                    continue;
351                }
352
353                let original_field_name = fragment.name.to_snake_case();
354                let final_field_name = keyword_replace(original_field_name);
355
356                context.push_field(ExpandedField {
357                    field_type: fragment.name.as_str().into(),
358                    field_type_qualifiers: &[GraphqlTypeQualifier::Required],
359                    graphql_name: None,
360                    rust_name: final_field_name,
361                    struct_id,
362                    flatten: true,
363                    deprecation: None,
364                    boxed: fragment_is_recursive(*fragment_id, context.query.query),
365                });
366
367                // We stop here, because the structs for the fragments are generated separately, to
368                // avoid duplication.
369            }
370        }
371    }
372}
373
/// Identifier of a response type: its index in the list of types collected
/// by `ExpandedSelection::push_type`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ResponseTypeId(u32);

/// Marks the response type `struct_id` as an alias of a fragment's type: it is
/// rendered as `pub type <Struct> = <name>;` instead of a struct or enum.
#[derive(Debug)]
struct TypeAlias<'a> {
    /// Name of the fragment type the alias points to.
    name: &'a str,
    /// The response type being aliased.
    struct_id: ResponseTypeId,
    /// When true, the alias target is wrapped in `Box` (used for recursive
    /// fragments).
    boxed: bool,
}
382
383struct ExpandedField<'a> {
384    graphql_name: Option<&'a str>,
385    rust_name: Cow<'a, str>,
386    field_type: Cow<'a, str>,
387    field_type_qualifiers: &'a [GraphqlTypeQualifier],
388    struct_id: ResponseTypeId,
389    flatten: bool,
390    deprecation: Option<Option<&'a str>>,
391    boxed: bool,
392}
393
394impl<'a> ExpandedField<'a> {
395    fn render(&self, options: &GraphQLClientCodegenOptions) -> Option<TokenStream> {
396        let ident = Ident::new(&self.rust_name, Span::call_site());
397        let qualified_type = decorate_type(
398            &Ident::new(&self.field_type, Span::call_site()),
399            self.field_type_qualifiers,
400        );
401
402        let qualified_type = if self.boxed {
403            quote!(Box<#qualified_type>)
404        } else {
405            qualified_type
406        };
407
408        let optional_skip_serializing_none = if *options.skip_serializing_none()
409            && self
410                .field_type_qualifiers
411                .get(0)
412                .map(|qualifier| !qualifier.is_required())
413                .unwrap_or(false)
414        {
415            Some(quote!(#[serde(skip_serializing_if = "Option::is_none")]))
416        } else {
417            None
418        };
419
420        let optional_rename = self
421            .graphql_name
422            .as_ref()
423            .map(|graphql_name| field_rename_annotation(graphql_name, &self.rust_name));
424        let optional_flatten = if self.flatten {
425            Some(quote!(#[serde(flatten)]))
426        } else {
427            None
428        };
429
430        let optional_deprecation_annotation =
431            match (self.deprecation, options.deprecation_strategy()) {
432                (None, _) | (Some(_), DeprecationStrategy::Allow) => None,
433                (Some(msg), DeprecationStrategy::Warn) => {
434                    let optional_msg = msg.map(|msg| quote!((note = #msg)));
435
436                    Some(quote!(#[deprecated#optional_msg]))
437                }
438                (Some(_), DeprecationStrategy::Deny) => return None,
439            };
440
441        let tokens = quote! {
442            #optional_skip_serializing_none
443            #optional_flatten
444            #optional_rename
445            #optional_deprecation_annotation
446            pub #ident: #qualified_type
447        };
448
449        Some(tokens)
450    }
451}
452
453struct ExpandedVariant<'a> {
454    name: Cow<'a, str>,
455    variant_type: Option<Cow<'a, str>>,
456    on: ResponseTypeId,
457    is_default_variant: bool,
458}
459
460impl<'a> ExpandedVariant<'a> {
461    fn render(&self) -> TokenStream {
462        let name_ident = Ident::new(&self.name, Span::call_site());
463        let optional_type_ident = self.variant_type.as_ref().map(|variant_type| {
464            let ident = Ident::new(variant_type, Span::call_site());
465            quote!((#ident))
466        });
467
468        if self.is_default_variant {
469            quote! {
470                    #[serde(other)]
471            #name_ident #optional_type_ident
472                }
473        } else {
474            quote!(#name_ident #optional_type_ident)
475        }
476    }
477}
478
479pub(crate) struct ExpandedType<'a> {
480    name: Cow<'a, str>,
481}
482
483pub(crate) struct ExpandedSelection<'a> {
484    query: &'a BoundQuery<'a>,
485    types: Vec<ExpandedType<'a>>,
486    fields: Vec<ExpandedField<'a>>,
487    variants: Vec<ExpandedVariant<'a>>,
488    aliases: Vec<TypeAlias<'a>>,
489    options: &'a GraphQLClientCodegenOptions,
490}
491
492impl<'a> ExpandedSelection<'a> {
493    pub(crate) fn schema(&self) -> &'a Schema {
494        self.query.schema
495    }
496
497    fn push_type(&mut self, tpe: ExpandedType<'a>) -> ResponseTypeId {
498        let id = self.types.len();
499        self.types.push(tpe);
500
501        ResponseTypeId(id as u32)
502    }
503
504    fn push_field(&mut self, field: ExpandedField<'a>) {
505        self.fields.push(field);
506    }
507
508    fn push_type_alias(&mut self, alias: TypeAlias<'a>) {
509        self.aliases.push(alias)
510    }
511
512    fn push_variant(&mut self, variant: ExpandedVariant<'a>) {
513        self.variants.push(variant);
514    }
515
516    /// Returns a tuple to be interpreted as (graphql_name, rust_name).
517    pub(crate) fn field_name(&self, field: &'a SelectedField) -> (&'a str, Cow<'a, str>) {
518        let name = field
519            .alias()
520            .unwrap_or_else(|| &field.schema_field(self.query.schema).name);
521        let snake_case_name = name.to_snake_case();
522        let final_name = keyword_replace(snake_case_name);
523
524        (name, final_name)
525    }
526
527    fn types(&self) -> impl Iterator<Item = (ResponseTypeId, &ExpandedType<'_>)> {
528        self.types
529            .iter()
530            .enumerate()
531            .map(|(idx, ty)| (ResponseTypeId(idx as u32), ty))
532    }
533
534    pub fn render(&self, response_derives: &impl quote::ToTokens) -> TokenStream {
535        let mut items = Vec::with_capacity(self.types.len());
536
537        for (type_id, ty) in self.types() {
538            let struct_name = Ident::new(&ty.name, Span::call_site());
539
540            // If the type is aliased, stop here.
541            if let Some(alias) = self.aliases.iter().find(|alias| alias.struct_id == type_id) {
542                let fragment_name = Ident::new(alias.name, Span::call_site());
543                let fragment_name = if alias.boxed {
544                    quote!(Box<#fragment_name>)
545                } else {
546                    quote!(#fragment_name)
547                };
548                let item = quote! {
549                    pub type #struct_name = #fragment_name;
550                };
551                items.push(item);
552                continue;
553            }
554
555            let mut fields = self
556                .fields
557                .iter()
558                .filter(|field| field.struct_id == type_id)
559                .filter_map(|field| field.render(self.options))
560                .peekable();
561
562            let on_variants: Vec<TokenStream> = self
563                .variants
564                .iter()
565                .filter(|variant| variant.on == type_id)
566                .map(|variant| variant.render())
567                .collect();
568
569            // If we only have an `on` field, turn the struct into the enum
570            // of the variants.
571            if fields.peek().is_none() {
572                let item = quote! {
573                    #response_derives
574                    #[serde(tag = "__typename")]
575                    pub enum #struct_name {
576                        #(#on_variants),*
577                    }
578                };
579                items.push(item);
580                continue;
581            }
582
583            let (on_field, on_enum) = if !on_variants.is_empty() {
584                let enum_name = Ident::new(&format!("{}On", ty.name), Span::call_site());
585
586                let on_field = quote!(#[serde(flatten)] pub on: #enum_name);
587
588                let on_enum = quote!(
589                    #response_derives
590                    #[serde(tag = "__typename")]
591                    pub enum #enum_name {
592                        #(#on_variants,)*
593                    }
594                );
595
596                (Some(on_field), Some(on_enum))
597            } else {
598                (None, None)
599            };
600
601            let tokens = quote! {
602                #response_derives
603                pub struct #struct_name {
604                    #(#fields,)*
605                    #on_field
606                }
607
608                #on_enum
609            };
610
611            items.push(tokens);
612        }
613
614        quote!(#(#items)*)
615    }
616}