shuttle_openai/
lib.rs

1use async_openai::config::OpenAIConfig;
2use async_openai::Client;
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use shuttle_service::{CustomError, Error, IntoResource, ResourceFactory, ResourceInputBuilder};
6
7pub use async_openai;
8
9#[derive(Default, Serialize)]
10pub struct OpenAI {
11    api_base: Option<String>,
12    api_key: Option<String>,
13    org_id: Option<String>,
14    project_id: Option<String>,
15}
16
17impl OpenAI {
18    pub fn api_base(mut self, api_base: &str) -> Self {
19        self.api_base = Some(api_base.to_string());
20        self
21    }
22    pub fn api_key(mut self, api_key: &str) -> Self {
23        self.api_key = Some(api_key.to_string());
24        self
25    }
26    pub fn org_id(mut self, org_id: &str) -> Self {
27        self.org_id = Some(org_id.to_string());
28        self
29    }
30    pub fn project_id(mut self, project_id: &str) -> Self {
31        self.project_id = Some(project_id.to_string());
32        self
33    }
34}
35
36#[derive(Serialize, Deserialize)]
37pub struct Config {
38    api_base: Option<String>,
39    api_key: String,
40    org_id: Option<String>,
41    project_id: Option<String>,
42}
43
44#[async_trait]
45impl ResourceInputBuilder for OpenAI {
46    type Input = Config;
47    type Output = Config;
48
49    async fn build(self, _factory: &ResourceFactory) -> Result<Self::Input, Error> {
50        let api_key = self
51            .api_key
52            .ok_or(Error::Custom(CustomError::msg("Open AI API key required")))?;
53        let config = Config {
54            api_base: self.api_base,
55            api_key,
56            org_id: self.org_id,
57            project_id: self.project_id,
58        };
59        Ok(config)
60    }
61}
62
63#[async_trait]
64impl IntoResource<Client<OpenAIConfig>> for Config {
65    async fn into_resource(self) -> Result<Client<OpenAIConfig>, Error> {
66        let mut openai_config = OpenAIConfig::new().with_api_key(self.api_key);
67        if let Some(api_base) = self.api_base {
68            openai_config = openai_config.with_api_base(api_base)
69        }
70        if let Some(org_id) = self.org_id {
71            openai_config = openai_config.with_org_id(org_id)
72        }
73        if let Some(project_id) = self.project_id {
74            openai_config = openai_config.with_project_id(project_id)
75        }
76        Ok(Client::with_config(openai_config))
77    }
78}