systemprompt_models/profile/gateway/
override_rule.rs1use serde::{Deserialize, Serialize};
12use systemprompt_identifiers::ProviderId;
13
14use super::error::{GatewayProfileError, GatewayResult};
15use super::route::match_pattern;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
18#[serde(rename_all = "snake_case")]
19pub enum OverrideRuleAction {
20 Replace,
21 Strip,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
25#[serde(deny_unknown_fields)]
26pub struct SystemPromptRule {
27 #[serde(default, skip_serializing_if = "Option::is_none")]
28 pub provider: Option<ProviderId>,
29 #[serde(default, skip_serializing_if = "Option::is_none")]
30 pub model_pattern: Option<String>,
31 pub action: OverrideRuleAction,
32 #[serde(default, skip_serializing_if = "Option::is_none")]
33 pub prompt: Option<String>,
34}
35
36impl SystemPromptRule {
37 #[must_use]
38 pub fn matches(&self, provider: &ProviderId, model: &str) -> bool {
39 let provider_ok = self
40 .provider
41 .as_ref()
42 .is_none_or(|p| p.as_str() == provider.as_str());
43 let model_ok = self
44 .model_pattern
45 .as_deref()
46 .is_none_or(|pat| match_pattern(pat, model));
47 provider_ok && model_ok
48 }
49
50 pub const fn validate(&self) -> GatewayResult<()> {
51 match self.action {
52 OverrideRuleAction::Replace if self.prompt.is_none() => {
53 Err(GatewayProfileError::OverrideReplaceMissingPrompt)
54 },
55 OverrideRuleAction::Strip if self.prompt.is_some() => {
56 Err(GatewayProfileError::OverrideStripWithPrompt)
57 },
58 _ => Ok(()),
59 }
60 }
61}