1use super::{
16 Error, JsonSnafu, ModelListParams, Schema, SchemaAllowCreate, SchemaAllowEdit, SchemaType,
17 SchemaView, SqlxSnafu, Status, format_datetime,
18};
19use serde::{Deserialize, Serialize};
20use snafu::ResultExt;
21use sqlx::FromRow;
22use sqlx::{Pool, Postgres, QueryBuilder};
23use std::collections::HashMap;
24use tibba_model::Model;
25use time::PrimitiveDateTime;
26
27type Result<T> = std::result::Result<T, Error>;
28
29#[derive(FromRow)]
30struct TokenPriceSchema {
31 id: i64,
32 service: String,
33 model: String,
34 input_price: i64,
35 output_price: i64,
36 fixed_price: i64,
37 unit_size: i32,
38 status: i16,
39 remark: String,
40 created: PrimitiveDateTime,
41 modified: PrimitiveDateTime,
42}
43
44#[derive(Debug, Clone, Deserialize, Serialize)]
45pub struct TokenPrice {
46 pub id: i64,
47 pub service: String,
48 pub model: String,
49 pub input_price: i64,
51 pub output_price: i64,
53 pub fixed_price: i64,
55 pub unit_size: i32,
57 pub status: i16,
58 pub remark: String,
59 pub created: String,
60 pub modified: String,
61}
62
63impl From<TokenPriceSchema> for TokenPrice {
64 fn from(s: TokenPriceSchema) -> Self {
65 Self {
66 id: s.id,
67 service: s.service,
68 model: s.model,
69 input_price: s.input_price,
70 output_price: s.output_price,
71 fixed_price: s.fixed_price,
72 unit_size: s.unit_size,
73 status: s.status,
74 remark: s.remark,
75 created: format_datetime(s.created),
76 modified: format_datetime(s.modified),
77 }
78 }
79}
80
81#[derive(Debug, Clone, Deserialize)]
82pub struct TokenPriceInsertParams {
83 pub service: String,
84 pub model: Option<String>,
85 pub input_price: i64,
86 pub output_price: i64,
87 pub fixed_price: Option<i64>,
88 pub unit_size: Option<i32>,
89 pub status: Option<i16>,
90 pub remark: Option<String>,
91}
92
93#[derive(Debug, Clone, Deserialize, Default)]
94pub struct TokenPriceUpdateParams {
95 pub input_price: Option<i64>,
96 pub output_price: Option<i64>,
97 pub fixed_price: Option<i64>,
98 pub unit_size: Option<i32>,
99 pub status: Option<i16>,
100 pub remark: Option<String>,
101}
102
/// Stateless handle for queries against the `token_prices` table.
///
/// It carries no data, so `Debug` and `Clone` are derived alongside `Default`
/// to make it easy to log and to share between handlers.
#[derive(Debug, Clone, Default)]
pub struct TokenPriceModel {}
105
106impl TokenPriceModel {
107 pub async fn get_by_service_model(
110 &self,
111 pool: &Pool<Postgres>,
112 service: &str,
113 model: &str,
114 ) -> Result<Option<TokenPrice>> {
115 let result = sqlx::query_as::<_, TokenPriceSchema>(
117 r#"SELECT * FROM token_prices
118 WHERE service = $1 AND model = $2 AND status = 1 AND deleted_at IS NULL
119 LIMIT 1"#,
120 )
121 .bind(service)
122 .bind(model)
123 .fetch_optional(pool)
124 .await
125 .context(SqlxSnafu)?;
126
127 if result.is_some() {
128 return Ok(result.map(Into::into));
129 }
130
131 if !model.is_empty() {
133 let fallback = sqlx::query_as::<_, TokenPriceSchema>(
134 r#"SELECT * FROM token_prices
135 WHERE service = $1 AND model = '' AND status = 1 AND deleted_at IS NULL
136 LIMIT 1"#,
137 )
138 .bind(service)
139 .fetch_optional(pool)
140 .await
141 .context(SqlxSnafu)?;
142 return Ok(fallback.map(Into::into));
143 }
144
145 Ok(None)
146 }
147
148 pub fn calculate_cost(price: &TokenPrice, input_tokens: i32, output_tokens: i32) -> i64 {
151 let unit = price.unit_size.max(1) as i64;
152 let input_cost = (input_tokens as i64 * price.input_price + unit - 1) / unit;
154 let output_cost = (output_tokens as i64 * price.output_price + unit - 1) / unit;
155 price.fixed_price + input_cost + output_cost
156 }
157}
158
159impl Model for TokenPriceModel {
160 type Output = TokenPrice;
161 fn new() -> Self {
162 Self::default()
163 }
164
165 async fn schema_view(&self, _pool: &Pool<Postgres>) -> SchemaView {
166 SchemaView {
167 schemas: vec![
168 Schema::new_id(),
169 Schema {
170 name: "service".to_string(),
171 category: SchemaType::String,
172 required: true,
173 fixed: true,
174 filterable: true,
175 ..Default::default()
176 },
177 Schema {
178 name: "model".to_string(),
179 category: SchemaType::String,
180 fixed: true,
181 filterable: true,
182 ..Default::default()
183 },
184 Schema {
185 name: "input_price".to_string(),
186 category: SchemaType::Number,
187 required: true,
188 ..Default::default()
189 },
190 Schema {
191 name: "output_price".to_string(),
192 category: SchemaType::Number,
193 required: true,
194 ..Default::default()
195 },
196 Schema {
197 name: "fixed_price".to_string(),
198 category: SchemaType::Number,
199 ..Default::default()
200 },
201 Schema {
202 name: "unit_size".to_string(),
203 category: SchemaType::Number,
204 default_value: Some(serde_json::json!(1000)),
205 ..Default::default()
206 },
207 Schema::new_status(),
208 Schema::new_remark(),
209 Schema::new_created(),
210 Schema::new_modified(),
211 ],
212 allow_edit: SchemaAllowEdit {
213 roles: vec!["su".to_string()],
214 ..Default::default()
215 },
216 allow_create: SchemaAllowCreate {
217 roles: vec!["su".to_string()],
218 ..Default::default()
219 },
220 }
221 }
222
223 async fn insert(&self, pool: &Pool<Postgres>, data: serde_json::Value) -> Result<u64> {
224 let p: TokenPriceInsertParams = serde_json::from_value(data).context(JsonSnafu)?;
225 let row: (i64,) = sqlx::query_as(
226 r#"INSERT INTO token_prices
227 (service, model, input_price, output_price, fixed_price, unit_size, status, remark)
228 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
229 RETURNING id"#,
230 )
231 .bind(&p.service)
232 .bind(p.model.unwrap_or_default())
233 .bind(p.input_price)
234 .bind(p.output_price)
235 .bind(p.fixed_price.unwrap_or(0))
236 .bind(p.unit_size.unwrap_or(1000))
237 .bind(p.status.unwrap_or(Status::Enabled as i16))
238 .bind(p.remark.unwrap_or_default())
239 .fetch_one(pool)
240 .await
241 .context(SqlxSnafu)?;
242 Ok(row.0 as u64)
243 }
244
245 async fn get_by_id(&self, pool: &Pool<Postgres>, id: u64) -> Result<Option<Self::Output>> {
246 let result = sqlx::query_as::<_, TokenPriceSchema>(
247 r#"SELECT * FROM token_prices WHERE id = $1 AND deleted_at IS NULL"#,
248 )
249 .bind(id as i64)
250 .fetch_optional(pool)
251 .await
252 .context(SqlxSnafu)?;
253 Ok(result.map(Into::into))
254 }
255
256 async fn update_by_id(
257 &self,
258 pool: &Pool<Postgres>,
259 id: u64,
260 data: serde_json::Value,
261 ) -> Result<()> {
262 let p: TokenPriceUpdateParams = serde_json::from_value(data).context(JsonSnafu)?;
263 let mut qb: QueryBuilder<Postgres> =
264 QueryBuilder::new("UPDATE token_prices SET modified = NOW()");
265 if let Some(v) = p.input_price {
266 qb.push(", input_price = ").push_bind(v);
267 }
268 if let Some(v) = p.output_price {
269 qb.push(", output_price = ").push_bind(v);
270 }
271 if let Some(v) = p.fixed_price {
272 qb.push(", fixed_price = ").push_bind(v);
273 }
274 if let Some(v) = p.unit_size {
275 qb.push(", unit_size = ").push_bind(v);
276 }
277 if let Some(v) = p.status {
278 qb.push(", status = ").push_bind(v);
279 }
280 if let Some(v) = p.remark {
281 qb.push(", remark = ").push_bind(v);
282 }
283 qb.push(" WHERE id = ")
284 .push_bind(id as i64)
285 .push(" AND deleted_at IS NULL");
286 qb.build().execute(pool).await.context(SqlxSnafu)?;
287 Ok(())
288 }
289
290 async fn delete_by_id(&self, pool: &Pool<Postgres>, id: u64) -> Result<()> {
291 sqlx::query(
292 r#"UPDATE token_prices SET deleted_at = NOW() WHERE id = $1 AND deleted_at IS NULL"#,
293 )
294 .bind(id as i64)
295 .execute(pool)
296 .await
297 .context(SqlxSnafu)?;
298 Ok(())
299 }
300
301 async fn count(&self, pool: &Pool<Postgres>, params: &ModelListParams) -> Result<i64> {
302 let mut qb: QueryBuilder<Postgres> = QueryBuilder::new("SELECT COUNT(*) FROM token_prices");
303 self.push_conditions(&mut qb, params)?;
304 let row: (i64,) = qb
305 .build_query_as()
306 .fetch_one(pool)
307 .await
308 .context(SqlxSnafu)?;
309 Ok(row.0)
310 }
311
312 async fn list(
313 &self,
314 pool: &Pool<Postgres>,
315 params: &ModelListParams,
316 ) -> Result<Vec<Self::Output>> {
317 let mut qb: QueryBuilder<Postgres> = QueryBuilder::new("SELECT * FROM token_prices");
318 self.push_conditions(&mut qb, params)?;
319 params.push_pagination(&mut qb);
320 let rows = qb
321 .build_query_as::<TokenPriceSchema>()
322 .fetch_all(pool)
323 .await
324 .context(SqlxSnafu)?;
325 Ok(rows.into_iter().map(Into::into).collect())
326 }
327
328 fn push_filter_conditions<'args>(
329 &self,
330 qb: &mut QueryBuilder<'args, Postgres>,
331 filters: &HashMap<String, String>,
332 ) -> Result<()> {
333 if let Some(service) = filters.get("service") {
334 qb.push(" AND service = ").push_bind(service.clone());
335 }
336 if let Some(status) = filters.get("status") {
337 if let Ok(v) = status.parse::<i16>() {
338 qb.push(" AND status = ").push_bind(v);
339 }
340 }
341 Ok(())
342 }
343}