1use proc_macro::TokenStream;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{quote, quote_spanned};
12use syn::spanned::Spanned;
13use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, LitStr, Type};
14
15#[proc_macro_derive(IntoDoc, attributes(zvec))]
43pub fn derive_into_doc(input: TokenStream) -> TokenStream {
44 let input = parse_macro_input!(input as DeriveInput);
45 match expand(input) {
46 Ok(ts) => ts.into(),
47 Err(e) => e.to_compile_error().into(),
48 }
49}
50
51fn expand(input: DeriveInput) -> syn::Result<TokenStream2> {
52 let name = &input.ident;
53 let fields = match &input.data {
54 Data::Struct(DataStruct {
55 fields: Fields::Named(f),
56 ..
57 }) => &f.named,
58 _ => {
59 return Err(syn::Error::new_spanned(
60 &input,
61 "IntoDoc can only be derived for structs with named fields",
62 ));
63 }
64 };
65
66 let mut body = TokenStream2::new();
67 let mut pk_seen = false;
68
69 for field in fields {
70 let attrs = FieldAttrs::from(field)?;
71 if attrs.skip {
72 continue;
73 }
74 let rust_ident = field.ident.as_ref().unwrap();
75 let zvec_name = attrs.rename.unwrap_or_else(|| rust_ident.to_string());
76
77 if attrs.pk {
78 if pk_seen {
79 return Err(syn::Error::new_spanned(
80 field,
81 "duplicate #[zvec(pk)] — only one field may be the primary key",
82 ));
83 }
84 pk_seen = true;
85 body.extend(quote_spanned! { field.span() =>
86 __doc.set_pk(&self.#rust_ident)?;
87 });
88 }
89
90 let setter = emit_setter(field, &attrs.kind, &zvec_name)?;
91 body.extend(setter);
92 }
93
94 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
95
96 Ok(quote! {
97 #[allow(clippy::deref_addrof)]
98 impl #impl_generics ::zvec::IntoDoc for #name #ty_generics #where_clause {
99 fn into_doc(&self) -> ::zvec::Result<::zvec::Doc> {
100 let mut __doc = ::zvec::Doc::new()?;
101 #body
102 Ok(__doc)
103 }
104 }
105 })
106}
107
108#[derive(Default)]
109struct FieldAttrs {
110 pk: bool,
111 skip: bool,
112 rename: Option<String>,
113 kind: TypeHint,
114}
115
/// Explicit storage hint attached to a field through `#[zvec(...)]`.
///
/// Any variant other than `Auto` makes the derives call the matching
/// `add_*` / `get_*` method without inspecting the Rust field type.
#[derive(Clone, Copy, Default, PartialEq, Eq)]
enum TypeHint {
    /// No hint given: the method is chosen from the field's type name.
    #[default]
    Auto,
    /// `#[zvec(binary)]` — `add_binary` / `get_binary`.
    Binary,
    /// `#[zvec(vector_fp32)]` — `add_vector_fp32` / `get_vector_fp32`.
    VectorFp32,
    /// `#[zvec(vector_fp64)]` — `add_vector_fp64` / `get_vector_fp64`.
    VectorFp64,
    /// `#[zvec(vector_int8)]` — `add_vector_int8` / `get_vector_int8`.
    VectorInt8,
    /// `#[zvec(vector_int16)]` — `add_vector_int16` / `get_vector_int16`.
    VectorInt16,
}
126
127impl FieldAttrs {
128 fn from(field: &Field) -> syn::Result<Self> {
129 let mut out = FieldAttrs::default();
130 for attr in &field.attrs {
131 if !attr.path().is_ident("zvec") {
132 continue;
133 }
134 attr.parse_nested_meta(|meta| {
135 let p = &meta.path;
136 if p.is_ident("pk") {
137 out.pk = true;
138 } else if p.is_ident("skip") {
139 out.skip = true;
140 } else if p.is_ident("rename") {
141 let lit: LitStr = meta.value()?.parse()?;
142 out.rename = Some(lit.value());
143 } else if p.is_ident("binary") {
144 out.kind = TypeHint::Binary;
145 } else if p.is_ident("vector_fp32") {
146 out.kind = TypeHint::VectorFp32;
147 } else if p.is_ident("vector_fp64") {
148 out.kind = TypeHint::VectorFp64;
149 } else if p.is_ident("vector_int8") {
150 out.kind = TypeHint::VectorInt8;
151 } else if p.is_ident("vector_int16") {
152 out.kind = TypeHint::VectorInt16;
153 } else {
154 return Err(meta.error(
155 "unknown zvec attribute; expected one of: \
156 pk, skip, rename, binary, vector_fp32, vector_fp64, \
157 vector_int8, vector_int16",
158 ));
159 }
160 Ok(())
161 })?;
162 }
163 Ok(out)
164 }
165}
166
167fn emit_setter(field: &Field, hint: &TypeHint, name: &str) -> syn::Result<TokenStream2> {
168 let ident = field.ident.as_ref().unwrap();
169 let ty = &field.ty;
170 let name_lit = LitStr::new(name, field.span());
171
172 if let Some(inner) = option_inner(ty) {
175 let inner_ty = inner.clone();
176 let inner_call =
177 scalar_or_hinted_setter(&inner_ty, hint, &name_lit, quote!(__inner), field.span())?;
178 return Ok(quote_spanned! { field.span() =>
179 match &self.#ident {
180 ::core::option::Option::Some(__inner) => { #inner_call },
181 ::core::option::Option::None => { __doc.set_field_null(#name_lit)?; },
182 }
183 });
184 }
185
186 let access = quote_spanned! { field.span() => &self.#ident };
188 scalar_or_hinted_setter(ty, hint, &name_lit, access, field.span())
189}
190
191fn scalar_or_hinted_setter(
192 ty: &Type,
193 hint: &TypeHint,
194 name: &LitStr,
195 access: TokenStream2,
196 span: proc_macro2::Span,
197) -> syn::Result<TokenStream2> {
198 match hint {
199 TypeHint::Binary => {
200 return Ok(quote_spanned! { span =>
201 __doc.add_binary(#name, #access)?;
202 });
203 }
204 TypeHint::VectorFp32 => {
205 return Ok(quote_spanned! { span =>
206 __doc.add_vector_fp32(#name, #access)?;
207 });
208 }
209 TypeHint::VectorFp64 => {
210 return Ok(quote_spanned! { span =>
211 __doc.add_vector_fp64(#name, #access)?;
212 });
213 }
214 TypeHint::VectorInt8 => {
215 return Ok(quote_spanned! { span =>
216 __doc.add_vector_int8(#name, #access)?;
217 });
218 }
219 TypeHint::VectorInt16 => {
220 return Ok(quote_spanned! { span =>
221 __doc.add_vector_int16(#name, #access)?;
222 });
223 }
224 TypeHint::Auto => {}
225 }
226
227 let last_segment = match ty {
229 Type::Path(p) => p.path.segments.last(),
230 _ => None,
231 };
232 let Some(last) = last_segment else {
233 return Err(syn::Error::new(
234 span,
235 "unsupported field type for IntoDoc; add a #[zvec(...)] type hint \
236 (e.g. #[zvec(vector_fp32)] for Vec<f32>)",
237 ));
238 };
239 let name_s = last.ident.to_string();
240 let setter = match name_s.as_str() {
241 "String" => quote!(add_string),
242 "bool" => {
243 return Ok(quote_spanned! { span =>
245 __doc.add_bool(#name, *#access)?;
246 });
247 }
248 "i32" => {
249 return Ok(quote_spanned! { span =>
250 __doc.add_int32(#name, *#access)?;
251 });
252 }
253 "i64" => {
254 return Ok(quote_spanned! { span =>
255 __doc.add_int64(#name, *#access)?;
256 });
257 }
258 "u32" => {
259 return Ok(quote_spanned! { span =>
260 __doc.add_uint32(#name, *#access)?;
261 });
262 }
263 "u64" => {
264 return Ok(quote_spanned! { span =>
265 __doc.add_uint64(#name, *#access)?;
266 });
267 }
268 "f32" => {
269 return Ok(quote_spanned! { span =>
270 __doc.add_float(#name, *#access)?;
271 });
272 }
273 "f64" => {
274 return Ok(quote_spanned! { span =>
275 __doc.add_double(#name, *#access)?;
276 });
277 }
278 _ => {
279 return Err(syn::Error::new(
280 span,
281 format!(
282 "unsupported field type `{name_s}` for IntoDoc; \
283 add a #[zvec(...)] type hint or extend the derive \
284 to cover this type",
285 ),
286 ));
287 }
288 };
289 Ok(quote_spanned! { span =>
290 __doc.#setter(#name, #access)?;
291 })
292}
293
294#[proc_macro_derive(FromDoc, attributes(zvec))]
318pub fn derive_from_doc(input: TokenStream) -> TokenStream {
319 let input = parse_macro_input!(input as DeriveInput);
320 match expand_from_doc(input) {
321 Ok(ts) => ts.into(),
322 Err(e) => e.to_compile_error().into(),
323 }
324}
325
326fn expand_from_doc(input: DeriveInput) -> syn::Result<TokenStream2> {
327 let name = &input.ident;
328 let fields = match &input.data {
329 Data::Struct(DataStruct {
330 fields: Fields::Named(f),
331 ..
332 }) => &f.named,
333 _ => {
334 return Err(syn::Error::new_spanned(
335 &input,
336 "FromDoc can only be derived for structs with named fields",
337 ));
338 }
339 };
340
341 let mut inits = TokenStream2::new();
342
343 for field in fields {
344 let attrs = FieldAttrs::from(field)?;
345 let ident = field.ident.as_ref().unwrap();
346
347 if attrs.skip {
348 inits.extend(quote_spanned! { field.span() =>
349 #ident: ::core::default::Default::default(),
350 });
351 continue;
352 }
353
354 let zvec_name = attrs.rename.unwrap_or_else(|| ident.to_string());
355
356 let expr = if attrs.pk {
357 quote_spanned! { field.span() =>
359 __doc.pk_copy().ok_or_else(|| {
360 ::zvec::ZvecError::with_message(
361 ::zvec::ErrorCode::InvalidArgument,
362 "doc is missing a primary key",
363 )
364 })?
365 }
366 } else {
367 let name_lit = LitStr::new(&zvec_name, field.span());
368 field_reader(field, &attrs.kind, &name_lit)?
369 };
370
371 inits.extend(quote_spanned! { field.span() =>
372 #ident: #expr,
373 });
374 }
375
376 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
377
378 Ok(quote! {
379 impl #impl_generics ::zvec::FromDoc for #name #ty_generics #where_clause {
380 fn from_doc(__doc: ::zvec::DocRef<'_>) -> ::zvec::Result<Self> {
381 Ok(Self {
382 #inits
383 })
384 }
385 }
386 })
387}
388
389fn field_reader(field: &Field, hint: &TypeHint, name: &LitStr) -> syn::Result<TokenStream2> {
390 let ty = &field.ty;
391 let span = field.span();
392
393 if let Some(inner) = option_inner(ty) {
396 if matches_named(inner, "String") && matches!(hint, TypeHint::Auto) {
400 return Ok(quote_spanned! { span =>
401 {
402 if !__doc.has_field(#name) || __doc.is_field_null(#name) {
403 ::core::option::Option::None
404 } else {
405 __doc.get_string(#name)?
406 }
407 }
408 });
409 }
410 let inner_reader = scalar_or_hinted_reader(inner, hint, name, span)?;
411 return Ok(quote_spanned! { span =>
412 {
413 if !__doc.has_field(#name) || __doc.is_field_null(#name) {
414 ::core::option::Option::None
415 } else {
416 ::core::option::Option::Some(#inner_reader)
417 }
418 }
419 });
420 }
421
422 if matches_named(ty, "String") && matches!(hint, TypeHint::Auto) {
425 let err_msg = LitStr::new(&format!("doc is missing field `{}`", name.value()), span);
426 return Ok(quote_spanned! { span =>
427 __doc.get_string(#name)?.ok_or_else(|| {
428 ::zvec::ZvecError::with_message(
429 ::zvec::ErrorCode::InvalidArgument,
430 #err_msg,
431 )
432 })?
433 });
434 }
435
436 scalar_or_hinted_reader(ty, hint, name, span)
437}
438
439fn scalar_or_hinted_reader(
440 ty: &Type,
441 hint: &TypeHint,
442 name: &LitStr,
443 span: proc_macro2::Span,
444) -> syn::Result<TokenStream2> {
445 match hint {
446 TypeHint::Binary => {
447 return Ok(quote_spanned! { span => __doc.get_binary(#name)? });
448 }
449 TypeHint::VectorFp32 => {
450 return Ok(quote_spanned! { span => __doc.get_vector_fp32(#name)? });
451 }
452 TypeHint::VectorFp64 => {
453 return Ok(quote_spanned! { span => __doc.get_vector_fp64(#name)? });
454 }
455 TypeHint::VectorInt8 => {
456 return Ok(quote_spanned! { span => __doc.get_vector_int8(#name)? });
457 }
458 TypeHint::VectorInt16 => {
459 return Ok(quote_spanned! { span => __doc.get_vector_int16(#name)? });
460 }
461 TypeHint::Auto => {}
462 }
463
464 let last = match ty {
465 Type::Path(p) => p.path.segments.last(),
466 _ => None,
467 };
468 let Some(last) = last else {
469 return Err(syn::Error::new(
470 span,
471 "unsupported field type for FromDoc; add a #[zvec(...)] type hint",
472 ));
473 };
474 let tok = match last.ident.to_string().as_str() {
475 "String" => {
476 let err_msg = LitStr::new(&format!("doc is missing field `{}`", name.value()), span);
477 quote!(__doc.get_string(#name)?.ok_or_else(|| {
478 ::zvec::ZvecError::with_message(
479 ::zvec::ErrorCode::InvalidArgument,
480 #err_msg,
481 )
482 })?)
483 }
484 "bool" => quote!(__doc.get_bool(#name)?),
485 "i32" => quote!(__doc.get_int32(#name)?),
486 "i64" => quote!(__doc.get_int64(#name)?),
487 "u32" => quote!(__doc.get_uint32(#name)?),
488 "u64" => quote!(__doc.get_uint64(#name)?),
489 "f32" => quote!(__doc.get_float(#name)?),
490 "f64" => quote!(__doc.get_double(#name)?),
491 other => {
492 return Err(syn::Error::new(
493 span,
494 format!(
495 "unsupported field type `{other}` for FromDoc; \
496 add a #[zvec(...)] type hint or extend the derive",
497 ),
498 ));
499 }
500 };
501 Ok(quote_spanned! { span => #tok })
502}
503
504fn matches_named(ty: &Type, wanted: &str) -> bool {
505 let Type::Path(p) = ty else { return false };
506 let Some(seg) = p.path.segments.last() else {
507 return false;
508 };
509 seg.ident == wanted
510}
511
512fn option_inner(ty: &Type) -> Option<&Type> {
513 let Type::Path(p) = ty else { return None };
514 let seg = p.path.segments.last()?;
515 if seg.ident != "Option" {
516 return None;
517 }
518 let syn::PathArguments::AngleBracketed(args) = &seg.arguments else {
519 return None;
520 };
521 let syn::GenericArgument::Type(inner) = args.args.first()? else {
522 return None;
523 };
524 Some(inner)
525}