1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use thiserror::Error;
11
12#[cfg(feature = "schemars")]
13pub use schemars;
14
15pub mod context_budget;
16pub mod token_counter;
17
18pub use context_budget::{ContextBudget, ContextSlot, Priority};
19pub use token_counter::{HeuristicTokenCounter, TokenCounter};
20
/// One piece of (possibly multimodal) message content.
///
/// Serialized as an internally tagged object: the `type` field holds the snake_case
/// variant name (`"text"`, `"image"`, `"audio"`, `"video"`, `"file"`, `"data"`,
/// `"reasoning"`) and the variant's fields sit alongside it. Field order below is the
/// serialized field order, so keep it stable.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image referenced by URL.
    Image {
        url: String,
        /// Optional `detail` string; omitted from JSON when `None`, defaults to `None` when absent.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        detail: Option<String>,
    },
    /// Audio referenced by URL.
    Audio {
        url: String,
    },
    /// Video referenced by URL.
    Video {
        url: String,
    },
    /// An arbitrary file referenced by URL.
    File {
        url: String,
        /// Optional MIME type; omitted from JSON when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        mime_type: Option<String>,
    },
    /// Structured JSON payload.
    Data {
        data: Value,
    },
    /// Reasoning text carried separately from the main content.
    Reasoning {
        content: String,
    },
}
55
/// A chat message, serialized as an internally tagged object keyed by `role`.
///
/// Tag values: `"system"`, `"human"`, `"assistant"` (the [`Message::AI`] variant),
/// `"tool"`, `"chat"` and `"remove"`. All optional maps/vectors are omitted from JSON when
/// empty and default to empty when absent, so minimal payloads round-trip.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "role")]
pub enum Message {
    /// System instruction message.
    #[serde(rename = "system")]
    System {
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// User-authored message.
    #[serde(rename = "human")]
    Human {
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// Model-authored message; serialized with role `"assistant"`.
    #[serde(rename = "assistant")]
    AI {
        content: String,
        /// Well-formed tool calls requested by the model.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
        /// Token usage attached to this message, if any.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        usage_metadata: Option<TokenUsage>,
        /// Tool calls that could not be turned into a [`ToolCall`], each with an error string.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        invalid_tool_calls: Vec<InvalidToolCall>,
    },
    /// Result of a tool invocation; `tool_call_id` links it to the originating call.
    #[serde(rename = "tool")]
    Tool {
        content: String,
        tool_call_id: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// Message with an arbitrary, caller-chosen role stored in `custom_role`.
    #[serde(rename = "chat")]
    Chat {
        custom_role: String,
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// Marker carrying only the `id` of a message to remove; it has no content or metadata.
    #[serde(rename = "remove")]
    Remove {
        id: String,
    },
}
150
/// Assigns `$value` to field `$field` of whichever `Message` variant `$self` refers to.
///
/// Every variant except `Message::Remove` shares the field, so a single or-pattern covers
/// them; for `Remove` nothing is assigned and `$value` is never evaluated.
macro_rules! set_message_field {
    ($self:expr, $field:ident, $value:expr) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => *$field = $value,
            Message::Remove { .. } => {}
        }
    };
}
165
/// Yields a reference to field `$field` of the `Message` that `$self` refers to.
///
/// Panics via `unreachable!` on `Message::Remove`, which has none of the shared fields;
/// callers are expected to handle `Remove` before expanding this macro.
macro_rules! get_message_field {
    ($self:expr, $field:ident) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => $field,
            Message::Remove { .. } => unreachable!("get_message_field called on Remove variant"),
        }
    };
}
180
181impl Message {
182 pub fn system(content: impl Into<String>) -> Self {
185 Message::System {
186 content: content.into(),
187 id: None,
188 name: None,
189 additional_kwargs: HashMap::new(),
190 response_metadata: HashMap::new(),
191 content_blocks: Vec::new(),
192 }
193 }
194
195 pub fn human(content: impl Into<String>) -> Self {
196 Message::Human {
197 content: content.into(),
198 id: None,
199 name: None,
200 additional_kwargs: HashMap::new(),
201 response_metadata: HashMap::new(),
202 content_blocks: Vec::new(),
203 }
204 }
205
206 pub fn ai(content: impl Into<String>) -> Self {
207 Message::AI {
208 content: content.into(),
209 tool_calls: vec![],
210 id: None,
211 name: None,
212 additional_kwargs: HashMap::new(),
213 response_metadata: HashMap::new(),
214 content_blocks: Vec::new(),
215 usage_metadata: None,
216 invalid_tool_calls: Vec::new(),
217 }
218 }
219
220 pub fn ai_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
221 Message::AI {
222 content: content.into(),
223 tool_calls,
224 id: None,
225 name: None,
226 additional_kwargs: HashMap::new(),
227 response_metadata: HashMap::new(),
228 content_blocks: Vec::new(),
229 usage_metadata: None,
230 invalid_tool_calls: Vec::new(),
231 }
232 }
233
234 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
235 Message::Tool {
236 content: content.into(),
237 tool_call_id: tool_call_id.into(),
238 id: None,
239 name: None,
240 additional_kwargs: HashMap::new(),
241 response_metadata: HashMap::new(),
242 content_blocks: Vec::new(),
243 }
244 }
245
246 pub fn chat(role: impl Into<String>, content: impl Into<String>) -> Self {
247 Message::Chat {
248 custom_role: role.into(),
249 content: content.into(),
250 id: None,
251 name: None,
252 additional_kwargs: HashMap::new(),
253 response_metadata: HashMap::new(),
254 content_blocks: Vec::new(),
255 }
256 }
257
258 pub fn remove(id: impl Into<String>) -> Self {
260 Message::Remove { id: id.into() }
261 }
262
263 pub fn with_id(mut self, value: impl Into<String>) -> Self {
266 set_message_field!(&mut self, id, Some(value.into()));
267 self
268 }
269
270 pub fn with_name(mut self, value: impl Into<String>) -> Self {
271 set_message_field!(&mut self, name, Some(value.into()));
272 self
273 }
274
275 pub fn with_additional_kwarg(mut self, key: impl Into<String>, value: Value) -> Self {
276 match &mut self {
277 Message::System {
278 additional_kwargs, ..
279 }
280 | Message::Human {
281 additional_kwargs, ..
282 }
283 | Message::AI {
284 additional_kwargs, ..
285 }
286 | Message::Tool {
287 additional_kwargs, ..
288 }
289 | Message::Chat {
290 additional_kwargs, ..
291 } => {
292 additional_kwargs.insert(key.into(), value);
293 }
294 Message::Remove { .. } => { }
295 }
296 self
297 }
298
299 pub fn with_response_metadata_entry(mut self, key: impl Into<String>, value: Value) -> Self {
300 match &mut self {
301 Message::System {
302 response_metadata, ..
303 }
304 | Message::Human {
305 response_metadata, ..
306 }
307 | Message::AI {
308 response_metadata, ..
309 }
310 | Message::Tool {
311 response_metadata, ..
312 }
313 | Message::Chat {
314 response_metadata, ..
315 } => {
316 response_metadata.insert(key.into(), value);
317 }
318 Message::Remove { .. } => { }
319 }
320 self
321 }
322
323 pub fn with_content_blocks(mut self, blocks: Vec<ContentBlock>) -> Self {
324 set_message_field!(&mut self, content_blocks, blocks);
325 self
326 }
327
328 pub fn with_usage_metadata(mut self, usage: TokenUsage) -> Self {
329 if let Message::AI { usage_metadata, .. } = &mut self {
330 *usage_metadata = Some(usage);
331 }
332 self
333 }
334
335 pub fn content(&self) -> &str {
338 match self {
339 Message::Remove { .. } => "",
340 other => get_message_field!(other, content),
341 }
342 }
343
344 pub fn role(&self) -> &str {
345 match self {
346 Message::System { .. } => "system",
347 Message::Human { .. } => "human",
348 Message::AI { .. } => "assistant",
349 Message::Tool { .. } => "tool",
350 Message::Chat { custom_role, .. } => custom_role,
351 Message::Remove { .. } => "remove",
352 }
353 }
354
355 pub fn is_system(&self) -> bool {
356 matches!(self, Message::System { .. })
357 }
358
359 pub fn is_human(&self) -> bool {
360 matches!(self, Message::Human { .. })
361 }
362
363 pub fn is_ai(&self) -> bool {
364 matches!(self, Message::AI { .. })
365 }
366
367 pub fn is_tool(&self) -> bool {
368 matches!(self, Message::Tool { .. })
369 }
370
371 pub fn is_chat(&self) -> bool {
372 matches!(self, Message::Chat { .. })
373 }
374
375 pub fn is_remove(&self) -> bool {
376 matches!(self, Message::Remove { .. })
377 }
378
379 pub fn tool_calls(&self) -> &[ToolCall] {
380 match self {
381 Message::AI { tool_calls, .. } => tool_calls,
382 _ => &[],
383 }
384 }
385
386 pub fn tool_call_id(&self) -> Option<&str> {
387 match self {
388 Message::Tool { tool_call_id, .. } => Some(tool_call_id),
389 _ => None,
390 }
391 }
392
393 pub fn id(&self) -> Option<&str> {
394 match self {
395 Message::Remove { id } => Some(id),
396 other => get_message_field!(other, id).as_deref(),
397 }
398 }
399
400 pub fn name(&self) -> Option<&str> {
401 match self {
402 Message::Remove { .. } => None,
403 other => get_message_field!(other, name).as_deref(),
404 }
405 }
406
407 pub fn additional_kwargs(&self) -> &HashMap<String, Value> {
408 match self {
409 Message::System {
410 additional_kwargs, ..
411 }
412 | Message::Human {
413 additional_kwargs, ..
414 }
415 | Message::AI {
416 additional_kwargs, ..
417 }
418 | Message::Tool {
419 additional_kwargs, ..
420 }
421 | Message::Chat {
422 additional_kwargs, ..
423 } => additional_kwargs,
424 Message::Remove { .. } => {
425 static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
426 std::sync::OnceLock::new();
427 EMPTY.get_or_init(HashMap::new)
428 }
429 }
430 }
431
432 pub fn response_metadata(&self) -> &HashMap<String, Value> {
433 match self {
434 Message::System {
435 response_metadata, ..
436 }
437 | Message::Human {
438 response_metadata, ..
439 }
440 | Message::AI {
441 response_metadata, ..
442 }
443 | Message::Tool {
444 response_metadata, ..
445 }
446 | Message::Chat {
447 response_metadata, ..
448 } => response_metadata,
449 Message::Remove { .. } => {
450 static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
451 std::sync::OnceLock::new();
452 EMPTY.get_or_init(HashMap::new)
453 }
454 }
455 }
456
457 pub fn content_blocks(&self) -> &[ContentBlock] {
458 match self {
459 Message::Remove { .. } => &[],
460 other => get_message_field!(other, content_blocks),
461 }
462 }
463
464 pub fn remove_id(&self) -> Option<&str> {
466 match self {
467 Message::Remove { id } => Some(id),
468 _ => None,
469 }
470 }
471
472 pub fn usage_metadata(&self) -> Option<&TokenUsage> {
473 match self {
474 Message::AI { usage_metadata, .. } => usage_metadata.as_ref(),
475 _ => None,
476 }
477 }
478
479 pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
480 match self {
481 Message::AI {
482 invalid_tool_calls, ..
483 } => invalid_tool_calls,
484 _ => &[],
485 }
486 }
487
488 pub fn set_content(&mut self, new_content: impl Into<String>) {
490 let new_content = new_content.into();
491 set_message_field!(self, content, new_content);
492 }
493}
494
495pub fn filter_messages(
501 messages: &[Message],
502 include_types: Option<&[&str]>,
503 exclude_types: Option<&[&str]>,
504 include_names: Option<&[&str]>,
505 exclude_names: Option<&[&str]>,
506 include_ids: Option<&[&str]>,
507 exclude_ids: Option<&[&str]>,
508) -> Vec<Message> {
509 messages
510 .iter()
511 .filter(|msg| {
512 if let Some(include) = include_types {
513 if !include.contains(&msg.role()) {
514 return false;
515 }
516 }
517 if let Some(exclude) = exclude_types {
518 if exclude.contains(&msg.role()) {
519 return false;
520 }
521 }
522 if let Some(include) = include_names {
523 match msg.name() {
524 Some(name) => {
525 if !include.contains(&name) {
526 return false;
527 }
528 }
529 None => return false,
530 }
531 }
532 if let Some(exclude) = exclude_names {
533 if let Some(name) = msg.name() {
534 if exclude.contains(&name) {
535 return false;
536 }
537 }
538 }
539 if let Some(include) = include_ids {
540 match msg.id() {
541 Some(id) => {
542 if !include.contains(&id) {
543 return false;
544 }
545 }
546 None => return false,
547 }
548 }
549 if let Some(exclude) = exclude_ids {
550 if let Some(id) = msg.id() {
551 if exclude.contains(&id) {
552 return false;
553 }
554 }
555 }
556 true
557 })
558 .cloned()
559 .collect()
560}
561
/// Which end of a conversation [`trim_messages`] preserves when the token budget is exceeded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrimStrategy {
    /// Keep the longest prefix (oldest messages) that fits.
    First,
    /// Keep the longest suffix (newest messages) that fits.
    Last,
}
570
571pub fn trim_messages(
577 messages: Vec<Message>,
578 max_tokens: usize,
579 token_counter: impl Fn(&Message) -> usize,
580 strategy: TrimStrategy,
581 include_system: bool,
582) -> Vec<Message> {
583 if messages.is_empty() {
584 return messages;
585 }
586
587 match strategy {
588 TrimStrategy::First => {
589 let mut result = Vec::new();
590 let mut total = 0;
591 for msg in messages {
592 let count = token_counter(&msg);
593 if total + count > max_tokens {
594 break;
595 }
596 total += count;
597 result.push(msg);
598 }
599 result
600 }
601 TrimStrategy::Last => {
602 let (system_msg, rest) = if include_system && messages[0].is_system() {
603 (Some(messages[0].clone()), &messages[1..])
604 } else {
605 (None, messages.as_slice())
606 };
607
608 let system_tokens = system_msg.as_ref().map(&token_counter).unwrap_or(0);
609 let budget = max_tokens.saturating_sub(system_tokens);
610
611 let mut selected = Vec::new();
612 let mut total = 0;
613 for msg in rest.iter().rev() {
614 let count = token_counter(msg);
615 if total + count > budget {
616 break;
617 }
618 total += count;
619 selected.push(msg.clone());
620 }
621 selected.reverse();
622
623 let mut result = Vec::new();
624 if let Some(sys) = system_msg {
625 result.push(sys);
626 }
627 result.extend(selected);
628 result
629 }
630 }
631}
632
633pub fn merge_message_runs(messages: Vec<Message>) -> Vec<Message> {
635 if messages.is_empty() {
636 return messages;
637 }
638
639 let mut result: Vec<Message> = Vec::new();
640
641 for msg in messages {
642 let should_merge = result
643 .last()
644 .map(|last| last.role() == msg.role())
645 .unwrap_or(false);
646
647 if should_merge {
648 let last = result.last_mut().unwrap();
649 let merged_content = format!("{}\n{}", last.content(), msg.content());
651 match last {
652 Message::System { content, .. } => *content = merged_content,
653 Message::Human { content, .. } => *content = merged_content,
654 Message::AI {
655 content,
656 tool_calls,
657 invalid_tool_calls,
658 ..
659 } => {
660 *content = merged_content;
661 tool_calls.extend(msg.tool_calls().to_vec());
662 invalid_tool_calls.extend(msg.invalid_tool_calls().to_vec());
663 }
664 Message::Tool { content, .. } => *content = merged_content,
665 Message::Chat { content, .. } => *content = merged_content,
666 Message::Remove { .. } => { }
667 }
668 } else {
669 result.push(msg);
670 }
671 }
672
673 result
674}
675
676pub fn get_buffer_string(messages: &[Message], human_prefix: &str, ai_prefix: &str) -> String {
678 messages
679 .iter()
680 .map(|msg| {
681 let prefix = match msg {
682 Message::System { .. } => "System",
683 Message::Human { .. } => human_prefix,
684 Message::AI { .. } => ai_prefix,
685 Message::Tool { .. } => "Tool",
686 Message::Chat { custom_role, .. } => custom_role.as_str(),
687 Message::Remove { .. } => "Remove",
688 };
689 format!("{prefix}: {}", msg.content())
690 })
691 .collect::<Vec<_>>()
692 .join("\n")
693}
694
/// An incremental piece of a streamed assistant response.
///
/// Chunks are combined with `+` / `+=` and converted with `into_message`. Empty
/// vectors and `None` options are omitted from JSON and default when absent.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct AIMessageChunk {
    /// Text delta carried by this chunk.
    pub content: String,
    /// Complete tool calls carried by this chunk.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Token usage reported with this chunk, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage: Option<TokenUsage>,
    /// Message id, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Partial tool-call fragments (string arguments, optional index).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_call_chunks: Vec<ToolCallChunk>,
    /// Tool calls that could not be parsed.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub invalid_tool_calls: Vec<InvalidToolCall>,
}
714
715impl AIMessageChunk {
716 pub fn into_message(self) -> Message {
717 Message::ai_with_tool_calls(self.content, self.tool_calls)
718 }
719}
720
721impl std::ops::Add for AIMessageChunk {
722 type Output = Self;
723
724 fn add(mut self, rhs: Self) -> Self {
725 self += rhs;
726 self
727 }
728}
729
730impl std::ops::AddAssign for AIMessageChunk {
731 fn add_assign(&mut self, rhs: Self) {
732 self.content.push_str(&rhs.content);
733 self.tool_calls.extend(rhs.tool_calls);
734 self.tool_call_chunks.extend(rhs.tool_call_chunks);
735 self.invalid_tool_calls.extend(rhs.invalid_tool_calls);
736 if self.id.is_none() {
737 self.id = rhs.id;
738 }
739 match (&mut self.usage, rhs.usage) {
740 (Some(u), Some(rhs_u)) => {
741 u.input_tokens += rhs_u.input_tokens;
742 u.output_tokens += rhs_u.output_tokens;
743 u.total_tokens += rhs_u.total_tokens;
744 }
745 (None, Some(rhs_u)) => {
746 self.usage = Some(rhs_u);
747 }
748 _ => {}
749 }
750 }
751}
752
/// A fully formed tool invocation requested by the model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier; a matching `Message::Tool` carries it as `tool_call_id`.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments as a JSON value.
    pub arguments: Value,
}
764
/// A tool call that could not be turned into a [`ToolCall`], kept together with its error.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InvalidToolCall {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Raw, unparsed argument text, if any was received.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
    /// Description of what went wrong.
    pub error: String,
}
776
/// A streamed fragment of a tool call; every field is optional because fragments are partial.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallChunk {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Argument text fragment (not parsed JSON).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
    /// Position of the call this fragment belongs to, when the provider supplies one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
}
789
/// Description of a tool as advertised to a chat model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    /// Parameter schema as JSON (the `Tool` trait defaults it to an empty object schema).
    pub parameters: Value,
    /// Extra provider-specific fields; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub extras: Option<HashMap<String, Value>>,
}
800
/// How the model should choose among the offered tools.
///
/// Unit variants serialize as lowercase strings (`"auto"`, `"required"`, `"none"`);
/// `Specific(name)` serializes externally tagged as `{"specific": name}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    /// Force the tool with this name.
    Specific(String),
}
810
/// Input to [`ChatModel::chat`]: the conversation plus optional tools and tool choice.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatRequest {
    pub messages: Vec<Message>,
    /// Tools offered to the model; omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Optional tool-selection constraint; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
824
825impl ChatRequest {
826 pub fn new(messages: Vec<Message>) -> Self {
827 Self {
828 messages,
829 tools: vec![],
830 tool_choice: None,
831 }
832 }
833
834 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
835 self.tools = tools;
836 self
837 }
838
839 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
840 self.tool_choice = Some(choice);
841 self
842 }
843}
844
/// Output of [`ChatModel::chat`]: the reply message and optional token usage.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatResponse {
    pub message: Message,
    pub usage: Option<TokenUsage>,
}
851
/// Token accounting for a model call, with optional per-category breakdowns.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    /// Stored as provided; nothing in this type derives it from the other two counts.
    pub total_tokens: u32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_details: Option<InputTokenDetails>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_details: Option<OutputTokenDetails>,
}
866
/// Breakdown of input tokens; missing fields deserialize as `0`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct InputTokenDetails {
    #[serde(default)]
    pub cached: u32,
    #[serde(default)]
    pub audio: u32,
}
875
/// Breakdown of output tokens; missing fields deserialize as `0`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct OutputTokenDetails {
    #[serde(default)]
    pub reasoning: u32,
    #[serde(default)]
    pub audio: u32,
}
884
/// Events delivered to [`CallbackHandler::on_event`]; every variant carries the `run_id`.
///
/// NOTE(review): when each event is emitted is decided by callers outside this file;
/// the per-variant notes below describe only the payload.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunEvent {
    /// Run start, with the session it belongs to.
    RunStarted {
        run_id: String,
        session_id: String,
    },
    /// A numbered step within the run.
    RunStep {
        run_id: String,
        step: usize,
    },
    /// A model call, with the number of messages sent.
    LlmCalled {
        run_id: String,
        message_count: usize,
    },
    /// A tool invocation by name.
    ToolCalled {
        run_id: String,
        tool_name: String,
    },
    /// Successful completion with the final output text.
    RunFinished {
        run_id: String,
        output: String,
    },
    /// Failure with an error description.
    RunFailed {
        run_id: String,
        error: String,
    },
    /// Pre-tool-call hook payload; `arguments` is a string rendering of the call arguments.
    BeforeToolCall {
        run_id: String,
        tool_name: String,
        arguments: String,
    },
    /// Post-tool-call hook payload with the stringified result.
    AfterToolCall {
        run_id: String,
        tool_name: String,
        result: String,
    },
    /// Pre-model-call hook payload.
    BeforeMessage {
        run_id: String,
        message_count: usize,
    },
    /// Post-model-call hook payload with the response length.
    AfterMessage {
        run_id: String,
        response_length: usize,
    },
}
939
/// Crate-wide error type. Each variant wraps a message describing the failure;
/// `Display` prefixes it with the category shown in the `#[error]` attribute.
#[derive(Debug, Error)]
pub enum SynapticError {
    /// Prompt construction or formatting failed.
    #[error("prompt error: {0}")]
    Prompt(String),
    /// A chat-model call failed.
    #[error("model error: {0}")]
    Model(String),
    /// A tool failed during execution.
    #[error("tool error: {0}")]
    Tool(String),
    /// The requested tool name is not registered.
    #[error("tool not found: {0}")]
    ToolNotFound(String),
    /// A memory store operation failed.
    #[error("memory error: {0}")]
    Memory(String),
    /// A rate limit was hit.
    #[error("rate limit: {0}")]
    RateLimit(String),
    /// An operation timed out.
    #[error("timeout: {0}")]
    Timeout(String),
    /// Input validation failed.
    #[error("validation error: {0}")]
    Validation(String),
    /// Output parsing failed.
    #[error("parsing error: {0}")]
    Parsing(String),
    /// A callback handler failed.
    #[error("callback error: {0}")]
    Callback(String),
    /// A run exceeded its step limit.
    #[error("max steps exceeded: {max_steps}")]
    MaxStepsExceeded { max_steps: usize },
    /// Embedding computation failed.
    #[error("embedding error: {0}")]
    Embedding(String),
    /// A vector store operation failed.
    #[error("vector store error: {0}")]
    VectorStore(String),
    /// Retrieval failed.
    #[error("retriever error: {0}")]
    Retriever(String),
    /// Document loading failed.
    #[error("loader error: {0}")]
    Loader(String),
    /// Text splitting failed.
    #[error("splitter error: {0}")]
    Splitter(String),
    /// Graph construction or execution failed.
    #[error("graph error: {0}")]
    Graph(String),
    /// An LLM cache operation failed.
    #[error("cache error: {0}")]
    Cache(String),
    /// A key-value store operation failed (also used by `validate_table_name`).
    #[error("store error: {0}")]
    Store(String),
    /// Configuration is invalid or missing.
    #[error("config error: {0}")]
    Config(String),
    /// An MCP operation failed.
    #[error("mcp error: {0}")]
    Mcp(String),
    /// A security check rejected the operation.
    #[error("security error: {0}")]
    Security(String),
}
992
/// Pinned, boxed, `Send` stream of [`AIMessageChunk`] results returned by
/// [`ChatModel::stream_chat`]; it may borrow from the model for `'a`.
pub type ChatStream<'a> =
    Pin<Box<dyn Stream<Item = Result<AIMessageChunk, SynapticError>> + Send + 'a>>;
1000
/// Static capability description a [`ChatModel`] may expose via `profile()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProfile {
    pub name: String,
    pub provider: String,
    pub supports_tool_calling: bool,
    pub supports_structured_output: bool,
    pub supports_streaming: bool,
    /// Input-context limit in tokens, if known.
    pub max_input_tokens: Option<usize>,
    /// Output limit in tokens, if known.
    pub max_output_tokens: Option<usize>,
}
1012
1013#[async_trait]
1015pub trait ChatModel: Send + Sync {
1016 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError>;
1017
1018 fn profile(&self) -> Option<ModelProfile> {
1020 None
1021 }
1022
1023 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
1024 Box::pin(async_stream::stream! {
1025 match self.chat(request).await {
1026 Ok(response) => {
1027 yield Ok(AIMessageChunk {
1028 content: response.message.content().to_string(),
1029 tool_calls: response.message.tool_calls().to_vec(),
1030 usage: response.usage,
1031 ..Default::default()
1032 });
1033 }
1034 Err(e) => yield Err(e),
1035 }
1036 })
1037 }
1038}
1039
1040#[async_trait]
1042pub trait Tool: Send + Sync {
1043 fn name(&self) -> &'static str;
1044 fn description(&self) -> &'static str;
1045
1046 fn parameters(&self) -> Option<Value> {
1047 None
1048 }
1049
1050 async fn call(&self, args: Value) -> Result<Value, SynapticError>;
1051
1052 fn as_tool_definition(&self) -> ToolDefinition {
1053 ToolDefinition {
1054 name: self.name().to_string(),
1055 description: self.description().to_string(),
1056 parameters: self
1057 .parameters()
1058 .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1059 extras: None,
1060 }
1061 }
1062}
1063
/// Per-session message history backend, keyed by `session_id`.
/// Persistence semantics are defined by each implementation.
#[async_trait]
pub trait MemoryStore: Send + Sync {
    /// Adds `message` to the history of `session_id`.
    async fn append(&self, session_id: &str, message: Message) -> Result<(), SynapticError>;
    /// Returns the stored messages for `session_id`.
    async fn load(&self, session_id: &str) -> Result<Vec<Message>, SynapticError>;
    /// Clears the stored history for `session_id`.
    async fn clear(&self, session_id: &str) -> Result<(), SynapticError>;
}
1075
/// Receiver for [`RunEvent`]s.
#[async_trait]
pub trait CallbackHandler: Send + Sync {
    /// Handles one event; an `Err` reports a callback failure to the caller.
    async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError>;
}
1081
/// Per-invocation options. Every field deserializes to its default when absent
/// (empty collections / `None`), and `Default` yields the same empty configuration.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RunnableConfig {
    #[serde(default)]
    pub tags: Vec<String>,
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
    /// Optional concurrency cap; how it is applied is up to the consumer.
    #[serde(default)]
    pub max_concurrency: Option<usize>,
    /// Optional recursion cap; how it is applied is up to the consumer.
    #[serde(default)]
    pub recursion_limit: Option<usize>,
    #[serde(default)]
    pub run_id: Option<String>,
    #[serde(default)]
    pub run_name: Option<String>,
}
1102
1103impl RunnableConfig {
1104 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
1105 self.tags = tags;
1106 self
1107 }
1108
1109 pub fn with_run_name(mut self, name: impl Into<String>) -> Self {
1110 self.run_name = Some(name.into());
1111 self
1112 }
1113
1114 pub fn with_run_id(mut self, id: impl Into<String>) -> Self {
1115 self.run_id = Some(id.into());
1116 self
1117 }
1118
1119 pub fn with_max_concurrency(mut self, max: usize) -> Self {
1120 self.max_concurrency = Some(max);
1121 self
1122 }
1123
1124 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
1125 self.recursion_limit = Some(limit);
1126 self
1127 }
1128
1129 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
1130 self.metadata.insert(key.into(), value);
1131 self
1132 }
1133}
1134
/// A value stored in a [`Store`] under `namespace` + `key`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Item {
    pub namespace: Vec<String>,
    pub key: String,
    pub value: Value,
    /// Creation timestamp string (see `now_iso`; format depends on the writer).
    pub created_at: String,
    /// Last-update timestamp string.
    pub updated_at: String,
    /// Optional relevance score, e.g. from a search; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub score: Option<f64>,
}
1151
/// Namespaced key-value store of JSON values; namespaces are string paths.
#[async_trait]
pub trait Store: Send + Sync {
    /// Looks up `key` in `namespace`; `Ok(None)` when absent.
    async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError>;

    /// Returns up to `limit` items from `namespace`, optionally filtered/ranked by `query`
    /// (matching semantics are implementation-defined).
    async fn search(
        &self,
        namespace: &[&str],
        query: Option<&str>,
        limit: usize,
    ) -> Result<Vec<Item>, SynapticError>;

    /// Writes `value` under `namespace` + `key`.
    async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError>;

    /// Deletes `namespace` + `key`.
    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError>;

    /// Lists namespaces that start with `prefix`.
    async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError>;
}
1178
/// Text-embedding backend producing `f32` vectors.
#[async_trait]
pub trait Embeddings: Send + Sync {
    /// Embeds each of `texts`; presumably one vector per input, in order (implementation contract).
    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError>;

    /// Embeds a single query string.
    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError>;
}
1192
/// Reference-counted, thread-safe callback that accepts JSON values pushed by a runtime.
pub type StreamWriter = Arc<dyn Fn(Value) + Send + Sync>;
1199
/// Shared runtime handles; both are optional and cheap to clone (`Arc`-backed).
#[derive(Clone)]
pub struct Runtime {
    pub store: Option<Arc<dyn Store>>,
    pub stream_writer: Option<StreamWriter>,
}
1210
/// Context handed to a [`RuntimeAwareTool`] on each call.
#[derive(Clone)]
pub struct ToolRuntime {
    pub store: Option<Arc<dyn Store>>,
    pub stream_writer: Option<StreamWriter>,
    /// Optional state snapshot as JSON.
    pub state: Option<Value>,
    /// Id of the tool call being served (empty when `RuntimeAwareToolAdapter` has no runtime set).
    pub tool_call_id: String,
    pub config: Option<RunnableConfig>,
}
1220
1221#[async_trait]
1231pub trait RuntimeAwareTool: Send + Sync {
1232 fn name(&self) -> &'static str;
1233 fn description(&self) -> &'static str;
1234
1235 fn parameters(&self) -> Option<Value> {
1236 None
1237 }
1238
1239 async fn call_with_runtime(
1240 &self,
1241 args: Value,
1242 runtime: ToolRuntime,
1243 ) -> Result<Value, SynapticError>;
1244
1245 fn as_tool_definition(&self) -> ToolDefinition {
1246 ToolDefinition {
1247 name: self.name().to_string(),
1248 description: self.description().to_string(),
1249 parameters: self
1250 .parameters()
1251 .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1252 extras: None,
1253 }
1254 }
1255}
1256
/// Adapts a [`RuntimeAwareTool`] to the plain [`Tool`] interface.
///
/// The runtime used for calls is whatever was last stored via `set_runtime`. It is shared
/// by all calls through this adapter, so concurrent callers see the latest value.
pub struct RuntimeAwareToolAdapter {
    inner: Arc<dyn RuntimeAwareTool>,
    runtime: Arc<tokio::sync::RwLock<Option<ToolRuntime>>>,
}
1265
1266impl RuntimeAwareToolAdapter {
1267 pub fn new(tool: Arc<dyn RuntimeAwareTool>) -> Self {
1268 Self {
1269 inner: tool,
1270 runtime: Arc::new(tokio::sync::RwLock::new(None)),
1271 }
1272 }
1273
1274 pub async fn set_runtime(&self, runtime: ToolRuntime) {
1275 *self.runtime.write().await = Some(runtime);
1276 }
1277}
1278
1279#[async_trait]
1280impl Tool for RuntimeAwareToolAdapter {
1281 fn name(&self) -> &'static str {
1282 self.inner.name()
1283 }
1284
1285 fn description(&self) -> &'static str {
1286 self.inner.description()
1287 }
1288
1289 fn parameters(&self) -> Option<Value> {
1290 self.inner.parameters()
1291 }
1292
1293 async fn call(&self, args: Value) -> Result<Value, SynapticError> {
1294 let runtime = self.runtime.read().await.clone().unwrap_or(ToolRuntime {
1295 store: None,
1296 stream_writer: None,
1297 state: None,
1298 tool_call_id: String::new(),
1299 config: None,
1300 });
1301 self.inner.call_with_runtime(args, runtime).await
1302 }
1303}
1304
/// A text document with an id and free-form JSON metadata.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Document {
    pub id: String,
    pub content: String,
    /// Omitted from JSON when empty; defaults to empty when absent.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}
1317
1318impl Document {
1319 pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
1320 Self {
1321 id: id.into(),
1322 content: content.into(),
1323 metadata: HashMap::new(),
1324 }
1325 }
1326
1327 pub fn with_metadata(
1328 id: impl Into<String>,
1329 content: impl Into<String>,
1330 metadata: HashMap<String, Value>,
1331 ) -> Self {
1332 Self {
1333 id: id.into(),
1334 content: content.into(),
1335 metadata,
1336 }
1337 }
1338}
1339
/// Returns documents relevant to a query.
#[async_trait]
pub trait Retriever: Send + Sync {
    /// Retrieves up to `top_k` documents for `query` (ranking is implementation-defined).
    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapticError>;
}
1349
/// Vector index over documents. Embeddings are computed with the caller-supplied
/// [`Embeddings`] where a text query or document is given.
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Embeds and stores `docs`, returning their ids.
    async fn add_documents(
        &self,
        docs: Vec<Document>,
        embeddings: &dyn Embeddings,
    ) -> Result<Vec<String>, SynapticError>;

    /// Returns up to `k` documents most similar to `query`.
    async fn similarity_search(
        &self,
        query: &str,
        k: usize,
        embeddings: &dyn Embeddings,
    ) -> Result<Vec<Document>, SynapticError>;

    /// Like `similarity_search`, also returning each document's score
    /// (score scale/direction is implementation-defined).
    async fn similarity_search_with_score(
        &self,
        query: &str,
        k: usize,
        embeddings: &dyn Embeddings,
    ) -> Result<Vec<(Document, f32)>, SynapticError>;

    /// Returns up to `k` documents nearest to a precomputed `embedding`.
    async fn similarity_search_by_vector(
        &self,
        embedding: &[f32],
        k: usize,
    ) -> Result<Vec<Document>, SynapticError>;

    /// Removes documents by id.
    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError>;
}
1390
1391#[async_trait]
1397pub trait Loader: Send + Sync {
1398 async fn load(&self) -> Result<Vec<Document>, SynapticError>;
1400
1401 fn lazy_load(
1403 &self,
1404 ) -> Pin<Box<dyn Stream<Item = Result<Document, SynapticError>> + Send + '_>> {
1405 Box::pin(async_stream::stream! {
1406 match self.load().await {
1407 Ok(docs) => {
1408 for doc in docs {
1409 yield Ok(doc);
1410 }
1411 }
1412 Err(e) => yield Err(e),
1413 }
1414 })
1415 }
1416}
1417
/// Cache of chat responses keyed by a caller-computed string.
#[async_trait]
pub trait LlmCache: Send + Sync {
    /// Returns the cached response for `key`, or `Ok(None)` on a miss.
    async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError>;
    /// Stores `response` under `key`.
    async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError>;
    /// Removes all cached entries.
    async fn clear(&self) -> Result<(), SynapticError>;
}
1432
/// Static metadata for an [`Entrypoint`].
#[derive(Debug, Clone)]
pub struct EntrypointConfig {
    pub name: &'static str,
    /// Optional checkpointer identifier; its interpretation lives outside this file.
    pub checkpointer: Option<&'static str>,
}
1443
/// Type-erased async function: takes a JSON input and returns a boxed `Send` future of
/// the JSON result. `Send + Sync`, so it can be shared across threads.
pub type EntrypointFn = dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<Value, SynapticError>> + Send>>
    + Send
    + Sync;
1451
/// A named, invocable async function plus its configuration.
pub struct Entrypoint {
    pub config: EntrypointConfig,
    pub invoke_fn: Box<EntrypointFn>,
}
1456
1457impl Entrypoint {
1458 pub async fn invoke(&self, input: Value) -> Result<Value, SynapticError> {
1459 (self.invoke_fn)(input).await
1460 }
1461}
1462
/// Returns the current UTC time as an RFC 3339 / ISO 8601 timestamp with microsecond
/// precision, e.g. `2023-11-14T22:13:20.000000Z`.
///
/// The previous implementation returned `format!("{:?}", SystemTime::now())`. That
/// `Debug` output is unspecified and platform-dependent (on Unix,
/// `SystemTime { tv_sec: .., tv_nsec: .. }`), not an ISO timestamp. All values
/// produced here have the same width, so they sort lexicographically in time order
/// (through year 9999).
///
/// A system clock set before the Unix epoch is clamped to `1970-01-01T00:00:00.000000Z`.
pub fn now_iso() -> String {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    format_unix_iso8601(since_epoch.as_secs(), since_epoch.subsec_micros())
}

// Formats `secs` (+ `micros`) after the Unix epoch as `YYYY-MM-DDTHH:MM:SS.ffffffZ` in UTC.
fn format_unix_iso8601(secs: u64, micros: u32) -> String {
    const SECS_PER_DAY: u64 = 86_400;
    let (year, month, day) = civil_from_days(secs / SECS_PER_DAY);
    let secs_of_day = secs % SECS_PER_DAY;
    format!(
        "{year:04}-{month:02}-{day:02}T{:02}:{:02}:{:02}.{micros:06}Z",
        secs_of_day / 3_600,
        (secs_of_day % 3_600) / 60,
        secs_of_day % 60,
    )
}

// Converts days since 1970-01-01 to a proleptic Gregorian (year, month, day).
// Howard Hinnant's `civil_from_days`, specialised to non-negative day counts.
fn civil_from_days(days_since_epoch: u64) -> (u64, u64, u64) {
    // Shift the origin to 0000-03-01 so each leap day falls at the end of its year.
    let z = days_since_epoch + 719_468;
    let era = z / 146_097;
    let doe = z - era * 146_097; // day of 400-year era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month index, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1;
    let month = if mp < 10 { mp + 3 } else { mp - 9 };
    // January and February belong to the following civil year.
    let year = era * 400 + yoe + u64::from(month <= 2);
    (year, month, day)
}
1471
/// Joins namespace components with `':'` (no escaping, so components containing `':'`
/// are not distinguishable after encoding).
pub fn encode_namespace(namespace: &[&str]) -> String {
    let mut encoded = String::new();
    for (position, part) in namespace.iter().enumerate() {
        if position > 0 {
            encoded.push(':');
        }
        encoded.push_str(part);
    }
    encoded
}
1476
1477pub fn validate_table_name(name: &str) -> Result<(), SynapticError> {
1479 if name.is_empty() {
1480 return Err(SynapticError::Store(
1481 "table name must not be empty".to_string(),
1482 ));
1483 }
1484 if !name
1485 .chars()
1486 .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
1487 {
1488 return Err(SynapticError::Store(format!(
1489 "invalid table name '{name}': only alphanumeric, underscore, and dot characters are allowed",
1490 )));
1491 }
1492 Ok(())
1493}