1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{FnArg, Ident, ImplItemFn, ItemFn, Pat, ReturnType, Signature, Type};
5
6fn returns_result(sig: &Signature) -> bool {
10 let ReturnType::Type(_, ty) = &sig.output else {
11 return false;
12 };
13 matches!(&**ty, Type::Path(tp)
14 if tp.path.segments.last().is_some_and(|seg| seg.ident == "Result"))
15}
16
17fn capture_args(sig: &mut Signature) -> Vec<Ident> {
22 let mut captured = Vec::new();
23 for arg in sig.inputs.iter_mut() {
24 let FnArg::Typed(pat_type) = arg else {
25 continue;
26 };
27 let skip = pat_type.attrs.iter().any(|a| a.path().is_ident("trace"));
28 pat_type.attrs.retain(|a| !a.path().is_ident("trace"));
29 if skip {
30 continue;
31 }
32 if let Pat::Ident(pat_ident) = &*pat_type.pat {
33 captured.push(pat_ident.ident.clone());
34 }
35 }
36 captured
37}
38
39fn expand(kind: TokenStream2, item: TokenStream) -> TokenStream {
46 let item2 = TokenStream2::from(item);
47
48 let parsed = syn::parse2::<ItemFn>(item2.clone())
52 .map(|f| (f.attrs, f.vis, quote!(), f.sig, *f.block))
53 .or_else(|_| {
54 syn::parse2::<ImplItemFn>(item2).map(|m| {
55 let defaultness = m.defaultness;
56 (m.attrs, m.vis, quote!(#defaultness), m.sig, m.block)
57 })
58 });
59 let (attrs, vis, defaultness, mut sig, block) = match parsed {
60 Ok(parts) => parts,
61 Err(err) => return err.to_compile_error().into(),
62 };
63 let captured = capture_args(&mut sig);
64 let returns_result = returns_result(&sig);
65 let name = &sig.ident;
66
67 let input_capture = if captured.is_empty() {
71 quote!()
72 } else {
73 let inserts = captured.iter().map(|id| {
74 let key = id.to_string();
75 quote! {
76 __input.insert(
77 #key.to_string(),
78 trace_weft::serde_json::to_value(&#id)
79 .unwrap_or(trace_weft::serde_json::Value::Null),
80 );
81 }
82 });
83 quote! {
84 if trace_weft::capture_enabled() {
85 let mut __input = trace_weft::serde_json::Map::new();
86 #(#inserts)*
87 _span.input_ref = trace_weft::capture_json(
88 "application/json",
89 trace_weft::serde_json::Value::Object(__input),
90 ).await;
91 }
92 }
93 };
94
95 let output_capture = if returns_result {
97 quote! {
98 if trace_weft::capture_enabled() {
99 if let Ok(__ok) = &result {
100 _span.output_ref = trace_weft::capture_json(
101 "application/json",
102 trace_weft::serde_json::to_value(__ok)
103 .unwrap_or(trace_weft::serde_json::Value::Null),
104 ).await;
105 }
106 }
107 }
108 } else {
109 quote! {
110 if trace_weft::capture_enabled() {
111 _span.output_ref = trace_weft::capture_json(
112 "application/json",
113 trace_weft::serde_json::to_value(&result)
114 .unwrap_or(trace_weft::serde_json::Value::Null),
115 ).await;
116 }
117 }
118 };
119
120 let status_update = if returns_result {
124 quote! {
125 match &result {
126 Ok(_) => { _span.status = trace_weft::SpanStatus::Ok; }
127 Err(__e) => {
128 _span.status = trace_weft::SpanStatus::Error;
129 _span.error_type = Some(std::any::type_name_of_val(__e).to_string());
130 _span.error_message_redacted = Some(trace_weft::redact_text(&format!("{}", __e)).redacted_text);
131 }
132 }
133 }
134 } else {
135 quote! { _span.status = trace_weft::SpanStatus::Ok; }
136 };
137
138 let expanded = quote! {
139 #(#attrs)*
140 #vis #defaultness #sig {
141 let mut _span = trace_weft::SpanRecord {
142 trace_id: trace_weft::TraceId(trace_weft::uuid::Uuid::now_v7()),
143 span_id: trace_weft::SpanId(trace_weft::uuid::Uuid::now_v7()),
144 parent_span_id: None,
145 run_id: trace_weft::RunId(trace_weft::uuid::Uuid::now_v7()),
146 session_id: None,
147 user_id_hash: None,
148 project_id: None,
149 span_kind: trace_weft::TraceWeftSpanKind::#kind,
150 name: stringify!(#name).to_string(),
151 start_time: std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64,
152 end_time: None,
153 status: trace_weft::SpanStatus::InProgress,
154 status_message: None,
155 error_type: None,
156 error_message_redacted: None,
157 attributes: std::collections::HashMap::new(),
158 otel_attributes: std::collections::HashMap::new(),
159 openinference_attributes: std::collections::HashMap::new(),
160 memory_state: None,
161 input_ref: None,
162 output_ref: None,
163 prompt_template_id: None,
164 prompt_version: None,
165 model_provider: None,
166 model_name: None,
167 tool_name: None,
168 tool_schema_hash: None,
169 retrieval_query_hash: None,
170 retrieved_document_refs: vec![],
171 token_usage: None,
172 cost_estimate: None,
173 latency_ms: None,
174 retry_count: None,
175 cache_hit: None,
176 redaction_policy: trace_weft::capture_policy(),
177 schema_version: "1.0".to_string(),
178 };
179
180 if let Some(__parent) = trace_weft::current_span_context() {
181 _span.trace_id = __parent.trace_id;
182 _span.run_id = __parent.run_id;
183 _span.parent_span_id = Some(__parent.span_id);
184 }
185
186 #input_capture
187
188 let __ctx = trace_weft::SpanContext {
189 trace_id: _span.trace_id,
190 run_id: _span.run_id,
191 span_id: _span.span_id,
192 };
193 let result = trace_weft::scope_current(__ctx, async move #block).await;
194
195 _span.end_time = Some(std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64);
196 _span.latency_ms = Some(_span.end_time.unwrap() - _span.start_time);
197 #status_update
198 #output_capture
199 trace_weft::record_span(_span).await;
200
201 result
202 }
203 };
204
205 TokenStream::from(expanded)
206}
207
208#[proc_macro_attribute]
209pub fn agent(_attr: TokenStream, item: TokenStream) -> TokenStream {
210 expand(quote!(Agent), item)
211}
212
213#[proc_macro_attribute]
214pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
215 expand(quote!(Tool), item)
216}
217
218#[proc_macro_attribute]
219pub fn llm_call(_attr: TokenStream, item: TokenStream) -> TokenStream {
220 expand(quote!(LlmCall), item)
221}