Skip to main content

wesichain_core/
llm.rs

1use serde::{de::DeserializeOwned, Deserialize, Serialize};
2
3use crate::{TokenUsage, Value, WesichainError};
4use async_trait::async_trait;
5
/// Role attached to a [`Message`].
///
/// Serde renames every variant to lowercase, so the wire values are
/// `"system"`, `"user"`, `"assistant"` and `"tool"`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"tool"`.
    Tool,
}
14
/// Body of a [`Message`]: either a single string or a list of typed parts.
///
/// The enum is `#[serde(untagged)]`, so `Text` is (de)serialized as a bare
/// string and `Parts` as a sequence of [`ContentPart`] values; serde tries the
/// variants in declaration order when deserializing.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text content.
    Text(String),
    /// Content made of several parts (text and/or images).
    Parts(Vec<ContentPart>),
}
21
/// One element of a multi-part [`MessageContent`].
///
/// Internally tagged with a `"type"` field whose value is the snake_case
/// variant name: `"text"`, `"image_url"` or `"image_data"`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// A text fragment.
    Text { text: String },
    /// An image referenced by URL.
    ///
    /// NOTE(review): `detail` is an optional free-form string; its accepted
    /// values are not defined in this file — confirm against the providers.
    ImageUrl { url: String, detail: Option<String> },
    /// An image carried inline.
    ///
    /// NOTE(review): `data` is presumably base64-encoded bytes and
    /// `media_type` a MIME type such as `image/png` — confirm with callers.
    ImageData { data: String, media_type: String },
}
29
30impl From<String> for MessageContent {
31    fn from(s: String) -> Self {
32        Self::Text(s)
33    }
34}
35
36impl From<&str> for MessageContent {
37    fn from(s: &str) -> Self {
38        Self::Text(s.to_string())
39    }
40}
41
42impl Default for MessageContent {
43    fn default() -> Self {
44        Self::Text(String::new())
45    }
46}
47
48impl std::fmt::Display for MessageContent {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        write!(f, "{}", self.to_text_lossy())
51    }
52}
53
54impl MessageContent {
55    pub fn as_text(&self) -> Option<&str> {
56        match self {
57            MessageContent::Text(s) => Some(s.as_str()),
58            MessageContent::Parts(_) => None,
59        }
60    }
61
62    pub fn to_text_lossy(&self) -> String {
63        match self {
64            MessageContent::Text(s) => s.clone(),
65            MessageContent::Parts(parts) => parts
66                .iter()
67                .filter_map(|p| {
68                    if let ContentPart::Text { text } = p {
69                        Some(text.as_str())
70                    } else {
71                        None
72                    }
73                })
74                .collect::<Vec<_>>()
75                .join(""),
76        }
77    }
78
79    pub fn is_empty(&self) -> bool {
80        match self {
81            MessageContent::Text(s) => s.is_empty(),
82            MessageContent::Parts(parts) => parts.is_empty(),
83        }
84    }
85}
86
/// A single chat message: a role, its content, and optional tool metadata.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Message {
    /// Role attached to this message.
    pub role: Role,
    /// Message body, either plain text or a list of parts.
    pub content: MessageContent,
    /// Optional tool-call identifier; left out of the serialized form when `None`.
    ///
    /// NOTE(review): presumably the `id` of the [`ToolCall`] this message
    /// answers — confirm against callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls carried by this message. Defaults to empty when the field is
    /// missing on deserialization and is left out of the output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
}
96
97impl Message {
98    pub fn user(content: impl Into<MessageContent>) -> Self {
99        Self { role: Role::User, content: content.into(), tool_call_id: None, tool_calls: vec![] }
100    }
101
102    pub fn system(content: impl Into<MessageContent>) -> Self {
103        Self { role: Role::System, content: content.into(), tool_call_id: None, tool_calls: vec![] }
104    }
105
106    pub fn assistant(content: impl Into<MessageContent>) -> Self {
107        Self { role: Role::Assistant, content: content.into(), tool_call_id: None, tool_calls: vec![] }
108    }
109
110    pub fn with_image_url(mut self, url: impl Into<String>, detail: Option<String>) -> Self {
111        let parts = match self.content {
112            MessageContent::Text(t) if !t.is_empty() => vec![
113                ContentPart::Text { text: t },
114                ContentPart::ImageUrl { url: url.into(), detail },
115            ],
116            MessageContent::Text(_) => vec![ContentPart::ImageUrl { url: url.into(), detail }],
117            MessageContent::Parts(mut parts) => {
118                parts.push(ContentPart::ImageUrl { url: url.into(), detail });
119                parts
120            }
121        };
122        self.content = MessageContent::Parts(parts);
123        self
124    }
125
126    pub fn with_image_data(mut self, data: impl Into<String>, media_type: impl Into<String>) -> Self {
127        let parts = match self.content {
128            MessageContent::Text(t) if !t.is_empty() => vec![
129                ContentPart::Text { text: t },
130                ContentPart::ImageData { data: data.into(), media_type: media_type.into() },
131            ],
132            MessageContent::Text(_) => vec![ContentPart::ImageData { data: data.into(), media_type: media_type.into() }],
133            MessageContent::Parts(mut parts) => {
134                parts.push(ContentPart::ImageData { data: data.into(), media_type: media_type.into() });
135                parts
136            }
137        };
138        self.content = MessageContent::Parts(parts);
139        self
140    }
141}
142
/// Description of a tool that can be attached to an [`LlmRequest`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolSpec {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameter description as an arbitrary JSON value.
    /// `ToolCallingLlmExt::with_structured_output` in this file stores a JSON
    /// Schema here.
    pub parameters: Value,
}
149
/// A tool invocation, as carried by [`Message::tool_calls`] and
/// [`LlmResponse::tool_calls`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
    /// Call arguments as an arbitrary JSON value.
    pub args: Value,
}
156
/// Request payload passed to an LLM runnable.
///
/// Optional fields (`None` / empty vectors) are left out of the serialized
/// form; the vector fields default to empty when missing on deserialization.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct LlmRequest {
    /// Model identifier.
    pub model: String,
    /// Messages that make up the request.
    pub messages: Vec<Message>,
    /// Tool specifications attached to the request; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolSpec>,
    /// Optional temperature setting; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional token limit; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
}
170
/// Response payload produced by an LLM runnable.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct LlmResponse {
    /// Text content of the response.
    pub content: String,
    /// Tool calls included in the response. Defaults to empty when missing on
    /// deserialization and is omitted from the output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Token usage information, when available; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<TokenUsage>,
    /// Model identifier returned by the provider (e.g. `"claude-3-5-sonnet-20241022"`).
    /// Empty string when the provider did not return a model name; an empty
    /// value is omitted from the serialized output.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub model: String,
}
183
/// Marker trait for runnables that map an [`LlmRequest`] to an [`LlmResponse`].
///
/// It declares no methods of its own; it only names the bound
/// `Runnable<LlmRequest, LlmResponse> + Send + Sync + 'static`.
#[async_trait]
pub trait ToolCallingLlm: crate::Runnable<LlmRequest, LlmResponse> + Send + Sync + 'static {}
186
187impl crate::Bindable for LlmRequest {
188    fn bind(&mut self, args: Value) -> Result<(), WesichainError> {
189        if let Some(obj) = args.as_object() {
190            if let Some(tools_val) = obj.get("tools") {
191                let tools: Vec<ToolSpec> =
192                    serde_json::from_value(tools_val.clone()).map_err(WesichainError::Serde)?;
193                self.tools.extend(tools);
194            }
195        }
196        Ok(())
197    }
198}
199
200pub trait ToolCallingLlmExt: ToolCallingLlm {
201    fn with_structured_output<T>(self) -> impl crate::Runnable<LlmRequest, T>
202    where
203        T: schemars::JsonSchema + DeserializeOwned + Serialize + Send + Sync + 'static,
204        Self: Sized,
205    {
206        use crate::{RunnableExt, StructuredOutputParser};
207
208        let schema = schemars::schema_for!(T);
209        let as_value = serde_json::to_value(schema).unwrap_or(Value::Null);
210
211        let tool_spec = ToolSpec {
212            name: "output_formatter".to_string(),
213            description: "Output the result in this format".to_string(),
214            parameters: as_value,
215        };
216
217        let bound = self.bind(serde_json::json!({
218            "tools": [tool_spec]
219        }));
220
221        bound.then(StructuredOutputParser::<T>::new())
222    }
223}
224
225impl<L> ToolCallingLlmExt for L where L: ToolCallingLlm {}