systemprompt_models/ai/
models.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Hash)]
5pub struct ToolModelConfig {
6    #[serde(skip_serializing_if = "Option::is_none")]
7    pub provider: Option<String>,
8    #[serde(skip_serializing_if = "Option::is_none")]
9    pub model: Option<String>,
10    #[serde(skip_serializing_if = "Option::is_none")]
11    pub max_output_tokens: Option<u32>,
12}
13
14impl ToolModelConfig {
15    pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
16        Self {
17            provider: Some(provider.into()),
18            model: Some(model.into()),
19            max_output_tokens: None,
20        }
21    }
22
23    pub const fn with_max_output_tokens(mut self, tokens: u32) -> Self {
24        self.max_output_tokens = Some(tokens);
25        self
26    }
27
28    pub const fn is_empty(&self) -> bool {
29        self.provider.is_none() && self.model.is_none() && self.max_output_tokens.is_none()
30    }
31
32    pub fn merge_with(&self, other: &Self) -> Self {
33        Self {
34            provider: other.provider.clone().or_else(|| self.provider.clone()),
35            model: other.model.clone().or_else(|| self.model.clone()),
36            max_output_tokens: other.max_output_tokens.or(self.max_output_tokens),
37        }
38    }
39}
40
/// Two-level lookup table of [`ToolModelConfig`] overrides: an outer string
/// key maps to an inner map keyed by a second string.
///
/// NOTE(review): what each key level denotes (e.g. agent/server name vs. tool
/// name) is not visible in this file — confirm against callers.
pub type ToolModelOverrides = HashMap<String, HashMap<String, ToolModelConfig>>;
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelConfig {
45    pub id: String,
46    pub max_tokens: u32,
47    pub supports_tools: bool,
48    #[serde(default)]
49    pub cost_per_1k_tokens: f32,
50}
51
52impl ModelConfig {
53    pub fn new(id: impl Into<String>, max_tokens: u32, supports_tools: bool) -> Self {
54        Self {
55            id: id.into(),
56            max_tokens,
57            supports_tools,
58            cost_per_1k_tokens: 0.0,
59        }
60    }
61
62    pub const fn with_cost(mut self, cost: f32) -> Self {
63        self.cost_per_1k_tokens = cost;
64        self
65    }
66}