systemd_unit_edit/
dropin.rs1use crate::unit::{Error, SystemdUnit};
7use std::path::Path;
8
9impl SystemdUnit {
10 pub fn from_file_with_dropins(path: &Path) -> Result<Self, Error> {
31 let mut unit = Self::from_file(path)?;
33
34 let mut dropin_dir = path.to_path_buf();
36 dropin_dir.set_extension(format!(
37 "{}.d",
38 path.extension().and_then(|e| e.to_str()).unwrap_or("")
39 ));
40
41 if dropin_dir.is_dir() {
43 let mut entries: Vec<_> = std::fs::read_dir(&dropin_dir)?
44 .filter_map(|e| e.ok())
45 .filter(|e| e.path().extension().and_then(|ext| ext.to_str()) == Some("conf"))
46 .collect();
47
48 entries.sort_by_key(|e| e.file_name());
50
51 for entry in entries {
53 let dropin = Self::from_file(&entry.path())?;
54 unit.merge_dropin(&dropin);
55 }
56 }
57
58 Ok(unit)
59 }
60
61 pub fn merge_dropin(&mut self, dropin: &SystemdUnit) {
85 for dropin_section in dropin.sections() {
86 let section_name = match dropin_section.name() {
87 Some(name) => name,
88 None => continue,
89 };
90
91 let mut main_section = match self.get_section(§ion_name) {
93 Some(section) => section,
94 None => {
95 self.add_section(§ion_name);
97 self.get_section(§ion_name).unwrap()
98 }
99 };
100
101 for entry in dropin_section.entries() {
103 let key = match entry.key() {
104 Some(k) => k,
105 None => continue,
106 };
107 let value = match entry.value() {
108 Some(v) => v,
109 None => continue,
110 };
111
112 if crate::systemd_metadata::is_accumulating_directive(&key) {
115 main_section.add(&key, &value);
116 } else {
117 main_section.set(&key, &value);
118 }
119 }
120 }
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Creates a fresh, empty scratch directory under the system temp dir.
    ///
    /// The name combines `tag`, the process id and the current time so that
    /// different tests and test runs do not share a directory.
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let dir = std::env::temp_dir().join(format!(
            "systemd-unit-edit-{}-{}-{}",
            tag,
            std::process::id(),
            nanos
        ));
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn test_merge_dropin_basic() {
        let mut main = SystemdUnit::from_str("[Unit]\nDescription=Main\n").unwrap();
        let dropin = SystemdUnit::from_str("[Unit]\nAfter=network.target\n").unwrap();

        main.merge_dropin(&dropin);

        let section = main.get_section("Unit").unwrap();
        assert_eq!(section.get("Description"), Some("Main".to_string()));
        assert_eq!(section.get("After"), Some("network.target".to_string()));
    }

    #[test]
    fn test_merge_dropin_replaces_non_accumulating() {
        let mut main = SystemdUnit::from_str("[Unit]\nDescription=Main\n").unwrap();
        let dropin = SystemdUnit::from_str("[Unit]\nDescription=Updated\n").unwrap();

        main.merge_dropin(&dropin);

        let section = main.get_section("Unit").unwrap();
        assert_eq!(section.get("Description"), Some("Updated".to_string()));
    }

    #[test]
    fn test_merge_dropin_accumulates() {
        let mut main =
            SystemdUnit::from_str("[Unit]\nWants=foo.service\nAfter=foo.service\n").unwrap();
        let dropin =
            SystemdUnit::from_str("[Unit]\nWants=bar.service\nAfter=bar.service\n").unwrap();

        main.merge_dropin(&dropin);

        let section = main.get_section("Unit").unwrap();
        let wants = section.get_all("Wants");
        assert_eq!(wants.len(), 2);
        assert!(wants.contains(&"foo.service".to_string()));
        assert!(wants.contains(&"bar.service".to_string()));

        let after = section.get_all("After");
        assert_eq!(after.len(), 2);
        assert!(after.contains(&"foo.service".to_string()));
        assert!(after.contains(&"bar.service".to_string()));
    }

    #[test]
    fn test_merge_dropin_new_section() {
        let mut main = SystemdUnit::from_str("[Unit]\nDescription=Main\n").unwrap();
        let dropin = SystemdUnit::from_str("[Service]\nType=simple\n").unwrap();

        main.merge_dropin(&dropin);

        assert_eq!(main.sections().count(), 2);
        let service = main.get_section("Service").unwrap();
        assert_eq!(service.get("Type"), Some("simple".to_string()));
    }

    #[test]
    fn test_merge_dropin_mixed() {
        let mut main = SystemdUnit::from_str(
            "[Unit]\nDescription=Main\nWants=foo.service\n\n[Service]\nType=simple\n",
        )
        .unwrap();
        let dropin = SystemdUnit::from_str(
            "[Unit]\nAfter=network.target\nWants=bar.service\n\n[Service]\nRestart=always\n",
        )
        .unwrap();

        main.merge_dropin(&dropin);

        let unit_section = main.get_section("Unit").unwrap();
        assert_eq!(unit_section.get("Description"), Some("Main".to_string()));
        assert_eq!(
            unit_section.get("After"),
            Some("network.target".to_string())
        );
        let wants = unit_section.get_all("Wants");
        assert_eq!(wants.len(), 2);

        let service_section = main.get_section("Service").unwrap();
        assert_eq!(service_section.get("Type"), Some("simple".to_string()));
        assert_eq!(service_section.get("Restart"), Some("always".to_string()));
    }

    #[test]
    fn test_from_file_with_dropins_applies_conf_files_in_name_order() {
        let dir = scratch_dir("order");
        let unit_path = dir.join("foo.service");
        let dropin_dir = dir.join("foo.service.d");
        std::fs::write(&unit_path, "[Unit]\nDescription=Main\nWants=a.service\n").unwrap();
        std::fs::create_dir(&dropin_dir).unwrap();
        // Written out of order on purpose: 20-late.conf must be applied last.
        std::fs::write(dropin_dir.join("20-late.conf"), "[Unit]\nDescription=Late\n").unwrap();
        std::fs::write(
            dropin_dir.join("10-early.conf"),
            "[Unit]\nDescription=Early\nWants=b.service\n",
        )
        .unwrap();
        // Files without the `.conf` extension are not drop-ins.
        std::fs::write(dropin_dir.join("notes.txt"), "[Unit]\nDescription=Ignored\n").unwrap();

        let result = SystemdUnit::from_file_with_dropins(&unit_path);
        let _ = std::fs::remove_dir_all(&dir);
        let unit = match result {
            Ok(unit) => unit,
            Err(_) => panic!("loading foo.service with its drop-ins failed"),
        };

        let section = unit.get_section("Unit").unwrap();
        assert_eq!(section.get("Description"), Some("Late".to_string()));
        let wants = section.get_all("Wants");
        assert_eq!(wants.len(), 2);
        assert!(wants.contains(&"a.service".to_string()));
        assert!(wants.contains(&"b.service".to_string()));
    }

    #[test]
    fn test_from_file_with_dropins_without_dropin_dir() {
        let dir = scratch_dir("nodir");
        let unit_path = dir.join("bar.service");
        std::fs::write(&unit_path, "[Unit]\nDescription=Main\n").unwrap();

        let result = SystemdUnit::from_file_with_dropins(&unit_path);
        let _ = std::fs::remove_dir_all(&dir);
        let unit = match result {
            Ok(unit) => unit,
            Err(_) => panic!("loading bar.service without a drop-in directory failed"),
        };

        let section = unit.get_section("Unit").unwrap();
        assert_eq!(section.get("Description"), Some("Main".to_string()));
    }
}