Skip to main content

steer_auth_plugin/
directive.rs

1use crate::error::Result;
2use crate::identifiers::ModelId;
3use async_trait::async_trait;
4use std::fmt;
5use std::sync::Arc;
6
/// A header name/value pair, as produced by [`AuthHeaderProvider::headers`].
///
/// The derived `Debug` prints `value` verbatim. Values produced by an auth
/// provider may carry credentials, so take care before logging them with `{:?}`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HeaderPair {
    /// Header name.
    pub name: String,
    /// Header value.
    pub value: String,
}

impl HeaderPair {
    /// Builds a pair from anything convertible into an owned `String`.
    pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            value: value.into(),
        }
    }
}
12
/// A query-string name/value pair, as carried by [`AnthropicAuth::query_params`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryParam {
    /// Parameter name.
    pub name: String,
    /// Parameter value.
    pub value: String,
}

impl QueryParam {
    /// Builds a parameter from anything convertible into an owned `String`.
    pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            value: value.into(),
        }
    }
}
18
/// Which kind of request an auth hook is being invoked for.
///
/// Carried by both [`AuthHeaderContext`] and [`AuthErrorContext`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RequestKind {
    /// Non-streamed request (by name — the distinction is applied by the caller).
    Complete,
    /// Streamed request (by name — the distinction is applied by the caller).
    Stream,
}
24
25#[derive(Debug, Clone)]
26pub struct AuthHeaderContext {
27    pub model_id: Option<ModelId>,
28    pub request_kind: RequestKind,
29}
30
31#[derive(Debug, Clone)]
32pub struct AuthErrorContext {
33    pub status: Option<u16>,
34    pub body_snippet: Option<String>,
35    pub request_kind: RequestKind,
36}
37
/// Decision returned by [`AuthHeaderProvider::on_auth_error`].
///
/// How each variant is acted on is up to the caller; the descriptions below
/// follow the variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthErrorAction {
    /// Retry the failed request a single time.
    RetryOnce,
    /// The user must re-authenticate before further requests can succeed.
    ReauthRequired,
    /// The provider has nothing to do for this error.
    NoAction,
}
44
45#[async_trait]
46pub trait AuthHeaderProvider: Send + Sync {
47    async fn headers(&self, ctx: AuthHeaderContext) -> Result<Vec<HeaderPair>>;
48
49    async fn on_auth_error(&self, ctx: AuthErrorContext) -> Result<AuthErrorAction>;
50}
51
/// How a directive's instruction text relates to the request's own instructions.
///
/// The policy is applied by whatever consumes the directive; the variant
/// descriptions below follow the names and should be confirmed against that code.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum InstructionPolicy {
    /// Place the given text before the request's instructions.
    Prefix(String),
    /// Use the given text only when the request supplies no instructions.
    DefaultIfEmpty(String),
    /// Replace the request's instructions with the given text.
    Override(String),
}
58
59#[derive(Debug, Clone)]
60pub enum AuthDirective {
61    OpenAiResponses(OpenAiResponsesAuth),
62    Anthropic(AnthropicAuth),
63}
64
65#[derive(Clone)]
66pub struct OpenAiResponsesAuth {
67    pub headers: Arc<dyn AuthHeaderProvider>,
68    pub base_url_override: Option<String>,
69    pub require_streaming: Option<bool>,
70    pub instruction_policy: Option<InstructionPolicy>,
71    pub include: Option<Vec<String>>,
72}
73
74impl fmt::Debug for OpenAiResponsesAuth {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        f.debug_struct("OpenAiResponsesAuth")
77            .field("headers", &"<AuthHeaderProvider>")
78            .field("base_url_override", &self.base_url_override)
79            .field("require_streaming", &self.require_streaming)
80            .field("instruction_policy", &self.instruction_policy)
81            .field("include", &self.include)
82            .finish()
83    }
84}
85
86#[derive(Clone)]
87pub struct AnthropicAuth {
88    pub headers: Arc<dyn AuthHeaderProvider>,
89    pub instruction_policy: Option<InstructionPolicy>,
90    pub query_params: Option<Vec<QueryParam>>,
91}
92
93impl fmt::Debug for AnthropicAuth {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        f.debug_struct("AnthropicAuth")
96            .field("headers", &"<AuthHeaderProvider>")
97            .field("instruction_policy", &self.instruction_policy)
98            .field("query_params", &self.query_params)
99            .finish()
100    }
101}