1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use thiserror::Error;
11
12#[cfg(feature = "schemars")]
13pub use schemars;
14
15pub mod context_budget;
16pub mod token_counter;
17
18pub use context_budget::{ContextBudget, ContextSlot, Priority};
19pub use token_counter::{HeuristicTokenCounter, TokenCounter};
20
/// A typed fragment of message content.
///
/// Serialized as an internally tagged enum: the variant is recorded in a
/// `type` field whose value is the snake_case variant name (`"text"`,
/// `"image"`, `"audio"`, ...).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image referenced by URL.
    Image {
        url: String,
        /// Optional detail setting; omitted from the serialized form when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        detail: Option<String>,
    },
    /// Audio referenced by URL.
    Audio {
        url: String,
    },
    /// Video referenced by URL.
    Video {
        url: String,
    },
    /// A file referenced by URL.
    File {
        url: String,
        /// Optional MIME type; omitted from the serialized form when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        mime_type: Option<String>,
    },
    /// An arbitrary JSON payload.
    Data {
        data: Value,
    },
    /// Reasoning text kept separate from ordinary `Text` content.
    Reasoning {
        content: String,
    },
}
55
/// A single conversation message.
///
/// Serialized as an internally tagged enum using a `role` field; the tag
/// values are `system`, `human`, `assistant`, `tool`, `chat` and `remove`.
/// Optional and collection fields default when missing on input and are
/// omitted from the output when `None` / empty.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "role")]
pub enum Message {
    /// A system message (`role = "system"`).
    #[serde(rename = "system")]
    System {
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A human message (`role = "human"`).
    #[serde(rename = "human")]
    Human {
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A model message (`role = "assistant"`); the only variant that carries
    /// tool calls, invalid tool calls and usage metadata.
    #[serde(rename = "assistant")]
    AI {
        content: String,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        usage_metadata: Option<TokenUsage>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        invalid_tool_calls: Vec<InvalidToolCall>,
    },
    /// A tool message (`role = "tool"`); `tool_call_id` is required.
    #[serde(rename = "tool")]
    Tool {
        content: String,
        tool_call_id: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A message with a caller-chosen role (`role = "chat"` on the wire; the
    /// custom role text is stored in `custom_role`).
    #[serde(rename = "chat")]
    Chat {
        custom_role: String,
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A marker carrying only an `id` (`role = "remove"`); it has no content
    /// or metadata fields.
    #[serde(rename = "remove")]
    Remove {
        id: String,
    },
}
150
// Assigns `$value` to the field named `$field` on whichever content-bearing
// `Message` variant `$self` currently is.
//
// `$self` must evaluate to `&mut Message` (the bound field is dereferenced
// with `*`). The `Remove` variant is deliberately left untouched.
macro_rules! set_message_field {
    ($self:expr, $field:ident, $value:expr) => {
        match $self {
            Message::System { $field, .. } => *$field = $value,
            Message::Human { $field, .. } => *$field = $value,
            Message::AI { $field, .. } => *$field = $value,
            Message::Tool { $field, .. } => *$field = $value,
            Message::Chat { $field, .. } => *$field = $value,
            Message::Remove { .. } => { }
        }
    };
}
165
// Evaluates to the field named `$field` of whichever content-bearing
// `Message` variant `$self` is.
//
// Panics via `unreachable!` when `$self` is `Message::Remove`; callers must
// handle that variant before invoking the macro.
macro_rules! get_message_field {
    ($self:expr, $field:ident) => {
        match $self {
            Message::System { $field, .. } => $field,
            Message::Human { $field, .. } => $field,
            Message::AI { $field, .. } => $field,
            Message::Tool { $field, .. } => $field,
            Message::Chat { $field, .. } => $field,
            Message::Remove { .. } => unreachable!("get_message_field called on Remove variant"),
        }
    };
}
180
181impl Message {
182 pub fn system(content: impl Into<String>) -> Self {
185 Message::System {
186 content: content.into(),
187 id: None,
188 name: None,
189 additional_kwargs: HashMap::new(),
190 response_metadata: HashMap::new(),
191 content_blocks: Vec::new(),
192 }
193 }
194
195 pub fn human(content: impl Into<String>) -> Self {
196 Message::Human {
197 content: content.into(),
198 id: None,
199 name: None,
200 additional_kwargs: HashMap::new(),
201 response_metadata: HashMap::new(),
202 content_blocks: Vec::new(),
203 }
204 }
205
206 pub fn ai(content: impl Into<String>) -> Self {
207 Message::AI {
208 content: content.into(),
209 tool_calls: vec![],
210 id: None,
211 name: None,
212 additional_kwargs: HashMap::new(),
213 response_metadata: HashMap::new(),
214 content_blocks: Vec::new(),
215 usage_metadata: None,
216 invalid_tool_calls: Vec::new(),
217 }
218 }
219
220 pub fn ai_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
221 Message::AI {
222 content: content.into(),
223 tool_calls,
224 id: None,
225 name: None,
226 additional_kwargs: HashMap::new(),
227 response_metadata: HashMap::new(),
228 content_blocks: Vec::new(),
229 usage_metadata: None,
230 invalid_tool_calls: Vec::new(),
231 }
232 }
233
234 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
235 Message::Tool {
236 content: content.into(),
237 tool_call_id: tool_call_id.into(),
238 id: None,
239 name: None,
240 additional_kwargs: HashMap::new(),
241 response_metadata: HashMap::new(),
242 content_blocks: Vec::new(),
243 }
244 }
245
246 pub fn chat(role: impl Into<String>, content: impl Into<String>) -> Self {
247 Message::Chat {
248 custom_role: role.into(),
249 content: content.into(),
250 id: None,
251 name: None,
252 additional_kwargs: HashMap::new(),
253 response_metadata: HashMap::new(),
254 content_blocks: Vec::new(),
255 }
256 }
257
258 pub fn remove(id: impl Into<String>) -> Self {
260 Message::Remove { id: id.into() }
261 }
262
263 pub fn with_id(mut self, value: impl Into<String>) -> Self {
266 set_message_field!(&mut self, id, Some(value.into()));
267 self
268 }
269
270 pub fn with_name(mut self, value: impl Into<String>) -> Self {
271 set_message_field!(&mut self, name, Some(value.into()));
272 self
273 }
274
275 pub fn with_additional_kwarg(mut self, key: impl Into<String>, value: Value) -> Self {
276 match &mut self {
277 Message::System {
278 additional_kwargs, ..
279 }
280 | Message::Human {
281 additional_kwargs, ..
282 }
283 | Message::AI {
284 additional_kwargs, ..
285 }
286 | Message::Tool {
287 additional_kwargs, ..
288 }
289 | Message::Chat {
290 additional_kwargs, ..
291 } => {
292 additional_kwargs.insert(key.into(), value);
293 }
294 Message::Remove { .. } => { }
295 }
296 self
297 }
298
299 pub fn with_response_metadata_entry(mut self, key: impl Into<String>, value: Value) -> Self {
300 match &mut self {
301 Message::System {
302 response_metadata, ..
303 }
304 | Message::Human {
305 response_metadata, ..
306 }
307 | Message::AI {
308 response_metadata, ..
309 }
310 | Message::Tool {
311 response_metadata, ..
312 }
313 | Message::Chat {
314 response_metadata, ..
315 } => {
316 response_metadata.insert(key.into(), value);
317 }
318 Message::Remove { .. } => { }
319 }
320 self
321 }
322
323 pub fn with_content_blocks(mut self, blocks: Vec<ContentBlock>) -> Self {
324 set_message_field!(&mut self, content_blocks, blocks);
325 self
326 }
327
328 pub fn with_usage_metadata(mut self, usage: TokenUsage) -> Self {
329 if let Message::AI { usage_metadata, .. } = &mut self {
330 *usage_metadata = Some(usage);
331 }
332 self
333 }
334
335 pub fn content(&self) -> &str {
338 match self {
339 Message::Remove { .. } => "",
340 other => get_message_field!(other, content),
341 }
342 }
343
344 pub fn role(&self) -> &str {
345 match self {
346 Message::System { .. } => "system",
347 Message::Human { .. } => "human",
348 Message::AI { .. } => "assistant",
349 Message::Tool { .. } => "tool",
350 Message::Chat { custom_role, .. } => custom_role,
351 Message::Remove { .. } => "remove",
352 }
353 }
354
355 pub fn is_system(&self) -> bool {
356 matches!(self, Message::System { .. })
357 }
358
359 pub fn is_human(&self) -> bool {
360 matches!(self, Message::Human { .. })
361 }
362
363 pub fn is_ai(&self) -> bool {
364 matches!(self, Message::AI { .. })
365 }
366
367 pub fn is_tool(&self) -> bool {
368 matches!(self, Message::Tool { .. })
369 }
370
371 pub fn is_chat(&self) -> bool {
372 matches!(self, Message::Chat { .. })
373 }
374
375 pub fn is_remove(&self) -> bool {
376 matches!(self, Message::Remove { .. })
377 }
378
379 pub fn tool_calls(&self) -> &[ToolCall] {
380 match self {
381 Message::AI { tool_calls, .. } => tool_calls,
382 _ => &[],
383 }
384 }
385
386 pub fn tool_call_id(&self) -> Option<&str> {
387 match self {
388 Message::Tool { tool_call_id, .. } => Some(tool_call_id),
389 _ => None,
390 }
391 }
392
393 pub fn id(&self) -> Option<&str> {
394 match self {
395 Message::Remove { id } => Some(id),
396 other => get_message_field!(other, id).as_deref(),
397 }
398 }
399
400 pub fn name(&self) -> Option<&str> {
401 match self {
402 Message::Remove { .. } => None,
403 other => get_message_field!(other, name).as_deref(),
404 }
405 }
406
407 pub fn additional_kwargs(&self) -> &HashMap<String, Value> {
408 match self {
409 Message::System {
410 additional_kwargs, ..
411 }
412 | Message::Human {
413 additional_kwargs, ..
414 }
415 | Message::AI {
416 additional_kwargs, ..
417 }
418 | Message::Tool {
419 additional_kwargs, ..
420 }
421 | Message::Chat {
422 additional_kwargs, ..
423 } => additional_kwargs,
424 Message::Remove { .. } => {
425 static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
426 std::sync::OnceLock::new();
427 EMPTY.get_or_init(HashMap::new)
428 }
429 }
430 }
431
432 pub fn response_metadata(&self) -> &HashMap<String, Value> {
433 match self {
434 Message::System {
435 response_metadata, ..
436 }
437 | Message::Human {
438 response_metadata, ..
439 }
440 | Message::AI {
441 response_metadata, ..
442 }
443 | Message::Tool {
444 response_metadata, ..
445 }
446 | Message::Chat {
447 response_metadata, ..
448 } => response_metadata,
449 Message::Remove { .. } => {
450 static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
451 std::sync::OnceLock::new();
452 EMPTY.get_or_init(HashMap::new)
453 }
454 }
455 }
456
457 pub fn content_blocks(&self) -> &[ContentBlock] {
458 match self {
459 Message::Remove { .. } => &[],
460 other => get_message_field!(other, content_blocks),
461 }
462 }
463
464 pub fn remove_id(&self) -> Option<&str> {
466 match self {
467 Message::Remove { id } => Some(id),
468 _ => None,
469 }
470 }
471
472 pub fn usage_metadata(&self) -> Option<&TokenUsage> {
473 match self {
474 Message::AI { usage_metadata, .. } => usage_metadata.as_ref(),
475 _ => None,
476 }
477 }
478
479 pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
480 match self {
481 Message::AI {
482 invalid_tool_calls, ..
483 } => invalid_tool_calls,
484 _ => &[],
485 }
486 }
487
488 pub fn set_content(&mut self, new_content: impl Into<String>) {
490 let new_content = new_content.into();
491 set_message_field!(self, content, new_content);
492 }
493}
494
495pub fn filter_messages(
501 messages: &[Message],
502 include_types: Option<&[&str]>,
503 exclude_types: Option<&[&str]>,
504 include_names: Option<&[&str]>,
505 exclude_names: Option<&[&str]>,
506 include_ids: Option<&[&str]>,
507 exclude_ids: Option<&[&str]>,
508) -> Vec<Message> {
509 messages
510 .iter()
511 .filter(|msg| {
512 if let Some(include) = include_types {
513 if !include.contains(&msg.role()) {
514 return false;
515 }
516 }
517 if let Some(exclude) = exclude_types {
518 if exclude.contains(&msg.role()) {
519 return false;
520 }
521 }
522 if let Some(include) = include_names {
523 match msg.name() {
524 Some(name) => {
525 if !include.contains(&name) {
526 return false;
527 }
528 }
529 None => return false,
530 }
531 }
532 if let Some(exclude) = exclude_names {
533 if let Some(name) = msg.name() {
534 if exclude.contains(&name) {
535 return false;
536 }
537 }
538 }
539 if let Some(include) = include_ids {
540 match msg.id() {
541 Some(id) => {
542 if !include.contains(&id) {
543 return false;
544 }
545 }
546 None => return false,
547 }
548 }
549 if let Some(exclude) = exclude_ids {
550 if let Some(id) = msg.id() {
551 if exclude.contains(&id) {
552 return false;
553 }
554 }
555 }
556 true
557 })
558 .cloned()
559 .collect()
560}
561
/// Selects which end of a message list `trim_messages` keeps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimStrategy {
    /// Keep messages from the start of the list until the budget is reached.
    First,
    /// Keep messages from the end of the list until the budget is reached.
    Last,
}
570
571pub fn trim_messages(
577 messages: Vec<Message>,
578 max_tokens: usize,
579 token_counter: impl Fn(&Message) -> usize,
580 strategy: TrimStrategy,
581 include_system: bool,
582) -> Vec<Message> {
583 if messages.is_empty() {
584 return messages;
585 }
586
587 match strategy {
588 TrimStrategy::First => {
589 let mut result = Vec::new();
590 let mut total = 0;
591 for msg in messages {
592 let count = token_counter(&msg);
593 if total + count > max_tokens {
594 break;
595 }
596 total += count;
597 result.push(msg);
598 }
599 result
600 }
601 TrimStrategy::Last => {
602 let (system_msg, rest) = if include_system && messages[0].is_system() {
603 (Some(messages[0].clone()), &messages[1..])
604 } else {
605 (None, messages.as_slice())
606 };
607
608 let system_tokens = system_msg.as_ref().map(&token_counter).unwrap_or(0);
609 let budget = max_tokens.saturating_sub(system_tokens);
610
611 let mut selected = Vec::new();
612 let mut total = 0;
613 for msg in rest.iter().rev() {
614 let count = token_counter(msg);
615 if total + count > budget {
616 break;
617 }
618 total += count;
619 selected.push(msg.clone());
620 }
621 selected.reverse();
622
623 let mut result = Vec::new();
624 if let Some(sys) = system_msg {
625 result.push(sys);
626 }
627 result.extend(selected);
628 result
629 }
630 }
631}
632
633pub fn merge_message_runs(messages: Vec<Message>) -> Vec<Message> {
635 if messages.is_empty() {
636 return messages;
637 }
638
639 let mut result: Vec<Message> = Vec::new();
640
641 for msg in messages {
642 let should_merge = result
643 .last()
644 .map(|last| last.role() == msg.role())
645 .unwrap_or(false);
646
647 if should_merge {
648 let last = result.last_mut().unwrap();
649 let merged_content = format!("{}\n{}", last.content(), msg.content());
651 match last {
652 Message::System { content, .. } => *content = merged_content,
653 Message::Human { content, .. } => *content = merged_content,
654 Message::AI {
655 content,
656 tool_calls,
657 invalid_tool_calls,
658 ..
659 } => {
660 *content = merged_content;
661 tool_calls.extend(msg.tool_calls().to_vec());
662 invalid_tool_calls.extend(msg.invalid_tool_calls().to_vec());
663 }
664 Message::Tool { content, .. } => *content = merged_content,
665 Message::Chat { content, .. } => *content = merged_content,
666 Message::Remove { .. } => { }
667 }
668 } else {
669 result.push(msg);
670 }
671 }
672
673 result
674}
675
676pub fn get_buffer_string(messages: &[Message], human_prefix: &str, ai_prefix: &str) -> String {
678 messages
679 .iter()
680 .map(|msg| {
681 let prefix = match msg {
682 Message::System { .. } => "System",
683 Message::Human { .. } => human_prefix,
684 Message::AI { .. } => ai_prefix,
685 Message::Tool { .. } => "Tool",
686 Message::Chat { custom_role, .. } => custom_role.as_str(),
687 Message::Remove { .. } => "Remove",
688 };
689 format!("{prefix}: {}", msg.content())
690 })
691 .collect::<Vec<_>>()
692 .join("\n")
693}
694
/// One piece of a streamed AI response.
///
/// Chunks can be combined with `+` / `+=`. All fields except `content`
/// default when missing on input and are omitted from the serialized form
/// when empty / `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct AIMessageChunk {
    /// Text carried by this chunk.
    pub content: String,
    /// Tool calls carried by this chunk.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Token usage reported with this chunk, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage: Option<TokenUsage>,
    /// Message id, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Partial tool-call fragments (see `ToolCallChunk`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_call_chunks: Vec<ToolCallChunk>,
    /// Tool calls that could not be parsed.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub invalid_tool_calls: Vec<InvalidToolCall>,
}
714
715impl AIMessageChunk {
716 pub fn into_message(self) -> Message {
717 Message::ai_with_tool_calls(self.content, self.tool_calls)
718 }
719}
720
721impl std::ops::Add for AIMessageChunk {
722 type Output = Self;
723
724 fn add(mut self, rhs: Self) -> Self {
725 self += rhs;
726 self
727 }
728}
729
730impl std::ops::AddAssign for AIMessageChunk {
731 fn add_assign(&mut self, rhs: Self) {
732 self.content.push_str(&rhs.content);
733 self.tool_calls.extend(rhs.tool_calls);
734 self.tool_call_chunks.extend(rhs.tool_call_chunks);
735 self.invalid_tool_calls.extend(rhs.invalid_tool_calls);
736 if self.id.is_none() {
737 self.id = rhs.id;
738 }
739 match (&mut self.usage, rhs.usage) {
740 (Some(u), Some(rhs_u)) => {
741 u.input_tokens += rhs_u.input_tokens;
742 u.output_tokens += rhs_u.output_tokens;
743 u.total_tokens += rhs_u.total_tokens;
744 }
745 (None, Some(rhs_u)) => {
746 self.usage = Some(rhs_u);
747 }
748 _ => {}
749 }
750 }
751}
752
/// A tool invocation requested in an AI message.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the tool, as a JSON value.
    pub arguments: Value,
}
764
/// A tool call that could not be turned into a valid `ToolCall`.
///
/// Every field except `error` is optional and omitted from the serialized
/// form when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InvalidToolCall {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Arguments kept as a string rather than a JSON value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
    /// Description of what was wrong with the call.
    pub error: String,
}
776
/// A partial tool call carried by an `AIMessageChunk`.
///
/// All fields are optional and omitted from the serialized form when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallChunk {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Argument text kept as a string rather than a JSON value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
    // NOTE(review): presumably the position of the tool call this fragment
    // belongs to — confirm against the code that assembles chunks.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
}
789
/// Description of a tool that can be attached to a `ChatRequest`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Tool description.
    pub description: String,
    /// Parameter schema as a JSON value (the `Tool` trait defaults this to
    /// an empty object schema).
    pub parameters: Value,
    /// Optional extra key/value data; omitted from the serialized form when
    /// `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub extras: Option<HashMap<String, Value>>,
}
800
/// Tool-selection setting for a `ChatRequest`.
///
/// Variant names are serialized in lowercase (`auto`, `required`, `none`,
/// `specific`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    /// Carries a tool name.
    Specific(String),
}
810
/// Input to `ChatModel::chat` / `ChatModel::stream_chat`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatRequest {
    /// Conversation messages.
    pub messages: Vec<Message>,
    /// Tool definitions; omitted from the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Optional tool-choice setting; omitted from the serialized form when
    /// `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
824
825impl ChatRequest {
826 pub fn new(messages: Vec<Message>) -> Self {
827 Self {
828 messages,
829 tools: vec![],
830 tool_choice: None,
831 }
832 }
833
834 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
835 self.tools = tools;
836 self
837 }
838
839 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
840 self.tool_choice = Some(choice);
841 self
842 }
843}
844
/// Result of `ChatModel::chat`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatResponse {
    /// The response message.
    pub message: Message,
    /// Token usage for the call, if available.
    pub usage: Option<TokenUsage>,
}
851
/// Token counts for a model call.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    // Stored as its own field; nothing in this file derives it from the
    // input and output counts.
    pub total_tokens: u32,
    /// Optional breakdown of input tokens; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_details: Option<InputTokenDetails>,
    /// Optional breakdown of output tokens; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_details: Option<OutputTokenDetails>,
}
866
/// Breakdown of input tokens; each field defaults to `0` when missing on
/// input.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct InputTokenDetails {
    /// Count of cached input tokens.
    #[serde(default)]
    pub cached: u32,
    /// Count of audio input tokens.
    #[serde(default)]
    pub audio: u32,
}
875
/// Breakdown of output tokens; each field defaults to `0` when missing on
/// input.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct OutputTokenDetails {
    /// Count of reasoning output tokens.
    #[serde(default)]
    pub reasoning: u32,
    /// Count of audio output tokens.
    #[serde(default)]
    pub audio: u32,
}
884
/// Events delivered to `CallbackHandler::on_event`.
///
/// Every variant carries the `run_id` it belongs to.
// NOTE(review): the exact points at which each event is emitted are decided
// by the code that drives a run, which is not part of this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunEvent {
    /// A run started for the given session.
    RunStarted {
        run_id: String,
        session_id: String,
    },
    /// A numbered step of the run.
    RunStep {
        run_id: String,
        step: usize,
    },
    /// A model call, with the number of messages involved.
    LlmCalled {
        run_id: String,
        message_count: usize,
    },
    /// A call to the named tool.
    ToolCalled {
        run_id: String,
        tool_name: String,
    },
    /// The run finished with the given output.
    RunFinished {
        run_id: String,
        output: String,
    },
    /// The run failed with the given error text.
    RunFailed {
        run_id: String,
        error: String,
    },
}
917
/// Error type returned by the traits and helpers in this file.
///
/// Most variants wrap a free-form message; the `Display` output is the
/// category prefix given in each `#[error]` attribute followed by that
/// message.
#[derive(Debug, Error)]
pub enum SynapticError {
    #[error("prompt error: {0}")]
    Prompt(String),
    #[error("model error: {0}")]
    Model(String),
    #[error("tool error: {0}")]
    Tool(String),
    #[error("tool not found: {0}")]
    ToolNotFound(String),
    #[error("memory error: {0}")]
    Memory(String),
    #[error("rate limit: {0}")]
    RateLimit(String),
    #[error("timeout: {0}")]
    Timeout(String),
    #[error("validation error: {0}")]
    Validation(String),
    #[error("parsing error: {0}")]
    Parsing(String),
    #[error("callback error: {0}")]
    Callback(String),
    /// Carries the step limit that was exceeded.
    #[error("max steps exceeded: {max_steps}")]
    MaxStepsExceeded { max_steps: usize },
    #[error("embedding error: {0}")]
    Embedding(String),
    #[error("vector store error: {0}")]
    VectorStore(String),
    #[error("retriever error: {0}")]
    Retriever(String),
    #[error("loader error: {0}")]
    Loader(String),
    #[error("splitter error: {0}")]
    Splitter(String),
    #[error("graph error: {0}")]
    Graph(String),
    #[error("cache error: {0}")]
    Cache(String),
    /// Also returned by `validate_table_name`.
    #[error("store error: {0}")]
    Store(String),
    #[error("config error: {0}")]
    Config(String),
    #[error("mcp error: {0}")]
    Mcp(String),
}
968
/// A pinned, boxed, `Send` stream of `AIMessageChunk` results, as returned
/// by `ChatModel::stream_chat`.
pub type ChatStream<'a> =
    Pin<Box<dyn Stream<Item = Result<AIMessageChunk, SynapticError>> + Send + 'a>>;
976
/// Descriptive information about a chat model, as returned by
/// `ChatModel::profile`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProfile {
    /// Model name.
    pub name: String,
    /// Provider name.
    pub provider: String,
    pub supports_tool_calling: bool,
    pub supports_structured_output: bool,
    pub supports_streaming: bool,
    /// Input token limit, if known.
    pub max_input_tokens: Option<usize>,
    /// Output token limit, if known.
    pub max_output_tokens: Option<usize>,
}
988
989#[async_trait]
991pub trait ChatModel: Send + Sync {
992 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError>;
993
994 fn profile(&self) -> Option<ModelProfile> {
996 None
997 }
998
999 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
1000 Box::pin(async_stream::stream! {
1001 match self.chat(request).await {
1002 Ok(response) => {
1003 yield Ok(AIMessageChunk {
1004 content: response.message.content().to_string(),
1005 tool_calls: response.message.tool_calls().to_vec(),
1006 usage: response.usage,
1007 ..Default::default()
1008 });
1009 }
1010 Err(e) => yield Err(e),
1011 }
1012 })
1013 }
1014}
1015
1016#[async_trait]
1018pub trait Tool: Send + Sync {
1019 fn name(&self) -> &'static str;
1020 fn description(&self) -> &'static str;
1021
1022 fn parameters(&self) -> Option<Value> {
1023 None
1024 }
1025
1026 async fn call(&self, args: Value) -> Result<Value, SynapticError>;
1027
1028 fn as_tool_definition(&self) -> ToolDefinition {
1029 ToolDefinition {
1030 name: self.name().to_string(),
1031 description: self.description().to_string(),
1032 parameters: self
1033 .parameters()
1034 .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1035 extras: None,
1036 }
1037 }
1038}
1039
/// Storage for messages, keyed by session id.
#[async_trait]
pub trait MemoryStore: Send + Sync {
    /// Appends one message to the session.
    async fn append(&self, session_id: &str, message: Message) -> Result<(), SynapticError>;
    /// Loads the messages stored for the session.
    async fn load(&self, session_id: &str) -> Result<Vec<Message>, SynapticError>;
    /// Clears the messages stored for the session.
    async fn clear(&self, session_id: &str) -> Result<(), SynapticError>;
}
1051
/// Receiver for `RunEvent`s.
#[async_trait]
pub trait CallbackHandler: Send + Sync {
    /// Handles one event; may fail with a `SynapticError`.
    async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError>;
}
1057
/// Per-run configuration.
///
/// Every field defaults when missing on input (`Default` gives empty
/// collections and `None`).
// NOTE(review): how each setting is enforced is decided by the code that
// consumes this config, which is not part of this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RunnableConfig {
    #[serde(default)]
    pub tags: Vec<String>,
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
    #[serde(default)]
    pub max_concurrency: Option<usize>,
    #[serde(default)]
    pub recursion_limit: Option<usize>,
    #[serde(default)]
    pub run_id: Option<String>,
    #[serde(default)]
    pub run_name: Option<String>,
}
1078
1079impl RunnableConfig {
1080 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
1081 self.tags = tags;
1082 self
1083 }
1084
1085 pub fn with_run_name(mut self, name: impl Into<String>) -> Self {
1086 self.run_name = Some(name.into());
1087 self
1088 }
1089
1090 pub fn with_run_id(mut self, id: impl Into<String>) -> Self {
1091 self.run_id = Some(id.into());
1092 self
1093 }
1094
1095 pub fn with_max_concurrency(mut self, max: usize) -> Self {
1096 self.max_concurrency = Some(max);
1097 self
1098 }
1099
1100 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
1101 self.recursion_limit = Some(limit);
1102 self
1103 }
1104
1105 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
1106 self.metadata.insert(key.into(), value);
1107 self
1108 }
1109}
1110
/// An entry returned by a `Store`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Item {
    /// Namespace path segments.
    pub namespace: Vec<String>,
    /// Key within the namespace.
    pub key: String,
    /// Stored JSON value.
    pub value: Value,
    // NOTE(review): timestamps are plain strings; their format is chosen by
    // the store implementation — confirm before parsing.
    pub created_at: String,
    pub updated_at: String,
    /// Optional score; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub score: Option<f64>,
}
1127
/// Key-value storage organised by namespace (a slice of path segments).
#[async_trait]
pub trait Store: Send + Sync {
    /// Fetches the item at `namespace`/`key`, or `None` if absent.
    async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError>;

    /// Searches within `namespace` with an optional `query`.
    // NOTE(review): `limit` presumably caps the number of results — confirm
    // against implementations.
    async fn search(
        &self,
        namespace: &[&str],
        query: Option<&str>,
        limit: usize,
    ) -> Result<Vec<Item>, SynapticError>;

    /// Stores `value` at `namespace`/`key`.
    async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError>;

    /// Deletes the item at `namespace`/`key`.
    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError>;

    /// Lists namespaces for the given `prefix`.
    async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError>;
}
1154
/// Produces embedding vectors (`Vec<f32>`) for text.
#[async_trait]
pub trait Embeddings: Send + Sync {
    /// Embeds a batch of texts, returning a list of vectors.
    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError>;

    /// Embeds a single query text.
    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError>;
}
1168
/// A shared, thread-safe callback that accepts a JSON value.
pub type StreamWriter = Arc<dyn Fn(Value) + Send + Sync>;
1175
/// Optional shared resources: a store and a stream writer.
#[derive(Clone)]
pub struct Runtime {
    pub store: Option<Arc<dyn Store>>,
    pub stream_writer: Option<StreamWriter>,
}
1186
/// Context passed to `RuntimeAwareTool::call_with_runtime`.
#[derive(Clone)]
pub struct ToolRuntime {
    pub store: Option<Arc<dyn Store>>,
    pub stream_writer: Option<StreamWriter>,
    /// Optional state as a JSON value.
    pub state: Option<Value>,
    /// Id of the tool call; `RuntimeAwareToolAdapter` uses an empty string
    /// when no runtime was set.
    pub tool_call_id: String,
    pub config: Option<RunnableConfig>,
}
1196
1197#[async_trait]
1207pub trait RuntimeAwareTool: Send + Sync {
1208 fn name(&self) -> &'static str;
1209 fn description(&self) -> &'static str;
1210
1211 fn parameters(&self) -> Option<Value> {
1212 None
1213 }
1214
1215 async fn call_with_runtime(
1216 &self,
1217 args: Value,
1218 runtime: ToolRuntime,
1219 ) -> Result<Value, SynapticError>;
1220
1221 fn as_tool_definition(&self) -> ToolDefinition {
1222 ToolDefinition {
1223 name: self.name().to_string(),
1224 description: self.description().to_string(),
1225 parameters: self
1226 .parameters()
1227 .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1228 extras: None,
1229 }
1230 }
1231}
1232
/// Wraps a `RuntimeAwareTool` so it can be used through the plain `Tool`
/// trait.
pub struct RuntimeAwareToolAdapter {
    /// The wrapped tool.
    inner: Arc<dyn RuntimeAwareTool>,
    /// Runtime handed to the wrapped tool on each call; `None` until
    /// `set_runtime` is called.
    runtime: Arc<tokio::sync::RwLock<Option<ToolRuntime>>>,
}
1241
1242impl RuntimeAwareToolAdapter {
1243 pub fn new(tool: Arc<dyn RuntimeAwareTool>) -> Self {
1244 Self {
1245 inner: tool,
1246 runtime: Arc::new(tokio::sync::RwLock::new(None)),
1247 }
1248 }
1249
1250 pub async fn set_runtime(&self, runtime: ToolRuntime) {
1251 *self.runtime.write().await = Some(runtime);
1252 }
1253}
1254
1255#[async_trait]
1256impl Tool for RuntimeAwareToolAdapter {
1257 fn name(&self) -> &'static str {
1258 self.inner.name()
1259 }
1260
1261 fn description(&self) -> &'static str {
1262 self.inner.description()
1263 }
1264
1265 fn parameters(&self) -> Option<Value> {
1266 self.inner.parameters()
1267 }
1268
1269 async fn call(&self, args: Value) -> Result<Value, SynapticError> {
1270 let runtime = self.runtime.read().await.clone().unwrap_or(ToolRuntime {
1271 store: None,
1272 stream_writer: None,
1273 state: None,
1274 tool_call_id: String::new(),
1275 config: None,
1276 });
1277 self.inner.call_with_runtime(args, runtime).await
1278 }
1279}
1280
/// A piece of text with an id and optional metadata.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Document {
    /// Document id.
    pub id: String,
    /// Document text.
    pub content: String,
    /// Arbitrary metadata; defaults to empty and is omitted from the
    /// serialized form when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}
1293
1294impl Document {
1295 pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
1296 Self {
1297 id: id.into(),
1298 content: content.into(),
1299 metadata: HashMap::new(),
1300 }
1301 }
1302
1303 pub fn with_metadata(
1304 id: impl Into<String>,
1305 content: impl Into<String>,
1306 metadata: HashMap<String, Value>,
1307 ) -> Self {
1308 Self {
1309 id: id.into(),
1310 content: content.into(),
1311 metadata,
1312 }
1313 }
1314}
1315
/// Retrieves documents for a query.
#[async_trait]
pub trait Retriever: Send + Sync {
    /// Returns documents for `query`.
    // NOTE(review): `top_k` presumably caps the number of results — confirm
    // against implementations.
    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapticError>;
}
1325
/// Document storage searchable by embedding similarity.
///
/// Methods that take text receive the `Embeddings` implementation to use as
/// an argument.
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Adds documents and returns a list of strings.
    // NOTE(review): the returned strings are presumably the stored document
    // ids — confirm against implementations.
    async fn add_documents(
        &self,
        docs: Vec<Document>,
        embeddings: &dyn Embeddings,
    ) -> Result<Vec<String>, SynapticError>;

    /// Searches for documents similar to `query`.
    async fn similarity_search(
        &self,
        query: &str,
        k: usize,
        embeddings: &dyn Embeddings,
    ) -> Result<Vec<Document>, SynapticError>;

    /// Like `similarity_search`, but pairs each document with an `f32`
    /// score.
    async fn similarity_search_with_score(
        &self,
        query: &str,
        k: usize,
        embeddings: &dyn Embeddings,
    ) -> Result<Vec<(Document, f32)>, SynapticError>;

    /// Searches using an already computed embedding vector.
    async fn similarity_search_by_vector(
        &self,
        embedding: &[f32],
        k: usize,
    ) -> Result<Vec<Document>, SynapticError>;

    /// Deletes the documents with the given ids.
    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError>;
}
1366
1367#[async_trait]
1373pub trait Loader: Send + Sync {
1374 async fn load(&self) -> Result<Vec<Document>, SynapticError>;
1376
1377 fn lazy_load(
1379 &self,
1380 ) -> Pin<Box<dyn Stream<Item = Result<Document, SynapticError>> + Send + '_>> {
1381 Box::pin(async_stream::stream! {
1382 match self.load().await {
1383 Ok(docs) => {
1384 for doc in docs {
1385 yield Ok(doc);
1386 }
1387 }
1388 Err(e) => yield Err(e),
1389 }
1390 })
1391 }
1392}
1393
/// Cache of `ChatResponse`s keyed by string.
#[async_trait]
pub trait LlmCache: Send + Sync {
    /// Returns the cached response for `key`, or `None` if absent.
    async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError>;
    /// Stores `response` under `key`.
    async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError>;
    /// Clears the cache.
    async fn clear(&self) -> Result<(), SynapticError>;
}
1408
/// Static configuration attached to an `Entrypoint`.
#[derive(Debug, Clone)]
pub struct EntrypointConfig {
    /// Entrypoint name.
    pub name: &'static str,
    // NOTE(review): presumably names the checkpointer to use — confirm
    // against the code that reads this field.
    pub checkpointer: Option<&'static str>,
}
1419
/// Signature of an entrypoint body: takes a JSON value and returns a boxed
/// `Send` future resolving to a JSON value or a `SynapticError`.
pub type EntrypointFn = dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<Value, SynapticError>> + Send>>
    + Send
    + Sync;
1427
/// An invocable function paired with its configuration.
pub struct Entrypoint {
    pub config: EntrypointConfig,
    /// The function executed by `Entrypoint::invoke`.
    pub invoke_fn: Box<EntrypointFn>,
}
1432
1433impl Entrypoint {
1434 pub async fn invoke(&self, input: Value) -> Result<Value, SynapticError> {
1435 (self.invoke_fn)(input).await
1436 }
1437}
1438
/// Returns the current UTC time as an ISO-8601 / RFC 3339 string with
/// millisecond precision, e.g. `2024-05-17T09:30:12.345Z`.
///
/// If the system clock reports a time before the Unix epoch, the epoch
/// itself is used.
pub fn now_iso() -> String {
    let elapsed = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    iso8601_from_unix(elapsed.as_secs(), elapsed.subsec_millis())
}

/// Formats seconds (plus milliseconds) since the Unix epoch as a UTC
/// ISO-8601 timestamp, using the days-to-civil-date conversion for the
/// proleptic Gregorian calendar.
fn iso8601_from_unix(secs: u64, millis: u32) -> String {
    let days = secs / 86_400;
    let secs_of_day = secs % 86_400;
    let hour = secs_of_day / 3_600;
    let minute = (secs_of_day % 3_600) / 60;
    let second = secs_of_day % 60;

    // Shift the epoch to 0000-03-01 so leap days fall at the end of a year.
    let shifted = days + 719_468;
    let era = shifted / 146_097; // 400-year cycles
    let day_of_era = shifted % 146_097; // [0, 146096]
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    let shifted_month = (5 * day_of_year + 2) / 153; // 0 = March
    let day = day_of_year - (153 * shifted_month + 2) / 5 + 1;
    let month = if shifted_month < 10 {
        shifted_month + 3
    } else {
        shifted_month - 9
    };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + u64::from(month <= 2);

    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}.{millis:03}Z")
}
1447
/// Joins namespace segments into a single string separated by `:`.
///
/// Segments are not escaped, so a segment that itself contains `:` makes the
/// result ambiguous.
pub fn encode_namespace(namespace: &[&str]) -> String {
    let mut encoded = String::new();
    for (position, segment) in namespace.iter().enumerate() {
        if position != 0 {
            encoded.push(':');
        }
        encoded.push_str(segment);
    }
    encoded
}
1452
1453pub fn validate_table_name(name: &str) -> Result<(), SynapticError> {
1455 if name.is_empty() {
1456 return Err(SynapticError::Store(
1457 "table name must not be empty".to_string(),
1458 ));
1459 }
1460 if !name
1461 .chars()
1462 .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
1463 {
1464 return Err(SynapticError::Store(format!(
1465 "invalid table name '{name}': only alphanumeric, underscore, and dot characters are allowed",
1466 )));
1467 }
1468 Ok(())
1469}