Skip to content

Commit 675c110

Browse files
committed
feat: add get_entity_position method and fix typo in get_blockades method
1 parent 0423ae5 commit 675c110

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

adf_core_python/core/agent/info/world_info.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from rcrs_core.entities.area import Area
44
from rcrs_core.entities.blockade import Blockade
55
from rcrs_core.entities.entity import Entity
6+
from rcrs_core.entities.human import Human
67
from rcrs_core.worldmodel.changeSet import ChangeSet
78
from rcrs_core.worldmodel.entityID import EntityID
89
from rcrs_core.worldmodel.worldmodel import WorldModel
@@ -145,6 +146,36 @@ def get_distance(self, entity_id1: EntityID, entity_id2: EntityID) -> float:
145146

146147
return distance
147148

149+
def get_entity_position(self, entity_id: EntityID) -> EntityID:
150+
"""
151+
Get the entity position
152+
153+
Parameters
154+
----------
155+
entity_id : EntityID
156+
Entity ID
157+
158+
Returns
159+
-------
160+
EntityID
161+
Entity position
162+
163+
Raises
164+
------
165+
ValueError
166+
If the entity is invalid
167+
"""
168+
entity = self.get_entity(entity_id)
169+
if entity is None:
170+
raise ValueError(f"Invalid entity: entity_id={entity_id}, entity={entity}")
171+
if isinstance(entity, Area):
172+
return entity.get_id()
173+
if isinstance(entity, Human):
174+
return entity.get_position()
175+
if isinstance(entity, Blockade):
176+
return entity.get_position()
177+
raise ValueError(f"Invalid entity type: entity_id={entity_id}, entity={entity}")
178+
148179
def get_change_set(self) -> ChangeSet:
149180
"""
150181
Get the change set
@@ -156,7 +187,7 @@ def get_change_set(self) -> ChangeSet:
156187
"""
157188
return self._change_set
158189

159-
def get_bloackades(self, area: Area) -> set[Blockade]:
190+
def get_blockades(self, area: Area) -> set[Blockade]:
160191
"""
161192
Get the blockades in the area
162193
@@ -165,12 +196,12 @@ def get_bloackades(self, area: Area) -> set[Blockade]:
165196
ChangeSet
166197
Blockade
167198
"""
168-
bloakcades = set()
199+
blockades = set()
169200
for blockade_entity_id in area.get_blockades():
170-
bloackde_entity = self.get_entity(blockade_entity_id)
171-
if isinstance(bloackde_entity, Blockade):
172-
bloakcades.add(cast(Blockade, bloackde_entity))
173-
return bloakcades
201+
blockades_entity = self.get_entity(blockade_entity_id)
202+
if isinstance(blockades_entity, Blockade):
203+
blockades.add(cast(Blockade, blockades_entity))
204+
return blockades
174205

175206
def add_entity(self, entity: Entity) -> None:
176207
"""

adf_core_python/implement/action/default_extend_action_clear.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def _get_intersect_edge_action(
432432
start_x = int(agent_x + vector[0])
433433
start_y = int(agent_y + vector[1])
434434

435-
for blockade in self.world_info.get_bloackades(road):
435+
for blockade in self.world_info.get_blockades(road):
436436
if self._is_intersecting_blockade(
437437
start_x, start_y, bp_x, bp_y, blockade
438438
):
@@ -563,7 +563,7 @@ def _get_area_clear_action(
563563
if road.get_blockades() == []:
564564
return None
565565

566-
blockades = set(self.world_info.get_bloackades(road))
566+
blockades = set(self.world_info.get_blockades(road))
567567
min_distance = sys.float_info.max
568568
clear_blockade: Optional[Blockade] = None
569569
for blockade in blockades:
@@ -662,7 +662,7 @@ def _get_neighbour_position_action(
662662
start_x = int(agent_x + vector[0])
663663
start_y = int(agent_y + vector[1])
664664

665-
for blockade in self.world_info.get_bloackades(road):
665+
for blockade in self.world_info.get_blockades(road):
666666
if self._is_intersecting_blockade(
667667
start_x, start_y, mid_x, mid_y, blockade
668668
):
@@ -715,7 +715,7 @@ def _get_neighbour_position_action(
715715
min_point_distance = sys.float_info.max
716716
clear_x = 0
717717
clear_y = 0
718-
for blockade in self.world_info.get_bloackades(road):
718+
for blockade in self.world_info.get_blockades(road):
719719
apexes = blockade.get_apexes()
720720
for i in range(0, len(apexes) - 2, 2):
721721
distance = self._get_distance(

0 commit comments

Comments
 (0)