Skip to main content

synaptic_core/
lib.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use thiserror::Error;
11
12#[cfg(feature = "schemars")]
13pub use schemars;
14
15// ---------------------------------------------------------------------------
16// ContentBlock — multimodal message content
17// ---------------------------------------------------------------------------
18
/// A block of content within a message, supporting multimodal inputs.
///
/// Serialized as an internally tagged object: the `"type"` field holds the
/// snake_case variant name (e.g. `{"type": "text", "text": "hi"}`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image referenced by URL.
    Image {
        url: String,
        // Optional `detail` string; omitted from output when `None` and
        // defaulted to `None` when absent. NOTE(review): presumably a
        // provider resolution hint — confirm against the provider adapters.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        detail: Option<String>,
    },
    /// An audio clip referenced by URL.
    Audio {
        url: String,
    },
    /// A video referenced by URL.
    Video {
        url: String,
    },
    /// A file referenced by URL.
    File {
        url: String,
        // Optional MIME type; omitted from output when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        mime_type: Option<String>,
    },
    /// Arbitrary structured JSON data.
    Data {
        data: Value,
    },
    /// Reasoning text attached to a message.
    Reasoning {
        content: String,
    },
}
49
50// ---------------------------------------------------------------------------
51// Message
52// ---------------------------------------------------------------------------
53
/// Represents a chat message.
///
/// Tagged enum with `System`, `Human`, `AI`, `Tool`, `Chat`, and `Remove`
/// variants. It is serialized as an internally tagged object whose `"role"`
/// field is one of `"system"`, `"human"`, `"assistant"`, `"tool"`, `"chat"`,
/// or `"remove"`. Optional fields are omitted when they are `None` or empty,
/// and default to `None`/empty when absent on input.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "role")]
pub enum Message {
    /// A system message.
    #[serde(rename = "system")]
    System {
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A message from the human user.
    #[serde(rename = "human")]
    Human {
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A message produced by the model. Serialized with role `"assistant"`.
    #[serde(rename = "assistant")]
    AI {
        content: String,
        // Tool invocations requested by the model.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
        // Token usage attached to this message, if any. Only the AI variant stores it.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        usage_metadata: Option<TokenUsage>,
        // Tool calls that could not be parsed.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        invalid_tool_calls: Vec<InvalidToolCall>,
    },
    /// The result of a tool invocation.
    #[serde(rename = "tool")]
    Tool {
        content: String,
        // Required (no serde default): identifies the tool call this message answers.
        tool_call_id: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A message with an arbitrary, caller-chosen role name.
    #[serde(rename = "chat")]
    Chat {
        // The role name; `Message::role()` returns this string for Chat messages.
        custom_role: String,
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A special message that signals removal of a message by its ID.
    /// Used in message history management.
    #[serde(rename = "remove")]
    Remove {
        /// ID of the message to remove.
        id: String,
    },
}
144
/// Assigns `$value` to the field `$field` on whichever non-`Remove` variant
/// `$self` (a `&mut Message`) currently holds.
///
/// `Remove` carries none of the shared fields, so for it the macro does
/// nothing and `$value` is never evaluated.
macro_rules! set_message_field {
    ($self:expr, $field:ident, $value:expr) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => *$field = $value,
            Message::Remove { .. } => { /* Remove has no common fields */ }
        }
    };
}
159
/// Evaluates to a reference to the field `$field` of whichever variant
/// `$self` holds.
///
/// # Panics
///
/// Panics (via `unreachable!`) on a `Remove` message. Every caller must
/// handle `Remove` in an earlier match arm before falling through to this macro.
macro_rules! get_message_field {
    ($self:expr, $field:ident) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => $field,
            Message::Remove { .. } => unreachable!("get_message_field called on Remove variant"),
        }
    };
}
174
175impl Message {
176    // -- Factory methods -----------------------------------------------------
177
178    pub fn system(content: impl Into<String>) -> Self {
179        Message::System {
180            content: content.into(),
181            id: None,
182            name: None,
183            additional_kwargs: HashMap::new(),
184            response_metadata: HashMap::new(),
185            content_blocks: Vec::new(),
186        }
187    }
188
189    pub fn human(content: impl Into<String>) -> Self {
190        Message::Human {
191            content: content.into(),
192            id: None,
193            name: None,
194            additional_kwargs: HashMap::new(),
195            response_metadata: HashMap::new(),
196            content_blocks: Vec::new(),
197        }
198    }
199
200    pub fn ai(content: impl Into<String>) -> Self {
201        Message::AI {
202            content: content.into(),
203            tool_calls: vec![],
204            id: None,
205            name: None,
206            additional_kwargs: HashMap::new(),
207            response_metadata: HashMap::new(),
208            content_blocks: Vec::new(),
209            usage_metadata: None,
210            invalid_tool_calls: Vec::new(),
211        }
212    }
213
214    pub fn ai_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
215        Message::AI {
216            content: content.into(),
217            tool_calls,
218            id: None,
219            name: None,
220            additional_kwargs: HashMap::new(),
221            response_metadata: HashMap::new(),
222            content_blocks: Vec::new(),
223            usage_metadata: None,
224            invalid_tool_calls: Vec::new(),
225        }
226    }
227
228    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
229        Message::Tool {
230            content: content.into(),
231            tool_call_id: tool_call_id.into(),
232            id: None,
233            name: None,
234            additional_kwargs: HashMap::new(),
235            response_metadata: HashMap::new(),
236            content_blocks: Vec::new(),
237        }
238    }
239
240    pub fn chat(role: impl Into<String>, content: impl Into<String>) -> Self {
241        Message::Chat {
242            custom_role: role.into(),
243            content: content.into(),
244            id: None,
245            name: None,
246            additional_kwargs: HashMap::new(),
247            response_metadata: HashMap::new(),
248            content_blocks: Vec::new(),
249        }
250    }
251
252    /// Create a Remove message that signals removal of a message by its ID.
253    pub fn remove(id: impl Into<String>) -> Self {
254        Message::Remove { id: id.into() }
255    }
256
257    // -- Builder methods -----------------------------------------------------
258
259    pub fn with_id(mut self, value: impl Into<String>) -> Self {
260        set_message_field!(&mut self, id, Some(value.into()));
261        self
262    }
263
264    pub fn with_name(mut self, value: impl Into<String>) -> Self {
265        set_message_field!(&mut self, name, Some(value.into()));
266        self
267    }
268
269    pub fn with_additional_kwarg(mut self, key: impl Into<String>, value: Value) -> Self {
270        match &mut self {
271            Message::System {
272                additional_kwargs, ..
273            }
274            | Message::Human {
275                additional_kwargs, ..
276            }
277            | Message::AI {
278                additional_kwargs, ..
279            }
280            | Message::Tool {
281                additional_kwargs, ..
282            }
283            | Message::Chat {
284                additional_kwargs, ..
285            } => {
286                additional_kwargs.insert(key.into(), value);
287            }
288            Message::Remove { .. } => { /* Remove has no additional_kwargs */ }
289        }
290        self
291    }
292
293    pub fn with_response_metadata_entry(mut self, key: impl Into<String>, value: Value) -> Self {
294        match &mut self {
295            Message::System {
296                response_metadata, ..
297            }
298            | Message::Human {
299                response_metadata, ..
300            }
301            | Message::AI {
302                response_metadata, ..
303            }
304            | Message::Tool {
305                response_metadata, ..
306            }
307            | Message::Chat {
308                response_metadata, ..
309            } => {
310                response_metadata.insert(key.into(), value);
311            }
312            Message::Remove { .. } => { /* Remove has no response_metadata */ }
313        }
314        self
315    }
316
317    pub fn with_content_blocks(mut self, blocks: Vec<ContentBlock>) -> Self {
318        set_message_field!(&mut self, content_blocks, blocks);
319        self
320    }
321
322    pub fn with_usage_metadata(mut self, usage: TokenUsage) -> Self {
323        if let Message::AI { usage_metadata, .. } = &mut self {
324            *usage_metadata = Some(usage);
325        }
326        self
327    }
328
329    // -- Accessor methods ----------------------------------------------------
330
331    pub fn content(&self) -> &str {
332        match self {
333            Message::Remove { .. } => "",
334            other => get_message_field!(other, content),
335        }
336    }
337
338    pub fn role(&self) -> &str {
339        match self {
340            Message::System { .. } => "system",
341            Message::Human { .. } => "human",
342            Message::AI { .. } => "assistant",
343            Message::Tool { .. } => "tool",
344            Message::Chat { custom_role, .. } => custom_role,
345            Message::Remove { .. } => "remove",
346        }
347    }
348
349    pub fn is_system(&self) -> bool {
350        matches!(self, Message::System { .. })
351    }
352
353    pub fn is_human(&self) -> bool {
354        matches!(self, Message::Human { .. })
355    }
356
357    pub fn is_ai(&self) -> bool {
358        matches!(self, Message::AI { .. })
359    }
360
361    pub fn is_tool(&self) -> bool {
362        matches!(self, Message::Tool { .. })
363    }
364
365    pub fn is_chat(&self) -> bool {
366        matches!(self, Message::Chat { .. })
367    }
368
369    pub fn is_remove(&self) -> bool {
370        matches!(self, Message::Remove { .. })
371    }
372
373    pub fn tool_calls(&self) -> &[ToolCall] {
374        match self {
375            Message::AI { tool_calls, .. } => tool_calls,
376            _ => &[],
377        }
378    }
379
380    pub fn tool_call_id(&self) -> Option<&str> {
381        match self {
382            Message::Tool { tool_call_id, .. } => Some(tool_call_id),
383            _ => None,
384        }
385    }
386
387    pub fn id(&self) -> Option<&str> {
388        match self {
389            Message::Remove { id } => Some(id),
390            other => get_message_field!(other, id).as_deref(),
391        }
392    }
393
394    pub fn name(&self) -> Option<&str> {
395        match self {
396            Message::Remove { .. } => None,
397            other => get_message_field!(other, name).as_deref(),
398        }
399    }
400
401    pub fn additional_kwargs(&self) -> &HashMap<String, Value> {
402        match self {
403            Message::System {
404                additional_kwargs, ..
405            }
406            | Message::Human {
407                additional_kwargs, ..
408            }
409            | Message::AI {
410                additional_kwargs, ..
411            }
412            | Message::Tool {
413                additional_kwargs, ..
414            }
415            | Message::Chat {
416                additional_kwargs, ..
417            } => additional_kwargs,
418            Message::Remove { .. } => {
419                static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
420                    std::sync::OnceLock::new();
421                EMPTY.get_or_init(HashMap::new)
422            }
423        }
424    }
425
426    pub fn response_metadata(&self) -> &HashMap<String, Value> {
427        match self {
428            Message::System {
429                response_metadata, ..
430            }
431            | Message::Human {
432                response_metadata, ..
433            }
434            | Message::AI {
435                response_metadata, ..
436            }
437            | Message::Tool {
438                response_metadata, ..
439            }
440            | Message::Chat {
441                response_metadata, ..
442            } => response_metadata,
443            Message::Remove { .. } => {
444                static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
445                    std::sync::OnceLock::new();
446                EMPTY.get_or_init(HashMap::new)
447            }
448        }
449    }
450
451    pub fn content_blocks(&self) -> &[ContentBlock] {
452        match self {
453            Message::Remove { .. } => &[],
454            other => get_message_field!(other, content_blocks),
455        }
456    }
457
458    /// Return the remove ID if this is a Remove message.
459    pub fn remove_id(&self) -> Option<&str> {
460        match self {
461            Message::Remove { id } => Some(id),
462            _ => None,
463        }
464    }
465
466    pub fn usage_metadata(&self) -> Option<&TokenUsage> {
467        match self {
468            Message::AI { usage_metadata, .. } => usage_metadata.as_ref(),
469            _ => None,
470        }
471    }
472
473    pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
474        match self {
475            Message::AI {
476                invalid_tool_calls, ..
477            } => invalid_tool_calls,
478            _ => &[],
479        }
480    }
481}
482
483// ---------------------------------------------------------------------------
484// Message utility functions
485// ---------------------------------------------------------------------------
486
487/// Filter messages by type, name, or id.
488pub fn filter_messages(
489    messages: &[Message],
490    include_types: Option<&[&str]>,
491    exclude_types: Option<&[&str]>,
492    include_names: Option<&[&str]>,
493    exclude_names: Option<&[&str]>,
494    include_ids: Option<&[&str]>,
495    exclude_ids: Option<&[&str]>,
496) -> Vec<Message> {
497    messages
498        .iter()
499        .filter(|msg| {
500            if let Some(include) = include_types {
501                if !include.contains(&msg.role()) {
502                    return false;
503                }
504            }
505            if let Some(exclude) = exclude_types {
506                if exclude.contains(&msg.role()) {
507                    return false;
508                }
509            }
510            if let Some(include) = include_names {
511                match msg.name() {
512                    Some(name) => {
513                        if !include.contains(&name) {
514                            return false;
515                        }
516                    }
517                    None => return false,
518                }
519            }
520            if let Some(exclude) = exclude_names {
521                if let Some(name) = msg.name() {
522                    if exclude.contains(&name) {
523                        return false;
524                    }
525                }
526            }
527            if let Some(include) = include_ids {
528                match msg.id() {
529                    Some(id) => {
530                        if !include.contains(&id) {
531                            return false;
532                        }
533                    }
534                    None => return false,
535                }
536            }
537            if let Some(exclude) = exclude_ids {
538                if let Some(id) = msg.id() {
539                    if exclude.contains(&id) {
540                        return false;
541                    }
542                }
543            }
544            true
545        })
546        .cloned()
547        .collect()
548}
549
/// Strategy for trimming messages.
///
/// Selects which end of the history `trim_messages` keeps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimStrategy {
    /// Keep the first messages that fit within the token budget.
    First,
    /// Keep the last messages that fit within the token budget.
    Last,
}
558
559/// Trim messages to fit within a token budget.
560///
561/// `token_counter` receives a single message and returns its token count.
562/// When `include_system` is true and `strategy` is `Last`, the leading system
563/// message is always preserved.
564pub fn trim_messages(
565    messages: Vec<Message>,
566    max_tokens: usize,
567    token_counter: impl Fn(&Message) -> usize,
568    strategy: TrimStrategy,
569    include_system: bool,
570) -> Vec<Message> {
571    if messages.is_empty() {
572        return messages;
573    }
574
575    match strategy {
576        TrimStrategy::First => {
577            let mut result = Vec::new();
578            let mut total = 0;
579            for msg in messages {
580                let count = token_counter(&msg);
581                if total + count > max_tokens {
582                    break;
583                }
584                total += count;
585                result.push(msg);
586            }
587            result
588        }
589        TrimStrategy::Last => {
590            let (system_msg, rest) = if include_system && messages[0].is_system() {
591                (Some(messages[0].clone()), &messages[1..])
592            } else {
593                (None, messages.as_slice())
594            };
595
596            let system_tokens = system_msg.as_ref().map(&token_counter).unwrap_or(0);
597            let budget = max_tokens.saturating_sub(system_tokens);
598
599            let mut selected = Vec::new();
600            let mut total = 0;
601            for msg in rest.iter().rev() {
602                let count = token_counter(msg);
603                if total + count > budget {
604                    break;
605                }
606                total += count;
607                selected.push(msg.clone());
608            }
609            selected.reverse();
610
611            let mut result = Vec::new();
612            if let Some(sys) = system_msg {
613                result.push(sys);
614            }
615            result.extend(selected);
616            result
617        }
618    }
619}
620
621/// Merge consecutive messages of the same role into a single message.
622pub fn merge_message_runs(messages: Vec<Message>) -> Vec<Message> {
623    if messages.is_empty() {
624        return messages;
625    }
626
627    let mut result: Vec<Message> = Vec::new();
628
629    for msg in messages {
630        let should_merge = result
631            .last()
632            .map(|last| last.role() == msg.role())
633            .unwrap_or(false);
634
635        if should_merge {
636            let last = result.last_mut().unwrap();
637            // Merge content
638            let merged_content = format!("{}\n{}", last.content(), msg.content());
639            match last {
640                Message::System { content, .. } => *content = merged_content,
641                Message::Human { content, .. } => *content = merged_content,
642                Message::AI {
643                    content,
644                    tool_calls,
645                    invalid_tool_calls,
646                    ..
647                } => {
648                    *content = merged_content;
649                    tool_calls.extend(msg.tool_calls().to_vec());
650                    invalid_tool_calls.extend(msg.invalid_tool_calls().to_vec());
651                }
652                Message::Tool { content, .. } => *content = merged_content,
653                Message::Chat { content, .. } => *content = merged_content,
654                Message::Remove { .. } => { /* Remove messages are not merged */ }
655            }
656        } else {
657            result.push(msg);
658        }
659    }
660
661    result
662}
663
664/// Convert messages to a human-readable buffer string.
665pub fn get_buffer_string(messages: &[Message], human_prefix: &str, ai_prefix: &str) -> String {
666    messages
667        .iter()
668        .map(|msg| {
669            let prefix = match msg {
670                Message::System { .. } => "System",
671                Message::Human { .. } => human_prefix,
672                Message::AI { .. } => ai_prefix,
673                Message::Tool { .. } => "Tool",
674                Message::Chat { custom_role, .. } => custom_role.as_str(),
675                Message::Remove { .. } => "Remove",
676            };
677            format!("{prefix}: {}", msg.content())
678        })
679        .collect::<Vec<_>>()
680        .join("\n")
681}
682
683// ---------------------------------------------------------------------------
684// AIMessageChunk
685// ---------------------------------------------------------------------------
686
/// A streaming chunk from an AI model response. Supports merge via `+`/`+=` operators and conversion to `Message` via `into_message()`.
///
/// `Default` yields an empty chunk, which is convenient as the starting
/// accumulator when summing chunks.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct AIMessageChunk {
    /// Text carried by this chunk.
    pub content: String,
    /// Tool calls carried by this chunk.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Token usage reported with this chunk, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage: Option<TokenUsage>,
    /// Message ID, if one was supplied.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Partial tool-call fragments received during streaming.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_call_chunks: Vec<ToolCallChunk>,
    /// Tool calls that failed to parse.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub invalid_tool_calls: Vec<InvalidToolCall>,
}
702
703impl AIMessageChunk {
704    pub fn into_message(self) -> Message {
705        Message::ai_with_tool_calls(self.content, self.tool_calls)
706    }
707}
708
709impl std::ops::Add for AIMessageChunk {
710    type Output = Self;
711
712    fn add(mut self, rhs: Self) -> Self {
713        self += rhs;
714        self
715    }
716}
717
718impl std::ops::AddAssign for AIMessageChunk {
719    fn add_assign(&mut self, rhs: Self) {
720        self.content.push_str(&rhs.content);
721        self.tool_calls.extend(rhs.tool_calls);
722        self.tool_call_chunks.extend(rhs.tool_call_chunks);
723        self.invalid_tool_calls.extend(rhs.invalid_tool_calls);
724        if self.id.is_none() {
725            self.id = rhs.id;
726        }
727        match (&mut self.usage, rhs.usage) {
728            (Some(u), Some(rhs_u)) => {
729                u.input_tokens += rhs_u.input_tokens;
730                u.output_tokens += rhs_u.output_tokens;
731                u.total_tokens += rhs_u.total_tokens;
732            }
733            (None, Some(rhs_u)) => {
734                self.usage = Some(rhs_u);
735            }
736            _ => {}
737        }
738    }
739}
740
741// ---------------------------------------------------------------------------
742// Tool-related types
743// ---------------------------------------------------------------------------
744
/// Represents a tool invocation requested by an AI model, with an ID, function name, and JSON arguments.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier for this call.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the tool, as a JSON value.
    pub arguments: Value,
}
752
/// A tool call that failed to parse correctly.
///
/// Every field except `error` is optional; missing fields default to `None`
/// and `None` fields are omitted when serializing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InvalidToolCall {
    /// Call identifier, if one was received.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool name, if one was received.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Raw argument text, kept as a string.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
    /// Description of the failure.
    pub error: String,
}
764
/// A partial tool call chunk received during streaming.
///
/// All fields are optional; missing fields default to `None` and `None`
/// fields are omitted when serializing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallChunk {
    /// Call identifier, if present in this fragment.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool name, if present in this fragment.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Partial argument text carried by this fragment.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
    /// NOTE(review): presumably the position of the tool call this fragment
    /// belongs to, used to group fragments — confirm against provider adapters.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
}
777
/// Schema definition for a tool, including its name, description, and JSON Schema for parameters.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// JSON Schema describing the tool's parameters.
    pub parameters: Value,
    /// Provider-specific parameters (e.g., Anthropic's `cache_control`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub extras: Option<HashMap<String, Value>>,
}
788
/// Controls how the model selects tools: Auto, Required, None, or a Specific named tool.
///
/// Variant names are serialized in lowercase (e.g. `"auto"`, and
/// `{"specific": "<tool name>"}` for `Specific`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    /// Let the model decide whether to call a tool.
    Auto,
    /// Require the model to call some tool.
    Required,
    /// Do not call any tool.
    None,
    /// Call the tool with this name.
    Specific(String),
}
798
799// ---------------------------------------------------------------------------
800// Chat request / response
801// ---------------------------------------------------------------------------
802
/// A request to a chat model containing messages, optional tool definitions, and tool choice configuration.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatRequest {
    /// Conversation history sent to the model.
    pub messages: Vec<Message>,
    /// Tools the model may call; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Tool selection policy; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
812
813impl ChatRequest {
814    pub fn new(messages: Vec<Message>) -> Self {
815        Self {
816            messages,
817            tools: vec![],
818            tool_choice: None,
819        }
820    }
821
822    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
823        self.tools = tools;
824        self
825    }
826
827    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
828        self.tool_choice = Some(choice);
829        self
830    }
831}
832
/// A response from a chat model containing the AI message and optional token usage statistics.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatResponse {
    /// The model's reply.
    pub message: Message,
    /// Token usage for the call, if reported.
    pub usage: Option<TokenUsage>,
}
839
840// ---------------------------------------------------------------------------
841// Token usage
842// ---------------------------------------------------------------------------
843
/// Token counts for a model call, with optional per-category breakdowns.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the input.
    pub input_tokens: u32,
    /// Tokens produced in the output.
    pub output_tokens: u32,
    /// Total token count as reported.
    pub total_tokens: u32,
    /// Breakdown of input tokens; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_details: Option<InputTokenDetails>,
    /// Breakdown of output tokens; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_details: Option<OutputTokenDetails>,
}
854
/// Detailed breakdown of input token usage.
///
/// Missing fields deserialize to `0`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct InputTokenDetails {
    /// Input tokens served from cache.
    #[serde(default)]
    pub cached: u32,
    /// Audio input tokens.
    #[serde(default)]
    pub audio: u32,
}
863
/// Detailed breakdown of output token usage.
///
/// Missing fields deserialize to `0`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct OutputTokenDetails {
    /// Output tokens spent on reasoning.
    #[serde(default)]
    pub reasoning: u32,
    /// Audio output tokens.
    #[serde(default)]
    pub audio: u32,
}
872
873// ---------------------------------------------------------------------------
874// Events
875// ---------------------------------------------------------------------------
876
/// Lifecycle events emitted during agent execution, used by `CallbackHandler` implementations.
///
/// Every variant carries the `run_id` of the run it belongs to.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunEvent {
    /// A run began within the given session.
    RunStarted {
        run_id: String,
        session_id: String,
    },
    /// The run advanced to step number `step`.
    RunStep {
        run_id: String,
        step: usize,
    },
    /// The model was called with `message_count` messages.
    LlmCalled {
        run_id: String,
        message_count: usize,
    },
    /// The named tool was called.
    ToolCalled {
        run_id: String,
        tool_name: String,
    },
    /// The run completed with the given output.
    RunFinished {
        run_id: String,
        output: String,
    },
    /// The run failed with the given error description.
    RunFailed {
        run_id: String,
        error: String,
    },
}
905
906// ---------------------------------------------------------------------------
907// Errors
908// ---------------------------------------------------------------------------
909
/// Unified error type for the Synaptic framework with variants covering all subsystems.
///
/// Most variants wrap a free-form message. The `Display` text is generated by
/// `thiserror` and is a subsystem prefix followed by that message. For example,
/// `SynapticError::Tool("boom".into())` displays as `tool error: boom`.
#[derive(Debug, Error)]
pub enum SynapticError {
    /// Prompt subsystem failure; displays as `prompt error: <msg>`.
    #[error("prompt error: {0}")]
    Prompt(String),
    /// Chat-model failure; displays as `model error: <msg>`.
    #[error("model error: {0}")]
    Model(String),
    /// Failure raised while running a tool; displays as `tool error: <msg>`.
    #[error("tool error: {0}")]
    Tool(String),
    /// A tool was requested that is not available. The payload is presumably
    /// the requested tool name. Displays as `tool not found: <msg>`.
    #[error("tool not found: {0}")]
    ToolNotFound(String),
    /// Memory-store failure; displays as `memory error: <msg>`.
    #[error("memory error: {0}")]
    Memory(String),
    /// A rate limit was hit; displays as `rate limit: <msg>`.
    #[error("rate limit: {0}")]
    RateLimit(String),
    /// An operation timed out; displays as `timeout: <msg>`.
    #[error("timeout: {0}")]
    Timeout(String),
    /// Input failed validation; displays as `validation error: <msg>`.
    #[error("validation error: {0}")]
    Validation(String),
    /// Output or input could not be parsed; displays as `parsing error: <msg>`.
    #[error("parsing error: {0}")]
    Parsing(String),
    /// Callback-handler failure; displays as `callback error: <msg>`.
    #[error("callback error: {0}")]
    Callback(String),
    /// A step limit was exceeded. `max_steps` is the limit (presumably the
    /// configured one). Displays as `max steps exceeded: <max_steps>`.
    #[error("max steps exceeded: {max_steps}")]
    MaxStepsExceeded { max_steps: usize },
    /// Embedding failure; displays as `embedding error: <msg>`.
    #[error("embedding error: {0}")]
    Embedding(String),
    /// Vector-store failure; displays as `vector store error: <msg>`.
    #[error("vector store error: {0}")]
    VectorStore(String),
    /// Retriever failure; displays as `retriever error: <msg>`.
    #[error("retriever error: {0}")]
    Retriever(String),
    /// Document-loader failure; displays as `loader error: <msg>`.
    #[error("loader error: {0}")]
    Loader(String),
    /// Text-splitter failure; displays as `splitter error: <msg>`.
    #[error("splitter error: {0}")]
    Splitter(String),
    /// Graph-execution failure; displays as `graph error: <msg>`.
    #[error("graph error: {0}")]
    Graph(String),
    /// Cache failure; displays as `cache error: <msg>`.
    #[error("cache error: {0}")]
    Cache(String),
    /// Configuration problem; displays as `config error: <msg>`.
    #[error("config error: {0}")]
    Config(String),
    /// MCP-related failure; displays as `mcp error: <msg>`.
    #[error("mcp error: {0}")]
    Mcp(String),
}
954
955// ---------------------------------------------------------------------------
956// Core traits
957// ---------------------------------------------------------------------------
958
/// Type alias for a pinned, boxed async stream of `AIMessageChunk` results.
///
/// Each item is either a chunk or a `SynapticError`. The stream is `Send` and
/// may borrow data for `'a`. `ChatModel::stream_chat` returns `ChatStream<'_>`,
/// which ties the stream to the `&self` borrow of the model.
pub type ChatStream<'a> =
    Pin<Box<dyn Stream<Item = Result<AIMessageChunk, SynapticError>> + Send + 'a>>;
962
/// Describes a model's capabilities and limits.
///
/// Returned (when known) by `ChatModel::profile`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProfile {
    /// Model identifier.
    pub name: String,
    /// Provider serving the model.
    pub provider: String,
    /// Whether the model supports tool calling.
    pub supports_tool_calling: bool,
    /// Whether the model supports structured output.
    pub supports_structured_output: bool,
    /// Whether the model supports streaming responses.
    pub supports_streaming: bool,
    /// Maximum number of input tokens, or `None` if unknown.
    pub max_input_tokens: Option<usize>,
    /// Maximum number of output tokens, or `None` if unknown.
    pub max_output_tokens: Option<usize>,
}
974
975/// The core trait for language model providers. Implementations provide `chat()` for single responses and optionally `stream_chat()` for streaming.
976#[async_trait]
977pub trait ChatModel: Send + Sync {
978    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError>;
979
980    /// Return the model's capability profile, if known.
981    fn profile(&self) -> Option<ModelProfile> {
982        None
983    }
984
985    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
986        Box::pin(async_stream::stream! {
987            match self.chat(request).await {
988                Ok(response) => {
989                    yield Ok(AIMessageChunk {
990                        content: response.message.content().to_string(),
991                        tool_calls: response.message.tool_calls().to_vec(),
992                        usage: response.usage,
993                        ..Default::default()
994                    });
995                }
996                Err(e) => yield Err(e),
997            }
998        })
999    }
1000}
1001
1002/// Defines an executable tool that can be called by an AI model. Each tool has a name, description, JSON schema for parameters, and an async `call()` method.
1003#[async_trait]
1004pub trait Tool: Send + Sync {
1005    fn name(&self) -> &'static str;
1006    fn description(&self) -> &'static str;
1007
1008    fn parameters(&self) -> Option<Value> {
1009        None
1010    }
1011
1012    async fn call(&self, args: Value) -> Result<Value, SynapticError>;
1013
1014    fn as_tool_definition(&self) -> ToolDefinition {
1015        ToolDefinition {
1016            name: self.name().to_string(),
1017            description: self.description().to_string(),
1018            parameters: self
1019                .parameters()
1020                .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1021            extras: None,
1022        }
1023    }
1024}
1025
1026// ---------------------------------------------------------------------------
1027// ToolContext — context-aware tool execution
1028// ---------------------------------------------------------------------------
1029
1030/// Context passed to tools during graph execution.
1031///
1032/// Provides access to the current graph state (serialized as JSON),
1033/// the tool call ID, and an optional key-value store reference.
1034#[derive(Debug, Clone, Default)]
1035pub struct ToolContext {
1036    /// The current graph state, serialized as JSON.
1037    pub state: Option<Value>,
1038    /// The ID of the tool call being executed.
1039    pub tool_call_id: String,
1040}
1041
1042/// A tool that receives execution context from the graph.
1043///
1044/// This extends the basic `Tool` trait with graph-level context
1045/// (current state, store, tool call ID). Implement this for tools
1046/// that need to read or modify graph state.
1047#[async_trait]
1048pub trait ContextAwareTool: Send + Sync {
1049    fn name(&self) -> &'static str;
1050    fn description(&self) -> &'static str;
1051    async fn call_with_context(
1052        &self,
1053        args: Value,
1054        ctx: ToolContext,
1055    ) -> Result<Value, SynapticError>;
1056}
1057
1058/// Wrapper that adapts a `ContextAwareTool` into a standard `Tool`.
1059///
1060/// When used outside a graph context, the tool receives a default
1061/// (empty) `ToolContext`.
1062pub struct ContextAwareToolAdapter {
1063    inner: Arc<dyn ContextAwareTool>,
1064}
1065
1066impl ContextAwareToolAdapter {
1067    pub fn new(inner: Arc<dyn ContextAwareTool>) -> Self {
1068        Self { inner }
1069    }
1070}
1071
1072#[async_trait]
1073impl Tool for ContextAwareToolAdapter {
1074    fn name(&self) -> &'static str {
1075        self.inner.name()
1076    }
1077
1078    fn description(&self) -> &'static str {
1079        self.inner.description()
1080    }
1081
1082    async fn call(&self, args: Value) -> Result<Value, SynapticError> {
1083        self.inner
1084            .call_with_context(args, ToolContext::default())
1085            .await
1086    }
1087}
1088
1089// ---------------------------------------------------------------------------
1090// MemoryStore
1091// ---------------------------------------------------------------------------
1092
/// Persistent storage for conversation message history, keyed by session ID.
#[async_trait]
pub trait MemoryStore: Send + Sync {
    /// Add `message` to the history stored under `session_id`.
    async fn append(&self, session_id: &str, message: Message) -> Result<(), SynapticError>;
    /// Load the history stored under `session_id`.
    /// NOTE(review): message order and the result for an unknown session
    /// (presumably an empty `Vec`) are implementation-defined — confirm.
    async fn load(&self, session_id: &str) -> Result<Vec<Message>, SynapticError>;
    /// Remove the history stored under `session_id`.
    async fn clear(&self, session_id: &str) -> Result<(), SynapticError>;
}
1100
/// Handler for lifecycle events during agent execution. Receives `RunEvent` notifications at each stage.
#[async_trait]
pub trait CallbackHandler: Send + Sync {
    /// Handle a single lifecycle event (passed by value).
    ///
    /// NOTE(review): what the emitter does with an `Err` (ignore, log, or abort
    /// the run) is decided by the caller, not here — confirm.
    async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError>;
}
1106
1107// ---------------------------------------------------------------------------
1108// RunnableConfig
1109// ---------------------------------------------------------------------------
1110
/// Runtime configuration passed through runnable chains, including tags, metadata, concurrency limits, and run identification.
///
/// Every field has `#[serde(default)]`, so any subset of fields can be
/// deserialized and the missing ones take their `Default` values (empty
/// collections / `None`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RunnableConfig {
    /// Free-form tags attached to the run.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Arbitrary JSON metadata keyed by string.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
    /// Optional concurrency limit (presumably for batch/parallel execution — confirm).
    #[serde(default)]
    pub max_concurrency: Option<usize>,
    /// Optional recursion limit (presumably bounds nested/graph step depth — confirm).
    #[serde(default)]
    pub recursion_limit: Option<usize>,
    /// Optional explicit run identifier.
    #[serde(default)]
    pub run_id: Option<String>,
    /// Optional human-readable run name.
    #[serde(default)]
    pub run_name: Option<String>,
}
1127
1128impl RunnableConfig {
1129    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
1130        self.tags = tags;
1131        self
1132    }
1133
1134    pub fn with_run_name(mut self, name: impl Into<String>) -> Self {
1135        self.run_name = Some(name.into());
1136        self
1137    }
1138
1139    pub fn with_run_id(mut self, id: impl Into<String>) -> Self {
1140        self.run_id = Some(id.into());
1141        self
1142    }
1143
1144    pub fn with_max_concurrency(mut self, max: usize) -> Self {
1145        self.max_concurrency = Some(max);
1146        self
1147    }
1148
1149    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
1150        self.recursion_limit = Some(limit);
1151        self
1152    }
1153
1154    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
1155        self.metadata.insert(key.into(), value);
1156        self
1157    }
1158}
1159
1160// ---------------------------------------------------------------------------
1161// Store trait (forward-declared in core, implemented in synaptic-store)
1162// ---------------------------------------------------------------------------
1163
/// A stored item in the key-value store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Item {
    /// Hierarchical namespace of the item: the owned form of the `&[&str]`
    /// namespaces accepted by `Store`.
    pub namespace: Vec<String>,
    /// Key of the item within its namespace.
    pub key: String,
    /// Arbitrary JSON value.
    pub value: Value,
    /// Creation timestamp as a string. The format is chosen by the store
    /// implementation (NOTE(review): presumably RFC 3339 — confirm).
    pub created_at: String,
    /// Last-update timestamp as a string (same format caveat as `created_at`).
    pub updated_at: String,
    /// Relevance score from a search operation (e.g., similarity score).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub score: Option<f64>,
}
1176
/// Persistent key-value store trait for cross-invocation state.
///
/// Namespaces are hierarchical (represented as slices of strings) and
/// keys are strings within a namespace. Values are arbitrary JSON.
#[async_trait]
pub trait Store: Send + Sync {
    /// Get an item by namespace and key.
    ///
    /// `Ok(None)` presumably means no item exists under that key.
    async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError>;

    /// Search items within a namespace.
    ///
    /// `query` optionally narrows the search; its interpretation (e.g. text vs.
    /// similarity match) is implementation-defined. `limit` presumably caps the
    /// number of results. Returned items may carry a relevance `score`.
    async fn search(
        &self,
        namespace: &[&str],
        query: Option<&str>,
        limit: usize,
    ) -> Result<Vec<Item>, SynapticError>;

    /// Put (upsert) an item.
    async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError>;

    /// Delete an item.
    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError>;

    /// List all namespaces, optionally filtered by prefix.
    ///
    /// `prefix` is a slice, not an `Option`. NOTE(review): presumably an empty
    /// slice means "no filter" — confirm against implementations.
    async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError>;
}
1203
1204// ---------------------------------------------------------------------------
1205// Embeddings trait (forward-declared here, implemented in synaptic-embeddings)
1206// ---------------------------------------------------------------------------
1207
/// Trait for embedding text into vectors.
#[async_trait]
pub trait Embeddings: Send + Sync {
    /// Embed multiple texts (for batch document embedding).
    ///
    /// NOTE(review): presumably returns one vector per input, in input order — confirm.
    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError>;

    /// Embed a single query text.
    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError>;
}
1217
1218// ---------------------------------------------------------------------------
1219// StreamWriter
1220// ---------------------------------------------------------------------------
1221
/// Custom stream writer that nodes can use to emit custom events.
///
/// A shared (`Arc`), `Send + Sync` callback that takes one JSON value per
/// event. It returns `()`, so a write cannot report failure to the caller.
pub type StreamWriter = Arc<dyn Fn(Value) + Send + Sync>;
1224
1225// ---------------------------------------------------------------------------
1226// Runtime types
1227// ---------------------------------------------------------------------------
1228
/// Graph execution runtime context passed to nodes and middleware.
///
/// Both fields are optional and `Arc`-backed, so cloning is cheap.
#[derive(Clone)]
pub struct Runtime {
    /// Key-value store available during execution, if any.
    pub store: Option<Arc<dyn Store>>,
    /// Writer for custom stream events, if any.
    pub stream_writer: Option<StreamWriter>,
}
1235
/// Tool execution runtime context.
///
/// Passed by value to `RuntimeAwareTool::call_with_runtime`. Every part may be
/// absent. An "empty" runtime has all `Option` fields set to `None` and an
/// empty `tool_call_id`.
#[derive(Clone)]
pub struct ToolRuntime {
    /// Key-value store available to the tool, if any.
    pub store: Option<Arc<dyn Store>>,
    /// Writer for custom stream events, if any.
    pub stream_writer: Option<StreamWriter>,
    /// Current graph state serialized as JSON, if available.
    pub state: Option<Value>,
    /// ID of the tool call being executed (empty when unknown).
    pub tool_call_id: String,
    /// Runnable configuration for the current run, if any.
    pub config: Option<RunnableConfig>,
}
1245
1246// ---------------------------------------------------------------------------
1247// RuntimeAwareTool
1248// ---------------------------------------------------------------------------
1249
1250/// Context-aware tool that receives runtime information.
1251///
1252/// This extends the basic `Tool` trait with runtime context
1253/// (current state, store, stream writer, tool call ID). Implement this
1254/// for tools that need to read or modify graph state.
1255#[async_trait]
1256pub trait RuntimeAwareTool: Send + Sync {
1257    fn name(&self) -> &'static str;
1258    fn description(&self) -> &'static str;
1259
1260    fn parameters(&self) -> Option<Value> {
1261        None
1262    }
1263
1264    async fn call_with_runtime(
1265        &self,
1266        args: Value,
1267        runtime: ToolRuntime,
1268    ) -> Result<Value, SynapticError>;
1269
1270    fn as_tool_definition(&self) -> ToolDefinition {
1271        ToolDefinition {
1272            name: self.name().to_string(),
1273            description: self.description().to_string(),
1274            parameters: self
1275                .parameters()
1276                .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1277            extras: None,
1278        }
1279    }
1280}
1281
1282/// Adapter that wraps a `RuntimeAwareTool` into a standard `Tool`.
1283///
1284/// When used outside a graph context, the tool receives a default
1285/// (empty) `ToolRuntime`.
1286pub struct RuntimeAwareToolAdapter {
1287    inner: Arc<dyn RuntimeAwareTool>,
1288    runtime: Arc<tokio::sync::RwLock<Option<ToolRuntime>>>,
1289}
1290
1291impl RuntimeAwareToolAdapter {
1292    pub fn new(tool: Arc<dyn RuntimeAwareTool>) -> Self {
1293        Self {
1294            inner: tool,
1295            runtime: Arc::new(tokio::sync::RwLock::new(None)),
1296        }
1297    }
1298
1299    pub async fn set_runtime(&self, runtime: ToolRuntime) {
1300        *self.runtime.write().await = Some(runtime);
1301    }
1302}
1303
1304#[async_trait]
1305impl Tool for RuntimeAwareToolAdapter {
1306    fn name(&self) -> &'static str {
1307        self.inner.name()
1308    }
1309
1310    fn description(&self) -> &'static str {
1311        self.inner.description()
1312    }
1313
1314    fn parameters(&self) -> Option<Value> {
1315        self.inner.parameters()
1316    }
1317
1318    async fn call(&self, args: Value) -> Result<Value, SynapticError> {
1319        let runtime = self.runtime.read().await.clone().unwrap_or(ToolRuntime {
1320            store: None,
1321            stream_writer: None,
1322            state: None,
1323            tool_call_id: String::new(),
1324            config: None,
1325        });
1326        self.inner.call_with_runtime(args, runtime).await
1327    }
1328}
1329
1330// ---------------------------------------------------------------------------
1331// Entrypoint / Task metadata (used by proc macros)
1332// ---------------------------------------------------------------------------
1333
/// Configuration for an `#[entrypoint]`-decorated function.
///
/// The fields are `&'static str`, which suits values written as string
/// literals (e.g. emitted by the proc macro).
#[derive(Debug, Clone)]
pub struct EntrypointConfig {
    /// Name of the entrypoint.
    pub name: &'static str,
    /// Optional checkpointer identifier. NOTE(review): presumably `None`
    /// means no checkpointing — confirm against the `#[entrypoint]` macro.
    pub checkpointer: Option<&'static str>,
}
1340
/// An entrypoint wrapping an async function as a runnable workflow.
///
/// The `invoke_fn` field is a type-erased async function (`Value -> Result<Value, SynapticError>`).
/// It is an `Fn` (not `FnOnce`) that is `Send + Sync` and returns a pinned,
/// boxed `Send` future. This means one `Entrypoint` can be invoked repeatedly
/// and shared across threads.
pub struct Entrypoint {
    /// Static configuration (name and optional checkpointer).
    pub config: EntrypointConfig,
    /// The type-erased workflow function called by [`Entrypoint::invoke`].
    pub invoke_fn: Box<
        dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<Value, SynapticError>> + Send>>
            + Send
            + Sync,
    >,
}
1352
1353impl Entrypoint {
1354    pub async fn invoke(&self, input: Value) -> Result<Value, SynapticError> {
1355        (self.invoke_fn)(input).await
1356    }
1357}