1use serde::{de::DeserializeOwned, Deserialize, Serialize};
2
3use crate::{TokenUsage, Value, WesichainError};
4use async_trait::async_trait;
5
/// The role a [`Message`] is sent under.
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`, `"tool"`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System / instruction message.
    System,
    /// Message from the end user.
    User,
    /// Message produced by the model.
    Assistant,
    /// Tool result message (presumably paired with `Message::tool_call_id` — verify).
    Tool,
}
14
/// The body of a [`Message`]: plain text or an ordered list of [`ContentPart`]s.
///
/// Uses `#[serde(untagged)]`, so it (de)serializes either as a bare string (`Text`) or as a
/// sequence of parts (`Parts`), with no wrapper object or tag. Variants are tried in
/// declaration order when deserializing.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text content.
    Text(String),
    /// Multi-part content (text and/or images), in order.
    Parts(Vec<ContentPart>),
}
21
22#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
23#[serde(tag = "type", rename_all = "snake_case")]
24pub enum ContentPart {
25 Text { text: String },
26 ImageUrl { url: String, detail: Option<String> },
27 ImageData { data: String, media_type: String },
28}
29
30impl From<String> for MessageContent {
31 fn from(s: String) -> Self {
32 Self::Text(s)
33 }
34}
35
36impl From<&str> for MessageContent {
37 fn from(s: &str) -> Self {
38 Self::Text(s.to_string())
39 }
40}
41
42impl Default for MessageContent {
43 fn default() -> Self {
44 Self::Text(String::new())
45 }
46}
47
48impl std::fmt::Display for MessageContent {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(f, "{}", self.to_text_lossy())
51 }
52}
53
54impl MessageContent {
55 pub fn as_text(&self) -> Option<&str> {
56 match self {
57 MessageContent::Text(s) => Some(s.as_str()),
58 MessageContent::Parts(_) => None,
59 }
60 }
61
62 pub fn to_text_lossy(&self) -> String {
63 match self {
64 MessageContent::Text(s) => s.clone(),
65 MessageContent::Parts(parts) => parts
66 .iter()
67 .filter_map(|p| {
68 if let ContentPart::Text { text } = p {
69 Some(text.as_str())
70 } else {
71 None
72 }
73 })
74 .collect::<Vec<_>>()
75 .join(""),
76 }
77 }
78
79 pub fn is_empty(&self) -> bool {
80 match self {
81 MessageContent::Text(s) => s.is_empty(),
82 MessageContent::Parts(parts) => parts.is_empty(),
83 }
84 }
85}
86
/// A single chat message.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Message {
    /// The role this message is sent under.
    pub role: Role,
    /// Plain-text or multi-part body.
    pub content: MessageContent,
    /// Optional tool-call id (presumably the [`ToolCall::id`] a `Role::Tool` message
    /// answers — verify). Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls attached to this message. Defaults to empty when missing on
    /// deserialization, and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
}
96
97impl Message {
98 pub fn user(content: impl Into<MessageContent>) -> Self {
99 Self { role: Role::User, content: content.into(), tool_call_id: None, tool_calls: vec![] }
100 }
101
102 pub fn system(content: impl Into<MessageContent>) -> Self {
103 Self { role: Role::System, content: content.into(), tool_call_id: None, tool_calls: vec![] }
104 }
105
106 pub fn assistant(content: impl Into<MessageContent>) -> Self {
107 Self { role: Role::Assistant, content: content.into(), tool_call_id: None, tool_calls: vec![] }
108 }
109
110 pub fn with_image_url(mut self, url: impl Into<String>, detail: Option<String>) -> Self {
111 let parts = match self.content {
112 MessageContent::Text(t) if !t.is_empty() => vec![
113 ContentPart::Text { text: t },
114 ContentPart::ImageUrl { url: url.into(), detail },
115 ],
116 MessageContent::Text(_) => vec![ContentPart::ImageUrl { url: url.into(), detail }],
117 MessageContent::Parts(mut parts) => {
118 parts.push(ContentPart::ImageUrl { url: url.into(), detail });
119 parts
120 }
121 };
122 self.content = MessageContent::Parts(parts);
123 self
124 }
125
126 pub fn with_image_data(mut self, data: impl Into<String>, media_type: impl Into<String>) -> Self {
127 let parts = match self.content {
128 MessageContent::Text(t) if !t.is_empty() => vec![
129 ContentPart::Text { text: t },
130 ContentPart::ImageData { data: data.into(), media_type: media_type.into() },
131 ],
132 MessageContent::Text(_) => vec![ContentPart::ImageData { data: data.into(), media_type: media_type.into() }],
133 MessageContent::Parts(mut parts) => {
134 parts.push(ContentPart::ImageData { data: data.into(), media_type: media_type.into() });
135 parts
136 }
137 };
138 self.content = MessageContent::Parts(parts);
139 self
140 }
141}
142
/// A tool definition offered to the model via [`LlmRequest::tools`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolSpec {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameter schema. `ToolCallingLlmExt::with_structured_output` fills this with the
    /// `schemars`-generated JSON Schema of its output type.
    pub parameters: Value,
}
149
/// A tool invocation, as carried in [`LlmResponse::tool_calls`] and [`Message::tool_calls`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call (presumably assigned by the provider — verify).
    pub id: String,
    /// Name of the tool being invoked; presumably matches a [`ToolSpec::name`].
    pub name: String,
    /// Call arguments as a `Value`.
    pub args: Value,
}
156
/// Input to a [`ToolCallingLlm`].
///
/// `None` options and empty lists are omitted when serialized. `tools` and
/// `stop_sequences` default to empty when missing on deserialization.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct LlmRequest {
    /// Model identifier.
    pub model: String,
    /// Conversation messages, in order.
    pub messages: Vec<Message>,
    /// Tools the model may call; the `Bindable` impl for this type appends to it.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolSpec>,
    /// Sampling temperature, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Token limit for the response, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences, if any.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
}
170
/// Output of a [`ToolCallingLlm`].
///
/// `Default` yields an empty response: empty text, no tool calls, no usage, empty model.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct LlmResponse {
    /// Text returned by the model.
    pub content: String,
    /// Tool calls in the response. Defaults to empty when missing, and is omitted from
    /// serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Token usage, when reported; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<TokenUsage>,
    /// Model name reported for the response. Empty string when missing, and omitted from
    /// serialized output when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub model: String,
}
183
/// Marker trait for runnables that turn an [`LlmRequest`] into an [`LlmResponse`].
///
/// It declares no methods. It only bundles the `Runnable<LlmRequest, LlmResponse>` bound
/// with `Send + Sync + 'static`. Every implementor also gets [`ToolCallingLlmExt`] through
/// the blanket impl in this module. No blanket impl of this trait itself is visible here,
/// so implementors presumably opt in explicitly.
// NOTE(review): with no (async) methods, `#[async_trait]` has nothing to rewrite on this
// trait. It is harmless, but it and its import could be dropped if no async methods are
// planned.
#[async_trait]
pub trait ToolCallingLlm: crate::Runnable<LlmRequest, LlmResponse> + Send + Sync + 'static {}
186
187impl crate::Bindable for LlmRequest {
188 fn bind(&mut self, args: Value) -> Result<(), WesichainError> {
189 if let Some(obj) = args.as_object() {
190 if let Some(tools_val) = obj.get("tools") {
191 let tools: Vec<ToolSpec> =
192 serde_json::from_value(tools_val.clone()).map_err(WesichainError::Serde)?;
193 self.tools.extend(tools);
194 }
195 }
196 Ok(())
197 }
198}
199
200pub trait ToolCallingLlmExt: ToolCallingLlm {
201 fn with_structured_output<T>(self) -> impl crate::Runnable<LlmRequest, T>
202 where
203 T: schemars::JsonSchema + DeserializeOwned + Serialize + Send + Sync + 'static,
204 Self: Sized,
205 {
206 use crate::{RunnableExt, StructuredOutputParser};
207
208 let schema = schemars::schema_for!(T);
209 let as_value = serde_json::to_value(schema).unwrap_or(Value::Null);
210
211 let tool_spec = ToolSpec {
212 name: "output_formatter".to_string(),
213 description: "Output the result in this format".to_string(),
214 parameters: as_value,
215 };
216
217 let bound = self.bind(serde_json::json!({
218 "tools": [tool_spec]
219 }));
220
221 bound.then(StructuredOutputParser::<T>::new())
222 }
223}
224
225impl<L> ToolCallingLlmExt for L where L: ToolCallingLlm {}