Skip to main content

synth_ai/
inference_api.rs

1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::models::JsonMap;
7use crate::openapi_paths;
8use crate::transport::Transport;
9use crate::types::Result;
10
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12pub struct ChatCompletionRequest {
13    pub model: String,
14    #[serde(default)]
15    pub messages: Vec<Value>,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub max_tokens: Option<u32>,
18    #[serde(flatten)]
19    pub extra: JsonMap,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, Default)]
23pub struct ChatCompletionResponse {
24    #[serde(default)]
25    pub id: Option<String>,
26    #[serde(default)]
27    pub choices: Vec<Value>,
28    #[serde(default)]
29    pub usage: Option<Value>,
30    #[serde(flatten)]
31    pub extra: JsonMap,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, Default)]
35pub struct InferenceJobCreateRequest {
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub model: Option<String>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub input: Option<Value>,
40    #[serde(flatten)]
41    pub extra: JsonMap,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize, Default)]
45pub struct InferenceJob {
46    #[serde(default)]
47    pub job_id: Option<String>,
48    #[serde(default)]
49    pub id: Option<String>,
50    #[serde(default)]
51    pub status: Option<String>,
52    #[serde(flatten)]
53    pub extra: JsonMap,
54}
55
56#[derive(Clone)]
57pub struct InferenceClient {
58    transport: Arc<Transport>,
59}
60
61impl InferenceClient {
62    pub(crate) fn new(transport: Arc<Transport>) -> Self {
63        Self { transport }
64    }
65
66    pub fn chat(&self) -> InferenceChatClient {
67        InferenceChatClient {
68            transport: self.transport.clone(),
69        }
70    }
71
72    pub fn jobs(&self) -> InferenceJobsClient {
73        InferenceJobsClient {
74            transport: self.transport.clone(),
75        }
76    }
77}
78
79#[derive(Clone)]
80pub struct InferenceChatClient {
81    transport: Arc<Transport>,
82}
83
84impl InferenceChatClient {
85    pub fn completions(&self) -> InferenceChatCompletionsClient {
86        InferenceChatCompletionsClient {
87            transport: self.transport.clone(),
88        }
89    }
90}
91
92#[derive(Clone)]
93pub struct InferenceChatCompletionsClient {
94    transport: Arc<Transport>,
95}
96
97impl InferenceChatCompletionsClient {
98    pub async fn create(&self, request: &ChatCompletionRequest) -> Result<ChatCompletionResponse> {
99        self.transport
100            .post_json(openapi_paths::API_INFERENCE_CHAT_COMPLETIONS, request)
101            .await
102    }
103}
104
105#[derive(Clone)]
106pub struct InferenceJobsClient {
107    transport: Arc<Transport>,
108}
109
110impl InferenceJobsClient {
111    pub async fn create(&self, request: &InferenceJobCreateRequest) -> Result<InferenceJob> {
112        self.transport
113            .post_json(openapi_paths::API_INFERENCE_JOBS, request)
114            .await
115    }
116
117    pub async fn get(&self, job_id: &str) -> Result<InferenceJob> {
118        self.transport
119            .get_json(&openapi_paths::api_inference_jobs_job(job_id))
120            .await
121    }
122
123    pub async fn artifact(&self, job_id: &str, artifact_id: &str) -> Result<Vec<u8>> {
124        self.transport
125            .get_bytes(&openapi_paths::inference_job_artifact(job_id, artifact_id))
126            .await
127    }
128}