skill_runtime/embeddings/
openai.rs1use super::{EmbeddingProvider, OpenAIEmbeddingModel};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use rig::embeddings::EmbeddingModel as RigEmbeddingModel;
10use rig::client::{EmbeddingsClient, ProviderClient};
11use rig::providers::openai::{self, Client as OpenAIClient};
12use std::sync::Arc;
13
14pub struct OpenAIEmbedProvider {
19 client: Arc<OpenAIClient>,
20 model: OpenAIEmbeddingModel,
21 dims: usize,
22}
23
24impl OpenAIEmbedProvider {
25 pub fn new() -> Result<Self> {
30 Self::with_model(OpenAIEmbeddingModel::default())
31 }
32
33 pub fn with_model(model: OpenAIEmbeddingModel) -> Result<Self> {
35 std::env::var("OPENAI_API_KEY").context(
37 "OPENAI_API_KEY environment variable not set. Set it with: export OPENAI_API_KEY=your-key-here"
38 )?;
39
40 let client = Arc::new(OpenAIClient::from_env());
41 let dims = model.dimensions();
42
43 Ok(Self {
44 client,
45 model,
46 dims,
47 })
48 }
49
50 pub fn with_api_key(api_key: &str, model: OpenAIEmbeddingModel) -> Result<Self> {
52 let client = Arc::new(OpenAIClient::new(api_key).context("Failed to create OpenAI client")?);
53 let dims = model.dimensions();
54
55 Ok(Self {
56 client,
57 model,
58 dims,
59 })
60 }
61
62 pub fn from_model_name(name: &str) -> Result<Self> {
64 let model: OpenAIEmbeddingModel = name.parse()?;
65 Self::with_model(model)
66 }
67
68 fn api_model_name(&self) -> &'static str {
70 match self.model {
71 OpenAIEmbeddingModel::Ada002 => openai::TEXT_EMBEDDING_ADA_002,
72 OpenAIEmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
73 OpenAIEmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
74 }
75 }
76
77}
78
79#[async_trait]
80impl EmbeddingProvider for OpenAIEmbedProvider {
81 async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
82 if texts.is_empty() {
83 return Ok(Vec::new());
84 }
85
86 let embedding_model = self.client.embedding_model(self.api_model_name());
87
88 let embeddings = embedding_model
90 .embed_texts(texts)
91 .await
92 .context("OpenAI failed to generate embeddings")?;
93
94 let results: Vec<Vec<f32>> = embeddings
96 .into_iter()
97 .map(|emb| emb.vec.into_iter().map(|x| x as f32).collect())
98 .collect();
99
100 Ok(results)
101 }
102
103 fn dimensions(&self) -> usize {
104 self.dims
105 }
106
107 fn model_name(&self) -> &str {
108 self.api_model_name()
109 }
110
111 fn provider_name(&self) -> &str {
112 "openai"
113 }
114
115 fn max_batch_size(&self) -> usize {
116 2048
118 }
119
120 async fn health_check(&self) -> Result<bool> {
121 match self.embed_query("test").await {
123 Ok(emb) => Ok(emb.len() == self.dims),
124 Err(_) => Ok(false),
125 }
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes an environment variable for the lifetime of the guard and
    /// restores its previous value on drop — including when the test panics,
    /// and including values that are not valid UTF-8.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn remove(key: &'static str) -> Self {
            let original = std::env::var_os(key);
            std::env::remove_var(key);
            Self { key, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.original.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn test_model_dimensions() {
        assert_eq!(OpenAIEmbeddingModel::Ada002.dimensions(), 1536);
        assert_eq!(OpenAIEmbeddingModel::TextEmbedding3Small.dimensions(), 1536);
        assert_eq!(OpenAIEmbeddingModel::TextEmbedding3Large.dimensions(), 3072);
    }

    #[test]
    fn test_api_model_names() {
        assert_eq!(OpenAIEmbeddingModel::Ada002.api_name(), "text-embedding-ada-002");
        assert_eq!(OpenAIEmbeddingModel::TextEmbedding3Small.api_name(), "text-embedding-3-small");
        assert_eq!(OpenAIEmbeddingModel::TextEmbedding3Large.api_name(), "text-embedding-3-large");
    }

    #[tokio::test]
    #[ignore = "requires OPENAI_API_KEY"]
    async fn test_embed_documents() {
        let provider = OpenAIEmbedProvider::new().unwrap();
        let texts = vec!["Hello world".to_string(), "How are you".to_string()];

        let embeddings = provider.embed_documents(texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), provider.dimensions());
    }

    #[test]
    fn test_missing_api_key() {
        // NOTE(review): mutating the process environment can race with any
        // other test that reads OPENAI_API_KEY concurrently; if more
        // env-dependent tests are added, serialize them (shared lock or
        // `--test-threads=1`).
        let _guard = EnvVarGuard::remove("OPENAI_API_KEY");

        let result = OpenAIEmbedProvider::new();
        assert!(result.is_err());
    }
}
177}