Skip to main content

sui_graphql_macros/
lib.rs

1//! Compile-time validated derive macro for deserializing Sui GraphQL responses.
2//!
3//! GraphQL responses are deeply nested JSON. Rather than manually writing
4//! `serde_json::Value` traversal code (with `.get()`, `.as_array()`, null checks, etc.),
5//! the [`Response`] derive macro generates all of that from a declarative field path.
6//! Paths are validated against the Sui GraphQL schema at compile time, so typos and
7//! type mismatches are caught before your code ever runs.
8//!
9//! For a complete client that uses this macro, see
10//! [`sui-graphql`](https://docs.rs/sui-graphql).
11//!
12//! # Quick Start
13//!
14//! ```no_run
15//! use sui_graphql_macros::Response;
16//!
17//! #[derive(Response)]
18//! struct ObjectData {
19//!     #[field(path = "object.address")]
20//!     address: String,
21//!     #[field(path = "object.version")]
22//!     version: u64,
23//! }
24//! fn main() {}
25//! ```
26//!
27//! The macro validates that `object.address` and `object.version` exist in the schema
28//! and that their types match at compile time. It then generates a
29//! `from_value(serde_json::Value) -> Result<Self, String>` method and a `Deserialize`
30//! implementation, so the struct can be used directly with
31//! `serde_json::from_value` or as a response type in GraphQL client calls.
32//!
33//! # Path Syntax
34//!
35//! Paths use dot-separated segments with optional suffixes:
36//!
37//! | Syntax | Meaning | Rust Type |
38//! |--------|---------|-----------|
39//! | `field` | Required field | `T` |
40//! | `field?` | Nullable field | `Option<T>` |
41//! | `field[]` | Required list | `Vec<T>` |
42//! | `field?[]` | Nullable list | `Option<Vec<T>>` |
43//! | `field[]?` | List with nullable elements | `Vec<Option<T>>` |
44//! | `field?[]?` | Nullable list, nullable elements | `Option<Vec<Option<T>>>` |
45//!
//! Multiple `?` markers between `[]` boundaries share one `Option` wrapper: for example,
//! `object?.address?` maps to `Option<String>`, not `Option<Option<String>>`.
//! Each `?` controls null tolerance at that specific segment.
48//!
49//! The macro enforces that path suffixes match the Rust type at compile time.
50//! For example, `field?` requires `Option<T>`, and `field[]` requires `Vec<T>`.
51//! A mismatch (e.g., `field?` with `String` or `field` with `Option<String>`)
52//! produces a compile error.
53//!
//! ## Null Handling
//!
//! A key that is missing from the JSON response is treated exactly like an explicit `null`.
//!
56//! ```no_run
57//! use sui_graphql_macros::Response;
58//!
59//! #[derive(Response)]
60//! struct Example {
61//!     // null at `object` → error, null at `address` → error
62//!     #[field(path = "object.address")]
63//!     strict: String,
64//!
65//!     // null at `object` → Ok(None), null at `address` → Ok(None)
66//!     #[field(path = "object?.address?")]
67//!     flexible: Option<String>,
68//!
69//!     // null at `object` → Ok(None), null at `address` → error
70//!     #[field(path = "object?.address")]
71//!     partial: Option<String>,
72//! }
73//! fn main() {}
74//! ```
75//!
76//! ## Lists
77//!
78//! Use `[]` to mark list fields. The macro validates this matches the schema.
79//!
80//! ```no_run
81//! use sui_graphql_macros::Response;
82//!
83//! #[derive(Response)]
84//! struct CheckpointDigests {
85//!     #[field(path = "checkpoints.nodes[].digest")]
86//!     digests: Vec<String>,
87//!
88//!     // Nullable list with nullable elements
89//!     #[field(path = "checkpoints?.nodes?[]?.digest?")]
90//!     maybe_digests: Option<Vec<Option<String>>>,
91//! }
92//! fn main() {}
93//! ```
94//!
95//! ## Aliases
96//!
97//! Use `alias:field` when your GraphQL query uses aliases. The alias (before `:`) is the
98//! JSON key used for extraction, while the field name (after `:`) is validated against
99//! the schema. The alias itself is not schema-validated since it is user-defined in the
100//! query.
101//!
102//! ```no_run
103//! use sui_graphql_macros::Response;
104//!
105//! #[derive(Response)]
106//! struct EpochCheckpoints {
107//!     // GraphQL alias "firstCp" maps to schema field "checkpoints"
108//!     #[field(path = "epoch.firstCp:checkpoints.nodes[].sequenceNumber")]
109//!     first_checkpoints: Vec<u64>,
110//! }
111//! fn main() {}
112//! ```
113//!
114//! ## Enums (GraphQL Unions)
115//!
116//! Use `#[response(root_type = "UnionType")]` on enums with newtype variants:
117//!
118//! ```ignore
119//! #[derive(Response)]
120//! #[response(root_type = "DynamicFieldValue")]
121//! enum FieldValue {
122//!     #[response(on = "MoveValue")]
123//!     Value(MoveValueData),
124//!     MoveObject(MoveObjectData), // `on` defaults to variant name
125//! }
126//! ```
127//!
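//! The macro dispatches on `__typename` in the JSON response, so the query must request
//! `__typename` on the union field; a missing or unrecognized `__typename` is a runtime error.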
129//!
130//! ## Attributes
131//!
132//! | Attribute | Level | Description |
133//! |-----------|-------|-------------|
134//! | `#[response(root_type = "Type")]` | struct/enum | Schema type to validate against (default: `"Query"`) |
135//! | `#[response(schema = "path")]` | struct/enum | Custom schema file (relative to `CARGO_MANIFEST_DIR`) |
136//! | `#[field(path = "...")]` | field | Dot-separated path with optional `?`/`[]`/alias |
137//! | `#[field(skip_schema_validation)]` | field | Skip compile-time schema checks for this field |
138//! | `#[response(on = "TypeName")]` | variant | GraphQL `__typename` to match (default: variant name) |
139
140extern crate proc_macro;
141
142mod path;
143mod schema;
144mod validation;
145
146use darling::FromDeriveInput;
147use darling::FromField;
148use darling::FromVariant;
149use darling::util::SpannedValue;
150use proc_macro::TokenStream;
151use proc_macro2::TokenStream as TokenStream2;
152use quote::quote;
153use syn::DeriveInput;
154use syn::parse_macro_input;
155
156// ---------------------------------------------------------------------------
157// Darling input structures — define the "schema" for macro input.
158// Darling generates parsing code automatically, including error messages.
159// ---------------------------------------------------------------------------
160
161#[derive(Debug, FromDeriveInput)]
162#[darling(attributes(response), supports(struct_named, enum_newtype))]
163struct ResponseInput {
164    ident: syn::Ident,
165    generics: syn::Generics,
166    data: darling::ast::Data<ResponseVariant, ResponseField>,
167    #[darling(default)]
168    schema: Option<String>,
169    #[darling(default)]
170    root_type: Option<SpannedValue<String>>,
171}
172
173/// A struct field (requires `#[field(path = "...")]`).
174#[derive(Debug, FromField)]
175#[darling(attributes(field))]
176struct ResponseField {
177    ident: Option<syn::Ident>,
178    ty: syn::Type,
179    path: SpannedValue<String>,
180    #[darling(default)]
181    skip_schema_validation: bool,
182}
183
184/// The inner type of a newtype enum variant.
185#[derive(Debug, FromField)]
186struct VariantInner {
187    ty: syn::Type,
188}
189
190/// An enum variant mapping to a GraphQL union member.
191#[derive(Debug, FromVariant)]
192#[darling(attributes(response))]
193struct ResponseVariant {
194    ident: syn::Ident,
195    fields: darling::ast::Fields<VariantInner>,
196    /// The GraphQL type name this variant maps to (e.g., `#[response(on = "MoveValue")]`).
197    /// Defaults to the variant ident if not specified.
198    #[darling(default)]
199    on: Option<SpannedValue<String>>,
200}
201
202/// Derive macro for GraphQL response types with nested field extraction.
203///
204/// Use `#[field(path = "...")]` to specify the JSON path to extract each field.
205/// Paths are dot-separated (e.g., `"object.address"` extracts `json["object"]["address"]`).
206///
207/// # Root Type
208///
209/// By default, field paths are validated against the `Query` type. Use
210/// `#[response(root_type = "...")]` to validate against a different type instead.
211///
212/// # Generated Code
213///
214/// The macro generates:
215/// - `from_value(serde_json::Value) -> Result<Self, String>` method
216/// - `Deserialize` implementation that uses `from_value`
217///
218/// # Example
219///
220/// ```ignore
221/// // Query response (default)
222/// #[derive(Response)]
223/// struct ChainInfo {
224///     #[field(path = "chainIdentifier")]
225///     chain_id: String,
226///
227///     #[field(path = "epoch.epochId")]
228///     epoch_id: Option<u64>,
229/// }
230///
231/// // Mutation response
232/// #[derive(Response)]
233/// #[response(root_type = "Mutation")]
234/// struct ExecuteResult {
235///     #[field(path = "executeTransaction.effects.effectsBcs")]
236///     effects_bcs: Option<String>,
237/// }
238/// ```
239#[proc_macro_derive(Response, attributes(response, field))]
240pub fn derive_query_response(input: TokenStream) -> TokenStream {
241    let input = parse_macro_input!(input as DeriveInput);
242
243    match derive_query_response_impl(input) {
244        Ok(tokens) => tokens.into(),
245        Err(err) => err.to_compile_error().into(),
246    }
247}
248
249fn derive_query_response_impl(input: DeriveInput) -> Result<TokenStream2, syn::Error> {
250    let parsed = ResponseInput::from_derive_input(&input)?;
251
252    // Load the GraphQL schema for validation.
253    // If a custom schema path is provided, load it; otherwise use the embedded Sui schema.
254    let loaded_schema = if let Some(path) = &parsed.schema {
255        // Resolve path relative to the crate's directory.
256        // SUI_GRAPHQL_SCHEMA_DIR is used by trybuild tests (which run from a temp directory).
257        let base_dir = std::env::var("SUI_GRAPHQL_SCHEMA_DIR")
258            .or_else(|_| std::env::var("CARGO_MANIFEST_DIR"))
259            .unwrap();
260        let full_path = std::path::Path::new(&base_dir).join(path);
261        let sdl = std::fs::read_to_string(&full_path).map_err(|e| {
262            syn::Error::new(
263                proc_macro2::Span::call_site(),
264                format!(
265                    "Failed to read schema from '{}': {}",
266                    full_path.display(),
267                    e
268                ),
269            )
270        })?;
271        Some(schema::Schema::from_sdl(&sdl)?)
272    } else {
273        None
274    };
275    let schema = if let Some(schema) = &loaded_schema {
276        schema
277    } else {
278        schema::Schema::load()?
279    };
280
281    // Determine root type: use specified root_type or default to "Query"
282    let root_type = parsed
283        .root_type
284        .as_ref()
285        .map(|s| s.as_str())
286        .unwrap_or("Query");
287
288    // Validate that the root type exists in the schema
289    if !schema.has_type(root_type) {
290        use std::fmt::Write;
291
292        let type_names = schema.type_names();
293        let suggestion = validation::find_similar(&type_names, root_type);
294
295        let mut msg = format!("Type '{}' not found in GraphQL schema", root_type);
296        if let Some(suggested) = suggestion {
297            write!(msg, ". Did you mean '{}'?", suggested).unwrap();
298        }
299
300        // We only enter this block if root_type was explicitly specified (and invalid),
301        // since "Query" (the default) always exists in a valid schema.
302        let span = parsed.root_type.as_ref().unwrap().span();
303
304        return Err(syn::Error::new(span, msg));
305    }
306
307    match parsed.data {
308        darling::ast::Data::Struct(ref fields) => {
309            generate_struct_impl(&parsed, &fields.fields, schema, root_type)
310        }
311        darling::ast::Data::Enum(ref variants) => {
312            generate_enum_impl(&parsed, variants, schema, root_type)
313        }
314    }
315}
316
317/// Generate `from_value` and `Deserialize` for a struct.
318fn generate_struct_impl(
319    input: &ResponseInput,
320    fields: &[ResponseField],
321    schema: &schema::Schema,
322    root_type: &str,
323) -> Result<TokenStream2, syn::Error> {
324    let ident = &input.ident;
325    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
326
327    // Generate extraction code for each field
328    let mut field_extractions = Vec::new();
329    let mut field_names = Vec::new();
330
331    for field in fields {
332        let field_ident = field
333            .ident
334            .as_ref()
335            .expect("darling ensures named fields only");
336
337        let spanned_path = &field.path;
338        let parsed_path = path::ParsedPath::parse(spanned_path.as_str())
339            .map_err(|e| syn::Error::new(spanned_path.span(), e.to_string()))?;
340
341        let terminal_type = if !field.skip_schema_validation {
342            Some(validation::validate_path_against_schema(
343                schema,
344                root_type,
345                &parsed_path,
346                spanned_path.span(),
347            )?)
348        } else {
349            None
350        };
351
352        // Skip Vec excess check when schema validation is skipped (user takes full
353        // responsibility) or when the terminal type is an object-like scalar (e.g., JSON)
354        // whose value can be an array.
355        let skip_vec_excess_check = field.skip_schema_validation
356            || terminal_type.is_some_and(validation::is_object_like_scalar);
357        validation::validate_type_matches_path(&parsed_path, &field.ty, skip_vec_excess_check)?;
358
359        // Generate extraction code using the same parsed path
360        let type_structure = validation::analyze_type(&field.ty);
361        let extraction = generate_field_extraction(&parsed_path, &type_structure, field_ident);
362        field_extractions.push(extraction);
363        field_names.push(field_ident);
364    }
365
366    // Generate both `from_value` and `Deserialize` impl:
367    //
368    // - `from_value`: Core extraction logic, parses from serde_json::Value
369    // - `Deserialize`: Allows direct use with serde (e.g., `serde_json::from_str::<MyStruct>(...)`)
370    //   and with the GraphQL client's `query::<T>()` which requires `T: DeserializeOwned`
371    Ok(quote! {
372        impl #impl_generics #ident #ty_generics #where_clause {
373            pub fn from_value(value: serde_json::Value) -> Result<Self, String> {
374                #(#field_extractions)*
375
376                Ok(Self {
377                    #(#field_names),*
378                })
379            }
380        }
381
382        // TODO: Implement efficient deserialization that only extracts the fields we need.
383        impl<'de> serde::Deserialize<'de> for #ident #ty_generics #where_clause {
384            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
385            where
386                D: serde::Deserializer<'de>,
387            {
388                let value = serde_json::Value::deserialize(deserializer)?;
389                Self::from_value(value).map_err(serde::de::Error::custom)
390            }
391        }
392    })
393}
394
395/// Generate `from_value` and `Deserialize` for an enum (GraphQL union).
396///
397/// Each variant wraps a type that implements `from_value`. Dispatches on `__typename`.
398fn generate_enum_impl(
399    input: &ResponseInput,
400    variants: &[ResponseVariant],
401    schema: &schema::Schema,
402    root_type: &str,
403) -> Result<TokenStream2, syn::Error> {
404    let ident = &input.ident;
405    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
406
407    let root_type_span = input
408        .root_type
409        .as_ref()
410        .map(|s| s.span())
411        .unwrap_or_else(|| ident.span());
412
413    if !schema.is_union(root_type) {
414        return Err(syn::Error::new(
415            root_type_span,
416            format!(
417                "'{}' is not a union type. \
418                 Enum Response requires root_type to be a GraphQL union",
419                root_type
420            ),
421        ));
422    }
423
424    let mut match_arms = Vec::new();
425
426    for variant in variants {
427        let variant_ident = &variant.ident;
428
429        // Resolve the GraphQL typename: explicit `on` or variant ident
430        let graphql_typename = variant
431            .on
432            .as_ref()
433            .map(|s| s.as_str().to_string())
434            .unwrap_or_else(|| variant_ident.to_string());
435
436        let span = variant
437            .on
438            .as_ref()
439            .map(|s| s.span())
440            .unwrap_or_else(|| variant_ident.span());
441
442        if let Err(mut err) =
443            validation::validate_union_member(schema, root_type, &graphql_typename, span)
444        {
445            if variant.on.is_none() {
446                err.combine(syn::Error::new(
447                    span,
448                    "hint: use #[response(on = \"...\")] to specify a GraphQL type name different from the variant name",
449                ));
450            }
451            return Err(err);
452        }
453
454        // Newtype variant: delegate to inner type's from_value
455        let inner_ty = &variant.fields.fields[0].ty;
456        match_arms.push(quote! {
457            #graphql_typename => {
458                Ok(Self::#variant_ident(
459                    <#inner_ty>::from_value(value)?
460                ))
461            }
462        });
463    }
464
465    let root_type_str = root_type;
466    let enum_name_str = ident.to_string();
467
468    Ok(quote! {
469        impl #impl_generics #ident #ty_generics #where_clause {
470            pub fn from_value(value: serde_json::Value) -> Result<Self, String> {
471                let typename = value.get("__typename")
472                    .and_then(|v| v.as_str())
473                    .ok_or_else(|| format!(
474                        "union '{}' requires '__typename' in the response to distinguish variants. \
475                         Make sure your query requests '__typename' on this field ({})",
476                        #root_type_str, #enum_name_str
477                    ))?;
478
479                match typename {
480                    #(#match_arms)*
481                    other => Err(format!(
482                        "unknown __typename '{}' for union '{}' ({})",
483                        other, #root_type_str, #enum_name_str
484                    )),
485                }
486            }
487        }
488
489        impl<'de> serde::Deserialize<'de> for #ident #ty_generics #where_clause {
490            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
491            where
492                D: serde::Deserializer<'de>,
493            {
494                let value = serde_json::Value::deserialize(deserializer)?;
495                Self::from_value(value).map_err(serde::de::Error::custom)
496            }
497        }
498    })
499}
500
501/// Generate code to extract a single field from JSON using its path.
502///
503/// Supports multiple path formats:
504/// - Simple: `"object.address"` - navigates to nested field
505/// - Array: `"nodes[].name"` - iterates over array, extracts field from each element
506/// - Nested arrays: `"nodes[].edges[].id"` - nested iteration, returns `Vec<Vec<T>>`
507/// - Aliased: `"alias:field"` - uses alias for JSON extraction, field for validation
508fn generate_field_extraction(
509    path: &path::ParsedPath,
510    type_structure: &validation::TypeStructure,
511    field_ident: &syn::Ident,
512) -> TokenStream2 {
513    let full_path = &path.raw;
514    let inner = generate_from_segments(full_path, &path.segments, type_structure);
515    // The inner expression returns Result<T, String>, so we use ? to unwrap
516    quote! {
517        let #field_ident = {
518            let current = &value;
519            #inner?
520        };
521    }
522}
523
524/// Recursively generate extraction code by traversing path segments.
525///
526/// For JSON extraction, uses the alias if present, otherwise uses the field name.
527/// Returns code that evaluates to `Result<T, String>` (caller adds `?` to unwrap).
528///
529/// ## Example: `"data.nodes[].edges[].id"` with `Option<Vec<Vec<String>>>`
530///
531/// Each `[]` in the path corresponds to one `Vec<_>` wrapper in the type.
532///
533/// For `Option<_>` types, null at the outer level returns `Ok(None)`. This is achieved
534/// by wrapping the extraction in a closure to capture early returns. However, once
535/// inside an array iteration, the element type (`Vec<String>`) is not Optional, so
536/// null values there return errors instead.
537///
538/// ```ignore
539/// (|| {
540///     // "data" (non-list) - missing/null returns None (outer Optional)
541///     let current = current.get("data").unwrap_or(&serde_json::Value::Null);
542///     if current.is_null() { return Ok(None); }
543///
544///     // "nodes[]" (list) - missing/null returns None (outer Optional)
545///     let field_value = current.get("nodes").unwrap_or(&serde_json::Value::Null);
546///     if field_value.is_null() { return Ok(None); }
547///     let array = field_value.as_array().ok_or_else(|| "expected array")?;
548///     array.iter().map(|current| {
549///         // Element type: Vec<String> (not Optional, so null = error)
550///
551///         // "edges[]" (list) - missing/null returns Err
552///         let field_value = current.get("edges").unwrap_or(&serde_json::Value::Null);
553///         if field_value.is_null() { return Err("null at 'edges'"); }
554///         let array = field_value.as_array().ok_or_else(|| "expected array")?;
555///         array.iter().map(|current| {
556///             // Element type: String (not Optional, so null = error)
557///
558///             // "id" (scalar) - missing/null returns Err
559///             let current = current.get("id").unwrap_or(&serde_json::Value::Null);
560///             if current.is_null() { return Err("null at 'id'"); }
561///             serde_json::from_value(current.clone())
562///         }).collect::<Result<Vec<_>, _>>()
563///     }).collect::<Result<Vec<_>, _>>()
564///     .map(Some)  // Wrap in Some for Option
565/// })()
566/// ```
567fn generate_from_segments(
568    full_path: &str,
569    segments: &[path::PathSegment],
570    type_structure: &validation::TypeStructure,
571) -> TokenStream2 {
572    // Step 1: Check if outer type is Optional and unwrap it
573    let (is_optional, inner_type) = match type_structure {
574        validation::TypeStructure::Optional(inner) => (true, inner.as_ref()),
575        other => (false, other),
576    };
577
578    // Step 2: Generate core extraction code
579    let core = generate_from_segments_core(full_path, segments, inner_type);
580
581    // Step 3: Wrap Optional types in a closure so `return Ok(None)` stays local to this field.
582    if is_optional {
583        quote! {
584            (|| {
585                // Handle null elements (from `[]?`) and null top-level values
586                if current.is_null() { return Ok(None) }
587                #core.map(Some)
588            })()
589        }
590    } else {
591        core
592    }
593}
594
595/// Core extraction logic that handles both list and non-list segments.
596///
597/// Each segment determines its own null behavior via `is_nullable`:
598/// - `is_nullable = true` (`?` marker): null → `return Ok(None)`
599/// - `is_nullable = false` (no `?`): null → `return Err(...)`
600fn generate_from_segments_core(
601    full_path: &str,
602    segments: &[path::PathSegment],
603    type_structure: &validation::TypeStructure,
604) -> TokenStream2 {
605    // Base case: no more segments, deserialize the current value
606    let Some((segment, rest)) = segments.split_first() else {
607        return quote! {
608            serde_json::from_value(current.clone())
609                .map_err(|e| format!("failed to deserialize '{}': {}", #full_path, e))
610        };
611    };
612
613    let name = segment.field;
614    // Use alias for JSON extraction if present, otherwise use field name
615    let json_key = segment.json_key();
616
617    // Generate null handling based on this segment's `?` marker
618    let on_null = if segment.is_nullable {
619        quote! { return Ok(None) }
620    } else {
621        quote! {
622            return Err(format!("null value at '{}' in path '{}'", #name, #full_path))
623        }
624    };
625
626    if segment.is_list() {
627        // For list segments, unwrap Vector to get element type
628        let element_type = match type_structure {
629            validation::TypeStructure::Vector(inner) => inner.as_ref(),
630            _ => unreachable!("validated: list segment requires Vec type"),
631        };
632
633        // Each array element is processed independently with its own type structure.
634        // Use generate_from_segments (not _core) to handle element-level Optional.
635        let rest_code = generate_from_segments(full_path, rest, element_type);
636
637        quote! {
638            // Treat missing fields as null (allows Option<T> to deserialize as None)
639            let field_value = current.get(#json_key).unwrap_or(&serde_json::Value::Null);
640            if field_value.is_null() {
641                #on_null
642            }
643            let array = field_value.as_array()
644                .ok_or_else(|| format!("expected array at '{}' in path '{}'", #json_key, #full_path))?;
645            array.iter()
646                .map(|current| { #rest_code })
647                .collect::<Result<Vec<_>, String>>()
648        }
649    } else {
650        // For non-list segments, pass type unchanged to handle nested structures
651        let rest_code = generate_from_segments_core(full_path, rest, type_structure);
652
653        quote! {
654            // Treat missing fields as null (allows Option<T> to deserialize as None)
655            let current = current.get(#json_key).unwrap_or(&serde_json::Value::Null);
656            if current.is_null() {
657                #on_null
658            }
659            #rest_code
660        }
661    }
662}