1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use thiserror::Error;
11
12#[cfg(feature = "schemars")]
13pub use schemars;
14
/// A typed piece of message content.
///
/// Serialized with an internal `"type"` tag whose value is the variant name in
/// `snake_case` (for example `{"type": "text", "text": "..."}`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content.
    Text {
        text: String,
    },
    /// An image referenced by URL.
    Image {
        url: String,
        /// Optional detail hint; omitted from the serialized form when `None`.
        // NOTE(review): the accepted values are not defined in this file — confirm with consumers.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        detail: Option<String>,
    },
    /// Audio referenced by URL.
    Audio {
        url: String,
    },
    /// Video referenced by URL.
    Video {
        url: String,
    },
    /// A generic file referenced by URL.
    File {
        url: String,
        /// Optional MIME type; omitted from the serialized form when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        mime_type: Option<String>,
    },
    /// Arbitrary structured JSON data.
    Data {
        data: Value,
    },
    /// Reasoning text carried alongside the regular content.
    Reasoning {
        content: String,
    },
}
49
/// A single conversation message.
///
/// Serialized with an internal `"role"` tag: `"system"`, `"human"`,
/// `"assistant"`, `"tool"`, `"chat"` or `"remove"`.
///
/// Every variant except [`Message::Remove`] carries the same optional fields
/// (`id`, `name`, `additional_kwargs`, `response_metadata`, `content_blocks`).
/// Each of them defaults when absent on deserialization and is skipped on
/// serialization when `None`/empty.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "role")]
pub enum Message {
    /// A system message (role `"system"`).
    #[serde(rename = "system")]
    System {
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A message written by the human user (role `"human"`).
    #[serde(rename = "human")]
    Human {
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A message produced by the model (role `"assistant"`).
    ///
    /// In addition to the common fields it carries the requested tool calls,
    /// tool calls that could not be parsed, and optional token usage.
    #[serde(rename = "assistant")]
    AI {
        content: String,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        usage_metadata: Option<TokenUsage>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        invalid_tool_calls: Vec<InvalidToolCall>,
    },
    /// The result of a tool invocation (role `"tool"`).
    ///
    /// `tool_call_id` is required and has no serde default.
    #[serde(rename = "tool")]
    Tool {
        content: String,
        tool_call_id: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A message with a caller-defined role (tag `"chat"`, role text in `custom_role`).
    #[serde(rename = "chat")]
    Chat {
        custom_role: String,
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        additional_kwargs: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        response_metadata: HashMap<String, Value>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        content_blocks: Vec<ContentBlock>,
    },
    /// A marker that carries only an `id` (tag `"remove"`); it has no content
    /// or metadata fields.
    // NOTE(review): presumably instructs a message store/reducer to delete the
    // message with this id — confirm against the consumer of this variant.
    #[serde(rename = "remove")]
    Remove {
        id: String,
    },
}
144
// Assigns `$value` to the field named `$field` on every `Message` variant that
// has such a field. `$self` must evaluate to `&mut Message`.
//
// `Message::Remove` has none of the shared fields, so it is left untouched and
// `$value` is never evaluated for it.
macro_rules! set_message_field {
    ($self:expr, $field:ident, $value:expr) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => *$field = $value,
            Message::Remove { .. } => {}
        }
    };
}
159
// Evaluates to a reference to the field named `$field`, which must exist on
// every `Message` variant other than `Remove`. `$self` must evaluate to
// `&Message`.
//
// Callers must handle `Message::Remove` themselves before invoking this macro:
// reaching the `Remove` arm panics via `unreachable!`.
macro_rules! get_message_field {
    ($self:expr, $field:ident) => {
        match $self {
            Message::System { $field, .. }
            | Message::Human { $field, .. }
            | Message::AI { $field, .. }
            | Message::Tool { $field, .. }
            | Message::Chat { $field, .. } => $field,
            Message::Remove { .. } => unreachable!("get_message_field called on Remove variant"),
        }
    };
}
174
175impl Message {
176 pub fn system(content: impl Into<String>) -> Self {
179 Message::System {
180 content: content.into(),
181 id: None,
182 name: None,
183 additional_kwargs: HashMap::new(),
184 response_metadata: HashMap::new(),
185 content_blocks: Vec::new(),
186 }
187 }
188
189 pub fn human(content: impl Into<String>) -> Self {
190 Message::Human {
191 content: content.into(),
192 id: None,
193 name: None,
194 additional_kwargs: HashMap::new(),
195 response_metadata: HashMap::new(),
196 content_blocks: Vec::new(),
197 }
198 }
199
200 pub fn ai(content: impl Into<String>) -> Self {
201 Message::AI {
202 content: content.into(),
203 tool_calls: vec![],
204 id: None,
205 name: None,
206 additional_kwargs: HashMap::new(),
207 response_metadata: HashMap::new(),
208 content_blocks: Vec::new(),
209 usage_metadata: None,
210 invalid_tool_calls: Vec::new(),
211 }
212 }
213
214 pub fn ai_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
215 Message::AI {
216 content: content.into(),
217 tool_calls,
218 id: None,
219 name: None,
220 additional_kwargs: HashMap::new(),
221 response_metadata: HashMap::new(),
222 content_blocks: Vec::new(),
223 usage_metadata: None,
224 invalid_tool_calls: Vec::new(),
225 }
226 }
227
228 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
229 Message::Tool {
230 content: content.into(),
231 tool_call_id: tool_call_id.into(),
232 id: None,
233 name: None,
234 additional_kwargs: HashMap::new(),
235 response_metadata: HashMap::new(),
236 content_blocks: Vec::new(),
237 }
238 }
239
240 pub fn chat(role: impl Into<String>, content: impl Into<String>) -> Self {
241 Message::Chat {
242 custom_role: role.into(),
243 content: content.into(),
244 id: None,
245 name: None,
246 additional_kwargs: HashMap::new(),
247 response_metadata: HashMap::new(),
248 content_blocks: Vec::new(),
249 }
250 }
251
252 pub fn remove(id: impl Into<String>) -> Self {
254 Message::Remove { id: id.into() }
255 }
256
257 pub fn with_id(mut self, value: impl Into<String>) -> Self {
260 set_message_field!(&mut self, id, Some(value.into()));
261 self
262 }
263
264 pub fn with_name(mut self, value: impl Into<String>) -> Self {
265 set_message_field!(&mut self, name, Some(value.into()));
266 self
267 }
268
269 pub fn with_additional_kwarg(mut self, key: impl Into<String>, value: Value) -> Self {
270 match &mut self {
271 Message::System {
272 additional_kwargs, ..
273 }
274 | Message::Human {
275 additional_kwargs, ..
276 }
277 | Message::AI {
278 additional_kwargs, ..
279 }
280 | Message::Tool {
281 additional_kwargs, ..
282 }
283 | Message::Chat {
284 additional_kwargs, ..
285 } => {
286 additional_kwargs.insert(key.into(), value);
287 }
288 Message::Remove { .. } => { }
289 }
290 self
291 }
292
293 pub fn with_response_metadata_entry(mut self, key: impl Into<String>, value: Value) -> Self {
294 match &mut self {
295 Message::System {
296 response_metadata, ..
297 }
298 | Message::Human {
299 response_metadata, ..
300 }
301 | Message::AI {
302 response_metadata, ..
303 }
304 | Message::Tool {
305 response_metadata, ..
306 }
307 | Message::Chat {
308 response_metadata, ..
309 } => {
310 response_metadata.insert(key.into(), value);
311 }
312 Message::Remove { .. } => { }
313 }
314 self
315 }
316
317 pub fn with_content_blocks(mut self, blocks: Vec<ContentBlock>) -> Self {
318 set_message_field!(&mut self, content_blocks, blocks);
319 self
320 }
321
322 pub fn with_usage_metadata(mut self, usage: TokenUsage) -> Self {
323 if let Message::AI { usage_metadata, .. } = &mut self {
324 *usage_metadata = Some(usage);
325 }
326 self
327 }
328
329 pub fn content(&self) -> &str {
332 match self {
333 Message::Remove { .. } => "",
334 other => get_message_field!(other, content),
335 }
336 }
337
338 pub fn role(&self) -> &str {
339 match self {
340 Message::System { .. } => "system",
341 Message::Human { .. } => "human",
342 Message::AI { .. } => "assistant",
343 Message::Tool { .. } => "tool",
344 Message::Chat { custom_role, .. } => custom_role,
345 Message::Remove { .. } => "remove",
346 }
347 }
348
349 pub fn is_system(&self) -> bool {
350 matches!(self, Message::System { .. })
351 }
352
353 pub fn is_human(&self) -> bool {
354 matches!(self, Message::Human { .. })
355 }
356
357 pub fn is_ai(&self) -> bool {
358 matches!(self, Message::AI { .. })
359 }
360
361 pub fn is_tool(&self) -> bool {
362 matches!(self, Message::Tool { .. })
363 }
364
365 pub fn is_chat(&self) -> bool {
366 matches!(self, Message::Chat { .. })
367 }
368
369 pub fn is_remove(&self) -> bool {
370 matches!(self, Message::Remove { .. })
371 }
372
373 pub fn tool_calls(&self) -> &[ToolCall] {
374 match self {
375 Message::AI { tool_calls, .. } => tool_calls,
376 _ => &[],
377 }
378 }
379
380 pub fn tool_call_id(&self) -> Option<&str> {
381 match self {
382 Message::Tool { tool_call_id, .. } => Some(tool_call_id),
383 _ => None,
384 }
385 }
386
387 pub fn id(&self) -> Option<&str> {
388 match self {
389 Message::Remove { id } => Some(id),
390 other => get_message_field!(other, id).as_deref(),
391 }
392 }
393
394 pub fn name(&self) -> Option<&str> {
395 match self {
396 Message::Remove { .. } => None,
397 other => get_message_field!(other, name).as_deref(),
398 }
399 }
400
401 pub fn additional_kwargs(&self) -> &HashMap<String, Value> {
402 match self {
403 Message::System {
404 additional_kwargs, ..
405 }
406 | Message::Human {
407 additional_kwargs, ..
408 }
409 | Message::AI {
410 additional_kwargs, ..
411 }
412 | Message::Tool {
413 additional_kwargs, ..
414 }
415 | Message::Chat {
416 additional_kwargs, ..
417 } => additional_kwargs,
418 Message::Remove { .. } => {
419 static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
420 std::sync::OnceLock::new();
421 EMPTY.get_or_init(HashMap::new)
422 }
423 }
424 }
425
426 pub fn response_metadata(&self) -> &HashMap<String, Value> {
427 match self {
428 Message::System {
429 response_metadata, ..
430 }
431 | Message::Human {
432 response_metadata, ..
433 }
434 | Message::AI {
435 response_metadata, ..
436 }
437 | Message::Tool {
438 response_metadata, ..
439 }
440 | Message::Chat {
441 response_metadata, ..
442 } => response_metadata,
443 Message::Remove { .. } => {
444 static EMPTY: std::sync::OnceLock<HashMap<String, Value>> =
445 std::sync::OnceLock::new();
446 EMPTY.get_or_init(HashMap::new)
447 }
448 }
449 }
450
451 pub fn content_blocks(&self) -> &[ContentBlock] {
452 match self {
453 Message::Remove { .. } => &[],
454 other => get_message_field!(other, content_blocks),
455 }
456 }
457
458 pub fn remove_id(&self) -> Option<&str> {
460 match self {
461 Message::Remove { id } => Some(id),
462 _ => None,
463 }
464 }
465
466 pub fn usage_metadata(&self) -> Option<&TokenUsage> {
467 match self {
468 Message::AI { usage_metadata, .. } => usage_metadata.as_ref(),
469 _ => None,
470 }
471 }
472
473 pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
474 match self {
475 Message::AI {
476 invalid_tool_calls, ..
477 } => invalid_tool_calls,
478 _ => &[],
479 }
480 }
481}
482
483pub fn filter_messages(
489 messages: &[Message],
490 include_types: Option<&[&str]>,
491 exclude_types: Option<&[&str]>,
492 include_names: Option<&[&str]>,
493 exclude_names: Option<&[&str]>,
494 include_ids: Option<&[&str]>,
495 exclude_ids: Option<&[&str]>,
496) -> Vec<Message> {
497 messages
498 .iter()
499 .filter(|msg| {
500 if let Some(include) = include_types {
501 if !include.contains(&msg.role()) {
502 return false;
503 }
504 }
505 if let Some(exclude) = exclude_types {
506 if exclude.contains(&msg.role()) {
507 return false;
508 }
509 }
510 if let Some(include) = include_names {
511 match msg.name() {
512 Some(name) => {
513 if !include.contains(&name) {
514 return false;
515 }
516 }
517 None => return false,
518 }
519 }
520 if let Some(exclude) = exclude_names {
521 if let Some(name) = msg.name() {
522 if exclude.contains(&name) {
523 return false;
524 }
525 }
526 }
527 if let Some(include) = include_ids {
528 match msg.id() {
529 Some(id) => {
530 if !include.contains(&id) {
531 return false;
532 }
533 }
534 None => return false,
535 }
536 }
537 if let Some(exclude) = exclude_ids {
538 if let Some(id) = msg.id() {
539 if exclude.contains(&id) {
540 return false;
541 }
542 }
543 }
544 true
545 })
546 .cloned()
547 .collect()
548}
549
/// Selects which end of the conversation `trim_messages` keeps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimStrategy {
    /// Keep messages from the start of the list.
    First,
    /// Keep messages from the end of the list.
    Last,
}
558
559pub fn trim_messages(
565 messages: Vec<Message>,
566 max_tokens: usize,
567 token_counter: impl Fn(&Message) -> usize,
568 strategy: TrimStrategy,
569 include_system: bool,
570) -> Vec<Message> {
571 if messages.is_empty() {
572 return messages;
573 }
574
575 match strategy {
576 TrimStrategy::First => {
577 let mut result = Vec::new();
578 let mut total = 0;
579 for msg in messages {
580 let count = token_counter(&msg);
581 if total + count > max_tokens {
582 break;
583 }
584 total += count;
585 result.push(msg);
586 }
587 result
588 }
589 TrimStrategy::Last => {
590 let (system_msg, rest) = if include_system && messages[0].is_system() {
591 (Some(messages[0].clone()), &messages[1..])
592 } else {
593 (None, messages.as_slice())
594 };
595
596 let system_tokens = system_msg.as_ref().map(&token_counter).unwrap_or(0);
597 let budget = max_tokens.saturating_sub(system_tokens);
598
599 let mut selected = Vec::new();
600 let mut total = 0;
601 for msg in rest.iter().rev() {
602 let count = token_counter(msg);
603 if total + count > budget {
604 break;
605 }
606 total += count;
607 selected.push(msg.clone());
608 }
609 selected.reverse();
610
611 let mut result = Vec::new();
612 if let Some(sys) = system_msg {
613 result.push(sys);
614 }
615 result.extend(selected);
616 result
617 }
618 }
619}
620
621pub fn merge_message_runs(messages: Vec<Message>) -> Vec<Message> {
623 if messages.is_empty() {
624 return messages;
625 }
626
627 let mut result: Vec<Message> = Vec::new();
628
629 for msg in messages {
630 let should_merge = result
631 .last()
632 .map(|last| last.role() == msg.role())
633 .unwrap_or(false);
634
635 if should_merge {
636 let last = result.last_mut().unwrap();
637 let merged_content = format!("{}\n{}", last.content(), msg.content());
639 match last {
640 Message::System { content, .. } => *content = merged_content,
641 Message::Human { content, .. } => *content = merged_content,
642 Message::AI {
643 content,
644 tool_calls,
645 invalid_tool_calls,
646 ..
647 } => {
648 *content = merged_content;
649 tool_calls.extend(msg.tool_calls().to_vec());
650 invalid_tool_calls.extend(msg.invalid_tool_calls().to_vec());
651 }
652 Message::Tool { content, .. } => *content = merged_content,
653 Message::Chat { content, .. } => *content = merged_content,
654 Message::Remove { .. } => { }
655 }
656 } else {
657 result.push(msg);
658 }
659 }
660
661 result
662}
663
664pub fn get_buffer_string(messages: &[Message], human_prefix: &str, ai_prefix: &str) -> String {
666 messages
667 .iter()
668 .map(|msg| {
669 let prefix = match msg {
670 Message::System { .. } => "System",
671 Message::Human { .. } => human_prefix,
672 Message::AI { .. } => ai_prefix,
673 Message::Tool { .. } => "Tool",
674 Message::Chat { custom_role, .. } => custom_role.as_str(),
675 Message::Remove { .. } => "Remove",
676 };
677 format!("{prefix}: {}", msg.content())
678 })
679 .collect::<Vec<_>>()
680 .join("\n")
681}
682
/// One piece of a streamed AI response.
///
/// Chunks can be combined with `+` / `+=`. All fields except `content` default
/// when absent on deserialization and are skipped on serialization when
/// `None`/empty.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct AIMessageChunk {
    /// Text carried by this chunk.
    pub content: String,
    /// Complete tool calls carried by this chunk.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Token usage reported with this chunk, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage: Option<TokenUsage>,
    /// Message id, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Partial tool-call fragments.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_call_chunks: Vec<ToolCallChunk>,
    /// Tool calls that could not be parsed.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub invalid_tool_calls: Vec<InvalidToolCall>,
}
702
703impl AIMessageChunk {
704 pub fn into_message(self) -> Message {
705 Message::ai_with_tool_calls(self.content, self.tool_calls)
706 }
707}
708
709impl std::ops::Add for AIMessageChunk {
710 type Output = Self;
711
712 fn add(mut self, rhs: Self) -> Self {
713 self += rhs;
714 self
715 }
716}
717
718impl std::ops::AddAssign for AIMessageChunk {
719 fn add_assign(&mut self, rhs: Self) {
720 self.content.push_str(&rhs.content);
721 self.tool_calls.extend(rhs.tool_calls);
722 self.tool_call_chunks.extend(rhs.tool_call_chunks);
723 self.invalid_tool_calls.extend(rhs.invalid_tool_calls);
724 if self.id.is_none() {
725 self.id = rhs.id;
726 }
727 match (&mut self.usage, rhs.usage) {
728 (Some(u), Some(rhs_u)) => {
729 u.input_tokens += rhs_u.input_tokens;
730 u.output_tokens += rhs_u.output_tokens;
731 u.total_tokens += rhs_u.total_tokens;
732 }
733 (None, Some(rhs_u)) => {
734 self.usage = Some(rhs_u);
735 }
736 _ => {}
737 }
738 }
739}
740
/// A tool invocation requested by the model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call; a `Message::Tool` refers back to it via `tool_call_id`.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the tool as a JSON value.
    pub arguments: Value,
}
752
/// A tool call that could not be turned into a valid [`ToolCall`].
///
/// Everything except `error` is optional and omitted from the serialized form
/// when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InvalidToolCall {
    /// Call id, if one was provided.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool name, if one was provided.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Arguments kept as a string rather than parsed JSON.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
    /// Description of why the call is invalid.
    pub error: String,
}
764
/// A fragment of a tool call received while streaming.
///
/// Every field is optional and omitted from the serialized form when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallChunk {
    /// Call id, if present in this fragment.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool name, if present in this fragment.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Argument text carried by this fragment, kept as a string.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
    /// Optional index of the fragment.
    // NOTE(review): presumably the position of the tool call this fragment
    // belongs to within the response — confirm with the provider adapters.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
}
777
/// Description of a tool as presented to a model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Parameter schema as a JSON value.
    pub parameters: Value,
    /// Optional extra key/value pairs; omitted from the serialized form when `None`.
    // NOTE(review): presumably provider-specific options — confirm with the adapters.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub extras: Option<HashMap<String, Value>>,
}
788
/// How the model should choose among the offered tools.
///
/// Variant names are lowercased in the serialized form.
// NOTE(review): the precise meaning of each variant is defined by the model
// adapters, not in this file; the descriptions below follow the variant names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    /// Let the model decide.
    Auto,
    /// Require a tool call.
    Required,
    /// Disallow tool calls.
    None,
    /// Select the tool with the given name.
    Specific(String),
}
798
/// A request sent to a [`ChatModel`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatRequest {
    /// The conversation so far.
    pub messages: Vec<Message>,
    /// Tools offered to the model; omitted from the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Optional tool-choice setting; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
812
813impl ChatRequest {
814 pub fn new(messages: Vec<Message>) -> Self {
815 Self {
816 messages,
817 tools: vec![],
818 tool_choice: None,
819 }
820 }
821
822 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
823 self.tools = tools;
824 self
825 }
826
827 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
828 self.tool_choice = Some(choice);
829 self
830 }
831}
832
/// The result of a [`ChatModel::chat`] call.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatResponse {
    /// The message returned by the model.
    pub message: Message,
    /// Token usage for the call, if reported.
    pub usage: Option<TokenUsage>,
}
839
/// Token accounting for a model call.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the input.
    pub input_tokens: u32,
    /// Tokens produced as output.
    pub output_tokens: u32,
    /// Total token count. Stored as its own field; nothing in this type
    /// enforces that it equals `input_tokens + output_tokens`.
    pub total_tokens: u32,
    /// Optional breakdown of input tokens; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_details: Option<InputTokenDetails>,
    /// Optional breakdown of output tokens; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_details: Option<OutputTokenDetails>,
}
854
/// Breakdown of input tokens. Missing fields deserialize as `0`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct InputTokenDetails {
    /// Count of cached input tokens.
    #[serde(default)]
    pub cached: u32,
    /// Count of audio input tokens.
    #[serde(default)]
    pub audio: u32,
}
863
/// Breakdown of output tokens. Missing fields deserialize as `0`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct OutputTokenDetails {
    /// Count of reasoning output tokens.
    #[serde(default)]
    pub reasoning: u32,
    /// Count of audio output tokens.
    #[serde(default)]
    pub audio: u32,
}
872
/// Lifecycle events of a run, delivered to a [`CallbackHandler`].
///
/// Every variant carries the `run_id` of the run it belongs to.
// NOTE(review): when each event is emitted is decided by the run loop, which is
// not part of this file; the variant descriptions follow the names and fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunEvent {
    /// A run started within the given session.
    RunStarted {
        run_id: String,
        session_id: String,
    },
    /// The run reached step number `step`.
    RunStep {
        run_id: String,
        step: usize,
    },
    /// The model was called with `message_count` messages.
    LlmCalled {
        run_id: String,
        message_count: usize,
    },
    /// The tool named `tool_name` was called.
    ToolCalled {
        run_id: String,
        tool_name: String,
    },
    /// The run finished with the given output.
    RunFinished {
        run_id: String,
        output: String,
    },
    /// The run failed with the given error text.
    RunFailed {
        run_id: String,
        error: String,
    },
}
905
/// The error type shared by the traits in this file.
///
/// Each variant names the subsystem that failed. Apart from
/// [`SynapticError::MaxStepsExceeded`], every variant wraps a free-form message
/// that is appended to a fixed prefix in the `Display` output.
#[derive(Debug, Error)]
pub enum SynapticError {
    /// Prompt-related failure.
    #[error("prompt error: {0}")]
    Prompt(String),
    /// Model-related failure.
    #[error("model error: {0}")]
    Model(String),
    /// Tool execution failure.
    #[error("tool error: {0}")]
    Tool(String),
    /// A tool with the given name could not be found.
    #[error("tool not found: {0}")]
    ToolNotFound(String),
    /// Memory-related failure.
    #[error("memory error: {0}")]
    Memory(String),
    /// A rate limit was hit.
    #[error("rate limit: {0}")]
    RateLimit(String),
    /// An operation timed out.
    #[error("timeout: {0}")]
    Timeout(String),
    /// Validation failure.
    #[error("validation error: {0}")]
    Validation(String),
    /// Parsing failure.
    #[error("parsing error: {0}")]
    Parsing(String),
    /// Callback failure.
    #[error("callback error: {0}")]
    Callback(String),
    /// The configured maximum number of steps was exceeded.
    #[error("max steps exceeded: {max_steps}")]
    MaxStepsExceeded { max_steps: usize },
    /// Embedding failure.
    #[error("embedding error: {0}")]
    Embedding(String),
    /// Vector store failure.
    #[error("vector store error: {0}")]
    VectorStore(String),
    /// Retriever failure.
    #[error("retriever error: {0}")]
    Retriever(String),
    /// Loader failure.
    #[error("loader error: {0}")]
    Loader(String),
    /// Splitter failure.
    #[error("splitter error: {0}")]
    Splitter(String),
    /// Graph failure.
    #[error("graph error: {0}")]
    Graph(String),
    /// Cache failure.
    #[error("cache error: {0}")]
    Cache(String),
    /// Configuration failure.
    #[error("config error: {0}")]
    Config(String),
    /// MCP-related failure.
    #[error("mcp error: {0}")]
    Mcp(String),
}
954
/// A pinned, boxed, `Send` stream of response chunks, each either an
/// [`AIMessageChunk`] or a [`SynapticError`], valid for the lifetime `'a`.
pub type ChatStream<'a> =
    Pin<Box<dyn Stream<Item = Result<AIMessageChunk, SynapticError>> + Send + 'a>>;
962
/// Static description of a model and its capabilities, as returned by
/// [`ChatModel::profile`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProfile {
    /// Model name.
    pub name: String,
    /// Provider name.
    pub provider: String,
    /// Whether the model supports tool calling.
    pub supports_tool_calling: bool,
    /// Whether the model supports structured output.
    pub supports_structured_output: bool,
    /// Whether the model supports streaming.
    pub supports_streaming: bool,
    /// Maximum number of input tokens, if known.
    pub max_input_tokens: Option<usize>,
    /// Maximum number of output tokens, if known.
    pub max_output_tokens: Option<usize>,
}
974
975#[async_trait]
977pub trait ChatModel: Send + Sync {
978 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError>;
979
980 fn profile(&self) -> Option<ModelProfile> {
982 None
983 }
984
985 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
986 Box::pin(async_stream::stream! {
987 match self.chat(request).await {
988 Ok(response) => {
989 yield Ok(AIMessageChunk {
990 content: response.message.content().to_string(),
991 tool_calls: response.message.tool_calls().to_vec(),
992 usage: response.usage,
993 ..Default::default()
994 });
995 }
996 Err(e) => yield Err(e),
997 }
998 })
999 }
1000}
1001
1002#[async_trait]
1004pub trait Tool: Send + Sync {
1005 fn name(&self) -> &'static str;
1006 fn description(&self) -> &'static str;
1007
1008 fn parameters(&self) -> Option<Value> {
1009 None
1010 }
1011
1012 async fn call(&self, args: Value) -> Result<Value, SynapticError>;
1013
1014 fn as_tool_definition(&self) -> ToolDefinition {
1015 ToolDefinition {
1016 name: self.name().to_string(),
1017 description: self.description().to_string(),
1018 parameters: self
1019 .parameters()
1020 .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1021 extras: None,
1022 }
1023 }
1024}
1025
/// Context handed to a [`ContextAwareTool`] alongside its arguments.
///
/// The `Default` value has no state and an empty `tool_call_id`.
#[derive(Debug, Clone, Default)]
pub struct ToolContext {
    /// Optional state as a JSON value.
    // NOTE(review): presumably the current agent/graph state — confirm with the caller.
    pub state: Option<Value>,
    /// Id of the tool call being served.
    pub tool_call_id: String,
}
1041
/// A tool that receives a [`ToolContext`] in addition to its arguments.
///
/// Use [`ContextAwareToolAdapter`] to expose one as a plain [`Tool`].
#[async_trait]
pub trait ContextAwareTool: Send + Sync {
    /// The tool's name.
    fn name(&self) -> &'static str;
    /// A description of what the tool does.
    fn description(&self) -> &'static str;
    /// Executes the tool with the given JSON arguments and context.
    async fn call_with_context(
        &self,
        args: Value,
        ctx: ToolContext,
    ) -> Result<Value, SynapticError>;
}
1057
/// Wraps a [`ContextAwareTool`] so that it can be used wherever a [`Tool`] is expected.
pub struct ContextAwareToolAdapter {
    /// The wrapped tool.
    inner: Arc<dyn ContextAwareTool>,
}
1065
1066impl ContextAwareToolAdapter {
1067 pub fn new(inner: Arc<dyn ContextAwareTool>) -> Self {
1068 Self { inner }
1069 }
1070}
1071
1072#[async_trait]
1073impl Tool for ContextAwareToolAdapter {
1074 fn name(&self) -> &'static str {
1075 self.inner.name()
1076 }
1077
1078 fn description(&self) -> &'static str {
1079 self.inner.description()
1080 }
1081
1082 async fn call(&self, args: Value) -> Result<Value, SynapticError> {
1083 self.inner
1084 .call_with_context(args, ToolContext::default())
1085 .await
1086 }
1087}
1088
/// Storage for conversation messages, keyed by session id.
#[async_trait]
pub trait MemoryStore: Send + Sync {
    /// Appends `message` to the session identified by `session_id`.
    async fn append(&self, session_id: &str, message: Message) -> Result<(), SynapticError>;
    /// Loads the messages stored for `session_id`.
    async fn load(&self, session_id: &str) -> Result<Vec<Message>, SynapticError>;
    /// Clears the messages stored for `session_id`.
    async fn clear(&self, session_id: &str) -> Result<(), SynapticError>;
}
1100
/// Receives [`RunEvent`]s.
#[async_trait]
pub trait CallbackHandler: Send + Sync {
    /// Handles one event; an `Err` reports that the handler failed.
    async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError>;
}
1106
/// Per-run configuration.
///
/// Every field defaults when absent on deserialization; the `Default` value has
/// no tags, no metadata and all options unset.
// NOTE(review): how each setting is honoured is decided by the code that
// consumes this config, which is not part of this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RunnableConfig {
    /// Tags attached to the run.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Arbitrary key/value metadata attached to the run.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
    /// Optional concurrency limit.
    #[serde(default)]
    pub max_concurrency: Option<usize>,
    /// Optional recursion limit.
    #[serde(default)]
    pub recursion_limit: Option<usize>,
    /// Optional run id.
    #[serde(default)]
    pub run_id: Option<String>,
    /// Optional run name.
    #[serde(default)]
    pub run_name: Option<String>,
}
1127
1128impl RunnableConfig {
1129 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
1130 self.tags = tags;
1131 self
1132 }
1133
1134 pub fn with_run_name(mut self, name: impl Into<String>) -> Self {
1135 self.run_name = Some(name.into());
1136 self
1137 }
1138
1139 pub fn with_run_id(mut self, id: impl Into<String>) -> Self {
1140 self.run_id = Some(id.into());
1141 self
1142 }
1143
1144 pub fn with_max_concurrency(mut self, max: usize) -> Self {
1145 self.max_concurrency = Some(max);
1146 self
1147 }
1148
1149 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
1150 self.recursion_limit = Some(limit);
1151 self
1152 }
1153
1154 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
1155 self.metadata.insert(key.into(), value);
1156 self
1157 }
1158}
1159
/// An entry held by a [`Store`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Item {
    /// Namespace path the item lives under.
    pub namespace: Vec<String>,
    /// Key of the item within its namespace.
    pub key: String,
    /// The stored JSON value.
    pub value: Value,
    /// Creation timestamp as a string.
    // NOTE(review): the timestamp format is chosen by the store implementation — confirm.
    pub created_at: String,
    /// Last-update timestamp as a string (same caveat as `created_at`).
    pub updated_at: String,
    /// Optional score; omitted from the serialized form when `None`.
    // NOTE(review): presumably a relevance score filled in by `Store::search` — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub score: Option<f64>,
}
1176
/// A namespaced key/value store of JSON values.
#[async_trait]
pub trait Store: Send + Sync {
    /// Looks up the item stored under `key` in `namespace`; `Ok(None)` means no item.
    async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError>;

    /// Searches `namespace`, optionally using `query`, returning items bounded by `limit`.
    // NOTE(review): query semantics and ordering are implementation-defined — confirm.
    async fn search(
        &self,
        namespace: &[&str],
        query: Option<&str>,
        limit: usize,
    ) -> Result<Vec<Item>, SynapticError>;

    /// Stores `value` under `key` in `namespace`.
    async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError>;

    /// Deletes the item stored under `key` in `namespace`.
    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError>;

    /// Lists the namespaces matching `prefix`.
    async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError>;
}
1203
/// Produces embedding vectors for text.
#[async_trait]
pub trait Embeddings: Send + Sync {
    /// Embeds a batch of texts.
    // NOTE(review): presumably returns one vector per input text, in input order — confirm.
    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError>;

    /// Embeds a single query text.
    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError>;
}
1217
/// A shared, thread-safe callback that accepts a JSON value.
// NOTE(review): presumably used to emit custom values into an output stream — confirm.
pub type StreamWriter = Arc<dyn Fn(Value) + Send + Sync>;
1224
/// Shared runtime resources: an optional [`Store`] and an optional [`StreamWriter`].
#[derive(Clone)]
pub struct Runtime {
    /// Optional shared store.
    pub store: Option<Arc<dyn Store>>,
    /// Optional stream-writer callback.
    pub stream_writer: Option<StreamWriter>,
}
1235
/// Runtime resources and per-call information handed to a [`RuntimeAwareTool`].
#[derive(Clone)]
pub struct ToolRuntime {
    /// Optional shared store.
    pub store: Option<Arc<dyn Store>>,
    /// Optional stream-writer callback.
    pub stream_writer: Option<StreamWriter>,
    /// Optional state as a JSON value.
    pub state: Option<Value>,
    /// Id of the tool call being served.
    pub tool_call_id: String,
    /// Optional run configuration.
    pub config: Option<RunnableConfig>,
}
1245
1246#[async_trait]
1256pub trait RuntimeAwareTool: Send + Sync {
1257 fn name(&self) -> &'static str;
1258 fn description(&self) -> &'static str;
1259
1260 fn parameters(&self) -> Option<Value> {
1261 None
1262 }
1263
1264 async fn call_with_runtime(
1265 &self,
1266 args: Value,
1267 runtime: ToolRuntime,
1268 ) -> Result<Value, SynapticError>;
1269
1270 fn as_tool_definition(&self) -> ToolDefinition {
1271 ToolDefinition {
1272 name: self.name().to_string(),
1273 description: self.description().to_string(),
1274 parameters: self
1275 .parameters()
1276 .unwrap_or(serde_json::json!({"type": "object", "properties": {}})),
1277 extras: None,
1278 }
1279 }
1280}
1281
/// Wraps a [`RuntimeAwareTool`] so that it can be used wherever a [`Tool`] is expected.
pub struct RuntimeAwareToolAdapter {
    /// The wrapped tool.
    inner: Arc<dyn RuntimeAwareTool>,
    /// The runtime passed to the wrapped tool on each call; `None` until
    /// `set_runtime` is called.
    runtime: Arc<tokio::sync::RwLock<Option<ToolRuntime>>>,
}
1290
1291impl RuntimeAwareToolAdapter {
1292 pub fn new(tool: Arc<dyn RuntimeAwareTool>) -> Self {
1293 Self {
1294 inner: tool,
1295 runtime: Arc::new(tokio::sync::RwLock::new(None)),
1296 }
1297 }
1298
1299 pub async fn set_runtime(&self, runtime: ToolRuntime) {
1300 *self.runtime.write().await = Some(runtime);
1301 }
1302}
1303
1304#[async_trait]
1305impl Tool for RuntimeAwareToolAdapter {
1306 fn name(&self) -> &'static str {
1307 self.inner.name()
1308 }
1309
1310 fn description(&self) -> &'static str {
1311 self.inner.description()
1312 }
1313
1314 fn parameters(&self) -> Option<Value> {
1315 self.inner.parameters()
1316 }
1317
1318 async fn call(&self, args: Value) -> Result<Value, SynapticError> {
1319 let runtime = self.runtime.read().await.clone().unwrap_or(ToolRuntime {
1320 store: None,
1321 stream_writer: None,
1322 state: None,
1323 tool_call_id: String::new(),
1324 config: None,
1325 });
1326 self.inner.call_with_runtime(args, runtime).await
1327 }
1328}
1329
/// Static configuration of an [`Entrypoint`].
#[derive(Debug, Clone)]
pub struct EntrypointConfig {
    /// Name of the entrypoint.
    pub name: &'static str,
    /// Optional checkpointer identifier.
    // NOTE(review): presumably a name resolved to a checkpointer elsewhere — confirm.
    pub checkpointer: Option<&'static str>,
}
1340
/// A named, invocable unit of work: configuration plus an async function from
/// JSON input to JSON output.
pub struct Entrypoint {
    /// Static configuration of this entrypoint.
    pub config: EntrypointConfig,
    /// The function to run: takes a JSON value and returns a boxed, pinned,
    /// `Send` future resolving to a JSON value or a [`SynapticError`].
    pub invoke_fn: Box<
        dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<Value, SynapticError>> + Send>>
            + Send
            + Sync,
    >,
}
1352
1353impl Entrypoint {
1354 pub async fn invoke(&self, input: Value) -> Result<Value, SynapticError> {
1355 (self.invoke_fn)(input).await
1356 }
1357}