1use schemars::JsonSchema;
5use serde::de::DeserializeOwned;
6
7use crate::LlmError;
8use crate::provider::{LlmProvider, Message, Role};
9
10pub struct Extractor<'a, P: LlmProvider> {
11 provider: &'a P,
12 preamble: Option<String>,
13}
14
15impl<'a, P: LlmProvider> Extractor<'a, P> {
16 pub fn new(provider: &'a P) -> Self {
17 Self {
18 provider,
19 preamble: None,
20 }
21 }
22
23 #[must_use]
24 pub fn with_preamble(mut self, preamble: impl Into<String>) -> Self {
25 self.preamble = Some(preamble.into());
26 self
27 }
28
29 pub async fn extract<T>(&self, input: &str) -> Result<T, LlmError>
33 where
34 T: DeserializeOwned + JsonSchema + 'static,
35 {
36 let mut messages = Vec::new();
37 if let Some(ref preamble) = self.preamble {
38 messages.push(Message::from_legacy(Role::System, preamble.clone()));
39 }
40 messages.push(Message::from_legacy(Role::User, input));
41 self.provider.chat_typed::<T>(&messages).await
42 }
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{ChatStream, LlmProvider, Message};

    /// Provider double whose `chat` always answers with a fixed string.
    struct StubProvider {
        response: String,
    }

    /// Builds a [`StubProvider`] that replies with `reply`.
    fn stub(reply: &str) -> StubProvider {
        StubProvider {
            response: reply.into(),
        }
    }

    impl LlmProvider for StubProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Ok(self.response.clone())
        }

        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            // Wrap the canned chat reply in a single-chunk stream.
            let text = self.chat(messages).await?;
            let chunk = crate::StreamChunk::Content(text);
            Ok(Box::pin(tokio_stream::once(Ok(chunk))))
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Err(LlmError::EmbedUnsupported {
                provider: "stub".into(),
            })
        }

        fn supports_embeddings(&self) -> bool {
            false
        }

        fn name(&self) -> &'static str {
            "stub"
        }
    }

    /// Deserialization target used by the extraction tests.
    #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
    struct TestOutput {
        value: String,
    }

    /// Builds the [`TestOutput`] a test expects to get back.
    fn output(value: &str) -> TestOutput {
        TestOutput {
            value: value.into(),
        }
    }

    #[tokio::test]
    async fn extract_without_preamble() {
        let provider = stub(r#"{"value": "result"}"#);
        let extractor = Extractor::new(&provider);

        let parsed: TestOutput = extractor.extract("test input").await.unwrap();

        assert_eq!(parsed, output("result"));
    }

    #[tokio::test]
    async fn extract_with_preamble() {
        let provider = stub(r#"{"value": "with_preamble"}"#);
        let extractor = Extractor::new(&provider).with_preamble("Analyze this");

        let parsed: TestOutput = extractor.extract("test input").await.unwrap();

        assert_eq!(parsed, output("with_preamble"));
    }

    #[tokio::test]
    async fn extract_error_propagation() {
        /// Provider double whose fallible operations all report `Unavailable`.
        struct FailProvider;

        impl LlmProvider for FailProvider {
            async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
                Err(LlmError::Unavailable)
            }

            async fn chat_stream(&self, _messages: &[Message]) -> Result<ChatStream, LlmError> {
                Err(LlmError::Unavailable)
            }

            fn supports_streaming(&self) -> bool {
                false
            }

            async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
                Err(LlmError::Unavailable)
            }

            fn supports_embeddings(&self) -> bool {
                false
            }

            fn name(&self) -> &'static str {
                "fail"
            }
        }

        let provider = FailProvider;
        let outcome = Extractor::new(&provider)
            .extract::<TestOutput>("test")
            .await;

        // The provider's error must reach the caller unchanged.
        assert!(matches!(outcome, Err(LlmError::Unavailable)));
    }
}