swiftide_core/chat_completion/
chat_completion_response.rs

1use std::collections::HashMap;
2
3use derive_builder::Builder;
4use serde::{Deserialize, Serialize};
5use uuid::Uuid;
6
7use super::{ToolCallBuilder, tools::ToolCall};
8
9/// A generic response from chat completions
10///
11/// When streaming, the delta is available. Every response will have the accumulated message if
12/// present. The final message will also have the final tool calls.
13#[derive(Clone, Builder, Debug, Serialize, Deserialize, PartialEq)]
14#[builder(setter(strip_option, into), build_fn(error = anyhow::Error))]
15pub struct ChatCompletionResponse {
16    /// An identifier for the response
17    ///
18    /// Useful when streaming to make sure chunks can be mapped to the right response
19    #[builder(private, default = Uuid::new_v4())]
20    pub id: Uuid,
21
22    #[builder(default)]
23    pub message: Option<String>,
24
25    #[builder(default)]
26    pub tool_calls: Option<Vec<ToolCall>>,
27
28    #[builder(default)]
29    pub usage: Option<Usage>,
30
31    /// Streaming response
32    #[builder(default)]
33    pub delta: Option<ChatCompletionResponseDelta>,
34}
35
36impl Default for ChatCompletionResponse {
37    fn default() -> Self {
38        Self {
39            id: Uuid::new_v4(),
40            message: None,
41            tool_calls: None,
42            delta: None,
43            usage: None,
44        }
45    }
46}
47
48#[derive(Clone, Builder, Debug, Serialize, Deserialize, PartialEq)]
49pub struct Usage {
50    pub prompt_tokens: u32,
51    pub completion_tokens: u32,
52    pub total_tokens: u32,
53}
54
55impl Usage {
56    pub fn builder() -> UsageBuilder {
57        UsageBuilder::default()
58    }
59}
60
61#[derive(Clone, Builder, Debug, Serialize, Deserialize, PartialEq)]
62pub struct ChatCompletionResponseDelta {
63    pub message_chunk: Option<String>,
64
65    // These are not public as the assumption is they are not usable
66    // until the tool calls are valid
67    tool_calls_chunk: Option<HashMap<usize, ToolCallAccum>>,
68}
69
70// Accumulator for streamed tool calls
71#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
72struct ToolCallAccum {
73    id: Option<String>,
74    name: Option<String>,
75    arguments: Option<String>,
76}
77
78impl ChatCompletionResponse {
79    pub fn builder() -> ChatCompletionResponseBuilder {
80        ChatCompletionResponseBuilder::default()
81    }
82
83    pub fn message(&self) -> Option<&str> {
84        self.message.as_deref()
85    }
86
87    pub fn tool_calls(&self) -> Option<&[ToolCall]> {
88        self.tool_calls.as_deref()
89    }
90
91    /// Adds a streaming chunk to the message and also the delta
92    pub fn append_message_delta(&mut self, message_delta: Option<&str>) -> &mut Self {
93        // let message: Option<String> = message;
94        let Some(message_delta) = message_delta else {
95            return self;
96        };
97
98        if let Some(delta) = &mut self.delta {
99            delta.message_chunk = Some(message_delta.to_string());
100        } else {
101            self.delta = Some(ChatCompletionResponseDelta {
102                message_chunk: Some(message_delta.to_string()),
103                tool_calls_chunk: None,
104            });
105        }
106
107        self.message
108            .as_mut()
109            .map(|m| {
110                m.push_str(message_delta);
111            })
112            .unwrap_or_else(|| {
113                self.message = Some(message_delta.to_string());
114            });
115        self
116    }
117
118    /// Adds a streaming chunk to the tool calls, if it can be build, the tool call will be build,
119    /// otherwise it will remain in the delta and retried on the next call
120    pub fn append_tool_call_delta(
121        &mut self,
122        index: usize,
123        id: Option<&str>,
124        name: Option<&str>,
125        arguments: Option<&str>,
126    ) -> &mut Self {
127        if let Some(delta) = &mut self.delta {
128            let map = delta.tool_calls_chunk.get_or_insert_with(HashMap::new);
129            map.entry(index)
130                .and_modify(|v| {
131                    if v.id.is_none() {
132                        v.id = id.map(Into::into);
133                    }
134                    if v.name.is_none() {
135                        v.name = name.map(Into::into);
136                    }
137                    if let Some(v) = v.arguments.as_mut() {
138                        if let Some(arguments) = arguments {
139                            v.push_str(arguments);
140                        }
141                    } else {
142                        v.arguments = arguments.map(Into::into);
143                    }
144                })
145                .or_insert(ToolCallAccum {
146                    id: id.map(Into::into),
147                    name: name.map(Into::into),
148                    arguments: arguments.map(Into::into),
149                });
150        } else {
151            self.delta = Some(ChatCompletionResponseDelta {
152                message_chunk: None,
153                tool_calls_chunk: Some(HashMap::from([(
154                    index,
155                    ToolCallAccum {
156                        id: id.map(Into::into),
157                        name: name.map(Into::into),
158                        arguments: arguments.map(Into::into),
159                    },
160                )])),
161            });
162        }
163
164        // Now let's try to rebuild _every_ tool call and overwrite
165        // Performance wise very meh but it works, in reality it's only a couple of tool calls most
166        self.finalize_tools_from_stream();
167
168        self
169    }
170
171    pub fn append_usage_delta(
172        &mut self,
173        prompt_tokens: u32,
174        completion_tokens: u32,
175        total_tokens: u32,
176    ) -> &mut Self {
177        debug_assert!(prompt_tokens + completion_tokens == total_tokens);
178
179        if let Some(usage) = &mut self.usage {
180            usage.prompt_tokens += prompt_tokens;
181            usage.completion_tokens += completion_tokens;
182            usage.total_tokens += total_tokens;
183        } else {
184            self.usage = Some(Usage {
185                prompt_tokens,
186                completion_tokens,
187                total_tokens,
188            });
189        }
190        self
191    }
192
193    fn finalize_tools_from_stream(&mut self) {
194        if let Some(values) = self
195            .delta
196            .as_ref()
197            .and_then(|d| d.tool_calls_chunk.as_ref().map(|t| t.values()))
198        {
199            let maybe_tool_calls = values
200                .filter_map(|maybe_tool_call| {
201                    ToolCallBuilder::default()
202                        .maybe_id(maybe_tool_call.id.clone())
203                        .maybe_name(maybe_tool_call.name.clone())
204                        .maybe_args(maybe_tool_call.arguments.clone())
205                        .build()
206                        .ok()
207                })
208                .collect::<Vec<_>>();
209
210            if !maybe_tool_calls.is_empty() {
211                self.tool_calls = Some(maybe_tool_calls);
212            }
213        }
214    }
215}
216
217impl ChatCompletionResponseBuilder {
218    pub fn maybe_message<T: Into<Option<String>>>(&mut self, message: T) -> &mut Self {
219        self.message = Some(message.into());
220        self
221    }
222
223    pub fn maybe_tool_calls<T: Into<Option<Vec<ToolCall>>>>(&mut self, tool_calls: T) -> &mut Self {
224        self.tool_calls = Some(tool_calls.into());
225        self
226    }
227}