// spring_ai_rs/ai_interface/callback/unit/mod.rs

1use std::{error::Error, hash::Hash};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    ai_interface::{
7        callback::{
8            facing::Facing,
9            group::Group,
10            resource::Resource,
11            teams::Team,
12            unit::{
13                command_info::{UnitCurrentCommand, UnitSupportedCommand},
14                weapon::UnitWeapon,
15            },
16            unit_def::UnitDef,
17        },
18        AIInterface,
19    },
20    get_callback,
21};
22
23mod command;
24pub mod command_info;
25pub mod weapon;
26
/// Copyable handle for the unit-related engine queries of one AI.
///
/// Obtained from [`AIInterface::unit_interface`].
#[derive(Debug, Clone, Copy)]
pub struct UnitInterface {
    /// Id of the AI, passed as the first argument to every unit callback.
    ai_id: i32,
}
31
32#[derive(Clone, Debug)]
33pub struct UnitInterfaceAll {
34    unit_limit: i32,
35    unit_max: i32,
36    unit_definitions: Vec<UnitDef>,
37}
38
39#[derive(Debug, Copy, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
40pub struct Unit {
41    pub ai_id: i32,
42    pub unit_id: i32,
43}
44
45#[derive(Clone, Debug)]
46pub struct UnitAll {
47    unit_def: UnitDef,
48    team: Team,
49    stockpile: i32,
50    stockpile_queued: i32,
51    max_speed: f32,
52    max_range: f32,
53    max_health: f32,
54    max_experience: f32,
55    health: f32,
56    speed: f32,
57    power: f32,
58    position: [f32; 3],
59    velocity: [f32; 3],
60    is_activated: bool,
61    is_being_built: bool,
62    is_cloaked: bool,
63    is_paralyzed: bool,
64    is_neutral: bool,
65    facing: Facing,
66    last_order_frame: i32,
67    weapons: Vec<UnitWeapon>,
68    group: Option<Group>,
69    command_queue_type: i32,
70    current_commands: Vec<UnitCurrentCommand>,
71    supported_commands: Vec<UnitSupportedCommand>,
72}
73
74impl Unit {
75    pub fn unit_def(&self) -> Result<UnitDef, Box<dyn Error>> {
76        let get_def_func = get_callback!(self.ai_id, Unit_getDef)?;
77        Ok(UnitDef {
78            ai_id: self.ai_id,
79            def_id: unsafe { get_def_func(self.ai_id, self.unit_id) },
80        })
81    }
82
83    pub fn team(&self) -> Result<Team, Box<dyn Error>> {
84        let get_team_func = get_callback!(self.ai_id, Unit_getTeam)?;
85
86        Ok(Team {
87            ai_id: self.ai_id,
88            team_id: unsafe { get_team_func(self.ai_id, self.unit_id) },
89        })
90    }
91
92    pub fn get_stockpile(&self) -> Result<i32, Box<dyn Error>> {
93        let get_stockpile_func = get_callback!(self.ai_id, Unit_getStockpile)?;
94
95        Ok(unsafe { get_stockpile_func(self.ai_id, self.unit_id) })
96    }
97
98    pub fn stockpile_queued(&self) -> Result<i32, Box<dyn Error>> {
99        let get_stpclpile_queue_func = get_callback!(self.ai_id, Unit_getStockpileQueued)?;
100
101        Ok(unsafe { get_stpclpile_queue_func(self.ai_id, self.unit_id) })
102    }
103
104    pub fn max_speed(&self) -> Result<f32, Box<dyn Error>> {
105        let get_speed_func = get_callback!(self.ai_id, Unit_getMaxSpeed)?;
106
107        Ok(unsafe { get_speed_func(self.ai_id, self.unit_id) })
108    }
109
110    pub fn max_range(&self) -> Result<f32, Box<dyn Error>> {
111        let get_range_func = get_callback!(self.ai_id, Unit_getMaxRange)?;
112
113        Ok(unsafe { get_range_func(self.ai_id, self.unit_id) })
114    }
115
116    pub fn max_health(&self) -> Result<f32, Box<dyn Error>> {
117        let get_health_func = get_callback!(self.ai_id, Unit_getMaxHealth)?;
118
119        Ok(unsafe { get_health_func(self.ai_id, self.unit_id) })
120    }
121
122    pub fn max_experience(&self) -> Result<f32, Box<dyn Error>> {
123        let get_experience_func = get_callback!(self.ai_id, Unit_getExperience)?;
124
125        Ok(unsafe { get_experience_func(self.ai_id, self.unit_id) })
126    }
127
128    pub fn health(&self) -> Result<f32, Box<dyn Error>> {
129        let get_health_func = get_callback!(self.ai_id, Unit_getHealth)?;
130
131        Ok(unsafe { get_health_func(self.ai_id, self.unit_id) })
132    }
133
134    pub fn speed(&self) -> Result<f32, Box<dyn Error>> {
135        let get_speed_func = get_callback!(self.ai_id, Unit_getSpeed)?;
136
137        Ok(unsafe { get_speed_func(self.ai_id, self.unit_id) })
138    }
139
140    pub fn power(&self) -> Result<f32, Box<dyn Error>> {
141        let get_power_func = get_callback!(self.ai_id, Unit_getPower)?;
142
143        Ok(unsafe { get_power_func(self.ai_id, self.unit_id) })
144    }
145
146    pub fn resource_use(&self, resource: Resource) -> Result<f32, Box<dyn Error>> {
147        let get_resource_use_func = get_callback!(self.ai_id, Unit_getResourceUse)?;
148
149        Ok(unsafe { get_resource_use_func(self.ai_id, self.unit_id, resource.resource_id) })
150    }
151
152    pub fn resource_make(&self, resource: Resource) -> Result<f32, Box<dyn Error>> {
153        let get_resource_make_func = get_callback!(self.ai_id, Unit_getResourceMake)?;
154
155        Ok(unsafe { get_resource_make_func(self.ai_id, self.unit_id, resource.resource_id) })
156    }
157
158    pub fn position(&self) -> Result<[f32; 3], Box<dyn Error>> {
159        let get_position_func = get_callback!(self.ai_id, Unit_getPos)?;
160
161        let mut ret = [0.0_f32; 3];
162        unsafe { get_position_func(self.ai_id, self.unit_id, ret.as_mut_ptr()) };
163
164        Ok(ret)
165    }
166
167    pub fn velocity(&self) -> Result<[f32; 3], Box<dyn Error>> {
168        let get_velocity_func = get_callback!(self.ai_id, Unit_getVel)?;
169
170        let mut ret = [0.0_f32; 3];
171        unsafe { get_velocity_func(self.ai_id, self.unit_id, ret.as_mut_ptr()) };
172
173        Ok(ret)
174    }
175
176    pub fn is_activated(&self) -> Result<bool, Box<dyn Error>> {
177        let get_is_activated_func = get_callback!(self.ai_id, Unit_isActivated)?;
178
179        Ok(unsafe { get_is_activated_func(self.ai_id, self.unit_id) })
180    }
181
182    pub fn is_being_built(&self) -> Result<bool, Box<dyn Error>> {
183        let get_is_being_built_func = get_callback!(self.ai_id, Unit_isBeingBuilt)?;
184
185        Ok(unsafe { get_is_being_built_func(self.ai_id, self.unit_id) })
186    }
187
188    pub fn is_cloaked(&self) -> Result<bool, Box<dyn Error>> {
189        let get_is_cloaked_func = get_callback!(self.ai_id, Unit_isCloaked)?;
190
191        Ok(unsafe { get_is_cloaked_func(self.ai_id, self.unit_id) })
192    }
193
194    pub fn is_paralyzed(&self) -> Result<bool, Box<dyn Error>> {
195        let get_is_paralyzed_func = get_callback!(self.ai_id, Unit_isParalyzed)?;
196
197        Ok(unsafe { get_is_paralyzed_func(self.ai_id, self.unit_id) })
198    }
199
200    pub fn is_neutral(&self) -> Result<bool, Box<dyn Error>> {
201        let get_is_neutral_func = get_callback!(self.ai_id, Unit_isNeutral)?;
202
203        Ok(unsafe { get_is_neutral_func(self.ai_id, self.unit_id) })
204    }
205
206    pub fn facing(&self) -> Result<Facing, Box<dyn Error>> {
207        let get_facing_func = get_callback!(self.ai_id, Unit_getBuildingFacing)?;
208
209        Ok(unsafe { get_facing_func(self.ai_id, self.unit_id) }.into())
210    }
211
212    pub fn last_order_frame(&self) -> Result<i32, Box<dyn Error>> {
213        let get_last_order_frame_func = get_callback!(self.ai_id, Unit_getLastUserOrderFrame)?;
214
215        Ok(unsafe { get_last_order_frame_func(self.ai_id, self.unit_id) })
216    }
217
218    pub fn weapons(&self) -> Result<Vec<UnitWeapon>, Box<dyn Error>> {
219        let get_weapons_func = get_callback!(self.ai_id, Unit_getWeapons)?;
220        let get_weapon_def_func = get_callback!(self.ai_id, Unit_getWeapon)?;
221
222        let number_of_weapons = unsafe { get_weapons_func(self.ai_id, self.unit_id) };
223
224        Ok((0..number_of_weapons)
225            .map(|i| UnitWeapon {
226                ai_id: self.ai_id,
227                unit_id: self.unit_id,
228                weapon_id: i,
229            })
230            .collect())
231    }
232
233    pub fn group(&self) -> Result<Option<Group>, Box<dyn Error>> {
234        let get_group_func = get_callback!(self.ai_id, Unit_getGroup)?;
235        let group_id = unsafe { get_group_func(self.ai_id, self.unit_id) };
236
237        Ok(if group_id == -1 {
238            None
239        } else {
240            Some(Group {
241                ai_id: self.ai_id,
242                group_id,
243            })
244        })
245    }
246
247    pub fn command_queue_type(&self) -> Result<i32, Box<dyn Error>> {
248        let get_current_command_type_func = get_callback!(self.ai_id, Unit_CurrentCommand_getType)?;
249
250        Ok(unsafe { get_current_command_type_func(self.ai_id, self.unit_id) })
251    }
252
253    pub fn current_commands(&self) -> Result<Vec<UnitCurrentCommand>, Box<dyn Error>> {
254        let get_current_commands_func = get_callback!(self.ai_id, Unit_getCurrentCommands)?;
255
256        let number_of_commands = unsafe { get_current_commands_func(self.ai_id, self.unit_id) };
257
258        Ok((0..number_of_commands)
259            .map(|current_command_index| UnitCurrentCommand {
260                ai_id: self.ai_id,
261                unit_id: self.unit_id,
262                current_command_index,
263            })
264            .collect())
265    }
266
267    pub fn supported_commands(&self) -> Result<Vec<UnitSupportedCommand>, Box<dyn Error>> {
268        let get_supported_commands_func = get_callback!(self.ai_id, Unit_getSupportedCommands)?;
269
270        let number_of_commands = unsafe { get_supported_commands_func(self.ai_id, self.unit_id) };
271
272        Ok((0..number_of_commands)
273            .map(|supported_command_index| UnitSupportedCommand {
274                ai_id: self.ai_id,
275                unit_id: self.unit_id,
276                supported_command_index,
277            })
278            .collect())
279    }
280
281    pub fn get_nearest_enemy(
282        &self,
283        radius: f32,
284        spherical: bool,
285    ) -> Result<Option<Unit>, Box<dyn Error>> {
286        let mut units = AIInterface { ai_id: self.ai_id }
287            .unit_interface()
288            .enemy_units_at(self.position()?, radius, spherical)?;
289
290        units.sort_by(|unit1, unit2| {
291            let self_pos = self.position().unwrap();
292            let unit1_pos = unit1.position().unwrap();
293            let unit2_pos = unit2.position().unwrap();
294
295            let distance1 = (unit1_pos[0] - self_pos[0]).abs()
296                + (unit1_pos[1] - self_pos[1]).abs()
297                + (unit1_pos[2] - self_pos[2]).abs();
298            let distance2 = (unit2_pos[0] - self_pos[0]).abs()
299                + (unit2_pos[1] - self_pos[1]).abs()
300                + (unit2_pos[2] - self_pos[2]).abs();
301
302            distance1.partial_cmp(&distance2).unwrap()
303        });
304
305        Ok(units.first().cloned())
306    }
307
308    pub fn all(&self) -> Result<UnitAll, Box<dyn Error>> {
309        Ok(UnitAll {
310            unit_def: self.unit_def()?,
311            team: self.team()?,
312            stockpile: self.get_stockpile()?,
313            stockpile_queued: self.stockpile_queued()?,
314            max_speed: self.max_speed()?,
315            max_range: self.max_range()?,
316            max_health: self.max_health()?,
317            max_experience: self.max_experience()?,
318            health: self.health()?,
319            speed: self.speed()?,
320            power: self.power()?,
321            position: self.position()?,
322            velocity: self.velocity()?,
323            is_activated: self.is_activated()?,
324            is_being_built: self.is_being_built()?,
325            is_cloaked: self.is_cloaked()?,
326            is_paralyzed: self.is_paralyzed()?,
327            is_neutral: self.is_neutral()?,
328            facing: self.facing()?,
329            last_order_frame: self.last_order_frame()?,
330            weapons: self.weapons()?,
331            group: self.group()?,
332            command_queue_type: self.command_queue_type()?,
333            current_commands: self.current_commands()?,
334            supported_commands: self.supported_commands()?,
335        })
336    }
337}
338
/// Capacity of the id buffer handed to `getUnitDefs`; at most this many unit
/// definitions can be returned by `UnitInterface::get_unit_definitions`.
const UNIT_DEF_MAX: usize = 1_024;
/// Capacity of the id buffer handed to the unit-list callbacks
/// (`getEnemyUnits`, `getFriendlyUnits`, …); at most this many units are returned.
const UNIT_LIST_MAX: usize = 1_024;
341
342impl AIInterface {
343    pub fn unit_interface(&self) -> UnitInterface {
344        UnitInterface { ai_id: self.ai_id }
345    }
346}
347
348impl UnitInterface {
349    pub fn unit_limit(&self) -> Result<i32, Box<dyn Error>> {
350        let unit_limit_func = get_callback!(self.ai_id, Unit_getLimit)?;
351        Ok(unsafe { unit_limit_func(self.ai_id) })
352    }
353
354    pub fn unit_max(&self) -> Result<i32, Box<dyn Error>> {
355        let unit_max_func = get_callback!(self.ai_id, Unit_getMax)?;
356        Ok(unsafe { unit_max_func(self.ai_id) })
357    }
358
359    pub fn enemy_units(&self) -> Result<Vec<Unit>, Box<dyn Error>> {
360        let get_enemy_units_func = get_callback!(self.ai_id, getEnemyUnits)?;
361
362        let mut unit_list = [-1_i32; UNIT_LIST_MAX];
363        unsafe { get_enemy_units_func(self.ai_id, unit_list.as_mut_ptr(), UNIT_LIST_MAX as i32) };
364
365        Ok(unit_list
366            .iter()
367            .filter_map(|&unit_id| {
368                if unit_id == -1 {
369                    None
370                } else {
371                    Some(Unit {
372                        ai_id: self.ai_id,
373                        unit_id,
374                    })
375                }
376            })
377            .collect())
378    }
379
380    pub fn enemy_units_at(
381        &self,
382        location: [f32; 3],
383        radius: f32,
384        spherical: bool,
385    ) -> Result<Vec<Unit>, Box<dyn Error>> {
386        let get_enemy_units_func = get_callback!(self.ai_id, getEnemyUnitsIn)?;
387
388        let mut unit_list = [-1_i32; UNIT_LIST_MAX];
389        unsafe {
390            get_enemy_units_func(
391                self.ai_id,
392                location.clone().as_mut_ptr(),
393                radius,
394                spherical,
395                unit_list.as_mut_ptr(),
396                UNIT_LIST_MAX as i32,
397            )
398        };
399
400        Ok(unit_list
401            .iter()
402            .filter_map(|&unit_id| {
403                if unit_id == -1 {
404                    None
405                } else {
406                    Some(Unit {
407                        ai_id: self.ai_id,
408                        unit_id,
409                    })
410                }
411            })
412            .collect())
413    }
414
415    pub fn friendly_units(&self) -> Result<Vec<Unit>, Box<dyn Error>> {
416        let get_friendly_units_func = get_callback!(self.ai_id, getFriendlyUnits)?;
417
418        let mut unit_list = [-1_i32; UNIT_LIST_MAX];
419        unsafe {
420            get_friendly_units_func(self.ai_id, unit_list.as_mut_ptr(), UNIT_LIST_MAX as i32)
421        };
422
423        Ok(unit_list
424            .iter()
425            .filter_map(|&unit_id| {
426                if unit_id == -1 {
427                    None
428                } else {
429                    Some(Unit {
430                        ai_id: self.ai_id,
431                        unit_id,
432                    })
433                }
434            })
435            .collect())
436    }
437
438    pub fn friendly_units_at(
439        &self,
440        location: [f32; 3],
441        radius: f32,
442        spherical: bool,
443    ) -> Result<Vec<Unit>, Box<dyn Error>> {
444        let get_friendly_units_func = get_callback!(self.ai_id, getFriendlyUnitsIn)?;
445
446        let mut unit_list = [-1_i32; UNIT_LIST_MAX];
447        unsafe {
448            get_friendly_units_func(
449                self.ai_id,
450                location.clone().as_mut_ptr(),
451                radius,
452                spherical,
453                unit_list.as_mut_ptr(),
454                UNIT_LIST_MAX as i32,
455            )
456        };
457
458        Ok(unit_list
459            .iter()
460            .filter_map(|&unit_id| {
461                if unit_id == -1 {
462                    None
463                } else {
464                    Some(Unit {
465                        ai_id: self.ai_id,
466                        unit_id,
467                    })
468                }
469            })
470            .collect())
471    }
472
473    pub fn neutral_units(&self) -> Result<Vec<Unit>, Box<dyn Error>> {
474        let get_neutral_units_func = get_callback!(self.ai_id, getNeutralUnits)?;
475
476        let mut unit_list = [-1_i32; UNIT_LIST_MAX];
477        unsafe { get_neutral_units_func(self.ai_id, unit_list.as_mut_ptr(), UNIT_LIST_MAX as i32) };
478
479        Ok(unit_list
480            .iter()
481            .filter_map(|&unit_id| {
482                if unit_id == -1 {
483                    None
484                } else {
485                    Some(Unit {
486                        ai_id: self.ai_id,
487                        unit_id,
488                    })
489                }
490            })
491            .collect())
492    }
493
494    pub fn neutral_units_at(
495        &self,
496        location: [f32; 3],
497        radius: f32,
498        spherical: bool,
499    ) -> Result<Vec<Unit>, Box<dyn Error>> {
500        let get_neutral_units_func = get_callback!(self.ai_id, getNeutralUnitsIn)?;
501
502        let mut unit_list = [-1_i32; UNIT_LIST_MAX];
503        unsafe {
504            get_neutral_units_func(
505                self.ai_id,
506                location.clone().as_mut_ptr(),
507                radius,
508                spherical,
509                unit_list.as_mut_ptr(),
510                UNIT_LIST_MAX as i32,
511            )
512        };
513
514        Ok(unit_list
515            .iter()
516            .filter_map(|&unit_id| {
517                if unit_id == -1 {
518                    None
519                } else {
520                    Some(Unit {
521                        ai_id: self.ai_id,
522                        unit_id,
523                    })
524                }
525            })
526            .collect())
527    }
528
529    pub fn team_units(&self) -> Result<Vec<Unit>, Box<dyn Error>> {
530        let get_team_units_func = get_callback!(self.ai_id, getTeamUnits)?;
531
532        let mut unit_list = [-1_i32; UNIT_LIST_MAX];
533        unsafe { get_team_units_func(self.ai_id, unit_list.as_mut_ptr(), UNIT_LIST_MAX as i32) };
534
535        Ok(unit_list
536            .iter()
537            .filter_map(|&unit_id| {
538                if unit_id == -1 {
539                    None
540                } else {
541                    Some(Unit {
542                        ai_id: self.ai_id,
543                        unit_id,
544                    })
545                }
546            })
547            .collect())
548    }
549
550    pub fn selected_units(&self) -> Result<Vec<Unit>, Box<dyn Error>> {
551        let get_enemy_units_func = get_callback!(self.ai_id, getEnemyUnits)?;
552
553        let mut unit_list = [-1_i32; UNIT_LIST_MAX];
554        unsafe { get_enemy_units_func(self.ai_id, unit_list.as_mut_ptr(), UNIT_LIST_MAX as i32) };
555
556        Ok(unit_list
557            .iter()
558            .filter_map(|&unit_id| {
559                if unit_id == -1 {
560                    None
561                } else {
562                    Some(Unit {
563                        ai_id: self.ai_id,
564                        unit_id,
565                    })
566                }
567            })
568            .collect())
569    }
570
571    pub fn get_unit_definitions(&self) -> Result<Vec<UnitDef>, Box<dyn Error>> {
572        let get_unit_defs_func = get_callback!(self.ai_id, getUnitDefs)?;
573
574        let mut temp = [-1; UNIT_DEF_MAX];
575
576        unsafe { get_unit_defs_func(self.ai_id, temp.as_mut_ptr(), temp.len() as i32) };
577
578        Ok(temp
579            .iter()
580            .filter(|&&def_id| def_id != -1)
581            .map(|&def_id| UnitDef {
582                ai_id: self.ai_id,
583                def_id,
584            })
585            .collect())
586    }
587
588    pub fn all(&self) -> Result<UnitInterfaceAll, Box<dyn Error>> {
589        Ok(UnitInterfaceAll {
590            unit_limit: self.unit_limit()?,
591            unit_max: self.unit_max()?,
592            unit_definitions: self.get_unit_definitions()?,
593        })
594    }
595}