Skip to main content

tibba_model_token/
price.rs

1// Copyright 2025 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::{
16    Error, JsonSnafu, ModelListParams, Schema, SchemaAllowCreate, SchemaAllowEdit, SchemaType,
17    SchemaView, SqlxSnafu, Status, format_datetime,
18};
19use serde::{Deserialize, Serialize};
20use snafu::ResultExt;
21use sqlx::FromRow;
22use sqlx::{Pool, Postgres, QueryBuilder};
23use std::collections::HashMap;
24use tibba_model::Model;
25use time::PrimitiveDateTime;
26
/// Result alias used throughout this module, carrying the parent module's `Error`.
type Result<T> = std::result::Result<T, Error>;
28
/// Raw row of the `token_prices` table, as decoded by sqlx from `SELECT *`.
///
/// Columns are matched by name via `FromRow`. The table also has a
/// `deleted_at` column (used for soft deletes), which is not mapped here.
/// Converted into the public [`TokenPrice`] via `From`, which renders the
/// timestamps as strings.
#[derive(FromRow)]
struct TokenPriceSchema {
    id: i64,
    service: String,
    model: String,
    input_price: i64,
    output_price: i64,
    fixed_price: i64,
    unit_size: i32,
    status: i16,
    remark: String,
    created: PrimitiveDateTime,
    modified: PrimitiveDateTime,
}
43
/// A token pricing rule, in its serializable (API) form.
///
/// Timestamps are pre-formatted strings (see `format_datetime`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TokenPrice {
    pub id: i64,
    pub service: String,
    /// Model name; an empty string marks the service-wide default rule that
    /// `TokenPriceModel::get_by_service_model` falls back to.
    pub model: String,
    /// Credits charged per `unit_size` input tokens.
    pub input_price: i64,
    /// Credits charged per `unit_size` output tokens.
    pub output_price: i64,
    /// Fixed credits charged on every call.
    pub fixed_price: i64,
    /// Billing unit size; defaults to 1000 (i.e. prices are per 1K tokens).
    pub unit_size: i32,
    /// Rule status; only rows with `status = 1` are matched by price lookups
    /// (presumably `Status::Enabled` — confirm).
    pub status: i16,
    pub remark: String,
    pub created: String,
    pub modified: String,
}
62
63impl From<TokenPriceSchema> for TokenPrice {
64    fn from(s: TokenPriceSchema) -> Self {
65        Self {
66            id: s.id,
67            service: s.service,
68            model: s.model,
69            input_price: s.input_price,
70            output_price: s.output_price,
71            fixed_price: s.fixed_price,
72            unit_size: s.unit_size,
73            status: s.status,
74            remark: s.remark,
75            created: format_datetime(s.created),
76            modified: format_datetime(s.modified),
77        }
78    }
79}
80
/// JSON payload accepted by `TokenPriceModel::insert`.
///
/// Omitted optional fields are filled in by `insert`: `model` → `""`
/// (service-wide default rule), `fixed_price` → 0, `unit_size` → 1000,
/// `status` → `Status::Enabled`, `remark` → `""`.
#[derive(Debug, Clone, Deserialize)]
pub struct TokenPriceInsertParams {
    pub service: String,
    pub model: Option<String>,
    pub input_price: i64,
    pub output_price: i64,
    pub fixed_price: Option<i64>,
    pub unit_size: Option<i32>,
    pub status: Option<i16>,
    pub remark: Option<String>,
}
92
/// JSON payload accepted by `TokenPriceModel::update_by_id`.
///
/// Partial update: only fields that are `Some` are written; `None` leaves the
/// column unchanged. `service` and `model` cannot be changed through this type.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct TokenPriceUpdateParams {
    pub input_price: Option<i64>,
    pub output_price: Option<i64>,
    pub fixed_price: Option<i64>,
    pub unit_size: Option<i32>,
    pub status: Option<i16>,
    pub remark: Option<String>,
}
102
/// Stateless handle for the `token_prices` table; every query method takes
/// the Postgres connection pool explicitly.
#[derive(Default)]
pub struct TokenPriceModel {}
105
106impl TokenPriceModel {
107    /// 按服务类型和模型名查询定价配置。
108    /// 先精确匹配 (service, model),找不到时退回匹配 (service, "")。
109    pub async fn get_by_service_model(
110        &self,
111        pool: &Pool<Postgres>,
112        service: &str,
113        model: &str,
114    ) -> Result<Option<TokenPrice>> {
115        // 精确匹配
116        let result = sqlx::query_as::<_, TokenPriceSchema>(
117            r#"SELECT * FROM token_prices
118               WHERE service = $1 AND model = $2 AND status = 1 AND deleted_at IS NULL
119               LIMIT 1"#,
120        )
121        .bind(service)
122        .bind(model)
123        .fetch_optional(pool)
124        .await
125        .context(SqlxSnafu)?;
126
127        if result.is_some() {
128            return Ok(result.map(Into::into));
129        }
130
131        // 回退:匹配该服务的默认定价(model = '')
132        if !model.is_empty() {
133            let fallback = sqlx::query_as::<_, TokenPriceSchema>(
134                r#"SELECT * FROM token_prices
135                   WHERE service = $1 AND model = '' AND status = 1 AND deleted_at IS NULL
136                   LIMIT 1"#,
137            )
138            .bind(service)
139            .fetch_optional(pool)
140            .await
141            .context(SqlxSnafu)?;
142            return Ok(fallback.map(Into::into));
143        }
144
145        Ok(None)
146    }
147
148    /// 根据定价配置和 token 用量计算本次消耗积分。
149    /// 使用整数向上取整,避免浮点误差。
150    pub fn calculate_cost(price: &TokenPrice, input_tokens: i32, output_tokens: i32) -> i64 {
151        let unit = price.unit_size.max(1) as i64;
152        // 向上取整:(n * p + unit - 1) / unit
153        let input_cost = (input_tokens as i64 * price.input_price + unit - 1) / unit;
154        let output_cost = (output_tokens as i64 * price.output_price + unit - 1) / unit;
155        price.fixed_price + input_cost + output_cost
156    }
157}
158
159impl Model for TokenPriceModel {
160    type Output = TokenPrice;
161    fn new() -> Self {
162        Self::default()
163    }
164
165    async fn schema_view(&self, _pool: &Pool<Postgres>) -> SchemaView {
166        SchemaView {
167            schemas: vec![
168                Schema::new_id(),
169                Schema {
170                    name: "service".to_string(),
171                    category: SchemaType::String,
172                    required: true,
173                    fixed: true,
174                    filterable: true,
175                    ..Default::default()
176                },
177                Schema {
178                    name: "model".to_string(),
179                    category: SchemaType::String,
180                    fixed: true,
181                    filterable: true,
182                    ..Default::default()
183                },
184                Schema {
185                    name: "input_price".to_string(),
186                    category: SchemaType::Number,
187                    required: true,
188                    ..Default::default()
189                },
190                Schema {
191                    name: "output_price".to_string(),
192                    category: SchemaType::Number,
193                    required: true,
194                    ..Default::default()
195                },
196                Schema {
197                    name: "fixed_price".to_string(),
198                    category: SchemaType::Number,
199                    ..Default::default()
200                },
201                Schema {
202                    name: "unit_size".to_string(),
203                    category: SchemaType::Number,
204                    default_value: Some(serde_json::json!(1000)),
205                    ..Default::default()
206                },
207                Schema::new_status(),
208                Schema::new_remark(),
209                Schema::new_created(),
210                Schema::new_modified(),
211            ],
212            allow_edit: SchemaAllowEdit {
213                roles: vec!["su".to_string()],
214                ..Default::default()
215            },
216            allow_create: SchemaAllowCreate {
217                roles: vec!["su".to_string()],
218                ..Default::default()
219            },
220        }
221    }
222
223    async fn insert(&self, pool: &Pool<Postgres>, data: serde_json::Value) -> Result<u64> {
224        let p: TokenPriceInsertParams = serde_json::from_value(data).context(JsonSnafu)?;
225        let row: (i64,) = sqlx::query_as(
226            r#"INSERT INTO token_prices
227               (service, model, input_price, output_price, fixed_price, unit_size, status, remark)
228               VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
229               RETURNING id"#,
230        )
231        .bind(&p.service)
232        .bind(p.model.unwrap_or_default())
233        .bind(p.input_price)
234        .bind(p.output_price)
235        .bind(p.fixed_price.unwrap_or(0))
236        .bind(p.unit_size.unwrap_or(1000))
237        .bind(p.status.unwrap_or(Status::Enabled as i16))
238        .bind(p.remark.unwrap_or_default())
239        .fetch_one(pool)
240        .await
241        .context(SqlxSnafu)?;
242        Ok(row.0 as u64)
243    }
244
245    async fn get_by_id(&self, pool: &Pool<Postgres>, id: u64) -> Result<Option<Self::Output>> {
246        let result = sqlx::query_as::<_, TokenPriceSchema>(
247            r#"SELECT * FROM token_prices WHERE id = $1 AND deleted_at IS NULL"#,
248        )
249        .bind(id as i64)
250        .fetch_optional(pool)
251        .await
252        .context(SqlxSnafu)?;
253        Ok(result.map(Into::into))
254    }
255
256    async fn update_by_id(
257        &self,
258        pool: &Pool<Postgres>,
259        id: u64,
260        data: serde_json::Value,
261    ) -> Result<()> {
262        let p: TokenPriceUpdateParams = serde_json::from_value(data).context(JsonSnafu)?;
263        let mut qb: QueryBuilder<Postgres> =
264            QueryBuilder::new("UPDATE token_prices SET modified = NOW()");
265        if let Some(v) = p.input_price {
266            qb.push(", input_price = ").push_bind(v);
267        }
268        if let Some(v) = p.output_price {
269            qb.push(", output_price = ").push_bind(v);
270        }
271        if let Some(v) = p.fixed_price {
272            qb.push(", fixed_price = ").push_bind(v);
273        }
274        if let Some(v) = p.unit_size {
275            qb.push(", unit_size = ").push_bind(v);
276        }
277        if let Some(v) = p.status {
278            qb.push(", status = ").push_bind(v);
279        }
280        if let Some(v) = p.remark {
281            qb.push(", remark = ").push_bind(v);
282        }
283        qb.push(" WHERE id = ")
284            .push_bind(id as i64)
285            .push(" AND deleted_at IS NULL");
286        qb.build().execute(pool).await.context(SqlxSnafu)?;
287        Ok(())
288    }
289
290    async fn delete_by_id(&self, pool: &Pool<Postgres>, id: u64) -> Result<()> {
291        sqlx::query(
292            r#"UPDATE token_prices SET deleted_at = NOW() WHERE id = $1 AND deleted_at IS NULL"#,
293        )
294        .bind(id as i64)
295        .execute(pool)
296        .await
297        .context(SqlxSnafu)?;
298        Ok(())
299    }
300
301    async fn count(&self, pool: &Pool<Postgres>, params: &ModelListParams) -> Result<i64> {
302        let mut qb: QueryBuilder<Postgres> = QueryBuilder::new("SELECT COUNT(*) FROM token_prices");
303        self.push_conditions(&mut qb, params)?;
304        let row: (i64,) = qb
305            .build_query_as()
306            .fetch_one(pool)
307            .await
308            .context(SqlxSnafu)?;
309        Ok(row.0)
310    }
311
312    async fn list(
313        &self,
314        pool: &Pool<Postgres>,
315        params: &ModelListParams,
316    ) -> Result<Vec<Self::Output>> {
317        let mut qb: QueryBuilder<Postgres> = QueryBuilder::new("SELECT * FROM token_prices");
318        self.push_conditions(&mut qb, params)?;
319        params.push_pagination(&mut qb);
320        let rows = qb
321            .build_query_as::<TokenPriceSchema>()
322            .fetch_all(pool)
323            .await
324            .context(SqlxSnafu)?;
325        Ok(rows.into_iter().map(Into::into).collect())
326    }
327
328    fn push_filter_conditions<'args>(
329        &self,
330        qb: &mut QueryBuilder<'args, Postgres>,
331        filters: &HashMap<String, String>,
332    ) -> Result<()> {
333        if let Some(service) = filters.get("service") {
334            qb.push(" AND service = ").push_bind(service.clone());
335        }
336        if let Some(status) = filters.get("status") {
337            if let Ok(v) = status.parse::<i16>() {
338                qb.push(" AND status = ").push_bind(v);
339            }
340        }
341        Ok(())
342    }
343}