semantic_commands/embedders/
openai.rs1use std::collections::HashMap;
2
3use anyhow::{Result, bail};
4use log::debug;
5use serde::Deserialize;
6
7use crate::Embedder;
8
9const BASE_URL: &str = "https://api.openai.com/v1/";
10
11#[derive(Deserialize, Debug)]
12pub struct ErrorDescription {
13 pub message: String,
14 }
17
18#[derive(Deserialize, Debug)]
19pub struct ErrorResponse {
20 pub error: ErrorDescription,
21}
22
23#[derive(Deserialize, Debug)]
24pub struct SuccessResponse<D> {
25 data: Vec<D>,
27 }
30
31#[derive(Deserialize, Debug)]
32pub struct EmbeddingResponse {
33 embedding: Vec<f32>,
35 }
37
38#[derive(Deserialize, Debug)]
39struct ResponseUsage {
40 }
43
44#[derive(Deserialize, Debug)]
45#[serde(untagged)]
46pub enum OpenAIResponse<D> {
47 Success(SuccessResponse<D>),
48 Error(ErrorResponse),
49}
50
/// [`Embedder`] implementation that calls the OpenAI embeddings endpoint.
pub struct OpenAIEmbedder {
    /// OpenAI API key, sent as a bearer token with each request.
    pub token: String,
}

55#[async_trait::async_trait]
56impl Embedder for OpenAIEmbedder {
57 async fn embed(&self, input: &str) -> Result<Vec<f32>> {
58 let client = reqwest::Client::new();
59 let map: HashMap<&str, &str> = HashMap::from_iter(vec![("input", input), ("model", "text-embedding-ada-002")]);
60
61 debug!("fetching embedding from openai for phrase: {input}...");
62 let response = client
63 .post(format!("{BASE_URL}/embeddings"))
64 .bearer_auth(self.token.clone())
65 .json(&map)
66 .send()
67 .await?
68 .json::<OpenAIResponse<EmbeddingResponse>>()
69 .await?;
70 match response {
71 OpenAIResponse::Error(error_response) => {
72 bail!(error_response.error.message)
73 }
74 OpenAIResponse::Success(embdding_response) => {
75 let asd = embdding_response.data.first().unwrap();
76 Ok(asd.embedding.clone())
77 }
78 }
79 }
80}