semantic_commands/embedders/openai.rs

1use std::collections::HashMap;
2
3use anyhow::{Result, bail};
4use log::debug;
5use serde::Deserialize;
6
7use crate::Embedder;
8
/// Root of the OpenAI REST API, deliberately without a trailing slash.
///
/// Endpoint paths are appended as `{BASE_URL}/<endpoint>`; a trailing slash here
/// would produce `//` in the final request URL.
const BASE_URL: &str = "https://api.openai.com/v1";
10
11#[derive(Deserialize, Debug)]
12pub struct ErrorDescription {
13	pub message: String,
14	// pub r#type: String,
15	// pub code: String,
16}
17
18#[derive(Deserialize, Debug)]
19pub struct ErrorResponse {
20	pub error: ErrorDescription,
21}
22
23#[derive(Deserialize, Debug)]
24pub struct SuccessResponse<D> {
25	// object: String,
26	data: Vec<D>,
27	// model: String,
28	// usage: ResponseUsage,
29}
30
31#[derive(Deserialize, Debug)]
32pub struct EmbeddingResponse {
33	// object: String,
34	embedding: Vec<f32>,
35	// index: u32,
36}
37
38#[derive(Deserialize, Debug)]
39struct ResponseUsage {
40	// prompt_tokens: u32,
41	// total_tokens: u32,
42}
43
44#[derive(Deserialize, Debug)]
45#[serde(untagged)]
46pub enum OpenAIResponse<D> {
47	Success(SuccessResponse<D>),
48	Error(ErrorResponse),
49}
50
/// [`Embedder`] implementation that calls the OpenAI embeddings API.
pub struct OpenAIEmbedder {
	/// API key, sent as a bearer token with each request.
	pub token: String,
}
54
55#[async_trait::async_trait]
56impl Embedder for OpenAIEmbedder {
57	async fn embed(&self, input: &str) -> Result<Vec<f32>> {
58		let client = reqwest::Client::new();
59		let map: HashMap<&str, &str> = HashMap::from_iter(vec![("input", input), ("model", "text-embedding-ada-002")]);
60
61		debug!("fetching embedding from openai for phrase: {input}...");
62		let response = client
63			.post(format!("{BASE_URL}/embeddings"))
64			.bearer_auth(self.token.clone())
65			.json(&map)
66			.send()
67			.await?
68			.json::<OpenAIResponse<EmbeddingResponse>>()
69			.await?;
70		match response {
71			OpenAIResponse::Error(error_response) => {
72				bail!(error_response.error.message)
73			}
74			OpenAIResponse::Success(embdding_response) => {
75				let asd = embdding_response.data.first().unwrap();
76				Ok(asd.embedding.clone())
77			}
78		}
79	}
80}