Skip to main content

steer_tui/tui/state/
setup.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use steer_grpc::client_api::{AuthProgress, ProviderId, ProviderInfo};
4
5#[derive(Debug, Clone)]
6pub struct SetupState {
7    pub current_step: SetupStep,
8    pub auth_providers: HashMap<ProviderId, AuthStatus>,
9    pub selected_provider: Option<ProviderId>,
10    pub auth_flow_id: Option<String>,
11    pub auth_progress: Option<AuthProgress>,
12    pub auth_input: String,
13    pub error_message: Option<String>,
14    pub provider_cursor: usize,
15    pub skip_setup: bool,
16    pub registry: Arc<RemoteProviderRegistry>,
17}
18
19#[derive(Debug, Clone, PartialEq)]
20pub enum SetupStep {
21    Welcome,
22    ProviderSelection,
23    Authentication(ProviderId),
24    Completion,
25}
26
/// Authentication state of a single provider as shown by the setup wizard.
///
/// Every variant is fieldless, so the type also derives `Eq` and `Hash`. That
/// gives it total equality and lets it serve as a key in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AuthStatus {
    /// The provider has not been configured.
    NotConfigured,
    /// An API key is set for the provider.
    ApiKeySet,
    /// OAuth is configured for the provider.
    OAuthConfigured,
    /// Authentication for the provider is in progress.
    InProgress,
}
34
/// Lightweight provider entry derived from a remote proto `ProviderInfo`.
///
/// Only the identifier and the display name are kept.
#[derive(Clone, Debug)]
pub struct RemoteProviderConfig {
    /// Provider identifier; `RemoteProviderRegistry::get` compares it with
    /// `ProviderId::storage_key()`.
    pub id: String,
    /// Human-readable provider name; `SetupState::available_providers` sorts by it.
    pub name: String,
}
41
42#[derive(Debug, Clone)]
43pub struct RemoteProviderRegistry {
44    providers: Vec<RemoteProviderConfig>,
45}
46
47impl RemoteProviderRegistry {
48    pub fn from_proto(providers: Vec<ProviderInfo>) -> Self {
49        let providers = providers
50            .into_iter()
51            .map(|p| RemoteProviderConfig {
52                id: p.id,
53                name: p.name,
54            })
55            .collect();
56        Self { providers }
57    }
58
59    pub fn all(&self) -> impl Iterator<Item = &RemoteProviderConfig> {
60        self.providers.iter()
61    }
62
63    pub fn get(&self, id: &ProviderId) -> Option<&RemoteProviderConfig> {
64        self.providers.iter().find(|p| p.id == id.storage_key())
65    }
66}
67
68impl SetupState {
69    pub fn new(
70        registry: Arc<RemoteProviderRegistry>,
71        auth_providers: HashMap<ProviderId, AuthStatus>,
72    ) -> Self {
73        Self {
74            current_step: SetupStep::Welcome,
75            auth_providers,
76            selected_provider: None,
77            auth_flow_id: None,
78            auth_progress: None,
79            auth_input: String::new(),
80            error_message: None,
81            provider_cursor: 0,
82            skip_setup: false,
83            registry,
84        }
85    }
86
87    /// Create a SetupState that skips the welcome page - for /auth command
88    pub fn new_for_auth_command(
89        registry: Arc<RemoteProviderRegistry>,
90        auth_providers: HashMap<ProviderId, AuthStatus>,
91    ) -> Self {
92        let mut state = Self::new(registry, auth_providers);
93        state.current_step = SetupStep::ProviderSelection;
94        state
95    }
96
97    pub fn next_step(&mut self) {
98        self.current_step = match &self.current_step {
99            SetupStep::Welcome => SetupStep::ProviderSelection,
100            SetupStep::ProviderSelection => {
101                if let Some(provider) = &self.selected_provider {
102                    SetupStep::Authentication(provider.clone())
103                } else {
104                    SetupStep::ProviderSelection
105                }
106            }
107            SetupStep::Authentication(_) => SetupStep::Completion,
108            SetupStep::Completion => SetupStep::Completion,
109        };
110        self.error_message = None;
111    }
112
113    pub fn previous_step(&mut self) {
114        self.current_step = match &self.current_step {
115            SetupStep::Welcome => SetupStep::Welcome,
116            SetupStep::ProviderSelection => SetupStep::Welcome,
117            SetupStep::Authentication(_) => SetupStep::ProviderSelection,
118            SetupStep::Completion => SetupStep::ProviderSelection,
119        };
120        self.error_message = None;
121    }
122
123    pub fn available_providers(&self) -> Vec<&RemoteProviderConfig> {
124        let mut providers: Vec<_> = self.registry.all().collect();
125        // Sort by name for consistent ordering
126        providers.sort_by_key(|p| p.name.clone());
127        providers
128    }
129}