Skip to main content

synaptic_core/
lib.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3
4use async_trait::async_trait;
5use futures::Stream;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use thiserror::Error;
9
10// ---------------------------------------------------------------------------
11// ContentBlock — multimodal message content
12// ---------------------------------------------------------------------------
13
14/// A block of content within a message, supporting multimodal inputs.
15#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
16#[serde(tag = "type", rename_all = "snake_case")]
17pub enum ContentBlock {
18    Text {
19        text: String,
20    },
21    Image {
22        url: String,
23        #[serde(default, skip_serializing_if = "Option::is_none")]
24        detail: Option<String>,
25    },
26    Audio {
27        url: String,
28    },
29    Video {
30        url: String,
31    },
32    File {
33        url: String,
34        #[serde(default, skip_serializing_if = "Option::is_none")]
35        mime_type: Option<String>,
36    },
37    Data {
38        data: Value,
39    },
40    Reasoning {
41        content: String,
42    },
43}
44
45// ---------------------------------------------------------------------------
46// Message
47// ---------------------------------------------------------------------------
48
49/// Represents a chat message. Tagged enum with System, Human, AI, and Tool variants.
50#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
51#[serde(tag = "role")]
52pub enum Message {
53    #[serde(rename = "system")]
54    System {
55        content: String,
56        #[serde(default, skip_serializing_if = "Option::is_none")]
57        id: Option<String>,
58        #[serde(default, skip_serializing_if = "Option::is_none")]
59        name: Option<String>,
60        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
61        additional_kwargs: HashMap<String, Value>,
62        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
63        response_metadata: HashMap<String, Value>,
64        #[serde(default, skip_serializing_if = "Vec::is_empty")]
65        content_blocks: Vec<ContentBlock>,
66    },
67    #[serde(rename = "human")]
68    Human {
69        content: String,
70        #[serde(default, skip_serializing_if = "Option::is_none")]
71        id: Option<String>,
72        #[serde(default, skip_serializing_if = "Option::is_none")]
73        name: Option<String>,
74        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
75        additional_kwargs: HashMap<String, Value>,
76        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
77        response_metadata: HashMap<String, Value>,
78        #[serde(default, skip_serializing_if = "Vec::is_empty")]
79        content_blocks: Vec<ContentBlock>,
80    },
81    #[serde(rename = "assistant")]
82    AI {
83        content: String,
84        #[serde(default, skip_serializing_if = "Vec::is_empty")]
85        tool_calls: Vec<ToolCall>,
86        #[serde(default, skip_serializing_if = "Option::is_none")]
87        id: Option<String>,
88        #[serde(default, skip_serializing_if = "Option::is_none")]
89        name: Option<String>,
90        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
91        additional_kwargs: HashMap<String, Value>,
92        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
93        response_metadata: HashMap<String, Value>,
94        #[serde(default, skip_serializing_if = "Vec::is_empty")]
95        content_blocks: Vec<ContentBlock>,
96        #[serde(default, skip_serializing_if = "Option::is_none")]
97        usage_metadata: Option<TokenUsage>,
98        #[serde(default, skip_serializing_if = "Vec::is_empty")]
99        invalid_tool_calls: Vec<InvalidToolCall>,
100    },
101    #[serde(rename = "tool")]
102    Tool {
103        content: String,
104        tool_call_id: String,
105        #[serde(default, skip_serializing_if = "Option::is_none")]
106        id: Option<String>,
107        #[serde(default, skip_serializing_if = "Option::is_none")]
108        name: Option<String>,
109        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
110        additional_kwargs: HashMap<String, Value>,
111        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
112        response_metadata: HashMap<String, Value>,
113        #[serde(default, skip_serializing_if = "Vec::is_empty")]
114        content_blocks: Vec<ContentBlock>,
115    },
116    #[serde(rename = "chat")]
117    Chat {
118        custom_role: String,
119        content: String,
120        #[serde(default, skip_serializing_if = "Option::is_none")]
121        id: Option<String>,
122        #[serde(default, skip_serializing_if = "Option::is_none")]
123        name: Option<String>,
124        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
125        additional_kwargs: HashMap<String, Value>,
126        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
127        response_metadata: HashMap<String, Value>,
128        #[serde(default, skip_serializing_if = "Vec::is_empty")]
129        content_blocks: Vec<ContentBlock>,
130    },
131    /// A special message that signals removal of a message by its ID.
132    /// Used in message history management.
133    #[serde(rename = "remove")]
134    Remove {
135        /// ID of the message to remove.
136        id: String,
137    },
138}
139
/// Assigns `$value` to the shared field `$field` of whichever variant
/// `$self` (a `&mut Message`) currently holds.
///
/// `Message::Remove` does not carry the shared fields, so it is left
/// untouched.
macro_rules! set_message_field {
    ($self:expr, $field:ident, $value:expr) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => {
                *$field = $value;
            }
            Message::Remove { .. } => { /* Remove has no common fields */ }
        }
    };
}
154
/// Evaluates to a reference to the shared field `$field` of whichever
/// variant `$self` holds.
///
/// Panics (via `unreachable!`) when `$self` is `Message::Remove`, so callers
/// must match that variant themselves before invoking the macro.
macro_rules! get_message_field {
    ($self:expr, $field:ident) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => $field,
            Message::Remove { .. } => unreachable!("get_message_field called on Remove variant"),
        }
    };
}
169
170impl Message {
171    // -- Factory methods -----------------------------------------------------
172
173    pub fn system(content: impl Into<String>) -> Self {
174        Message::System {
175            content: content.into(),
176            id: None,
177            name: None,
178            additional_kwargs: HashMap::new(),
179            response_metadata: HashMap::new(),
180            content_blocks: Vec::new(),
181        }
182    }
183
184    pub fn human(content: impl Into<String>) -> Self {
185        Message::Human {
186            content: content.into(),
187            id: None,
188            name: None,
189            additional_kwargs: HashMap::new(),
190            response_metadata: HashMap::new(),
191            content_blocks: Vec::new(),
192        }
193    }
194
195    pub fn ai(content: impl Into<String>) -> Self {
196        Message::AI {
197            content: content.into(),
198            tool_calls: vec![],
199            id: None,
200            name: None,
201            additional_kwargs: HashMap::new(),
202            response_metadata: HashMap::new(),
203            content_blocks: Vec::new(),
204            usage_metadata: None,
205            invalid_tool_calls: Vec::new(),
206        }
207    }
208
209    pub fn ai_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
210        Message::AI {
211            content: content.into(),
212            tool_calls,
213            id: None,
214            name: None,
215            additional_kwargs: HashMap::new(),
216            response_metadata: HashMap::new(),
217            content_blocks: Vec::new(),
218            usage_metadata: None,
219            invalid_tool_calls: Vec::new(),
220        }
221    }
222
223    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
224        Message::Tool {
225            content: content.into(),
226            tool_call_id: tool_call_id.into(),
227            id: None,
228            name: None,
229            additional_kwargs: HashMap::new(),
230            response_metadata: HashMap::new(),
231            content_blocks: Vec::new(),
232        }
233    }
234
235    pub fn chat(role: impl Into<String>, content: impl Into<String>) -> Self {
236        Message::Chat {
237            custom_role: role.into(),
238            content: content.into(),
239            id: None,
240            name: None,
241            additional_kwargs: HashMap::new(),
242            response_metadata: HashMap::new(),
243            content_blocks: Vec::new(),
244        }
245    }
246
247    /// Create a Remove message that signals removal of a message by its ID.
248    pub fn remove(id: impl Into<String>) -> Self {
249        Message::Remove { id: id.into() }
250    }
251
252    // -- Builder methods -----------------------------------------------------
253
254    pub fn with_id(mut self, value: impl Into<String>) -> Self {
255        set_message_field!(&mut self, id, Some(value.into()));
256        self
257    }
258
259    pub fn with_name(mut self, value: impl Into<String>) -> Self {
260        set_message_field!(&mut self, name, Some(value.into()));
261        self
262    }
263
264    pub fn with_additional_kwarg(mut self, key: impl Into<String>, value: Value) -> Self {
265        match &mut self {
266            Message::System {
267                additional_kwargs, ..
268            }
269            | Message::Human {
270                additional_kwargs, ..
271            }
272            | Message::AI {
273                additional_kwargs, ..
274            }
275            | Message::Tool {
276                additional_kwargs, ..
277            }
278            | Message::Chat {
279                additional_kwargs, ..
280            } => {
281                additional_kwargs.insert(key.into(), value);
282            }
283            Message::Remove { .. } => { /* Remove has no additional_kwargs */ }
284        }
285        self
286    }
287
288    pub fn with_response_metadata_entry(mut self, key: impl Into<String>, value: Value) -> Self {
289        match &mut self {
290            Message::System {
291                response_metadata, ..
292            }
293            | Message::Human {
294                response_metadata, ..
295            }
296            | Message::AI {
297                response_metadata, ..
298            }
299            | Message::Tool {
300                response_metadata, ..
301            }
302            | Message::Chat {
303                response_metadata, ..
304            } => {
305                response_metadata.insert(key.into(), value);
306            }
307            Message::Remove { .. } => { /* Remove has no response_metadata */ }
308        }
309        self
310    }
311
312    pub fn with_content_blocks(mut self, blocks: Vec<ContentBlock>) -> Self {
313        set_message_field!(&mut self, content_blocks, blocks);
314        self
315    }
316
317    pub fn with_usage_metadata(mut self, usage: TokenUsage) -> Self {
318        if let Message::AI { usage_metadata, .. } = &mut self {
319            *usage_metadata = Some(usage);
320        }
321        self
322    }
323
324    // -- Accessor methods ----------------------------------------------------
325
326    pub fn content(&self) -> &str {
327        match self {
328            Message::Remove { .. } => "",
329            other => get_message_field!(other, content),
330        }
331    }
332
333    pub fn role(&self) -> &str {
334        match self {
335            Message::System { .. } => "system",
336            Message::Human { .. } => "human",
337            Message::AI { .. } => "assistant",
338            Message::Tool { .. } => "tool",
339            Message::Chat { custom_role, .. } => custom_role,
340            Message::Remove { .. } => "remove",
341        }
342    }
343
344    pub fn is_system(&self) -> bool {
345        matches!(self, Message::System { .. })
346    }
347
348    pub fn is_human(&self) -> bool {
349        matches!(self, Message::Human { .. })
350    }
351
352    pub fn is_ai(&self) -> bool {
353        matches!(self, Message::AI { .. })
354    }
355
356    pub fn is_tool(&self) -> bool {
357        matches!(self, Message::Tool { .. })
358    }
359
360    pub fn is_chat(&self) -> bool {
361        matches!(self, Message::Chat { .. })
362    }
363
364    pub fn is_remove(&self) -> bool {
365        matches!(self, Message::Remove { .. })
366    }
367
368    pub fn tool_calls(&self) -> &[ToolCall] {
369        match self {
370            Message::AI { tool_calls, .. } => tool_calls,
371            _ => &[],
372        }
373    }
374
375    pub fn tool_call_id(&self) -> Option<&str> {
376        match self {
377            Message::Tool { tool_call_id, .. } => Some(tool_call_id),
378            _ => None,
379        }
380    }
381
382    pub fn id(&self) -> Option<&str> {
383        match self {
384            Message::Remove { id } => Some(id),
385            other => get_message_field!(other, id).as_deref(),
386        }
387    }
388
389    pub fn name(&self) -> Option<&str> {
390        match self {
391            Message::Remove { .. } => None,
392            other => get_message_field!(other, name).as_deref(),
393        }
394    }
395
396    pub fn additional_kwargs(&self) -> &HashMap<String, Value> {
397        match self {
398            Message::System {
399                additional_kwargs, ..
400            }
401            | Message::Human {
402                additional_kwargs, ..
403            }
404            | Message::AI {
405                additional_kwargs, ..
406            }
407            | Message::Tool {
408                additional_kwargs, ..
409            }
410            | Message::Chat {
411                additional_kwargs, ..
412            } => additional_kwargs,
413            Message::Remove { .. } => {
414                static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
415                    std::sync::OnceLock::new();
416                EMPTY.get_or_init(HashMap::new)
417            }
418        }
419    }
420
421    pub fn response_metadata(&self) -> &HashMap<String, Value> {
422        match self {
423            Message::System {
424                response_metadata, ..
425            }
426            | Message::Human {
427                response_metadata, ..
428            }
429            | Message::AI {
430                response_metadata, ..
431            }
432            | Message::Tool {
433                response_metadata, ..
434            }
435            | Message::Chat {
436                response_metadata, ..
437            } => response_metadata,
438            Message::Remove { .. } => {
439                static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
440                    std::sync::OnceLock::new();
441                EMPTY.get_or_init(HashMap::new)
442            }
443        }
444    }
445
446    pub fn content_blocks(&self) -> &[ContentBlock] {
447        match self {
448            Message::Remove { .. } => &[],
449            other => get_message_field!(other, content_blocks),
450        }
451    }
452
453    /// Return the remove ID if this is a Remove message.
454    pub fn remove_id(&self) -> Option<&str> {
455        match self {
456            Message::Remove { id } => Some(id),
457            _ => None,
458        }
459    }
460
461    pub fn usage_metadata(&self) -> Option<&TokenUsage> {
462        match self {
463            Message::AI { usage_metadata, .. } => usage_metadata.as_ref(),
464            _ => None,
465        }
466    }
467
468    pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
469        match self {
470            Message::AI {
471                invalid_tool_calls, ..
472            } => invalid_tool_calls,
473            _ => &[],
474        }
475    }
476}
477
478// ---------------------------------------------------------------------------
479// Message utility functions
480// ---------------------------------------------------------------------------
481
482/// Filter messages by type, name, or id.
483pub fn filter_messages(
484    messages: &[Message],
485    include_types: Option<&[&str]>,
486    exclude_types: Option<&[&str]>,
487    include_names: Option<&[&str]>,
488    exclude_names: Option<&[&str]>,
489    include_ids: Option<&[&str]>,
490    exclude_ids: Option<&[&str]>,
491) -> Vec<Message> {
492    messages
493        .iter()
494        .filter(|msg| {
495            if let Some(include) = include_types {
496                if !include.contains(&msg.role()) {
497                    return false;
498                }
499            }
500            if let Some(exclude) = exclude_types {
501                if exclude.contains(&msg.role()) {
502                    return false;
503                }
504            }
505            if let Some(include) = include_names {
506                match msg.name() {
507                    Some(name) => {
508                        if !include.contains(&name) {
509                            return false;
510                        }
511                    }
512                    None => return false,
513                }
514            }
515            if let Some(exclude) = exclude_names {
516                if let Some(name) = msg.name() {
517                    if exclude.contains(&name) {
518                        return false;
519                    }
520                }
521            }
522            if let Some(include) = include_ids {
523                match msg.id() {
524                    Some(id) => {
525                        if !include.contains(&id) {
526                            return false;
527                        }
528                    }
529                    None => return false,
530                }
531            }
532            if let Some(exclude) = exclude_ids {
533                if let Some(id) = msg.id() {
534                    if exclude.contains(&id) {
535                        return false;
536                    }
537                }
538            }
539            true
540        })
541        .cloned()
542        .collect()
543}
544
/// Which end of the conversation `trim_messages` keeps when the token
/// budget is too small for every message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimStrategy {
    /// Keep messages from the start, stopping at the first one that does
    /// not fit.
    First,
    /// Keep messages from the end, stopping at the first one (walking
    /// backwards) that does not fit.
    Last,
}
553
554/// Trim messages to fit within a token budget.
555///
556/// `token_counter` receives a single message and returns its token count.
557/// When `include_system` is true and `strategy` is `Last`, the leading system
558/// message is always preserved.
559pub fn trim_messages(
560    messages: Vec<Message>,
561    max_tokens: usize,
562    token_counter: impl Fn(&Message) -> usize,
563    strategy: TrimStrategy,
564    include_system: bool,
565) -> Vec<Message> {
566    if messages.is_empty() {
567        return messages;
568    }
569
570    match strategy {
571        TrimStrategy::First => {
572            let mut result = Vec::new();
573            let mut total = 0;
574            for msg in messages {
575                let count = token_counter(&msg);
576                if total + count > max_tokens {
577                    break;
578                }
579                total += count;
580                result.push(msg);
581            }
582            result
583        }
584        TrimStrategy::Last => {
585            let (system_msg, rest) = if include_system && messages[0].is_system() {
586                (Some(messages[0].clone()), &messages[1..])
587            } else {
588                (None, messages.as_slice())
589            };
590
591            let system_tokens = system_msg.as_ref().map(&token_counter).unwrap_or(0);
592            let budget = max_tokens.saturating_sub(system_tokens);
593
594            let mut selected = Vec::new();
595            let mut total = 0;
596            for msg in rest.iter().rev() {
597                let count = token_counter(msg);
598                if total + count > budget {
599                    break;
600                }
601                total += count;
602                selected.push(msg.clone());
603            }
604            selected.reverse();
605
606            let mut result = Vec::new();
607            if let Some(sys) = system_msg {
608                result.push(sys);
609            }
610            result.extend(selected);
611            result
612        }
613    }
614}
615
616/// Merge consecutive messages of the same role into a single message.
617pub fn merge_message_runs(messages: Vec<Message>) -> Vec<Message> {
618    if messages.is_empty() {
619        return messages;
620    }
621
622    let mut result: Vec<Message> = Vec::new();
623
624    for msg in messages {
625        let should_merge = result
626            .last()
627            .map(|last| last.role() == msg.role())
628            .unwrap_or(false);
629
630        if should_merge {
631            let last = result.last_mut().unwrap();
632            // Merge content
633            let merged_content = format!("{}\n{}", last.content(), msg.content());
634            match last {
635                Message::System { content, .. } => *content = merged_content,
636                Message::Human { content, .. } => *content = merged_content,
637                Message::AI {
638                    content,
639                    tool_calls,
640                    invalid_tool_calls,
641                    ..
642                } => {
643                    *content = merged_content;
644                    tool_calls.extend(msg.tool_calls().to_vec());
645                    invalid_tool_calls.extend(msg.invalid_tool_calls().to_vec());
646                }
647                Message::Tool { content, .. } => *content = merged_content,
648                Message::Chat { content, .. } => *content = merged_content,
649                Message::Remove { .. } => { /* Remove messages are not merged */ }
650            }
651        } else {
652            result.push(msg);
653        }
654    }
655
656    result
657}
658
659/// Convert messages to a human-readable buffer string.
660pub fn get_buffer_string(messages: &[Message], human_prefix: &str, ai_prefix: &str) -> String {
661    messages
662        .iter()
663        .map(|msg| {
664            let prefix = match msg {
665                Message::System { .. } => "System",
666                Message::Human { .. } => human_prefix,
667                Message::AI { .. } => ai_prefix,
668                Message::Tool { .. } => "Tool",
669                Message::Chat { custom_role, .. } => custom_role.as_str(),
670                Message::Remove { .. } => "Remove",
671            };
672            format!("{prefix}: {}", msg.content())
673        })
674        .collect::<Vec<_>>()
675        .join("\n")
676}
677
678// ---------------------------------------------------------------------------
679// AIMessageChunk
680// ---------------------------------------------------------------------------
681
682/// A streaming chunk from an AI model response. Supports merge via `+`/`+=` operators and conversion to `Message` via `into_message()`.
683#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
684pub struct AIMessageChunk {
685    pub content: String,
686    #[serde(default, skip_serializing_if = "Vec::is_empty")]
687    pub tool_calls: Vec<ToolCall>,
688    #[serde(default, skip_serializing_if = "Option::is_none")]
689    pub usage: Option<TokenUsage>,
690    #[serde(default, skip_serializing_if = "Option::is_none")]
691    pub id: Option<String>,
692    #[serde(default, skip_serializing_if = "Vec::is_empty")]
693    pub tool_call_chunks: Vec<ToolCallChunk>,
694    #[serde(default, skip_serializing_if = "Vec::is_empty")]
695    pub invalid_tool_calls: Vec<InvalidToolCall>,
696}
697
698impl AIMessageChunk {
699    pub fn into_message(self) -> Message {
700        Message::ai_with_tool_calls(self.content, self.tool_calls)
701    }
702}
703
704impl std::ops::Add for AIMessageChunk {
705    type Output = Self;
706
707    fn add(mut self, rhs: Self) -> Self {
708        self += rhs;
709        self
710    }
711}
712
713impl std::ops::AddAssign for AIMessageChunk {
714    fn add_assign(&mut self, rhs: Self) {
715        self.content.push_str(&rhs.content);
716        self.tool_calls.extend(rhs.tool_calls);
717        self.tool_call_chunks.extend(rhs.tool_call_chunks);
718        self.invalid_tool_calls.extend(rhs.invalid_tool_calls);
719        if self.id.is_none() {
720            self.id = rhs.id;
721        }
722        match (&mut self.usage, rhs.usage) {
723            (Some(u), Some(rhs_u)) => {
724                u.input_tokens += rhs_u.input_tokens;
725                u.output_tokens += rhs_u.output_tokens;
726                u.total_tokens += rhs_u.total_tokens;
727            }
728            (None, Some(rhs_u)) => {
729                self.usage = Some(rhs_u);
730            }
731            _ => {}
732        }
733    }
734}
735
736// ---------------------------------------------------------------------------
737// Tool-related types
738// ---------------------------------------------------------------------------
739
740/// Represents a tool invocation requested by an AI model, with an ID, function name, and JSON arguments.
741#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
742pub struct ToolCall {
743    pub id: String,
744    pub name: String,
745    pub arguments: Value,
746}
747
748/// A tool call that failed to parse correctly.
749#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
750pub struct InvalidToolCall {
751    #[serde(default, skip_serializing_if = "Option::is_none")]
752    pub id: Option<String>,
753    #[serde(default, skip_serializing_if = "Option::is_none")]
754    pub name: Option<String>,
755    #[serde(default, skip_serializing_if = "Option::is_none")]
756    pub arguments: Option<String>,
757    pub error: String,
758}
759
760/// A partial tool call chunk received during streaming.
761#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
762pub struct ToolCallChunk {
763    #[serde(default, skip_serializing_if = "Option::is_none")]
764    pub id: Option<String>,
765    #[serde(default, skip_serializing_if = "Option::is_none")]
766    pub name: Option<String>,
767    #[serde(default, skip_serializing_if = "Option::is_none")]
768    pub arguments: Option<String>,
769    #[serde(default, skip_serializing_if = "Option::is_none")]
770    pub index: Option<usize>,
771}
772
773/// Schema definition for a tool, including its name, description, and JSON Schema for parameters.
774#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
775pub struct ToolDefinition {
776    pub name: String,
777    pub description: String,
778    pub parameters: Value,
779}
780
781/// Controls how the model selects tools: Auto, Required, None, or a Specific named tool.
782#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
783#[serde(rename_all = "lowercase")]
784pub enum ToolChoice {
785    Auto,
786    Required,
787    None,
788    Specific(String),
789}
790
791// ---------------------------------------------------------------------------
792// Chat request / response
793// ---------------------------------------------------------------------------
794
795/// A request to a chat model containing messages, optional tool definitions, and tool choice configuration.
796#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
797pub struct ChatRequest {
798    pub messages: Vec<Message>,
799    #[serde(default, skip_serializing_if = "Vec::is_empty")]
800    pub tools: Vec<ToolDefinition>,
801    #[serde(default, skip_serializing_if = "Option::is_none")]
802    pub tool_choice: Option<ToolChoice>,
803}
804
805impl ChatRequest {
806    pub fn new(messages: Vec<Message>) -> Self {
807        Self {
808            messages,
809            tools: vec![],
810            tool_choice: None,
811        }
812    }
813
814    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
815        self.tools = tools;
816        self
817    }
818
819    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
820        self.tool_choice = Some(choice);
821        self
822    }
823}
824
825/// A response from a chat model containing the AI message and optional token usage statistics.
826#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
827pub struct ChatResponse {
828    pub message: Message,
829    pub usage: Option<TokenUsage>,
830}
831
832// ---------------------------------------------------------------------------
833// Token usage
834// ---------------------------------------------------------------------------
835
836#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
837pub struct TokenUsage {
838    pub input_tokens: u32,
839    pub output_tokens: u32,
840    pub total_tokens: u32,
841    #[serde(default, skip_serializing_if = "Option::is_none")]
842    pub input_details: Option<InputTokenDetails>,
843    #[serde(default, skip_serializing_if = "Option::is_none")]
844    pub output_details: Option<OutputTokenDetails>,
845}
846
847/// Detailed breakdown of input token usage.
848#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
849pub struct InputTokenDetails {
850    #[serde(default)]
851    pub cached: u32,
852    #[serde(default)]
853    pub audio: u32,
854}
855
856/// Detailed breakdown of output token usage.
857#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
858pub struct OutputTokenDetails {
859    #[serde(default)]
860    pub reasoning: u32,
861    #[serde(default)]
862    pub audio: u32,
863}
864
865// ---------------------------------------------------------------------------
866// Events
867// ---------------------------------------------------------------------------
868
869/// Lifecycle events emitted during agent execution, used by `CallbackHandler` implementations.
870#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
871pub enum RunEvent {
872    RunStarted {
873        run_id: String,
874        session_id: String,
875    },
876    RunStep {
877        run_id: String,
878        step: usize,
879    },
880    LlmCalled {
881        run_id: String,
882        message_count: usize,
883    },
884    ToolCalled {
885        run_id: String,
886        tool_name: String,
887    },
888    RunFinished {
889        run_id: String,
890        output: String,
891    },
892    RunFailed {
893        run_id: String,
894        error: String,
895    },
896}
897
898// ---------------------------------------------------------------------------
899// Errors
900// ---------------------------------------------------------------------------
901
902/// Unified error type for the Synapse framework with variants covering all subsystems.
903#[derive(Debug, Error)]
904pub enum SynapseError {
905    #[error("prompt error: {0}")]
906    Prompt(String),
907    #[error("model error: {0}")]
908    Model(String),
909    #[error("tool error: {0}")]
910    Tool(String),
911    #[error("tool not found: {0}")]
912    ToolNotFound(String),
913    #[error("memory error: {0}")]
914    Memory(String),
915    #[error("rate limit: {0}")]
916    RateLimit(String),
917    #[error("timeout: {0}")]
918    Timeout(String),
919    #[error("validation error: {0}")]
920    Validation(String),
921    #[error("parsing error: {0}")]
922    Parsing(String),
923    #[error("callback error: {0}")]
924    Callback(String),
925    #[error("max steps exceeded: {max_steps}")]
926    MaxStepsExceeded { max_steps: usize },
927    #[error("embedding error: {0}")]
928    Embedding(String),
929    #[error("vector store error: {0}")]
930    VectorStore(String),
931    #[error("retriever error: {0}")]
932    Retriever(String),
933    #[error("loader error: {0}")]
934    Loader(String),
935    #[error("splitter error: {0}")]
936    Splitter(String),
937    #[error("graph error: {0}")]
938    Graph(String),
939    #[error("cache error: {0}")]
940    Cache(String),
941    #[error("config error: {0}")]
942    Config(String),
943}
944
945// ---------------------------------------------------------------------------
946// Core traits
947// ---------------------------------------------------------------------------
948
949/// Type alias for a pinned, boxed async stream of `AIMessageChunk` results.
950pub type ChatStream<'a> =
951    Pin<Box<dyn Stream<Item = Result<AIMessageChunk, SynapseError>> + Send + 'a>>;
952
953/// The core trait for language model providers. Implementations provide `chat()` for single responses and optionally `stream_chat()` for streaming.
954#[async_trait]
955pub trait ChatModel: Send + Sync {
956    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapseError>;
957
958    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
959        Box::pin(async_stream::stream! {
960            match self.chat(request).await {
961                Ok(response) => {
962                    yield Ok(AIMessageChunk {
963                        content: response.message.content().to_string(),
964                        tool_calls: response.message.tool_calls().to_vec(),
965                        usage: response.usage,
966                        ..Default::default()
967                    });
968                }
969                Err(e) => yield Err(e),
970            }
971        })
972    }
973}
974
975/// Defines an executable tool that can be called by an AI model. Each tool has a name, description, JSON schema for parameters, and an async `call()` method.
976#[async_trait]
977pub trait Tool: Send + Sync {
978    fn name(&self) -> &'static str;
979    fn description(&self) -> &'static str;
980    async fn call(&self, args: Value) -> Result<Value, SynapseError>;
981}
982
983/// Persistent storage for conversation message history, keyed by session ID.
984#[async_trait]
985pub trait MemoryStore: Send + Sync {
986    async fn append(&self, session_id: &str, message: Message) -> Result<(), SynapseError>;
987    async fn load(&self, session_id: &str) -> Result<Vec<Message>, SynapseError>;
988    async fn clear(&self, session_id: &str) -> Result<(), SynapseError>;
989}
990
991/// Handler for lifecycle events during agent execution. Receives `RunEvent` notifications at each stage.
992#[async_trait]
993pub trait CallbackHandler: Send + Sync {
994    async fn on_event(&self, event: RunEvent) -> Result<(), SynapseError>;
995}
996
997// ---------------------------------------------------------------------------
998// RunnableConfig
999// ---------------------------------------------------------------------------
1000
1001/// Runtime configuration passed through runnable chains, including tags, metadata, concurrency limits, and run identification.
1002#[derive(Debug, Clone, Default, Serialize, Deserialize)]
1003pub struct RunnableConfig {
1004    #[serde(default)]
1005    pub tags: Vec<String>,
1006    #[serde(default)]
1007    pub metadata: HashMap<String, Value>,
1008    #[serde(default)]
1009    pub max_concurrency: Option<usize>,
1010    #[serde(default)]
1011    pub recursion_limit: Option<usize>,
1012    #[serde(default)]
1013    pub run_id: Option<String>,
1014    #[serde(default)]
1015    pub run_name: Option<String>,
1016}
1017
1018impl RunnableConfig {
1019    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
1020        self.tags = tags;
1021        self
1022    }
1023
1024    pub fn with_run_name(mut self, name: impl Into<String>) -> Self {
1025        self.run_name = Some(name.into());
1026        self
1027    }
1028
1029    pub fn with_run_id(mut self, id: impl Into<String>) -> Self {
1030        self.run_id = Some(id.into());
1031        self
1032    }
1033
1034    pub fn with_max_concurrency(mut self, max: usize) -> Self {
1035        self.max_concurrency = Some(max);
1036        self
1037    }
1038
1039    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
1040        self.recursion_limit = Some(limit);
1041        self
1042    }
1043
1044    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
1045        self.metadata.insert(key.into(), value);
1046        self
1047    }
1048}