|
1 | 1 | import sys |
2 | | -from typing import Any, NoReturn |
| 2 | +from abc import abstractmethod |
| 3 | +from typing import Any, Callable, NoReturn |
3 | 4 |
|
4 | | -from rcrs_core.agents.agent import Agent as RCRSAgent |
| 5 | +from bitarray import bitarray |
| 6 | +from rcrs_core.commands.AKClear import AKClear |
| 7 | +from rcrs_core.commands.AKClearArea import AKClearArea |
| 8 | +from rcrs_core.commands.AKLoad import AKLoad |
| 9 | +from rcrs_core.commands.AKMove import AKMove |
| 10 | +from rcrs_core.commands.AKRescue import AKRescue |
| 11 | +from rcrs_core.commands.AKRest import AKRest |
| 12 | +from rcrs_core.commands.AKSay import AKSay |
| 13 | +from rcrs_core.commands.AKSpeak import AKSpeak |
| 14 | +from rcrs_core.commands.AKSubscribe import AKSubscribe |
| 15 | +from rcrs_core.commands.AKTell import AKTell |
| 16 | +from rcrs_core.commands.AKUnload import AKUnload |
| 17 | +from rcrs_core.commands.Command import Command |
| 18 | +from rcrs_core.config.config import Config as RCRSConfig |
| 19 | +from rcrs_core.connection.URN import Command as CommandURN |
| 20 | +from rcrs_core.connection.URN import ComponentCommand as ComponentCommandMessageID |
| 21 | +from rcrs_core.connection.URN import ComponentControlMSG as ComponentControlMessageID |
| 22 | +from rcrs_core.connection.URN import Entity as EntityURN |
| 23 | +from rcrs_core.messages.AKAcknowledge import AKAcknowledge |
| 24 | +from rcrs_core.messages.AKConnect import AKConnect |
| 25 | +from rcrs_core.messages.controlMessageFactory import ControlMessageFactory |
| 26 | +from rcrs_core.messages.KAConnectError import KAConnectError |
| 27 | +from rcrs_core.messages.KAConnectOK import KAConnectOK |
| 28 | +from rcrs_core.messages.KASense import KASense |
| 29 | +from rcrs_core.worldmodel.changeSet import ChangeSet |
| 30 | +from rcrs_core.worldmodel.entityID import EntityID |
5 | 31 | from rcrs_core.worldmodel.worldmodel import WorldModel |
6 | 32 |
|
7 | | -from adf_core_python.core.logger.logger import get_logger |
| 33 | +from adf_core_python.core.agent.communication.message_manager import MessageManager |
| 34 | +from adf_core_python.core.agent.communication.standard.bundle.centralized.command_ambulance import ( |
| 35 | + CommandAmbulance, |
| 36 | +) |
| 37 | +from adf_core_python.core.agent.communication.standard.bundle.centralized.command_fire import ( |
| 38 | + CommandFire, |
| 39 | +) |
| 40 | +from adf_core_python.core.agent.communication.standard.bundle.centralized.command_police import ( |
| 41 | + CommandPolice, |
| 42 | +) |
| 43 | +from adf_core_python.core.agent.communication.standard.bundle.centralized.command_scout import ( |
| 44 | + CommandScout, |
| 45 | +) |
| 46 | +from adf_core_python.core.agent.communication.standard.bundle.centralized.message_report import ( |
| 47 | + MessageReport, |
| 48 | +) |
| 49 | +from adf_core_python.core.agent.communication.standard.bundle.information.message_ambulance_team import ( |
| 50 | + MessageAmbulanceTeam, |
| 51 | +) |
| 52 | +from adf_core_python.core.agent.communication.standard.bundle.information.message_building import ( |
| 53 | + MessageBuilding, |
| 54 | +) |
| 55 | +from adf_core_python.core.agent.communication.standard.bundle.information.message_civilian import ( |
| 56 | + MessageCivilian, |
| 57 | +) |
| 58 | +from adf_core_python.core.agent.communication.standard.bundle.information.message_fire_brigade import ( |
| 59 | + MessageFireBrigade, |
| 60 | +) |
| 61 | +from adf_core_python.core.agent.communication.standard.bundle.information.message_police_force import ( |
| 62 | + MessagePoliceForce, |
| 63 | +) |
| 64 | +from adf_core_python.core.agent.communication.standard.bundle.information.message_road import ( |
| 65 | + MessageRoad, |
| 66 | +) |
| 67 | +from adf_core_python.core.agent.communication.standard.standard_communication_module import ( |
| 68 | + StandardCommunicationModule, |
| 69 | +) |
| 70 | +from adf_core_python.core.agent.config.module_config import ModuleConfig |
| 71 | +from adf_core_python.core.agent.develop.develop_data import DevelopData |
| 72 | +from adf_core_python.core.agent.info.agent_info import AgentInfo |
| 73 | +from adf_core_python.core.agent.info.scenario_info import Mode, ScenarioInfo |
| 74 | +from adf_core_python.core.agent.info.world_info import WorldInfo |
| 75 | +from adf_core_python.core.agent.precompute.precompute_data import PrecomputeData |
| 76 | +from adf_core_python.core.component.communication.communication_module import ( |
| 77 | + CommunicationModule, |
| 78 | +) |
| 79 | +from adf_core_python.core.config.config import Config |
| 80 | +from adf_core_python.core.launcher.config_key import ConfigKey |
| 81 | +from adf_core_python.core.logger.logger import get_agent_logger, get_logger |
8 | 82 |
|
9 | 83 |
|
10 | | -class Agent(RCRSAgent): |
11 | | - def __init__(self, is_precompute: bool, name: str) -> None: |
| 84 | +class Agent: |
| 85 | + def __init__( |
| 86 | + self, |
| 87 | + is_precompute: bool, |
| 88 | + name: str, |
| 89 | + is_debug: bool, |
| 90 | + team_name: str, |
| 91 | + data_storage_name: str, |
| 92 | + module_config: ModuleConfig, |
| 93 | + develop_data: DevelopData, |
| 94 | + ) -> None: |
12 | 95 | self.name = name |
13 | 96 | self.connect_request_id = None |
14 | 97 | self.world_model = WorldModel() |
15 | | - self.config = None |
| 98 | + self.config: Config |
16 | 99 | self.random = None |
17 | | - self.agent_id = None |
| 100 | + self.agent_id: EntityID |
18 | 101 | self.precompute_flag = is_precompute |
19 | 102 | self.logger = get_logger( |
20 | 103 | f"{self.__class__.__module__}.{self.__class__.__qualname__}" |
21 | 104 | ) |
22 | 105 |
|
| 106 | + self.team_name = team_name |
| 107 | + self.is_debug = is_debug |
| 108 | + self.is_precompute = is_precompute |
| 109 | + |
| 110 | + if is_precompute: |
| 111 | + # PrecomputeData.remove_date(data_storage_name) |
| 112 | + self.mode = Mode.PRECOMPUTATION |
| 113 | + |
| 114 | + self._module_config = module_config |
| 115 | + self._develop_data = develop_data |
| 116 | + self._precompute_data = PrecomputeData(data_storage_name) |
| 117 | + self._message_manager: MessageManager = MessageManager() |
| 118 | + self._communication_module: CommunicationModule = StandardCommunicationModule() |
| 119 | + |
| 120 | + def get_entity_id(self) -> EntityID: |
| 121 | + return self.agent_id |
| 122 | + |
| 123 | + def set_send_msg(self, connection_send_func: Callable) -> None: |
| 124 | + self.send_msg = connection_send_func |
| 125 | + |
| 126 | + def post_connect(self) -> None: |
| 127 | + if self.is_precompute: |
| 128 | + self._mode = Mode.PRECOMPUTATION |
| 129 | + else: |
| 130 | + # if self._precompute_data.is_ready(): |
| 131 | + # self._mode = Mode.PRECOMPUTED |
| 132 | + # else: |
| 133 | + # self._mode = Mode.NON_PRECOMPUTE |
| 134 | + self._mode = Mode.NON_PRECOMPUTE |
| 135 | + |
| 136 | + self.config.set_value(ConfigKey.KEY_DEBUG_FLAG, self.is_debug) |
| 137 | + self.config.set_value( |
| 138 | + ConfigKey.KEY_DEVELOP_FLAG, self._develop_data.is_develop_mode() |
| 139 | + ) |
| 140 | + self._ignore_time: int = int( |
| 141 | + self.config.get_value("kernel.agents.ignoreuntil", 3) |
| 142 | + ) |
| 143 | + self._scenario_info: ScenarioInfo = ScenarioInfo(self.config, self._mode) |
| 144 | + self._world_info: WorldInfo = WorldInfo(self.world_model) |
| 145 | + self._agent_info = AgentInfo(self, self.world_model) |
| 146 | + self.logger = get_agent_logger( |
| 147 | + f"{self.__class__.__module__}.{self.__class__.__qualname__}", |
| 148 | + self._agent_info, |
| 149 | + ) |
| 150 | + |
| 151 | + self.logger.info(f"config: {self.config}") |
| 152 | + |
| 153 | + def update_step_info( |
| 154 | + self, time: int, change_set: ChangeSet, hear: list[Command] |
| 155 | + ) -> None: |
| 156 | + self._agent_info.record_think_start_time() |
| 157 | + self._agent_info.set_time(time) |
| 158 | + |
| 159 | + if time == 1: |
| 160 | + self._message_manager.register_message_class(0, MessageAmbulanceTeam) |
| 161 | + self._message_manager.register_message_class(1, MessageFireBrigade) |
| 162 | + self._message_manager.register_message_class(2, MessagePoliceForce) |
| 163 | + self._message_manager.register_message_class(3, MessageBuilding) |
| 164 | + self._message_manager.register_message_class(4, MessageCivilian) |
| 165 | + self._message_manager.register_message_class(5, MessageRoad) |
| 166 | + self._message_manager.register_message_class(6, CommandAmbulance) |
| 167 | + self._message_manager.register_message_class(7, CommandFire) |
| 168 | + self._message_manager.register_message_class(8, CommandPolice) |
| 169 | + self._message_manager.register_message_class(9, CommandScout) |
| 170 | + self._message_manager.register_message_class(10, MessageReport) |
| 171 | + |
| 172 | + if time > self._ignore_time: |
| 173 | + self._message_manager.subscribe( |
| 174 | + self._agent_info, self._world_info, self._scenario_info |
| 175 | + ) |
| 176 | + if not self._message_manager.get_is_subscribed(): |
| 177 | + subscribed_channels = self._message_manager.get_subscribed_channels() |
| 178 | + if subscribed_channels: |
| 179 | + self.logger.debug( |
| 180 | + f"Subscribed channels: {subscribed_channels}", |
| 181 | + message_manager=self._message_manager, |
| 182 | + ) |
| 183 | + self.send_subscribe(time, subscribed_channels) |
| 184 | + self._message_manager.set_is_subscribed(True) |
| 185 | + |
| 186 | + self._agent_info.set_heard_commands(hear) |
| 187 | + self._agent_info.set_change_set(change_set) |
| 188 | + self._world_info.set_change_set(change_set) |
| 189 | + |
| 190 | + self._message_manager.refresh() |
| 191 | + self._communication_module.receive(self, self._message_manager) |
| 192 | + |
| 193 | + self.think() |
| 194 | + |
| 195 | + self.logger.debug( |
| 196 | + f"send messages: {self._message_manager.get_send_message_list()}", |
| 197 | + message_manager=self._message_manager, |
| 198 | + ) |
| 199 | + |
| 200 | + self._message_manager.coordinate_message( |
| 201 | + self._agent_info, self._world_info, self._scenario_info |
| 202 | + ) |
| 203 | + self._communication_module.send(self, self._message_manager) |
| 204 | + |
| 205 | + @abstractmethod |
| 206 | + def think(self) -> None: |
| 207 | + pass |
| 208 | + |
| 209 | + @abstractmethod |
| 210 | + def precompute(self) -> None: |
| 211 | + pass |
| 212 | + |
| 213 | + @abstractmethod |
| 214 | + def get_requested_entities(self) -> list[EntityURN]: |
| 215 | + pass |
| 216 | + |
| 217 | + def start_up(self, request_id: int) -> None: |
| 218 | + ak_connect = AKConnect() |
| 219 | + self.send_msg(ak_connect.write(request_id, self)) |
| 220 | + |
| 221 | + def message_received(self, msg: Any) -> None: |
| 222 | + c_msg = ControlMessageFactory().make_message(msg) |
| 223 | + if isinstance(c_msg, KASense): |
| 224 | + self.handler_sense(c_msg) |
| 225 | + elif isinstance(c_msg, KAConnectOK): |
| 226 | + self.handle_connect_ok(c_msg) |
| 227 | + elif isinstance(c_msg, KAConnectError): |
| 228 | + self.handle_connect_error(c_msg) |
| 229 | + |
23 | 230 | def handle_connect_error(self, msg: Any) -> NoReturn: |
24 | 231 | self.logger.error( |
25 | 232 | "Failed to connect agent: %s(request_id: %s)", msg.reason, msg.request_id |
26 | 233 | ) |
27 | 234 | sys.exit(1) |
| 235 | + |
| 236 | + def handle_connect_ok(self, msg: Any) -> None: |
| 237 | + self.agent_id = EntityID(msg.agent_id) |
| 238 | + self.world_model.add_entities(msg.world) |
| 239 | + config: RCRSConfig = msg.config |
| 240 | + self.config = Config() |
| 241 | + if config is not None: |
| 242 | + for key, value in config.data.items(): |
| 243 | + self.config.set_value(key, value) |
| 244 | + for key, value in config.int_data.items(): |
| 245 | + self.config.set_value(key, value) |
| 246 | + for key, value in config.float_data.items(): |
| 247 | + self.config.set_value(key, value) |
| 248 | + for key, value in config.boolean_data.items(): |
| 249 | + self.config.set_value(key, value) |
| 250 | + for key, value in config.array_data.items(): |
| 251 | + self.config.set_value(key, value) |
| 252 | + self.send_acknowledge(msg.request_id) |
| 253 | + self.post_connect() |
| 254 | + if self.precompute_flag: |
| 255 | + print("self.precompute_flag: ", self.precompute_flag) |
| 256 | + self.precompute() |
| 257 | + |
| 258 | + def handler_sense(self, msg: Any) -> None: |
| 259 | + _id = EntityID(msg.agent_id) |
| 260 | + time = msg.time |
| 261 | + change_set = msg.change_set |
| 262 | + heard = msg.hear.commands |
| 263 | + |
| 264 | + if _id != self.get_entity_id(): |
| 265 | + self.logger.error("Agent ID mismatch: %s != %s", _id, self.get_entity_id()) |
| 266 | + return |
| 267 | + |
| 268 | + heard_commands: list[Command] = [] |
| 269 | + for herad_command in heard: |
| 270 | + if herad_command.urn == CommandURN.AK_SPEAK: |
| 271 | + heard_commands.append( |
| 272 | + AKSpeak( |
| 273 | + herad_command.components[ |
| 274 | + ComponentControlMessageID.AgentID |
| 275 | + ].entityID, |
| 276 | + herad_command.components[ |
| 277 | + ComponentControlMessageID.Time |
| 278 | + ].intValue, |
| 279 | + herad_command.components[ |
| 280 | + ComponentCommandMessageID.Message |
| 281 | + ].rawData, |
| 282 | + herad_command.components[ |
| 283 | + ComponentCommandMessageID.Channel |
| 284 | + ].intValue, |
| 285 | + ) |
| 286 | + ) |
| 287 | + |
| 288 | + self.world_model.merge(change_set) |
| 289 | + self.update_step_info(time, change_set, heard_commands) |
| 290 | + |
| 291 | + def send_acknowledge(self, request_id: int) -> None: |
| 292 | + ak_ack = AKAcknowledge() |
| 293 | + self.send_msg(ak_ack.write(request_id, self.agent_id)) |
| 294 | + |
| 295 | + def send_clear(self, time: int, target: EntityID) -> None: |
| 296 | + cmd = AKClear(self.get_entity_id(), time, target) |
| 297 | + msg = cmd.prepare_cmd() |
| 298 | + self.send_msg(msg) |
| 299 | + |
| 300 | + def send_clear_area(self, time: int, x: int = -1, y: int = -1) -> None: |
| 301 | + cmd = AKClearArea(self.get_entity_id(), time, x, y) |
| 302 | + msg = cmd.prepare_cmd() |
| 303 | + self.send_msg(msg) |
| 304 | + |
| 305 | + def send_load(self, time: int, target: EntityID) -> None: |
| 306 | + cmd = AKLoad(self.get_entity_id(), time, target) |
| 307 | + msg = cmd.prepare_cmd() |
| 308 | + self.send_msg(msg) |
| 309 | + |
| 310 | + def send_move(self, time: int, path: list[int], x: int = -1, y: int = -1) -> None: |
| 311 | + cmd = AKMove(self.get_entity_id(), time, path[:], x, y) |
| 312 | + msg = cmd.prepare_cmd() |
| 313 | + self.send_msg(msg) |
| 314 | + |
| 315 | + def send_rescue(self, time: int, target: EntityID) -> None: |
| 316 | + cmd = AKRescue(self.get_entity_id(), time, target) |
| 317 | + msg = cmd.prepare_cmd() |
| 318 | + self.send_msg(msg) |
| 319 | + |
| 320 | + def send_rest(self, time: int) -> None: |
| 321 | + cmd = AKRest(self.get_entity_id(), time) |
| 322 | + msg = cmd.prepare_cmd() |
| 323 | + self.send_msg(msg) |
| 324 | + |
| 325 | + def send_say(self, time_step: int, message: str) -> None: |
| 326 | + cmd = AKSay(self.get_entity_id(), time_step, message) |
| 327 | + msg = cmd.prepare_cmd() |
| 328 | + self.send_msg(msg) |
| 329 | + |
| 330 | + def send_speak(self, time_step: int, message: bitarray, channel: int) -> None: |
| 331 | + cmd = AKSpeak(self.get_entity_id(), time_step, bytes(message), channel) # type: ignore |
| 332 | + msg = cmd.prepare_cmd() |
| 333 | + self.send_msg(msg) |
| 334 | + |
| 335 | + def send_subscribe(self, time: int, channels: list[int]) -> None: |
| 336 | + cmd = AKSubscribe(self.get_entity_id(), time, channels) |
| 337 | + msg = cmd.prepare_cmd() |
| 338 | + self.send_msg(msg) |
| 339 | + |
| 340 | + def send_tell(self, time: int, message: str) -> None: |
| 341 | + cmd = AKTell(self.get_entity_id(), time, message) |
| 342 | + msg = cmd.prepare_cmd() |
| 343 | + self.send_msg(msg) |
| 344 | + |
| 345 | + def send_unload(self, time: int) -> None: |
| 346 | + cmd = AKUnload(self.get_entity_id(), time) |
| 347 | + msg = cmd.prepare_cmd() |
| 348 | + self.send_msg(msg) |
0 commit comments