Skip to main content

synaptic_core/
lib.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use thiserror::Error;
11
12#[cfg(feature = "schemars")]
13pub use schemars;
14
15pub mod context_budget;
16pub mod token_counter;
17
18pub use context_budget::{ContextBudget, ContextSlot, Priority};
19pub use token_counter::{HeuristicTokenCounter, TokenCounter};
20
21// ---------------------------------------------------------------------------
22// ContentBlock — multimodal message content
23// ---------------------------------------------------------------------------
24
25/// A block of content within a message, supporting multimodal inputs.
26#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
27#[serde(tag = "type", rename_all = "snake_case")]
28pub enum ContentBlock {
29    Text {
30        text: String,
31    },
32    Image {
33        url: String,
34        #[serde(default, skip_serializing_if = "Option::is_none")]
35        detail: Option<String>,
36    },
37    Audio {
38        url: String,
39    },
40    Video {
41        url: String,
42    },
43    File {
44        url: String,
45        #[serde(default, skip_serializing_if = "Option::is_none")]
46        mime_type: Option<String>,
47    },
48    Data {
49        data: Value,
50    },
51    Reasoning {
52        content: String,
53    },
54}
55
56// ---------------------------------------------------------------------------
57// Message
58// ---------------------------------------------------------------------------
59
60/// Represents a chat message. Tagged enum with System, Human, AI, and Tool variants.
61#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
62#[serde(tag = "role")]
63pub enum Message {
64    #[serde(rename = "system")]
65    System {
66        content: String,
67        #[serde(default, skip_serializing_if = "Option::is_none")]
68        id: Option<String>,
69        #[serde(default, skip_serializing_if = "Option::is_none")]
70        name: Option<String>,
71        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
72        additional_kwargs: HashMap<String, Value>,
73        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
74        response_metadata: HashMap<String, Value>,
75        #[serde(default, skip_serializing_if = "Vec::is_empty")]
76        content_blocks: Vec<ContentBlock>,
77    },
78    #[serde(rename = "human")]
79    Human {
80        content: String,
81        #[serde(default, skip_serializing_if = "Option::is_none")]
82        id: Option<String>,
83        #[serde(default, skip_serializing_if = "Option::is_none")]
84        name: Option<String>,
85        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
86        additional_kwargs: HashMap<String, Value>,
87        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
88        response_metadata: HashMap<String, Value>,
89        #[serde(default, skip_serializing_if = "Vec::is_empty")]
90        content_blocks: Vec<ContentBlock>,
91    },
92    #[serde(rename = "assistant")]
93    AI {
94        content: String,
95        #[serde(default, skip_serializing_if = "Vec::is_empty")]
96        tool_calls: Vec<ToolCall>,
97        #[serde(default, skip_serializing_if = "Option::is_none")]
98        id: Option<String>,
99        #[serde(default, skip_serializing_if = "Option::is_none")]
100        name: Option<String>,
101        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
102        additional_kwargs: HashMap<String, Value>,
103        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
104        response_metadata: HashMap<String, Value>,
105        #[serde(default, skip_serializing_if = "Vec::is_empty")]
106        content_blocks: Vec<ContentBlock>,
107        #[serde(default, skip_serializing_if = "Option::is_none")]
108        usage_metadata: Option<TokenUsage>,
109        #[serde(default, skip_serializing_if = "Vec::is_empty")]
110        invalid_tool_calls: Vec<InvalidToolCall>,
111    },
112    #[serde(rename = "tool")]
113    Tool {
114        content: String,
115        tool_call_id: String,
116        #[serde(default, skip_serializing_if = "Option::is_none")]
117        id: Option<String>,
118        #[serde(default, skip_serializing_if = "Option::is_none")]
119        name: Option<String>,
120        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
121        additional_kwargs: HashMap<String, Value>,
122        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
123        response_metadata: HashMap<String, Value>,
124        #[serde(default, skip_serializing_if = "Vec::is_empty")]
125        content_blocks: Vec<ContentBlock>,
126    },
127    #[serde(rename = "chat")]
128    Chat {
129        custom_role: String,
130        content: String,
131        #[serde(default, skip_serializing_if = "Option::is_none")]
132        id: Option<String>,
133        #[serde(default, skip_serializing_if = "Option::is_none")]
134        name: Option<String>,
135        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
136        additional_kwargs: HashMap<String, Value>,
137        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
138        response_metadata: HashMap<String, Value>,
139        #[serde(default, skip_serializing_if = "Vec::is_empty")]
140        content_blocks: Vec<ContentBlock>,
141    },
142    /// A special message that signals removal of a message by its ID.
143    /// Used in message history management.
144    #[serde(rename = "remove")]
145    Remove {
146        /// ID of the message to remove.
147        id: String,
148    },
149}
150
/// Assigns `$value` to the field `$field` that all content-bearing `Message`
/// variants have in common.
///
/// `$self` must evaluate to something that matches as `&mut Message`.
/// `Remove` is left untouched, and `$value` is not evaluated for it.
macro_rules! set_message_field {
    ($self:expr, $field:ident, $value:expr) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => *$field = $value,
            // Remove has no common fields, so there is nothing to assign.
            Message::Remove { .. } => {}
        }
    };
}
165
/// Yields a reference to the field `$field` that all content-bearing
/// `Message` variants have in common.
///
/// Panics via `unreachable!` when handed a `Remove` message, so callers must
/// deal with `Remove` before invoking the macro.
macro_rules! get_message_field {
    ($self:expr, $field:ident) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => $field,
            Message::Remove { .. } => unreachable!("get_message_field called on Remove variant"),
        }
    };
}
180
181impl Message {
182    // -- Factory methods -----------------------------------------------------
183
184    pub fn system(content: impl Into<String>) -> Self {
185        Message::System {
186            content: content.into(),
187            id: None,
188            name: None,
189            additional_kwargs: HashMap::new(),
190            response_metadata: HashMap::new(),
191            content_blocks: Vec::new(),
192        }
193    }
194
195    pub fn human(content: impl Into<String>) -> Self {
196        Message::Human {
197            content: content.into(),
198            id: None,
199            name: None,
200            additional_kwargs: HashMap::new(),
201            response_metadata: HashMap::new(),
202            content_blocks: Vec::new(),
203        }
204    }
205
206    pub fn ai(content: impl Into<String>) -> Self {
207        Message::AI {
208            content: content.into(),
209            tool_calls: vec![],
210            id: None,
211            name: None,
212            additional_kwargs: HashMap::new(),
213            response_metadata: HashMap::new(),
214            content_blocks: Vec::new(),
215            usage_metadata: None,
216            invalid_tool_calls: Vec::new(),
217        }
218    }
219
220    pub fn ai_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
221        Message::AI {
222            content: content.into(),
223            tool_calls,
224            id: None,
225            name: None,
226            additional_kwargs: HashMap::new(),
227            response_metadata: HashMap::new(),
228            content_blocks: Vec::new(),
229            usage_metadata: None,
230            invalid_tool_calls: Vec::new(),
231        }
232    }
233
234    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
235        Message::Tool {
236            content: content.into(),
237            tool_call_id: tool_call_id.into(),
238            id: None,
239            name: None,
240            additional_kwargs: HashMap::new(),
241            response_metadata: HashMap::new(),
242            content_blocks: Vec::new(),
243        }
244    }
245
246    pub fn chat(role: impl Into<String>, content: impl Into<String>) -> Self {
247        Message::Chat {
248            custom_role: role.into(),
249            content: content.into(),
250            id: None,
251            name: None,
252            additional_kwargs: HashMap::new(),
253            response_metadata: HashMap::new(),
254            content_blocks: Vec::new(),
255        }
256    }
257
258    /// Create a Remove message that signals removal of a message by its ID.
259    pub fn remove(id: impl Into<String>) -> Self {
260        Message::Remove { id: id.into() }
261    }
262
263    // -- Builder methods -----------------------------------------------------
264
265    pub fn with_id(mut self, value: impl Into<String>) -> Self {
266        set_message_field!(&mut self, id, Some(value.into()));
267        self
268    }
269
270    pub fn with_name(mut self, value: impl Into<String>) -> Self {
271        set_message_field!(&mut self, name, Some(value.into()));
272        self
273    }
274
275    pub fn with_additional_kwarg(mut self, key: impl Into<String>, value: Value) -> Self {
276        match &mut self {
277            Message::System {
278                additional_kwargs, ..
279            }
280            | Message::Human {
281                additional_kwargs, ..
282            }
283            | Message::AI {
284                additional_kwargs, ..
285            }
286            | Message::Tool {
287                additional_kwargs, ..
288            }
289            | Message::Chat {
290                additional_kwargs, ..
291            } => {
292                additional_kwargs.insert(key.into(), value);
293            }
294            Message::Remove { .. } => { /* Remove has no additional_kwargs */ }
295        }
296        self
297    }
298
299    pub fn with_response_metadata_entry(mut self, key: impl Into<String>, value: Value) -> Self {
300        match &mut self {
301            Message::System {
302                response_metadata, ..
303            }
304            | Message::Human {
305                response_metadata, ..
306            }
307            | Message::AI {
308                response_metadata, ..
309            }
310            | Message::Tool {
311                response_metadata, ..
312            }
313            | Message::Chat {
314                response_metadata, ..
315            } => {
316                response_metadata.insert(key.into(), value);
317            }
318            Message::Remove { .. } => { /* Remove has no response_metadata */ }
319        }
320        self
321    }
322
323    pub fn with_content_blocks(mut self, blocks: Vec<ContentBlock>) -> Self {
324        set_message_field!(&mut self, content_blocks, blocks);
325        self
326    }
327
328    pub fn with_usage_metadata(mut self, usage: TokenUsage) -> Self {
329        if let Message::AI { usage_metadata, .. } = &mut self {
330            *usage_metadata = Some(usage);
331        }
332        self
333    }
334
335    // -- Accessor methods ----------------------------------------------------
336
337    pub fn content(&self) -> &str {
338        match self {
339            Message::Remove { .. } => "",
340            other => get_message_field!(other, content),
341        }
342    }
343
344    pub fn role(&self) -> &str {
345        match self {
346            Message::System { .. } => "system",
347            Message::Human { .. } => "human",
348            Message::AI { .. } => "assistant",
349            Message::Tool { .. } => "tool",
350            Message::Chat { custom_role, .. } => custom_role,
351            Message::Remove { .. } => "remove",
352        }
353    }
354
355    pub fn is_system(&self) -> bool {
356        matches!(self, Message::System { .. })
357    }
358
359    pub fn is_human(&self) -> bool {
360        matches!(self, Message::Human { .. })
361    }
362
363    pub fn is_ai(&self) -> bool {
364        matches!(self, Message::AI { .. })
365    }
366
367    pub fn is_tool(&self) -> bool {
368        matches!(self, Message::Tool { .. })
369    }
370
371    pub fn is_chat(&self) -> bool {
372        matches!(self, Message::Chat { .. })
373    }
374
375    pub fn is_remove(&self) -> bool {
376        matches!(self, Message::Remove { .. })
377    }
378
379    pub fn tool_calls(&self) -> &[ToolCall] {
380        match self {
381            Message::AI { tool_calls, .. } => tool_calls,
382            _ => &[],
383        }
384    }
385
386    pub fn tool_call_id(&self) -> Option<&str> {
387        match self {
388            Message::Tool { tool_call_id, .. } => Some(tool_call_id),
389            _ => None,
390        }
391    }
392
393    pub fn id(&self) -> Option<&str> {
394        match self {
395            Message::Remove { id } => Some(id),
396            other => get_message_field!(other, id).as_deref(),
397        }
398    }
399
400    pub fn name(&self) -> Option<&str> {
401        match self {
402            Message::Remove { .. } => None,
403            other => get_message_field!(other, name).as_deref(),
404        }
405    }
406
407    pub fn additional_kwargs(&self) -> &HashMap<String, Value> {
408        match self {
409            Message::System {
410                additional_kwargs, ..
411            }
412            | Message::Human {
413                additional_kwargs, ..
414            }
415            | Message::AI {
416                additional_kwargs, ..
417            }
418            | Message::Tool {
419                additional_kwargs, ..
420            }
421            | Message::Chat {
422                additional_kwargs, ..
423            } => additional_kwargs,
424            Message::Remove { .. } => {
425                static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
426                    std::sync::OnceLock::new();
427                EMPTY.get_or_init(HashMap::new)
428            }
429        }
430    }
431
432    pub fn response_metadata(&self) -> &HashMap<String, Value> {
433        match self {
434            Message::System {
435                response_metadata, ..
436            }
437            | Message::Human {
438                response_metadata, ..
439            }
440            | Message::AI {
441                response_metadata, ..
442            }
443            | Message::Tool {
444                response_metadata, ..
445            }
446            | Message::Chat {
447                response_metadata, ..
448            } => response_metadata,
449            Message::Remove { .. } => {
450                static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
451                    std::sync::OnceLock::new();
452                EMPTY.get_or_init(HashMap::new)
453            }
454        }
455    }
456
457    pub fn content_blocks(&self) -> &[ContentBlock] {
458        match self {
459            Message::Remove { .. } => &[],
460            other => get_message_field!(other, content_blocks),
461        }
462    }
463
464    /// Return the remove ID if this is a Remove message.
465    pub fn remove_id(&self) -> Option<&str> {
466        match self {
467            Message::Remove { id } => Some(id),
468            _ => None,
469        }
470    }
471
472    pub fn usage_metadata(&self) -> Option<&TokenUsage> {
473        match self {
474            Message::AI { usage_metadata, .. } => usage_metadata.as_ref(),
475            _ => None,
476        }
477    }
478
479    pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
480        match self {
481            Message::AI {
482                invalid_tool_calls, ..
483            } => invalid_tool_calls,
484            _ => &[],
485        }
486    }
487
488    /// Set the content of this message. No-op for Remove variant.
489    pub fn set_content(&mut self, new_content: impl Into<String>) {
490        let new_content = new_content.into();
491        set_message_field!(self, content, new_content);
492    }
493}
494
495// ---------------------------------------------------------------------------
496// Message utility functions
497// ---------------------------------------------------------------------------
498
499/// Filter messages by type, name, or id.
500pub fn filter_messages(
501    messages: &[Message],
502    include_types: Option<&[&str]>,
503    exclude_types: Option<&[&str]>,
504    include_names: Option<&[&str]>,
505    exclude_names: Option<&[&str]>,
506    include_ids: Option<&[&str]>,
507    exclude_ids: Option<&[&str]>,
508) -> Vec<Message> {
509    messages
510        .iter()
511        .filter(|msg| {
512            if let Some(include) = include_types {
513                if !include.contains(&msg.role()) {
514                    return false;
515                }
516            }
517            if let Some(exclude) = exclude_types {
518                if exclude.contains(&msg.role()) {
519                    return false;
520                }
521            }
522            if let Some(include) = include_names {
523                match msg.name() {
524                    Some(name) => {
525                        if !include.contains(&name) {
526                            return false;
527                        }
528                    }
529                    None => return false,
530                }
531            }
532            if let Some(exclude) = exclude_names {
533                if let Some(name) = msg.name() {
534                    if exclude.contains(&name) {
535                        return false;
536                    }
537                }
538            }
539            if let Some(include) = include_ids {
540                match msg.id() {
541                    Some(id) => {
542                        if !include.contains(&id) {
543                            return false;
544                        }
545                    }
546                    None => return false,
547                }
548            }
549            if let Some(exclude) = exclude_ids {
550                if let Some(id) = msg.id() {
551                    if exclude.contains(&id) {
552                        return false;
553                    }
554                }
555            }
556            true
557        })
558        .cloned()
559        .collect()
560}
561
/// Strategy for trimming messages.
///
/// Selects which end of the message list is kept when the list has to be cut
/// down to a token budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimStrategy {
    /// Keep the first messages that fit within the token budget.
    First,
    /// Keep the last messages that fit within the token budget.
    Last,
}
570
571/// Trim messages to fit within a token budget.
572///
573/// `token_counter` receives a single message and returns its token count.
574/// When `include_system` is true and `strategy` is `Last`, the leading system
575/// message is always preserved.
576pub fn trim_messages(
577    messages: Vec<Message>,
578    max_tokens: usize,
579    token_counter: impl Fn(&Message) -> usize,
580    strategy: TrimStrategy,
581    include_system: bool,
582) -> Vec<Message> {
583    if messages.is_empty() {
584        return messages;
585    }
586
587    match strategy {
588        TrimStrategy::First => {
589            let mut result = Vec::new();
590            let mut total = 0;
591            for msg in messages {
592                let count = token_counter(&msg);
593                if total + count > max_tokens {
594                    break;
595                }
596                total += count;
597                result.push(msg);
598            }
599            result
600        }
601        TrimStrategy::Last => {
602            let (system_msg, rest) = if include_system && messages[0].is_system() {
603                (Some(messages[0].clone()), &messages[1..])
604            } else {
605                (None, messages.as_slice())
606            };
607
608            let system_tokens = system_msg.as_ref().map(&token_counter).unwrap_or(0);
609            let budget = max_tokens.saturating_sub(system_tokens);
610
611            let mut selected = Vec::new();
612            let mut total = 0;
613            for msg in rest.iter().rev() {
614                let count = token_counter(msg);
615                if total + count > budget {
616                    break;
617                }
618                total += count;
619                selected.push(msg.clone());
620            }
621            selected.reverse();
622
623            let mut result = Vec::new();
624            if let Some(sys) = system_msg {
625                result.push(sys);
626            }
627            result.extend(selected);
628            result
629        }
630    }
631}
632
633/// Merge consecutive messages of the same role into a single message.
634pub fn merge_message_runs(messages: Vec<Message>) -> Vec<Message> {
635    if messages.is_empty() {
636        return messages;
637    }
638
639    let mut result: Vec<Message> = Vec::new();
640
641    for msg in messages {
642        let should_merge = result
643            .last()
644            .map(|last| last.role() == msg.role())
645            .unwrap_or(false);
646
647        if should_merge {
648            let last = result.last_mut().unwrap();
649            // Merge content
650            let merged_content = format!("{}\n{}", last.content(), msg.content());
651            match last {
652                Message::System { content, .. } => *content = merged_content,
653                Message::Human { content, .. } => *content = merged_content,
654                Message::AI {
655                    content,
656                    tool_calls,
657                    invalid_tool_calls,
658                    ..
659                } => {
660                    *content = merged_content;
661                    tool_calls.extend(msg.tool_calls().to_vec());
662                    invalid_tool_calls.extend(msg.invalid_tool_calls().to_vec());
663                }
664                Message::Tool { content, .. } => *content = merged_content,
665                Message::Chat { content, .. } => *content = merged_content,
666                Message::Remove { .. } => { /* Remove messages are not merged */ }
667            }
668        } else {
669            result.push(msg);
670        }
671    }
672
673    result
674}
675
676/// Convert messages to a human-readable buffer string.
677pub fn get_buffer_string(messages: &[Message], human_prefix: &str, ai_prefix: &str) -> String {
678    messages
679        .iter()
680        .map(|msg| {
681            let prefix = match msg {
682                Message::System { .. } => "System",
683                Message::Human { .. } => human_prefix,
684                Message::AI { .. } => ai_prefix,
685                Message::Tool { .. } => "Tool",
686                Message::Chat { custom_role, .. } => custom_role.as_str(),
687                Message::Remove { .. } => "Remove",
688            };
689            format!("{prefix}: {}", msg.content())
690        })
691        .collect::<Vec<_>>()
692        .join("\n")
693}
694
695// ---------------------------------------------------------------------------
696// AIMessageChunk
697// ---------------------------------------------------------------------------
698
699/// A streaming chunk from an AI model response. Supports merge via `+`/`+=` operators and conversion to `Message` via `into_message()`.
700#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
701pub struct AIMessageChunk {
702    pub content: String,
703    #[serde(default, skip_serializing_if = "Vec::is_empty")]
704    pub tool_calls: Vec<ToolCall>,
705    #[serde(default, skip_serializing_if = "Option::is_none")]
706    pub usage: Option<TokenUsage>,
707    #[serde(default, skip_serializing_if = "Option::is_none")]
708    pub id: Option<String>,
709    #[serde(default, skip_serializing_if = "Vec::is_empty")]
710    pub tool_call_chunks: Vec<ToolCallChunk>,
711    #[serde(default, skip_serializing_if = "Vec::is_empty")]
712    pub invalid_tool_calls: Vec<InvalidToolCall>,
713}
714
715impl AIMessageChunk {
716    pub fn into_message(self) -> Message {
717        Message::ai_with_tool_calls(self.content, self.tool_calls)
718    }
719}
720
721impl std::ops::Add for AIMessageChunk {
722    type Output = Self;
723
724    fn add(mut self, rhs: Self) -> Self {
725        self += rhs;
726        self
727    }
728}
729
730impl std::ops::AddAssign for AIMessageChunk {
731    fn add_assign(&mut self, rhs: Self) {
732        self.content.push_str(&rhs.content);
733        self.tool_calls.extend(rhs.tool_calls);
734        self.tool_call_chunks.extend(rhs.tool_call_chunks);
735        self.invalid_tool_calls.extend(rhs.invalid_tool_calls);
736        if self.id.is_none() {
737            self.id = rhs.id;
738        }
739        match (&mut self.usage, rhs.usage) {
740            (Some(u), Some(rhs_u)) => {
741                u.input_tokens += rhs_u.input_tokens;
742                u.output_tokens += rhs_u.output_tokens;
743                u.total_tokens += rhs_u.total_tokens;
744            }
745            (None, Some(rhs_u)) => {
746                self.usage = Some(rhs_u);
747            }
748            _ => {}
749        }
750    }
751}
752
753// ---------------------------------------------------------------------------
754// Tool-related types
755// ---------------------------------------------------------------------------
756
757/// Represents a tool invocation requested by an AI model, with an ID, function name, and JSON arguments.
758#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
759pub struct ToolCall {
760    pub id: String,
761    pub name: String,
762    pub arguments: Value,
763}
764
765/// A tool call that failed to parse correctly.
766#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
767pub struct InvalidToolCall {
768    #[serde(default, skip_serializing_if = "Option::is_none")]
769    pub id: Option<String>,
770    #[serde(default, skip_serializing_if = "Option::is_none")]
771    pub name: Option<String>,
772    #[serde(default, skip_serializing_if = "Option::is_none")]
773    pub arguments: Option<String>,
774    pub error: String,
775}
776
777/// A partial tool call chunk received during streaming.
778#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
779pub struct ToolCallChunk {
780    #[serde(default, skip_serializing_if = "Option::is_none")]
781    pub id: Option<String>,
782    #[serde(default, skip_serializing_if = "Option::is_none")]
783    pub name: Option<String>,
784    #[serde(default, skip_serializing_if = "Option::is_none")]
785    pub arguments: Option<String>,
786    #[serde(default, skip_serializing_if = "Option::is_none")]
787    pub index: Option<usize>,
788}
789
790/// Schema definition for a tool, including its name, description, and JSON Schema for parameters.
791#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
792pub struct ToolDefinition {
793    pub name: String,
794    pub description: String,
795    pub parameters: Value,
796    /// Provider-specific parameters (e.g., Anthropic's `cache_control`).
797    #[serde(default, skip_serializing_if = "Option::is_none")]
798    pub extras: Option<HashMap<String, Value>>,
799}
800
801/// Controls how the model selects tools: Auto, Required, None, or a Specific named tool.
802#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
803#[serde(rename_all = "lowercase")]
804pub enum ToolChoice {
805    Auto,
806    Required,
807    None,
808    Specific(String),
809}
810
811// ---------------------------------------------------------------------------
812// Chat request / response
813// ---------------------------------------------------------------------------
814
815/// A request to a chat model containing messages, optional tool definitions, and tool choice configuration.
816#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
817pub struct ChatRequest {
818    pub messages: Vec<Message>,
819    #[serde(default, skip_serializing_if = "Vec::is_empty")]
820    pub tools: Vec<ToolDefinition>,
821    #[serde(default, skip_serializing_if = "Option::is_none")]
822    pub tool_choice: Option<ToolChoice>,
823}
824
825impl ChatRequest {
826    pub fn new(messages: Vec<Message>) -> Self {
827        Self {
828            messages,
829            tools: vec![],
830            tool_choice: None,
831        }
832    }
833
834    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
835        self.tools = tools;
836        self
837    }
838
839    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
840        self.tool_choice = Some(choice);
841        self
842    }
843}
844
845/// A response from a chat model containing the AI message and optional token usage statistics.
846#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
847pub struct ChatResponse {
848    pub message: Message,
849    pub usage: Option<TokenUsage>,
850}
851
852// ---------------------------------------------------------------------------
853// Token usage
854// ---------------------------------------------------------------------------
855
856#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
857pub struct TokenUsage {
858    pub input_tokens: u32,
859    pub output_tokens: u32,
860    pub total_tokens: u32,
861    #[serde(default, skip_serializing_if = "Option::is_none")]
862    pub input_details: Option<InputTokenDetails>,
863    #[serde(default, skip_serializing_if = "Option::is_none")]
864    pub output_details: Option<OutputTokenDetails>,
865}
866
867/// Detailed breakdown of input token usage.
868#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
869pub struct InputTokenDetails {
870    #[serde(default)]
871    pub cached: u32,
872    #[serde(default)]
873    pub audio: u32,
874}
875
876/// Detailed breakdown of output token usage.
877#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
878pub struct OutputTokenDetails {
879    #[serde(default)]
880    pub reasoning: u32,
881    #[serde(default)]
882    pub audio: u32,
883}
884
885// ---------------------------------------------------------------------------
886// Events
887// ---------------------------------------------------------------------------
888
889/// Lifecycle events emitted during agent execution, used by `CallbackHandler` implementations.
890#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
891pub enum RunEvent {
892    RunStarted {
893        run_id: String,
894        session_id: String,
895    },
896    RunStep {
897        run_id: String,
898        step: usize,
899    },
900    LlmCalled {
901        run_id: String,
902        message_count: usize,
903    },
904    ToolCalled {
905        run_id: String,
906        tool_name: String,
907    },
908    RunFinished {
909        run_id: String,
910        output: String,
911    },
912    RunFailed {
913        run_id: String,
914        error: String,
915    },
916    /// Emitted before a tool call is executed.
917    BeforeToolCall {
918        run_id: String,
919        tool_name: String,
920        arguments: String,
921    },
922    /// Emitted after a tool call completes.
923    AfterToolCall {
924        run_id: String,
925        tool_name: String,
926        result: String,
927    },
928    /// Emitted before a message is sent to the model.
929    BeforeMessage {
930        run_id: String,
931        message_count: usize,
932    },
933    /// Emitted after a model response is received.
934    AfterMessage {
935        run_id: String,
936        response_length: usize,
937    },
938}
939
940// ---------------------------------------------------------------------------
941// Errors
942// ---------------------------------------------------------------------------
943
944/// Unified error type for the Synaptic framework with variants covering all subsystems.
945#[derive(Debug, Error)]
946pub enum SynapticError {
947    #[error("prompt error: {0}")]
948    Prompt(String),
949    #[error("model error: {0}")]
950    Model(String),
951    #[error("tool error: {0}")]
952    Tool(String),
953    #[error("tool not found: {0}")]
954    ToolNotFound(String),
955    #[error("memory error: {0}")]
956    Memory(String),
957    #[error("rate limit: {0}")]
958    RateLimit(String),
959    #[error("timeout: {0}")]
960    Timeout(String),
961    #[error("validation error: {0}")]
962    Validation(String),
963    #[error("parsing error: {0}")]
964    Parsing(String),
965    #[error("callback error: {0}")]
966    Callback(String),
967    #[error("max steps exceeded: {max_steps}")]
968    MaxStepsExceeded { max_steps: usize },
969    #[error("embedding error: {0}")]
970    Embedding(String),
971    #[error("vector store error: {0}")]
972    VectorStore(String),
973    #[error("retriever error: {0}")]
974    Retriever(String),
975    #[error("loader error: {0}")]
976    Loader(String),
977    #[error("splitter error: {0}")]
978    Splitter(String),
979    #[error("graph error: {0}")]
980    Graph(String),
981    #[error("cache error: {0}")]
982    Cache(String),
983    #[error("store error: {0}")]
984    Store(String),
985    #[error("config error: {0}")]
986    Config(String),
987    #[error("mcp error: {0}")]
988    Mcp(String),
989    #[error("security error: {0}")]
990    Security(String),
991}
992
993// ---------------------------------------------------------------------------
994// Core traits
995// ---------------------------------------------------------------------------
996
997/// Type alias for a pinned, boxed async stream of `AIMessageChunk` results.
998pub type ChatStream<'a> =
999    Pin<Box<dyn Stream<Item = Result<AIMessageChunk, SynapticError>> + Send + 'a>>;
1000
1001/// Describes a model's capabilities and limits.
1002#[derive(Debug, Clone, Serialize, Deserialize)]
1003pub struct ModelProfile {
1004    pub name: String,
1005    pub provider: String,
1006    pub supports_tool_calling: bool,
1007    pub supports_structured_output: bool,
1008    pub supports_streaming: bool,
1009    pub max_input_tokens: Option<usize>,
1010    pub max_output_tokens: Option<usize>,
1011}
1012
1013/// The core trait for language model providers. Implementations provide `chat()` for single responses and optionally `stream_chat()` for streaming.
1014#[async_trait]
1015pub trait ChatModel: Send + Sync {
1016    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError>;
1017
1018    /// Return the model's capability profile, if known.
1019    fn profile(&self) -> Option<ModelProfile> {
1020        None
1021    }
1022
1023    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
1024        Box::pin(async_stream::stream! {
1025            match self.chat(request).await {
1026                Ok(response) => {
1027                    yield Ok(AIMessageChunk {
1028                        content: response.message.content().to_string(),
1029                        tool_calls: response.message.tool_calls().to_vec(),
1030                        usage: response.usage,
1031                        ..Default::default()
1032                    });
1033                }
1034                Err(e) => yield Err(e),
1035            }
1036        })
1037    }
1038}
1039
1040/// Defines an executable tool that can be called by an AI model. Each tool has a name, description, JSON schema for parameters, and an async `call()` method.
1041#[async_trait]
1042pub trait Tool: Send + Sync {
1043    fn name(&self) -> &'static str;
1044    fn description(&self) -> &'static str;
1045
1046    fn parameters(&self) -> Option<Value> {
1047        None
1048    }
1049
1050    async fn call(&self, args: Value) -> Result<Value, SynapticError>;
1051
1052    fn as_tool_definition(&self) -> ToolDefinition {
1053        ToolDefinition {
1054            name: self.name().to_string(),
1055            description: self.description().to_string(),
1056            parameters: self
1057                .parameters()
1058                .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1059            extras: None,
1060        }
1061    }
1062}
1063
1064// ---------------------------------------------------------------------------
1065// MemoryStore
1066// ---------------------------------------------------------------------------
1067
1068/// Persistent storage for conversation message history, keyed by session ID.
1069#[async_trait]
1070pub trait MemoryStore: Send + Sync {
1071    async fn append(&self, session_id: &str, message: Message) -> Result<(), SynapticError>;
1072    async fn load(&self, session_id: &str) -> Result<Vec<Message>, SynapticError>;
1073    async fn clear(&self, session_id: &str) -> Result<(), SynapticError>;
1074}
1075
1076/// Handler for lifecycle events during agent execution. Receives `RunEvent` notifications at each stage.
1077#[async_trait]
1078pub trait CallbackHandler: Send + Sync {
1079    async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError>;
1080}
1081
1082// ---------------------------------------------------------------------------
1083// RunnableConfig
1084// ---------------------------------------------------------------------------
1085
1086/// Runtime configuration passed through runnable chains, including tags, metadata, concurrency limits, and run identification.
1087#[derive(Debug, Clone, Default, Serialize, Deserialize)]
1088pub struct RunnableConfig {
1089    #[serde(default)]
1090    pub tags: Vec<String>,
1091    #[serde(default)]
1092    pub metadata: HashMap<String, Value>,
1093    #[serde(default)]
1094    pub max_concurrency: Option<usize>,
1095    #[serde(default)]
1096    pub recursion_limit: Option<usize>,
1097    #[serde(default)]
1098    pub run_id: Option<String>,
1099    #[serde(default)]
1100    pub run_name: Option<String>,
1101}
1102
1103impl RunnableConfig {
1104    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
1105        self.tags = tags;
1106        self
1107    }
1108
1109    pub fn with_run_name(mut self, name: impl Into<String>) -> Self {
1110        self.run_name = Some(name.into());
1111        self
1112    }
1113
1114    pub fn with_run_id(mut self, id: impl Into<String>) -> Self {
1115        self.run_id = Some(id.into());
1116        self
1117    }
1118
1119    pub fn with_max_concurrency(mut self, max: usize) -> Self {
1120        self.max_concurrency = Some(max);
1121        self
1122    }
1123
1124    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
1125        self.recursion_limit = Some(limit);
1126        self
1127    }
1128
1129    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
1130        self.metadata.insert(key.into(), value);
1131        self
1132    }
1133}
1134
1135// ---------------------------------------------------------------------------
1136// Store trait (forward-declared in core, implemented in synaptic-store)
1137// ---------------------------------------------------------------------------
1138
1139/// A stored item in the key-value store.
1140#[derive(Debug, Clone, Serialize, Deserialize)]
1141pub struct Item {
1142    pub namespace: Vec<String>,
1143    pub key: String,
1144    pub value: Value,
1145    pub created_at: String,
1146    pub updated_at: String,
1147    /// Relevance score from a search operation (e.g., similarity score).
1148    #[serde(default, skip_serializing_if = "Option::is_none")]
1149    pub score: Option<f64>,
1150}
1151
1152/// Persistent key-value store trait for cross-invocation state.
1153///
1154/// Namespaces are hierarchical (represented as slices of strings) and
1155/// keys are strings within a namespace. Values are arbitrary JSON.
1156#[async_trait]
1157pub trait Store: Send + Sync {
1158    /// Get an item by namespace and key.
1159    async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError>;
1160
1161    /// Search items within a namespace.
1162    async fn search(
1163        &self,
1164        namespace: &[&str],
1165        query: Option<&str>,
1166        limit: usize,
1167    ) -> Result<Vec<Item>, SynapticError>;
1168
1169    /// Put (upsert) an item.
1170    async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError>;
1171
1172    /// Delete an item.
1173    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError>;
1174
1175    /// List all namespaces, optionally filtered by prefix.
1176    async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError>;
1177}
1178
1179// ---------------------------------------------------------------------------
1180// Embeddings trait (forward-declared here, implemented in synaptic-embeddings)
1181// ---------------------------------------------------------------------------
1182
1183/// Trait for embedding text into vectors.
1184#[async_trait]
1185pub trait Embeddings: Send + Sync {
1186    /// Embed multiple texts (for batch document embedding).
1187    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError>;
1188
1189    /// Embed a single query text.
1190    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError>;
1191}
1192
1193// ---------------------------------------------------------------------------
1194// StreamWriter
1195// ---------------------------------------------------------------------------
1196
1197/// Custom stream writer that nodes can use to emit custom events.
1198pub type StreamWriter = Arc<dyn Fn(Value) + Send + Sync>;
1199
1200// ---------------------------------------------------------------------------
1201// Runtime types
1202// ---------------------------------------------------------------------------
1203
1204/// Graph execution runtime context passed to nodes and middleware.
1205#[derive(Clone)]
1206pub struct Runtime {
1207    pub store: Option<Arc<dyn Store>>,
1208    pub stream_writer: Option<StreamWriter>,
1209}
1210
1211/// Tool execution runtime context.
1212#[derive(Clone)]
1213pub struct ToolRuntime {
1214    pub store: Option<Arc<dyn Store>>,
1215    pub stream_writer: Option<StreamWriter>,
1216    pub state: Option<Value>,
1217    pub tool_call_id: String,
1218    pub config: Option<RunnableConfig>,
1219}
1220
1221// ---------------------------------------------------------------------------
1222// RuntimeAwareTool
1223// ---------------------------------------------------------------------------
1224
1225/// Context-aware tool that receives runtime information.
1226///
1227/// This extends the basic `Tool` trait with runtime context
1228/// (current state, store, stream writer, tool call ID). Implement this
1229/// for tools that need to read or modify graph state.
1230#[async_trait]
1231pub trait RuntimeAwareTool: Send + Sync {
1232    fn name(&self) -> &'static str;
1233    fn description(&self) -> &'static str;
1234
1235    fn parameters(&self) -> Option<Value> {
1236        None
1237    }
1238
1239    async fn call_with_runtime(
1240        &self,
1241        args: Value,
1242        runtime: ToolRuntime,
1243    ) -> Result<Value, SynapticError>;
1244
1245    fn as_tool_definition(&self) -> ToolDefinition {
1246        ToolDefinition {
1247            name: self.name().to_string(),
1248            description: self.description().to_string(),
1249            parameters: self
1250                .parameters()
1251                .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1252            extras: None,
1253        }
1254    }
1255}
1256
1257/// Adapter that wraps a `RuntimeAwareTool` into a standard `Tool`.
1258///
1259/// When used outside a graph context, the tool receives a default
1260/// (empty) `ToolRuntime`.
1261pub struct RuntimeAwareToolAdapter {
1262    inner: Arc<dyn RuntimeAwareTool>,
1263    runtime: Arc<tokio::sync::RwLock<Option<ToolRuntime>>>,
1264}
1265
1266impl RuntimeAwareToolAdapter {
1267    pub fn new(tool: Arc<dyn RuntimeAwareTool>) -> Self {
1268        Self {
1269            inner: tool,
1270            runtime: Arc::new(tokio::sync::RwLock::new(None)),
1271        }
1272    }
1273
1274    pub async fn set_runtime(&self, runtime: ToolRuntime) {
1275        *self.runtime.write().await = Some(runtime);
1276    }
1277}
1278
1279#[async_trait]
1280impl Tool for RuntimeAwareToolAdapter {
1281    fn name(&self) -> &'static str {
1282        self.inner.name()
1283    }
1284
1285    fn description(&self) -> &'static str {
1286        self.inner.description()
1287    }
1288
1289    fn parameters(&self) -> Option<Value> {
1290        self.inner.parameters()
1291    }
1292
1293    async fn call(&self, args: Value) -> Result<Value, SynapticError> {
1294        let runtime = self.runtime.read().await.clone().unwrap_or(ToolRuntime {
1295            store: None,
1296            stream_writer: None,
1297            state: None,
1298            tool_call_id: String::new(),
1299            config: None,
1300        });
1301        self.inner.call_with_runtime(args, runtime).await
1302    }
1303}
1304
1305// ---------------------------------------------------------------------------
1306// Document
1307// ---------------------------------------------------------------------------
1308
1309/// A document with content and metadata, used throughout the retrieval pipeline.
1310#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
1311pub struct Document {
1312    pub id: String,
1313    pub content: String,
1314    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
1315    pub metadata: HashMap<String, Value>,
1316}
1317
1318impl Document {
1319    pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
1320        Self {
1321            id: id.into(),
1322            content: content.into(),
1323            metadata: HashMap::new(),
1324        }
1325    }
1326
1327    pub fn with_metadata(
1328        id: impl Into<String>,
1329        content: impl Into<String>,
1330        metadata: HashMap<String, Value>,
1331    ) -> Self {
1332        Self {
1333            id: id.into(),
1334            content: content.into(),
1335            metadata,
1336        }
1337    }
1338}
1339
1340// ---------------------------------------------------------------------------
1341// Retriever trait (forward-declared here, implementations in synaptic-retrieval)
1342// ---------------------------------------------------------------------------
1343
1344/// Trait for retrieving relevant documents given a query string.
1345#[async_trait]
1346pub trait Retriever: Send + Sync {
1347    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapticError>;
1348}
1349
1350// ---------------------------------------------------------------------------
1351// VectorStore trait (forward-declared here, implementations in synaptic-vectorstores)
1352// ---------------------------------------------------------------------------
1353
1354/// Trait for vector storage backends.
1355#[async_trait]
1356pub trait VectorStore: Send + Sync {
1357    /// Add documents to the store, computing their embeddings.
1358    async fn add_documents(
1359        &self,
1360        docs: Vec<Document>,
1361        embeddings: &dyn Embeddings,
1362    ) -> Result<Vec<String>, SynapticError>;
1363
1364    /// Search for similar documents by query string.
1365    async fn similarity_search(
1366        &self,
1367        query: &str,
1368        k: usize,
1369        embeddings: &dyn Embeddings,
1370    ) -> Result<Vec<Document>, SynapticError>;
1371
1372    /// Search with similarity scores (higher = more similar).
1373    async fn similarity_search_with_score(
1374        &self,
1375        query: &str,
1376        k: usize,
1377        embeddings: &dyn Embeddings,
1378    ) -> Result<Vec<(Document, f32)>, SynapticError>;
1379
1380    /// Search by pre-computed embedding vector instead of text query.
1381    async fn similarity_search_by_vector(
1382        &self,
1383        embedding: &[f32],
1384        k: usize,
1385    ) -> Result<Vec<Document>, SynapticError>;
1386
1387    /// Delete documents by ID.
1388    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError>;
1389}
1390
1391// ---------------------------------------------------------------------------
1392// Loader trait (forward-declared here, implementations in synaptic-loaders)
1393// ---------------------------------------------------------------------------
1394
1395/// Trait for loading documents from various sources.
1396#[async_trait]
1397pub trait Loader: Send + Sync {
1398    /// Load all documents from this source.
1399    async fn load(&self) -> Result<Vec<Document>, SynapticError>;
1400
1401    /// Stream documents lazily. Default implementation wraps load().
1402    fn lazy_load(
1403        &self,
1404    ) -> Pin<Box<dyn Stream<Item = Result<Document, SynapticError>> + Send + '_>> {
1405        Box::pin(async_stream::stream! {
1406            match self.load().await {
1407                Ok(docs) => {
1408                    for doc in docs {
1409                        yield Ok(doc);
1410                    }
1411                }
1412                Err(e) => yield Err(e),
1413            }
1414        })
1415    }
1416}
1417
1418// ---------------------------------------------------------------------------
1419// LlmCache trait (forward-declared here, implementations in synaptic-cache)
1420// ---------------------------------------------------------------------------
1421
1422/// Trait for caching LLM responses.
1423#[async_trait]
1424pub trait LlmCache: Send + Sync {
1425    /// Look up a cached response by cache key.
1426    async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError>;
1427    /// Store a response in the cache.
1428    async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError>;
1429    /// Clear all entries from the cache.
1430    async fn clear(&self) -> Result<(), SynapticError>;
1431}
1432
1433// ---------------------------------------------------------------------------
1434// Entrypoint / Task metadata (used by proc macros)
1435// ---------------------------------------------------------------------------
1436
/// Static settings attached to a function decorated with `#[entrypoint]`.
#[derive(Debug, Clone)]
pub struct EntrypointConfig {
    /// Name of the entrypoint.
    pub name: &'static str,
    /// Optional checkpointer, given as a static string.
    pub checkpointer: Option<&'static str>,
}
1443
1444/// An entrypoint wrapping an async function as a runnable workflow.
1445///
1446/// The `invoke_fn` field is a type-erased async function (`Value -> Result<Value, SynapticError>`).
1447/// Type alias for the async entrypoint function signature.
1448pub type EntrypointFn = dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<Value, SynapticError>> + Send>>
1449    + Send
1450    + Sync;
1451
1452pub struct Entrypoint {
1453    pub config: EntrypointConfig,
1454    pub invoke_fn: Box<EntrypointFn>,
1455}
1456
1457impl Entrypoint {
1458    pub async fn invoke(&self, input: Value) -> Result<Value, SynapticError> {
1459        (self.invoke_fn)(input).await
1460    }
1461}
1462
1463// ---------------------------------------------------------------------------
1464// Shared store utilities
1465// ---------------------------------------------------------------------------
1466
/// Current UTC time as an ISO 8601 timestamp for store metadata
/// (no chrono dependency).
///
/// The result has the fixed shape `YYYY-MM-DDTHH:MM:SS.mmmZ`, e.g.
/// `2023-11-14T22:13:20.123Z`. If the system clock reports a time before the
/// Unix epoch, the epoch itself is used.
pub fn now_iso() -> String {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    format_iso_utc(since_epoch.as_secs(), since_epoch.subsec_millis())
}

/// Format seconds since the Unix epoch (plus a millisecond part) as a UTC ISO 8601 string.
fn format_iso_utc(epoch_secs: u64, millis: u32) -> String {
    const SECS_PER_DAY: u64 = 86_400;

    // `u64::MAX / 86_400` is far below `i64::MAX`, so this cast cannot truncate.
    let days = (epoch_secs / SECS_PER_DAY) as i64;
    let secs_of_day = epoch_secs % SECS_PER_DAY;
    let hour = secs_of_day / 3_600;
    let minute = secs_of_day % 3_600 / 60;
    let second = secs_of_day % 60;

    // Days-since-epoch to proleptic Gregorian date (Howard Hinnant's
    // `civil_from_days`), with the year shifted to start on 1 March so the
    // leap day falls at the end of the cycle.
    let shifted = days + 719_468;
    let era = shifted.div_euclid(146_097);
    let day_of_era = shifted.rem_euclid(146_097);
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + i64::from(month <= 2);

    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}.{millis:03}Z")
}
1471
/// Flatten a namespace path into one string, placing `:` between segments.
///
/// An empty slice yields an empty string; segments are copied verbatim.
pub fn encode_namespace(namespace: &[&str]) -> String {
    let mut encoded = String::new();
    for (position, segment) in namespace.iter().enumerate() {
        if position > 0 {
            encoded.push(':');
        }
        encoded.push_str(segment);
    }
    encoded
}
1476
1477/// Validate a SQL table name (alphanumeric, underscore, and dot only).
1478pub fn validate_table_name(name: &str) -> Result<(), SynapticError> {
1479    if name.is_empty() {
1480        return Err(SynapticError::Store(
1481            "table name must not be empty".to_string(),
1482        ));
1483    }
1484    if !name
1485        .chars()
1486        .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
1487    {
1488        return Err(SynapticError::Store(format!(
1489            "invalid table name '{name}': only alphanumeric, underscore, and dot characters are allowed",
1490        )));
1491    }
1492    Ok(())
1493}