1use schemars::JsonSchema;
35use serde::de::DeserializeOwned;
36
37use crate::LlmError;
38use crate::provider::{LlmProvider, Message, Role};
39
/// Thin helper that sends text to an LLM provider and deserializes the
/// structured reply into a caller-chosen type.
///
/// The `P: LlmProvider` requirement lives on the `impl` block rather than on
/// the struct definition, so naming `Extractor<'a, P>` in signatures does not
/// force callers to repeat the bound unless they actually call its methods.
#[derive(Debug)]
pub struct Extractor<'a, P> {
    /// Borrowed backend that performs the typed chat request.
    provider: &'a P,
    /// Optional system prompt; when set, it is sent as a system message ahead
    /// of the user input on every extraction.
    preamble: Option<String>,
}
47
48impl<'a, P: LlmProvider> Extractor<'a, P> {
49 pub fn new(provider: &'a P) -> Self {
51 Self {
52 provider,
53 preamble: None,
54 }
55 }
56
57 #[must_use]
59 pub fn with_preamble(mut self, preamble: impl Into<String>) -> Self {
60 self.preamble = Some(preamble.into());
61 self
62 }
63
64 pub async fn extract<T>(&self, input: &str) -> Result<T, LlmError>
73 where
74 T: DeserializeOwned + JsonSchema + 'static,
75 {
76 let mut messages = Vec::new();
77 if let Some(ref preamble) = self.preamble {
78 messages.push(Message::from_legacy(Role::System, preamble.clone()));
79 }
80 messages.push(Message::from_legacy(Role::User, input));
81 self.provider.chat_typed::<T>(&messages).await
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{ChatStream, LlmProvider, Message};

    /// Provider double that answers every chat request with one canned body.
    struct StubProvider {
        response: String,
    }

    impl StubProvider {
        /// Builds a stub whose `chat` always replies with `body`.
        fn replying(body: &str) -> Self {
            Self {
                response: body.to_owned(),
            }
        }
    }

    impl LlmProvider for StubProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Ok(self.response.clone())
        }

        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            // Reuse `chat` and surface the whole reply as a single content chunk.
            let body = self.chat(messages).await?;
            let chunk = crate::StreamChunk::Content(body);
            Ok(Box::pin(tokio_stream::once(Ok(chunk))))
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Err(LlmError::EmbedUnsupported {
                provider: "stub".into(),
            })
        }

        fn supports_embeddings(&self) -> bool {
            false
        }

        fn name(&self) -> &'static str {
            "stub"
        }
    }

    /// Provider double whose every operation fails with `LlmError::Unavailable`.
    struct FailProvider;

    impl LlmProvider for FailProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Err(LlmError::Unavailable)
        }

        async fn chat_stream(&self, _messages: &[Message]) -> Result<ChatStream, LlmError> {
            Err(LlmError::Unavailable)
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Err(LlmError::Unavailable)
        }

        fn supports_embeddings(&self) -> bool {
            false
        }

        fn name(&self) -> &'static str {
            "fail"
        }
    }

    #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
    struct TestOutput {
        value: String,
    }

    /// Shorthand for the `TestOutput` a successful extraction should yield.
    fn expected(value: &str) -> TestOutput {
        TestOutput {
            value: value.to_owned(),
        }
    }

    #[tokio::test]
    async fn extract_without_preamble() {
        let provider = StubProvider::replying(r#"{"value": "result"}"#);
        let extractor = Extractor::new(&provider);

        let parsed = extractor
            .extract::<TestOutput>("test input")
            .await
            .unwrap();

        assert_eq!(parsed, expected("result"));
    }

    #[tokio::test]
    async fn extract_with_preamble() {
        let provider = StubProvider::replying(r#"{"value": "with_preamble"}"#);
        let extractor = Extractor::new(&provider).with_preamble("Analyze this");

        let parsed = extractor
            .extract::<TestOutput>("test input")
            .await
            .unwrap();

        assert_eq!(parsed, expected("with_preamble"));
    }

    #[tokio::test]
    async fn extract_error_propagation() {
        let provider = FailProvider;
        let extractor = Extractor::new(&provider);

        let outcome = extractor.extract::<TestOutput>("test").await;

        assert!(matches!(outcome, Err(LlmError::Unavailable)));
    }
}