// synaptic_core/lib.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use thiserror::Error;
11
12#[cfg(feature = "schemars")]
13pub use schemars;
14
15// ---------------------------------------------------------------------------
16// ContentBlock — multimodal message content
17// ---------------------------------------------------------------------------
18
19/// A block of content within a message, supporting multimodal inputs.
20#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
21#[serde(tag = "type", rename_all = "snake_case")]
22pub enum ContentBlock {
23    Text {
24        text: String,
25    },
26    Image {
27        url: String,
28        #[serde(default, skip_serializing_if = "Option::is_none")]
29        detail: Option<String>,
30    },
31    Audio {
32        url: String,
33    },
34    Video {
35        url: String,
36    },
37    File {
38        url: String,
39        #[serde(default, skip_serializing_if = "Option::is_none")]
40        mime_type: Option<String>,
41    },
42    Data {
43        data: Value,
44    },
45    Reasoning {
46        content: String,
47    },
48}
49
50// ---------------------------------------------------------------------------
51// Message
52// ---------------------------------------------------------------------------
53
54/// Represents a chat message. Tagged enum with System, Human, AI, and Tool variants.
55#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
56#[serde(tag = "role")]
57pub enum Message {
58    #[serde(rename = "system")]
59    System {
60        content: String,
61        #[serde(default, skip_serializing_if = "Option::is_none")]
62        id: Option<String>,
63        #[serde(default, skip_serializing_if = "Option::is_none")]
64        name: Option<String>,
65        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
66        additional_kwargs: HashMap<String, Value>,
67        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
68        response_metadata: HashMap<String, Value>,
69        #[serde(default, skip_serializing_if = "Vec::is_empty")]
70        content_blocks: Vec<ContentBlock>,
71    },
72    #[serde(rename = "human")]
73    Human {
74        content: String,
75        #[serde(default, skip_serializing_if = "Option::is_none")]
76        id: Option<String>,
77        #[serde(default, skip_serializing_if = "Option::is_none")]
78        name: Option<String>,
79        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
80        additional_kwargs: HashMap<String, Value>,
81        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
82        response_metadata: HashMap<String, Value>,
83        #[serde(default, skip_serializing_if = "Vec::is_empty")]
84        content_blocks: Vec<ContentBlock>,
85    },
86    #[serde(rename = "assistant")]
87    AI {
88        content: String,
89        #[serde(default, skip_serializing_if = "Vec::is_empty")]
90        tool_calls: Vec<ToolCall>,
91        #[serde(default, skip_serializing_if = "Option::is_none")]
92        id: Option<String>,
93        #[serde(default, skip_serializing_if = "Option::is_none")]
94        name: Option<String>,
95        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
96        additional_kwargs: HashMap<String, Value>,
97        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
98        response_metadata: HashMap<String, Value>,
99        #[serde(default, skip_serializing_if = "Vec::is_empty")]
100        content_blocks: Vec<ContentBlock>,
101        #[serde(default, skip_serializing_if = "Option::is_none")]
102        usage_metadata: Option<TokenUsage>,
103        #[serde(default, skip_serializing_if = "Vec::is_empty")]
104        invalid_tool_calls: Vec<InvalidToolCall>,
105    },
106    #[serde(rename = "tool")]
107    Tool {
108        content: String,
109        tool_call_id: String,
110        #[serde(default, skip_serializing_if = "Option::is_none")]
111        id: Option<String>,
112        #[serde(default, skip_serializing_if = "Option::is_none")]
113        name: Option<String>,
114        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
115        additional_kwargs: HashMap<String, Value>,
116        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
117        response_metadata: HashMap<String, Value>,
118        #[serde(default, skip_serializing_if = "Vec::is_empty")]
119        content_blocks: Vec<ContentBlock>,
120    },
121    #[serde(rename = "chat")]
122    Chat {
123        custom_role: String,
124        content: String,
125        #[serde(default, skip_serializing_if = "Option::is_none")]
126        id: Option<String>,
127        #[serde(default, skip_serializing_if = "Option::is_none")]
128        name: Option<String>,
129        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
130        additional_kwargs: HashMap<String, Value>,
131        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
132        response_metadata: HashMap<String, Value>,
133        #[serde(default, skip_serializing_if = "Vec::is_empty")]
134        content_blocks: Vec<ContentBlock>,
135    },
136    /// A special message that signals removal of a message by its ID.
137    /// Used in message history management.
138    #[serde(rename = "remove")]
139    Remove {
140        /// ID of the message to remove.
141        id: String,
142    },
143}
144
/// Assigns `$value` to the field `$field` of whichever content-bearing
/// `Message` variant `$self` (a `&mut Message`) currently is.
///
/// `Remove` does not have the shared fields, so for that variant nothing is
/// assigned and `$value` is never evaluated.
macro_rules! set_message_field {
    ($self:expr, $field:ident, $value:expr) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => *$field = $value,
            Message::Remove { .. } => { /* Remove has no common fields */ }
        }
    };
}
159
/// Evaluates to the field `$field` of whichever content-bearing `Message`
/// variant `$self` is.
///
/// Panics (via `unreachable!`) when `$self` is `Remove`; callers are expected
/// to handle that variant before invoking the macro.
macro_rules! get_message_field {
    ($self:expr, $field:ident) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => $field,
            Message::Remove { .. } => unreachable!("get_message_field called on Remove variant"),
        }
    };
}
174
175impl Message {
176    // -- Factory methods -----------------------------------------------------
177
178    pub fn system(content: impl Into<String>) -> Self {
179        Message::System {
180            content: content.into(),
181            id: None,
182            name: None,
183            additional_kwargs: HashMap::new(),
184            response_metadata: HashMap::new(),
185            content_blocks: Vec::new(),
186        }
187    }
188
189    pub fn human(content: impl Into<String>) -> Self {
190        Message::Human {
191            content: content.into(),
192            id: None,
193            name: None,
194            additional_kwargs: HashMap::new(),
195            response_metadata: HashMap::new(),
196            content_blocks: Vec::new(),
197        }
198    }
199
200    pub fn ai(content: impl Into<String>) -> Self {
201        Message::AI {
202            content: content.into(),
203            tool_calls: vec![],
204            id: None,
205            name: None,
206            additional_kwargs: HashMap::new(),
207            response_metadata: HashMap::new(),
208            content_blocks: Vec::new(),
209            usage_metadata: None,
210            invalid_tool_calls: Vec::new(),
211        }
212    }
213
214    pub fn ai_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
215        Message::AI {
216            content: content.into(),
217            tool_calls,
218            id: None,
219            name: None,
220            additional_kwargs: HashMap::new(),
221            response_metadata: HashMap::new(),
222            content_blocks: Vec::new(),
223            usage_metadata: None,
224            invalid_tool_calls: Vec::new(),
225        }
226    }
227
228    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
229        Message::Tool {
230            content: content.into(),
231            tool_call_id: tool_call_id.into(),
232            id: None,
233            name: None,
234            additional_kwargs: HashMap::new(),
235            response_metadata: HashMap::new(),
236            content_blocks: Vec::new(),
237        }
238    }
239
240    pub fn chat(role: impl Into<String>, content: impl Into<String>) -> Self {
241        Message::Chat {
242            custom_role: role.into(),
243            content: content.into(),
244            id: None,
245            name: None,
246            additional_kwargs: HashMap::new(),
247            response_metadata: HashMap::new(),
248            content_blocks: Vec::new(),
249        }
250    }
251
252    /// Create a Remove message that signals removal of a message by its ID.
253    pub fn remove(id: impl Into<String>) -> Self {
254        Message::Remove { id: id.into() }
255    }
256
257    // -- Builder methods -----------------------------------------------------
258
259    pub fn with_id(mut self, value: impl Into<String>) -> Self {
260        set_message_field!(&mut self, id, Some(value.into()));
261        self
262    }
263
264    pub fn with_name(mut self, value: impl Into<String>) -> Self {
265        set_message_field!(&mut self, name, Some(value.into()));
266        self
267    }
268
269    pub fn with_additional_kwarg(mut self, key: impl Into<String>, value: Value) -> Self {
270        match &mut self {
271            Message::System {
272                additional_kwargs, ..
273            }
274            | Message::Human {
275                additional_kwargs, ..
276            }
277            | Message::AI {
278                additional_kwargs, ..
279            }
280            | Message::Tool {
281                additional_kwargs, ..
282            }
283            | Message::Chat {
284                additional_kwargs, ..
285            } => {
286                additional_kwargs.insert(key.into(), value);
287            }
288            Message::Remove { .. } => { /* Remove has no additional_kwargs */ }
289        }
290        self
291    }
292
293    pub fn with_response_metadata_entry(mut self, key: impl Into<String>, value: Value) -> Self {
294        match &mut self {
295            Message::System {
296                response_metadata, ..
297            }
298            | Message::Human {
299                response_metadata, ..
300            }
301            | Message::AI {
302                response_metadata, ..
303            }
304            | Message::Tool {
305                response_metadata, ..
306            }
307            | Message::Chat {
308                response_metadata, ..
309            } => {
310                response_metadata.insert(key.into(), value);
311            }
312            Message::Remove { .. } => { /* Remove has no response_metadata */ }
313        }
314        self
315    }
316
317    pub fn with_content_blocks(mut self, blocks: Vec<ContentBlock>) -> Self {
318        set_message_field!(&mut self, content_blocks, blocks);
319        self
320    }
321
322    pub fn with_usage_metadata(mut self, usage: TokenUsage) -> Self {
323        if let Message::AI { usage_metadata, .. } = &mut self {
324            *usage_metadata = Some(usage);
325        }
326        self
327    }
328
329    // -- Accessor methods ----------------------------------------------------
330
331    pub fn content(&self) -> &str {
332        match self {
333            Message::Remove { .. } => "",
334            other => get_message_field!(other, content),
335        }
336    }
337
338    pub fn role(&self) -> &str {
339        match self {
340            Message::System { .. } => "system",
341            Message::Human { .. } => "human",
342            Message::AI { .. } => "assistant",
343            Message::Tool { .. } => "tool",
344            Message::Chat { custom_role, .. } => custom_role,
345            Message::Remove { .. } => "remove",
346        }
347    }
348
349    pub fn is_system(&self) -> bool {
350        matches!(self, Message::System { .. })
351    }
352
353    pub fn is_human(&self) -> bool {
354        matches!(self, Message::Human { .. })
355    }
356
357    pub fn is_ai(&self) -> bool {
358        matches!(self, Message::AI { .. })
359    }
360
361    pub fn is_tool(&self) -> bool {
362        matches!(self, Message::Tool { .. })
363    }
364
365    pub fn is_chat(&self) -> bool {
366        matches!(self, Message::Chat { .. })
367    }
368
369    pub fn is_remove(&self) -> bool {
370        matches!(self, Message::Remove { .. })
371    }
372
373    pub fn tool_calls(&self) -> &[ToolCall] {
374        match self {
375            Message::AI { tool_calls, .. } => tool_calls,
376            _ => &[],
377        }
378    }
379
380    pub fn tool_call_id(&self) -> Option<&str> {
381        match self {
382            Message::Tool { tool_call_id, .. } => Some(tool_call_id),
383            _ => None,
384        }
385    }
386
387    pub fn id(&self) -> Option<&str> {
388        match self {
389            Message::Remove { id } => Some(id),
390            other => get_message_field!(other, id).as_deref(),
391        }
392    }
393
394    pub fn name(&self) -> Option<&str> {
395        match self {
396            Message::Remove { .. } => None,
397            other => get_message_field!(other, name).as_deref(),
398        }
399    }
400
401    pub fn additional_kwargs(&self) -> &HashMap<String, Value> {
402        match self {
403            Message::System {
404                additional_kwargs, ..
405            }
406            | Message::Human {
407                additional_kwargs, ..
408            }
409            | Message::AI {
410                additional_kwargs, ..
411            }
412            | Message::Tool {
413                additional_kwargs, ..
414            }
415            | Message::Chat {
416                additional_kwargs, ..
417            } => additional_kwargs,
418            Message::Remove { .. } => {
419                static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
420                    std::sync::OnceLock::new();
421                EMPTY.get_or_init(HashMap::new)
422            }
423        }
424    }
425
426    pub fn response_metadata(&self) -> &HashMap<String, Value> {
427        match self {
428            Message::System {
429                response_metadata, ..
430            }
431            | Message::Human {
432                response_metadata, ..
433            }
434            | Message::AI {
435                response_metadata, ..
436            }
437            | Message::Tool {
438                response_metadata, ..
439            }
440            | Message::Chat {
441                response_metadata, ..
442            } => response_metadata,
443            Message::Remove { .. } => {
444                static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
445                    std::sync::OnceLock::new();
446                EMPTY.get_or_init(HashMap::new)
447            }
448        }
449    }
450
451    pub fn content_blocks(&self) -> &[ContentBlock] {
452        match self {
453            Message::Remove { .. } => &[],
454            other => get_message_field!(other, content_blocks),
455        }
456    }
457
458    /// Return the remove ID if this is a Remove message.
459    pub fn remove_id(&self) -> Option<&str> {
460        match self {
461            Message::Remove { id } => Some(id),
462            _ => None,
463        }
464    }
465
466    pub fn usage_metadata(&self) -> Option<&TokenUsage> {
467        match self {
468            Message::AI { usage_metadata, .. } => usage_metadata.as_ref(),
469            _ => None,
470        }
471    }
472
473    pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
474        match self {
475            Message::AI {
476                invalid_tool_calls, ..
477            } => invalid_tool_calls,
478            _ => &[],
479        }
480    }
481}
482
483// ---------------------------------------------------------------------------
484// Message utility functions
485// ---------------------------------------------------------------------------
486
487/// Filter messages by type, name, or id.
488pub fn filter_messages(
489    messages: &[Message],
490    include_types: Option<&[&str]>,
491    exclude_types: Option<&[&str]>,
492    include_names: Option<&[&str]>,
493    exclude_names: Option<&[&str]>,
494    include_ids: Option<&[&str]>,
495    exclude_ids: Option<&[&str]>,
496) -> Vec<Message> {
497    messages
498        .iter()
499        .filter(|msg| {
500            if let Some(include) = include_types {
501                if !include.contains(&msg.role()) {
502                    return false;
503                }
504            }
505            if let Some(exclude) = exclude_types {
506                if exclude.contains(&msg.role()) {
507                    return false;
508                }
509            }
510            if let Some(include) = include_names {
511                match msg.name() {
512                    Some(name) => {
513                        if !include.contains(&name) {
514                            return false;
515                        }
516                    }
517                    None => return false,
518                }
519            }
520            if let Some(exclude) = exclude_names {
521                if let Some(name) = msg.name() {
522                    if exclude.contains(&name) {
523                        return false;
524                    }
525                }
526            }
527            if let Some(include) = include_ids {
528                match msg.id() {
529                    Some(id) => {
530                        if !include.contains(&id) {
531                            return false;
532                        }
533                    }
534                    None => return false,
535                }
536            }
537            if let Some(exclude) = exclude_ids {
538                if let Some(id) = msg.id() {
539                    if exclude.contains(&id) {
540                        return false;
541                    }
542                }
543            }
544            true
545        })
546        .cloned()
547        .collect()
548}
549
/// Selects which end of the conversation `trim_messages` keeps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimStrategy {
    /// Keep messages from the start until the token budget is used up.
    First,
    /// Keep messages from the end until the token budget is used up.
    Last,
}
558
559/// Trim messages to fit within a token budget.
560///
561/// `token_counter` receives a single message and returns its token count.
562/// When `include_system` is true and `strategy` is `Last`, the leading system
563/// message is always preserved.
564pub fn trim_messages(
565    messages: Vec<Message>,
566    max_tokens: usize,
567    token_counter: impl Fn(&Message) -> usize,
568    strategy: TrimStrategy,
569    include_system: bool,
570) -> Vec<Message> {
571    if messages.is_empty() {
572        return messages;
573    }
574
575    match strategy {
576        TrimStrategy::First => {
577            let mut result = Vec::new();
578            let mut total = 0;
579            for msg in messages {
580                let count = token_counter(&msg);
581                if total + count > max_tokens {
582                    break;
583                }
584                total += count;
585                result.push(msg);
586            }
587            result
588        }
589        TrimStrategy::Last => {
590            let (system_msg, rest) = if include_system && messages[0].is_system() {
591                (Some(messages[0].clone()), &messages[1..])
592            } else {
593                (None, messages.as_slice())
594            };
595
596            let system_tokens = system_msg.as_ref().map(&token_counter).unwrap_or(0);
597            let budget = max_tokens.saturating_sub(system_tokens);
598
599            let mut selected = Vec::new();
600            let mut total = 0;
601            for msg in rest.iter().rev() {
602                let count = token_counter(msg);
603                if total + count > budget {
604                    break;
605                }
606                total += count;
607                selected.push(msg.clone());
608            }
609            selected.reverse();
610
611            let mut result = Vec::new();
612            if let Some(sys) = system_msg {
613                result.push(sys);
614            }
615            result.extend(selected);
616            result
617        }
618    }
619}
620
621/// Merge consecutive messages of the same role into a single message.
622pub fn merge_message_runs(messages: Vec<Message>) -> Vec<Message> {
623    if messages.is_empty() {
624        return messages;
625    }
626
627    let mut result: Vec<Message> = Vec::new();
628
629    for msg in messages {
630        let should_merge = result
631            .last()
632            .map(|last| last.role() == msg.role())
633            .unwrap_or(false);
634
635        if should_merge {
636            let last = result.last_mut().unwrap();
637            // Merge content
638            let merged_content = format!("{}\n{}", last.content(), msg.content());
639            match last {
640                Message::System { content, .. } => *content = merged_content,
641                Message::Human { content, .. } => *content = merged_content,
642                Message::AI {
643                    content,
644                    tool_calls,
645                    invalid_tool_calls,
646                    ..
647                } => {
648                    *content = merged_content;
649                    tool_calls.extend(msg.tool_calls().to_vec());
650                    invalid_tool_calls.extend(msg.invalid_tool_calls().to_vec());
651                }
652                Message::Tool { content, .. } => *content = merged_content,
653                Message::Chat { content, .. } => *content = merged_content,
654                Message::Remove { .. } => { /* Remove messages are not merged */ }
655            }
656        } else {
657            result.push(msg);
658        }
659    }
660
661    result
662}
663
664/// Convert messages to a human-readable buffer string.
665pub fn get_buffer_string(messages: &[Message], human_prefix: &str, ai_prefix: &str) -> String {
666    messages
667        .iter()
668        .map(|msg| {
669            let prefix = match msg {
670                Message::System { .. } => "System",
671                Message::Human { .. } => human_prefix,
672                Message::AI { .. } => ai_prefix,
673                Message::Tool { .. } => "Tool",
674                Message::Chat { custom_role, .. } => custom_role.as_str(),
675                Message::Remove { .. } => "Remove",
676            };
677            format!("{prefix}: {}", msg.content())
678        })
679        .collect::<Vec<_>>()
680        .join("\n")
681}
682
683// ---------------------------------------------------------------------------
684// AIMessageChunk
685// ---------------------------------------------------------------------------
686
687/// A streaming chunk from an AI model response. Supports merge via `+`/`+=` operators and conversion to `Message` via `into_message()`.
688#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
689pub struct AIMessageChunk {
690    pub content: String,
691    #[serde(default, skip_serializing_if = "Vec::is_empty")]
692    pub tool_calls: Vec<ToolCall>,
693    #[serde(default, skip_serializing_if = "Option::is_none")]
694    pub usage: Option<TokenUsage>,
695    #[serde(default, skip_serializing_if = "Option::is_none")]
696    pub id: Option<String>,
697    #[serde(default, skip_serializing_if = "Vec::is_empty")]
698    pub tool_call_chunks: Vec<ToolCallChunk>,
699    #[serde(default, skip_serializing_if = "Vec::is_empty")]
700    pub invalid_tool_calls: Vec<InvalidToolCall>,
701}
702
703impl AIMessageChunk {
704    pub fn into_message(self) -> Message {
705        Message::ai_with_tool_calls(self.content, self.tool_calls)
706    }
707}
708
709impl std::ops::Add for AIMessageChunk {
710    type Output = Self;
711
712    fn add(mut self, rhs: Self) -> Self {
713        self += rhs;
714        self
715    }
716}
717
718impl std::ops::AddAssign for AIMessageChunk {
719    fn add_assign(&mut self, rhs: Self) {
720        self.content.push_str(&rhs.content);
721        self.tool_calls.extend(rhs.tool_calls);
722        self.tool_call_chunks.extend(rhs.tool_call_chunks);
723        self.invalid_tool_calls.extend(rhs.invalid_tool_calls);
724        if self.id.is_none() {
725            self.id = rhs.id;
726        }
727        match (&mut self.usage, rhs.usage) {
728            (Some(u), Some(rhs_u)) => {
729                u.input_tokens += rhs_u.input_tokens;
730                u.output_tokens += rhs_u.output_tokens;
731                u.total_tokens += rhs_u.total_tokens;
732            }
733            (None, Some(rhs_u)) => {
734                self.usage = Some(rhs_u);
735            }
736            _ => {}
737        }
738    }
739}
740
741// ---------------------------------------------------------------------------
742// Tool-related types
743// ---------------------------------------------------------------------------
744
745/// Represents a tool invocation requested by an AI model, with an ID, function name, and JSON arguments.
746#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
747pub struct ToolCall {
748    pub id: String,
749    pub name: String,
750    pub arguments: Value,
751}
752
753/// A tool call that failed to parse correctly.
754#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
755pub struct InvalidToolCall {
756    #[serde(default, skip_serializing_if = "Option::is_none")]
757    pub id: Option<String>,
758    #[serde(default, skip_serializing_if = "Option::is_none")]
759    pub name: Option<String>,
760    #[serde(default, skip_serializing_if = "Option::is_none")]
761    pub arguments: Option<String>,
762    pub error: String,
763}
764
765/// A partial tool call chunk received during streaming.
766#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
767pub struct ToolCallChunk {
768    #[serde(default, skip_serializing_if = "Option::is_none")]
769    pub id: Option<String>,
770    #[serde(default, skip_serializing_if = "Option::is_none")]
771    pub name: Option<String>,
772    #[serde(default, skip_serializing_if = "Option::is_none")]
773    pub arguments: Option<String>,
774    #[serde(default, skip_serializing_if = "Option::is_none")]
775    pub index: Option<usize>,
776}
777
778/// Schema definition for a tool, including its name, description, and JSON Schema for parameters.
779#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
780pub struct ToolDefinition {
781    pub name: String,
782    pub description: String,
783    pub parameters: Value,
784    /// Provider-specific parameters (e.g., Anthropic's `cache_control`).
785    #[serde(default, skip_serializing_if = "Option::is_none")]
786    pub extras: Option<HashMap<String, Value>>,
787}
788
789/// Controls how the model selects tools: Auto, Required, None, or a Specific named tool.
790#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
791#[serde(rename_all = "lowercase")]
792pub enum ToolChoice {
793    Auto,
794    Required,
795    None,
796    Specific(String),
797}
798
799// ---------------------------------------------------------------------------
800// Chat request / response
801// ---------------------------------------------------------------------------
802
803/// A request to a chat model containing messages, optional tool definitions, and tool choice configuration.
804#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
805pub struct ChatRequest {
806    pub messages: Vec<Message>,
807    #[serde(default, skip_serializing_if = "Vec::is_empty")]
808    pub tools: Vec<ToolDefinition>,
809    #[serde(default, skip_serializing_if = "Option::is_none")]
810    pub tool_choice: Option<ToolChoice>,
811}
812
813impl ChatRequest {
814    pub fn new(messages: Vec<Message>) -> Self {
815        Self {
816            messages,
817            tools: vec![],
818            tool_choice: None,
819        }
820    }
821
822    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
823        self.tools = tools;
824        self
825    }
826
827    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
828        self.tool_choice = Some(choice);
829        self
830    }
831}
832
833/// A response from a chat model containing the AI message and optional token usage statistics.
834#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
835pub struct ChatResponse {
836    pub message: Message,
837    pub usage: Option<TokenUsage>,
838}
839
840// ---------------------------------------------------------------------------
841// Token usage
842// ---------------------------------------------------------------------------
843
844#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
845pub struct TokenUsage {
846    pub input_tokens: u32,
847    pub output_tokens: u32,
848    pub total_tokens: u32,
849    #[serde(default, skip_serializing_if = "Option::is_none")]
850    pub input_details: Option<InputTokenDetails>,
851    #[serde(default, skip_serializing_if = "Option::is_none")]
852    pub output_details: Option<OutputTokenDetails>,
853}
854
855/// Detailed breakdown of input token usage.
856#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
857pub struct InputTokenDetails {
858    #[serde(default)]
859    pub cached: u32,
860    #[serde(default)]
861    pub audio: u32,
862}
863
864/// Detailed breakdown of output token usage.
865#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
866pub struct OutputTokenDetails {
867    #[serde(default)]
868    pub reasoning: u32,
869    #[serde(default)]
870    pub audio: u32,
871}
872
873// ---------------------------------------------------------------------------
874// Events
875// ---------------------------------------------------------------------------
876
877/// Lifecycle events emitted during agent execution, used by `CallbackHandler` implementations.
878#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
879pub enum RunEvent {
880    RunStarted {
881        run_id: String,
882        session_id: String,
883    },
884    RunStep {
885        run_id: String,
886        step: usize,
887    },
888    LlmCalled {
889        run_id: String,
890        message_count: usize,
891    },
892    ToolCalled {
893        run_id: String,
894        tool_name: String,
895    },
896    RunFinished {
897        run_id: String,
898        output: String,
899    },
900    RunFailed {
901        run_id: String,
902        error: String,
903    },
904}
905
906// ---------------------------------------------------------------------------
907// Errors
908// ---------------------------------------------------------------------------
909
/// Unified error type for the Synaptic framework with variants covering all subsystems.
///
/// `Display` and `std::error::Error` are derived via `thiserror`. Each tuple
/// variant wraps a free-form message that is rendered after the
/// subsystem-specific prefix given in its `#[error(...)]` attribute
/// (for example `Tool("boom")` displays as `tool error: boom`).
#[derive(Debug, Error)]
pub enum SynapticError {
    #[error("prompt error: {0}")]
    Prompt(String),
    #[error("model error: {0}")]
    Model(String),
    #[error("tool error: {0}")]
    Tool(String),
    /// Displays as `tool not found: <payload>`; the payload is presumably
    /// the requested tool name — confirm at the construction sites.
    #[error("tool not found: {0}")]
    ToolNotFound(String),
    #[error("memory error: {0}")]
    Memory(String),
    #[error("rate limit: {0}")]
    RateLimit(String),
    #[error("timeout: {0}")]
    Timeout(String),
    #[error("validation error: {0}")]
    Validation(String),
    #[error("parsing error: {0}")]
    Parsing(String),
    #[error("callback error: {0}")]
    Callback(String),
    /// The only struct-like variant: records the step limit that was hit.
    #[error("max steps exceeded: {max_steps}")]
    MaxStepsExceeded { max_steps: usize },
    #[error("embedding error: {0}")]
    Embedding(String),
    #[error("vector store error: {0}")]
    VectorStore(String),
    #[error("retriever error: {0}")]
    Retriever(String),
    #[error("loader error: {0}")]
    Loader(String),
    #[error("splitter error: {0}")]
    Splitter(String),
    #[error("graph error: {0}")]
    Graph(String),
    #[error("cache error: {0}")]
    Cache(String),
    #[error("store error: {0}")]
    Store(String),
    #[error("config error: {0}")]
    Config(String),
    #[error("mcp error: {0}")]
    Mcp(String),
}
956
957// ---------------------------------------------------------------------------
958// Core traits
959// ---------------------------------------------------------------------------
960
/// Type alias for a pinned, boxed async stream of `AIMessageChunk` results.
///
/// Each item is either a chunk or a `SynapticError`. The stream is `Send`
/// and may borrow data for the lifetime `'a`.
pub type ChatStream<'a> =
    Pin<Box<dyn Stream<Item = Result<AIMessageChunk, SynapticError>> + Send + 'a>>;
964
/// Describes a model's capabilities and limits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProfile {
    /// Model name.
    pub name: String,
    /// Provider name.
    pub provider: String,
    /// Whether the model supports tool calling.
    pub supports_tool_calling: bool,
    /// Whether the model supports structured output.
    pub supports_structured_output: bool,
    /// Whether the model supports streaming.
    pub supports_streaming: bool,
    /// Maximum number of input tokens, or `None` when not specified.
    pub max_input_tokens: Option<usize>,
    /// Maximum number of output tokens, or `None` when not specified.
    pub max_output_tokens: Option<usize>,
}
976
977/// The core trait for language model providers. Implementations provide `chat()` for single responses and optionally `stream_chat()` for streaming.
978#[async_trait]
979pub trait ChatModel: Send + Sync {
980    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError>;
981
982    /// Return the model's capability profile, if known.
983    fn profile(&self) -> Option<ModelProfile> {
984        None
985    }
986
987    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
988        Box::pin(async_stream::stream! {
989            match self.chat(request).await {
990                Ok(response) => {
991                    yield Ok(AIMessageChunk {
992                        content: response.message.content().to_string(),
993                        tool_calls: response.message.tool_calls().to_vec(),
994                        usage: response.usage,
995                        ..Default::default()
996                    });
997                }
998                Err(e) => yield Err(e),
999            }
1000        })
1001    }
1002}
1003
1004/// Defines an executable tool that can be called by an AI model. Each tool has a name, description, JSON schema for parameters, and an async `call()` method.
1005#[async_trait]
1006pub trait Tool: Send + Sync {
1007    fn name(&self) -> &'static str;
1008    fn description(&self) -> &'static str;
1009
1010    fn parameters(&self) -> Option<Value> {
1011        None
1012    }
1013
1014    async fn call(&self, args: Value) -> Result<Value, SynapticError>;
1015
1016    fn as_tool_definition(&self) -> ToolDefinition {
1017        ToolDefinition {
1018            name: self.name().to_string(),
1019            description: self.description().to_string(),
1020            parameters: self
1021                .parameters()
1022                .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1023            extras: None,
1024        }
1025    }
1026}
1027
1028// ---------------------------------------------------------------------------
1029// ToolContext — context-aware tool execution
1030// ---------------------------------------------------------------------------
1031
/// Context passed to tools during graph execution.
///
/// Provides access to the current graph state (serialized as JSON)
/// and the ID of the tool call being executed. `Default` yields no state
/// and an empty tool call ID.
#[derive(Debug, Clone, Default)]
pub struct ToolContext {
    /// The current graph state, serialized as JSON.
    pub state: Option<Value>,
    /// The ID of the tool call being executed.
    pub tool_call_id: String,
}
1043
/// A tool that receives execution context from the graph.
///
/// This mirrors the basic `Tool` trait (it is a separate trait, not a
/// subtrait) and adds graph-level context: the current state and the
/// tool call ID, delivered as a `ToolContext`. Implement this for tools
/// that need to read graph state.
#[async_trait]
pub trait ContextAwareTool: Send + Sync {
    /// Name of the tool.
    fn name(&self) -> &'static str;
    /// Description of the tool.
    fn description(&self) -> &'static str;
    /// Execute the tool with JSON `args` and the context `ctx` supplied by
    /// the caller. `ctx` is passed by value, so changes made to it are not
    /// visible to the caller.
    async fn call_with_context(
        &self,
        args: Value,
        ctx: ToolContext,
    ) -> Result<Value, SynapticError>;
}
1059
/// Wrapper that adapts a `ContextAwareTool` into a standard `Tool`.
///
/// When used outside a graph context, the tool receives a default
/// (empty) `ToolContext`.
pub struct ContextAwareToolAdapter {
    /// The wrapped context-aware tool that all calls are forwarded to.
    inner: Arc<dyn ContextAwareTool>,
}
1067
1068impl ContextAwareToolAdapter {
1069    pub fn new(inner: Arc<dyn ContextAwareTool>) -> Self {
1070        Self { inner }
1071    }
1072}
1073
1074#[async_trait]
1075impl Tool for ContextAwareToolAdapter {
1076    fn name(&self) -> &'static str {
1077        self.inner.name()
1078    }
1079
1080    fn description(&self) -> &'static str {
1081        self.inner.description()
1082    }
1083
1084    async fn call(&self, args: Value) -> Result<Value, SynapticError> {
1085        self.inner
1086            .call_with_context(args, ToolContext::default())
1087            .await
1088    }
1089}
1090
1091// ---------------------------------------------------------------------------
1092// MemoryStore
1093// ---------------------------------------------------------------------------
1094
/// Persistent storage for conversation message history, keyed by session ID.
#[async_trait]
pub trait MemoryStore: Send + Sync {
    /// Append `message` to the history stored under `session_id`.
    async fn append(&self, session_id: &str, message: Message) -> Result<(), SynapticError>;
    /// Load the messages stored under `session_id`.
    async fn load(&self, session_id: &str) -> Result<Vec<Message>, SynapticError>;
    /// Clear the history stored under `session_id`.
    async fn clear(&self, session_id: &str) -> Result<(), SynapticError>;
}
1102
/// Handler for lifecycle events during agent execution. Receives `RunEvent` notifications at each stage.
#[async_trait]
pub trait CallbackHandler: Send + Sync {
    /// Handle a single lifecycle `event`.
    ///
    /// # Errors
    ///
    /// Implementations may return a `SynapticError`; how the emitter reacts
    /// to a failing handler is decided by the caller, not by this trait.
    async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError>;
}
1108
1109// ---------------------------------------------------------------------------
1110// RunnableConfig
1111// ---------------------------------------------------------------------------
1112
/// Runtime configuration passed through runnable chains, including tags, metadata, concurrency limits, and run identification.
///
/// Every field is marked `#[serde(default)]`, so any of them may be missing
/// from the deserialized input; `Default` gives empty collections and `None`
/// for all optional fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RunnableConfig {
    /// Tags attached to the run.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Arbitrary JSON metadata, keyed by string.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
    /// Optional concurrency limit; `None` when unset.
    #[serde(default)]
    pub max_concurrency: Option<usize>,
    /// Optional recursion limit; `None` when unset.
    #[serde(default)]
    pub recursion_limit: Option<usize>,
    /// Optional run identifier.
    #[serde(default)]
    pub run_id: Option<String>,
    /// Optional run name.
    #[serde(default)]
    pub run_name: Option<String>,
}
1129
1130impl RunnableConfig {
1131    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
1132        self.tags = tags;
1133        self
1134    }
1135
1136    pub fn with_run_name(mut self, name: impl Into<String>) -> Self {
1137        self.run_name = Some(name.into());
1138        self
1139    }
1140
1141    pub fn with_run_id(mut self, id: impl Into<String>) -> Self {
1142        self.run_id = Some(id.into());
1143        self
1144    }
1145
1146    pub fn with_max_concurrency(mut self, max: usize) -> Self {
1147        self.max_concurrency = Some(max);
1148        self
1149    }
1150
1151    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
1152        self.recursion_limit = Some(limit);
1153        self
1154    }
1155
1156    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
1157        self.metadata.insert(key.into(), value);
1158        self
1159    }
1160}
1161
1162// ---------------------------------------------------------------------------
1163// Store trait (forward-declared in core, implemented in synaptic-store)
1164// ---------------------------------------------------------------------------
1165
/// A stored item in the key-value store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Item {
    /// Hierarchical namespace the item lives in, one path segment per element.
    pub namespace: Vec<String>,
    /// Key of the item within its namespace.
    pub key: String,
    /// Stored JSON value.
    pub value: Value,
    /// Creation timestamp as a string (format is not fixed by this type —
    /// confirm against the store implementations).
    pub created_at: String,
    /// Last-update timestamp as a string (same caveat as `created_at`).
    pub updated_at: String,
    /// Relevance score from a search operation (e.g., similarity score).
    /// Defaults to `None` when absent and is omitted from serialized output
    /// when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub score: Option<f64>,
}
1178
/// Persistent key-value store trait for cross-invocation state.
///
/// Namespaces are hierarchical (represented as slices of strings) and
/// keys are strings within a namespace. Values are arbitrary JSON.
/// Every method returns a `SynapticError` on failure.
#[async_trait]
pub trait Store: Send + Sync {
    /// Get an item by namespace and key.
    ///
    /// Returns `Ok(None)` when there is no such item.
    async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError>;

    /// Search items within a namespace.
    ///
    /// `query` is optional; `limit` presumably caps the number of returned
    /// items — confirm against the store implementations.
    async fn search(
        &self,
        namespace: &[&str],
        query: Option<&str>,
        limit: usize,
    ) -> Result<Vec<Item>, SynapticError>;

    /// Put (upsert) an item.
    async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError>;

    /// Delete an item.
    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError>;

    /// List all namespaces, optionally filtered by prefix.
    ///
    /// Each returned namespace is a list of path segments.
    async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError>;
}
1205
1206// ---------------------------------------------------------------------------
1207// Embeddings trait (forward-declared here, implemented in synaptic-embeddings)
1208// ---------------------------------------------------------------------------
1209
/// Trait for embedding text into vectors.
#[async_trait]
pub trait Embeddings: Send + Sync {
    /// Embed multiple texts (for batch document embedding).
    ///
    /// Returns a list of `f32` vectors; presumably one per input text, in
    /// input order — confirm against the implementations.
    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError>;

    /// Embed a single query text.
    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError>;
}
1219
1220// ---------------------------------------------------------------------------
1221// StreamWriter
1222// ---------------------------------------------------------------------------
1223
/// Custom stream writer that nodes can use to emit custom events.
///
/// A shared (`Arc`), thread-safe callback that takes one JSON value per
/// event and returns nothing.
pub type StreamWriter = Arc<dyn Fn(Value) + Send + Sync>;
1226
1227// ---------------------------------------------------------------------------
1228// Runtime types
1229// ---------------------------------------------------------------------------
1230
/// Graph execution runtime context passed to nodes and middleware.
///
/// Cloning is cheap with respect to the referenced objects: both fields are
/// optional `Arc` handles.
#[derive(Clone)]
pub struct Runtime {
    /// Optional shared key-value store.
    pub store: Option<Arc<dyn Store>>,
    /// Optional writer for emitting custom stream events.
    pub stream_writer: Option<StreamWriter>,
}
1237
/// Tool execution runtime context.
///
/// Note that this type does not implement `Default`; an "empty" runtime has
/// to be spelled out field by field.
#[derive(Clone)]
pub struct ToolRuntime {
    /// Optional shared key-value store.
    pub store: Option<Arc<dyn Store>>,
    /// Optional writer for emitting custom stream events.
    pub stream_writer: Option<StreamWriter>,
    /// Optional state as JSON (presumably the serialized graph state, as in
    /// `ToolContext` — confirm with the graph executor).
    pub state: Option<Value>,
    /// ID of the tool call being executed.
    pub tool_call_id: String,
    /// Optional runnable configuration.
    pub config: Option<RunnableConfig>,
}
1247
1248// ---------------------------------------------------------------------------
1249// RuntimeAwareTool
1250// ---------------------------------------------------------------------------
1251
1252/// Context-aware tool that receives runtime information.
1253///
1254/// This extends the basic `Tool` trait with runtime context
1255/// (current state, store, stream writer, tool call ID). Implement this
1256/// for tools that need to read or modify graph state.
1257#[async_trait]
1258pub trait RuntimeAwareTool: Send + Sync {
1259    fn name(&self) -> &'static str;
1260    fn description(&self) -> &'static str;
1261
1262    fn parameters(&self) -> Option<Value> {
1263        None
1264    }
1265
1266    async fn call_with_runtime(
1267        &self,
1268        args: Value,
1269        runtime: ToolRuntime,
1270    ) -> Result<Value, SynapticError>;
1271
1272    fn as_tool_definition(&self) -> ToolDefinition {
1273        ToolDefinition {
1274            name: self.name().to_string(),
1275            description: self.description().to_string(),
1276            parameters: self
1277                .parameters()
1278                .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1279            extras: None,
1280        }
1281    }
1282}
1283
/// Adapter that wraps a `RuntimeAwareTool` into a standard `Tool`.
///
/// When used outside a graph context, the tool receives a default
/// (empty) `ToolRuntime`.
pub struct RuntimeAwareToolAdapter {
    /// The wrapped runtime-aware tool.
    inner: Arc<dyn RuntimeAwareTool>,
    /// Runtime slot behind an async read-write lock; starts out as `None`.
    /// There is a single slot per adapter, so it is shared by every call
    /// made through this adapter.
    runtime: Arc<tokio::sync::RwLock<Option<ToolRuntime>>>,
}
1292
1293impl RuntimeAwareToolAdapter {
1294    pub fn new(tool: Arc<dyn RuntimeAwareTool>) -> Self {
1295        Self {
1296            inner: tool,
1297            runtime: Arc::new(tokio::sync::RwLock::new(None)),
1298        }
1299    }
1300
1301    pub async fn set_runtime(&self, runtime: ToolRuntime) {
1302        *self.runtime.write().await = Some(runtime);
1303    }
1304}
1305
1306#[async_trait]
1307impl Tool for RuntimeAwareToolAdapter {
1308    fn name(&self) -> &'static str {
1309        self.inner.name()
1310    }
1311
1312    fn description(&self) -> &'static str {
1313        self.inner.description()
1314    }
1315
1316    fn parameters(&self) -> Option<Value> {
1317        self.inner.parameters()
1318    }
1319
1320    async fn call(&self, args: Value) -> Result<Value, SynapticError> {
1321        let runtime = self.runtime.read().await.clone().unwrap_or(ToolRuntime {
1322            store: None,
1323            stream_writer: None,
1324            state: None,
1325            tool_call_id: String::new(),
1326            config: None,
1327        });
1328        self.inner.call_with_runtime(args, runtime).await
1329    }
1330}
1331
1332// ---------------------------------------------------------------------------
1333// Document
1334// ---------------------------------------------------------------------------
1335
/// A document with content and metadata, used throughout the retrieval pipeline.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Document {
    /// Identifier of the document.
    pub id: String,
    /// Text content of the document.
    pub content: String,
    /// Arbitrary JSON metadata. Defaults to an empty map when absent on
    /// deserialization and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}
1344
1345impl Document {
1346    pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
1347        Self {
1348            id: id.into(),
1349            content: content.into(),
1350            metadata: HashMap::new(),
1351        }
1352    }
1353
1354    pub fn with_metadata(
1355        id: impl Into<String>,
1356        content: impl Into<String>,
1357        metadata: HashMap<String, Value>,
1358    ) -> Self {
1359        Self {
1360            id: id.into(),
1361            content: content.into(),
1362            metadata,
1363        }
1364    }
1365}
1366
1367// ---------------------------------------------------------------------------
1368// Retriever trait (forward-declared here, implementations in synaptic-retrieval)
1369// ---------------------------------------------------------------------------
1370
/// Trait for retrieving relevant documents given a query string.
#[async_trait]
pub trait Retriever: Send + Sync {
    /// Retrieve documents for `query`.
    ///
    /// `top_k` presumably caps the number of returned documents — confirm
    /// against the implementations.
    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapticError>;
}
1376
1377// ---------------------------------------------------------------------------
1378// VectorStore trait (forward-declared here, implementations in synaptic-vectorstores)
1379// ---------------------------------------------------------------------------
1380
/// Trait for vector storage backends.
///
/// The text-based methods take the `Embeddings` implementation to use as an
/// argument rather than owning one. Every method returns a `SynapticError`
/// on failure.
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Add documents to the store, computing their embeddings.
    ///
    /// Returns a list of string IDs (presumably the IDs of the stored
    /// documents — confirm against the implementations).
    async fn add_documents(
        &self,
        docs: Vec<Document>,
        embeddings: &dyn Embeddings,
    ) -> Result<Vec<String>, SynapticError>;

    /// Search for similar documents by query string.
    ///
    /// `k` is the requested number of results.
    async fn similarity_search(
        &self,
        query: &str,
        k: usize,
        embeddings: &dyn Embeddings,
    ) -> Result<Vec<Document>, SynapticError>;

    /// Search with similarity scores (higher = more similar).
    ///
    /// Each result pairs a document with its `f32` score.
    async fn similarity_search_with_score(
        &self,
        query: &str,
        k: usize,
        embeddings: &dyn Embeddings,
    ) -> Result<Vec<(Document, f32)>, SynapticError>;

    /// Search by pre-computed embedding vector instead of text query.
    ///
    /// No `Embeddings` implementation is needed for this variant.
    async fn similarity_search_by_vector(
        &self,
        embedding: &[f32],
        k: usize,
    ) -> Result<Vec<Document>, SynapticError>;

    /// Delete documents by ID.
    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError>;
}
1417
1418// ---------------------------------------------------------------------------
1419// Loader trait (forward-declared here, implementations in synaptic-loaders)
1420// ---------------------------------------------------------------------------
1421
1422/// Trait for loading documents from various sources.
1423#[async_trait]
1424pub trait Loader: Send + Sync {
1425    /// Load all documents from this source.
1426    async fn load(&self) -> Result<Vec<Document>, SynapticError>;
1427
1428    /// Stream documents lazily. Default implementation wraps load().
1429    fn lazy_load(
1430        &self,
1431    ) -> Pin<Box<dyn Stream<Item = Result<Document, SynapticError>> + Send + '_>> {
1432        Box::pin(async_stream::stream! {
1433            match self.load().await {
1434                Ok(docs) => {
1435                    for doc in docs {
1436                        yield Ok(doc);
1437                    }
1438                }
1439                Err(e) => yield Err(e),
1440            }
1441        })
1442    }
1443}
1444
1445// ---------------------------------------------------------------------------
1446// LlmCache trait (forward-declared here, implementations in synaptic-cache)
1447// ---------------------------------------------------------------------------
1448
/// Trait for caching LLM responses.
///
/// How the cache key is derived is up to the caller; this trait only sees
/// it as an opaque string.
#[async_trait]
pub trait LlmCache: Send + Sync {
    /// Look up a cached response by cache key.
    ///
    /// Returns `Ok(None)` when nothing is cached under `key`.
    async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError>;
    /// Store a response in the cache.
    async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError>;
    /// Clear all entries from the cache.
    async fn clear(&self) -> Result<(), SynapticError>;
}
1459
1460// ---------------------------------------------------------------------------
1461// Entrypoint / Task metadata (used by proc macros)
1462// ---------------------------------------------------------------------------
1463
/// Configuration for an `#[entrypoint]`-decorated function.
///
/// Both values are `'static` strings (presumably emitted by the proc macro
/// as literals — confirm in the macro crate).
#[derive(Debug, Clone)]
pub struct EntrypointConfig {
    /// Name of the entrypoint.
    pub name: &'static str,
    /// Optional checkpointer identifier; `None` when not configured.
    pub checkpointer: Option<&'static str>,
}
1470
/// Type alias for the async entrypoint function signature.
///
/// A type-erased, thread-safe (`Send + Sync`) callable that takes a JSON
/// `Value` and returns a boxed `Send` future resolving to
/// `Result<Value, SynapticError>`.
pub type EntrypointFn = dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<Value, SynapticError>> + Send>>
    + Send
    + Sync;

/// An entrypoint wrapping an async function as a runnable workflow.
///
/// The `invoke_fn` field is a type-erased async function (`Value -> Result<Value, SynapticError>`).
pub struct Entrypoint {
    /// Static configuration of the entrypoint.
    pub config: EntrypointConfig,
    /// The boxed function that `invoke()` calls.
    pub invoke_fn: Box<EntrypointFn>,
}
1483
1484impl Entrypoint {
1485    pub async fn invoke(&self, input: Value) -> Result<Value, SynapticError> {
1486        (self.invoke_fn)(input).await
1487    }
1488}