tutti_core/
runner.rs

1use std::{
2    collections::{HashMap, HashSet, VecDeque},
3    time::Duration,
4};
5
6use colored::Colorize;
7use futures::StreamExt;
8use tokio::{
9    sync::mpsc::{self, Receiver, Sender},
10    task::JoinHandle,
11};
12use tutti_config::Project;
13
14use crate::{
15    process::{BoxStream, ProcId},
16    CommandSpec, ProcessManager,
17};
18
/// Fallback timeout, in seconds, used by `Runner::down` when `RunnerConfig::kill_timeout`
/// is `None`.
const DEFAULT_KILL_TIMEOUT: u64 = 10;
20
/// An event produced while forwarding a service's output.
#[derive(Debug)]
pub enum LogEvent {
    /// One item read from a service's stdout or stderr stream.
    Log {
        /// Name of the service that produced the output.
        service_name: String,
        /// Raw bytes as delivered by the output stream; this type enforces no UTF-8 guarantee.
        line: Vec<u8>,
    },
    /// Sent after a service's stdout stream has ended.
    Stop {
        /// Name of the service whose stdout ended.
        service_name: String,
    },
}
26
27async fn follow_output(
28    is_stdout: bool,
29    mut output: BoxStream<Vec<u8>>,
30    service_name: String,
31    rx: Sender<LogEvent>,
32) {
33    while let Some(line) = output.next().await {
34        if rx
35            .send(LogEvent::Log {
36                service_name: service_name.clone(),
37                line,
38            })
39            .await
40            .is_err()
41        {
42            break;
43        }
44    }
45    if is_stdout && rx.send(LogEvent::Stop { service_name }).await.is_err() {
46        eprintln!("Failed to send stop event");
47    }
48}
49
/// Settings that control how a [`Runner`] manages its services.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RunnerConfig {
    /// How long, in seconds, to wait for a service to exit at each stage of shutdown before
    /// escalating. `None` falls back to the default of 10 seconds.
    pub kill_timeout: Option<u64>,
}
55
56#[derive(Debug)]
57pub struct Runner<M: ProcessManager> {
58    project: Project,
59    pm: M,
60
61    tasks: Vec<JoinHandle<()>>,
62    processes: Vec<(String, ProcId)>,
63
64    config: RunnerConfig,
65}
66
67impl<M: ProcessManager> Runner<M> {
68    pub fn new(project: Project, pm: M, config: RunnerConfig) -> Self {
69        let tasks = Vec::with_capacity(project.services.len() * 2);
70        let processes = Vec::with_capacity(project.services.len());
71
72        Self {
73            project,
74            pm,
75            tasks,
76            processes,
77            config,
78        }
79    }
80
81    /// Performs topological sort on services considering their dependencies.
82    /// Returns services in order they should be started.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if a service is not found or if there is a cycle in the dependencies.
87    fn topological_sort(&self, service_names: &HashSet<String>) -> anyhow::Result<Vec<String>> {
88        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
89        let mut in_degree: HashMap<String, usize> = HashMap::new();
90        let mut all_services = HashSet::new();
91
92        let mut to_process = VecDeque::new();
93        for name in service_names {
94            to_process.push_back(name.clone());
95        }
96
97        while let Some(service_name) = to_process.pop_front() {
98            if all_services.contains(&service_name) {
99                continue;
100            }
101
102            let service = self
103                .project
104                .services
105                .get(&service_name)
106                .ok_or_else(|| anyhow::anyhow!("Service '{service_name}' not found"))?;
107
108            all_services.insert(service_name.clone());
109            graph.entry(service_name.clone()).or_default();
110            in_degree.entry(service_name.clone()).or_insert(0);
111
112            for dep in &service.deps {
113                if !self.project.services.contains_key(dep) {
114                    return Err(anyhow::anyhow!(
115                        "Dependency '{dep}' of service '{service_name}' not found",
116                    ));
117                }
118
119                graph
120                    .entry(dep.clone())
121                    .or_default()
122                    .push(service_name.clone());
123                *in_degree.entry(service_name.clone()).or_insert(0) += 1;
124
125                to_process.push_back(dep.clone());
126            }
127        }
128
129        let mut queue = VecDeque::new();
130        for (service, &degree) in &in_degree {
131            if degree == 0 {
132                queue.push_back(service.clone());
133            }
134        }
135
136        let mut result = Vec::new();
137        while let Some(service) = queue.pop_front() {
138            result.push(service.clone());
139
140            if let Some(dependents) = graph.get(&service) {
141                for dependent in dependents {
142                    if let Some(degree) = in_degree.get_mut(dependent) {
143                        *degree -= 1;
144                        if *degree == 0 {
145                            queue.push_back(dependent.clone());
146                        }
147                    }
148                }
149            }
150        }
151
152        if result.len() != all_services.len() {
153            return Err(anyhow::anyhow!("Circular dependency detected"));
154        }
155
156        Ok(result)
157    }
158
159    /// Starts the services defined in the project configuration.
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if any of the services fail to start.
164    pub async fn up(&mut self, services: Vec<String>) -> anyhow::Result<Receiver<LogEvent>> {
165        let (tx, rx) = mpsc::channel(10);
166
167        let service_names = if services.is_empty() {
168            self.project
169                .services
170                .keys()
171                .cloned()
172                .collect::<HashSet<_>>()
173        } else {
174            services
175                .into_iter()
176                .filter(|name| self.project.services.contains_key(name))
177                .collect::<HashSet<_>>()
178        };
179
180        let sorted_services = self.topological_sort(&service_names)?;
181
182        let to_run = sorted_services
183            .into_iter()
184            .map(|name| {
185                let service = self
186                    .project
187                    .services
188                    .get(&name)
189                    .ok_or_else(|| anyhow::anyhow!("unknown service: {name}"))?;
190                Ok((name, service))
191            })
192            .collect::<anyhow::Result<Vec<_>>>()?;
193
194        for (name, service) in &to_run {
195            let service = self
196                .pm
197                .spawn(CommandSpec {
198                    name: name.to_owned(),
199                    cmd: service.cmd.clone(),
200                    cwd: service.cwd.clone(),
201                    env: service
202                        .env
203                        .clone()
204                        .map(|h| h.into_iter().collect())
205                        .unwrap_or_default(),
206                })
207                .await?;
208            let stdout = service.stdout;
209            let stderr = service.stderr;
210
211            self.tasks.push(tokio::spawn(follow_output(
212                true,
213                stdout,
214                name.clone(),
215                tx.clone(),
216            )));
217            self.tasks.push(tokio::spawn(follow_output(
218                false,
219                stderr,
220                name.clone(),
221                tx.clone(),
222            )));
223            self.processes.push((name.to_owned(), service.id));
224        }
225
226        Ok(rx)
227    }
228
229    /// Stops all services.
230    ///
231    /// # Errors
232    ///
233    /// Returns an error if any of the services fail to stop.
234    pub async fn down(&mut self) -> anyhow::Result<()> {
235        let duration =
236            Duration::from_secs(self.config.kill_timeout.unwrap_or(DEFAULT_KILL_TIMEOUT));
237
238        for (name, id) in self.processes.drain(..) {
239            let line = format!("Stopping service {name:?}").yellow();
240            println!("{line}");
241            if self.pm.wait(id, duration).await?.is_some() {
242                let line = format!("service {name:?} already stopped").yellow();
243                println!("{line}");
244                continue;
245            }
246
247            self.pm.shutdown(id).await?;
248            if let Some(exit_code) = self.pm.wait(id, duration).await? {
249                let line = format!("service {name:?} stopped with {exit_code} code").yellow();
250                println!("{line}");
251            } else {
252                self.pm.kill(id).await?;
253                let line = format!("service {name:?} killed").black().on_red();
254                println!("{line}");
255            }
256        }
257
258        Ok(())
259    }
260
261    /// Waits for all services to exit.
262    ///
263    /// # Errors
264    ///
265    /// Returns an error if any of the services fail to exit.
266    pub async fn wait(&mut self) -> anyhow::Result<()> {
267        for task in self.tasks.drain(..) {
268            task.await?;
269        }
270
271        Ok(())
272    }
273}