Source code for flatland.envs.agent_utils

from enum import IntEnum
from itertools import starmap
from typing import Tuple, Optional, NamedTuple

from attr import attrs, attrib, Factory

from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.schedule_utils import Schedule


class RailAgentStatus(IntEnum):
    """Lifecycle state of an agent with respect to the grid.

    The note on each member records whether the agent has a position in that
    state and what a prediction for the agent is expected to be.
    """

    # Not in grid yet (position is None)
    # -> prediction as if it were at initial position.
    READY_TO_DEPART = 0

    # In grid (position is not None), not done
    # -> prediction is remaining path.
    ACTIVE = 1

    # In grid (position is not None), but done
    # -> prediction is stay at target forever.
    DONE = 2

    # Removed from grid (position is None)
    # -> prediction is None.
    DONE_REMOVED = 3
# Plain-tuple view of an agent's fields, in the same order as the attributes
# declared on EnvAgent (see EnvAgent.to_agent).
# 'position', 'old_direction' and 'old_position' are Optional because EnvAgent
# defaults them to None and passes them through unchanged.
Agent = NamedTuple('Agent', [
    ('initial_position', Tuple[int, int]),
    ('initial_direction', Grid4TransitionsEnum),
    ('direction', Grid4TransitionsEnum),
    ('target', Tuple[int, int]),
    ('moving', bool),
    ('speed_data', dict),
    ('malfunction_data', dict),
    ('handle', int),
    ('status', RailAgentStatus),
    ('position', Optional[Tuple[int, int]]),
    ('old_direction', Optional[Grid4TransitionsEnum]),
    ('old_position', Optional[Tuple[int, int]]),
])
@attrs
class EnvAgent:
    """Mutable per-agent state: start, target, current pose, speed and malfunction data.

    The attribute order is significant: ``from_schedule`` constructs instances
    positionally (initial_position, initial_direction, direction, target,
    moving, speed_data, malfunction_data, handle).
    """

    initial_position = attrib(type=Tuple[int, int])
    initial_direction = attrib(type=Grid4TransitionsEnum)
    direction = attrib(type=Grid4TransitionsEnum)
    target = attrib(type=Tuple[int, int])
    moving = attrib(default=False, type=bool)

    # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
    # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the
    # previous cell if speed=1, as default)
    # N.B. we need to use factory since default arguments are not recreated on each call!
    speed_data = attrib(
        default=Factory(lambda: {'position_fraction': 0.0,
                                 'speed': 1.0,
                                 'transition_action_on_cellexit': 0}))

    # malfunction_data: per-agent malfunction bookkeeping.
    # Original note: "if broken>0, the agent's actions are ignored for 'broken' steps;
    # number of time the agent had to stop, since the last time it broke down".
    # NOTE(review): no 'broken' key exists here -- presumably this refers to the
    # 'malfunction' entry; confirm against the environment's step logic.
    # A factory is used for the same reason as for speed_data.
    malfunction_data = attrib(
        default=Factory(lambda: {'malfunction': 0,
                                 'malfunction_rate': 0,
                                 'next_malfunction': 0,
                                 'nr_malfunctions': 0,
                                 'moving_before_malfunction': False}))

    handle = attrib(default=None)

    status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
    position = attrib(default=None, type=Optional[Tuple[int, int]])

    # used in rendering
    old_direction = attrib(default=None)
    old_position = attrib(default=None)

    def reset(self):
        """Reset the agent to its initial values for a new episode.

        Position, rendering history, status and the moving flag are cleared and
        the direction is restored to ``initial_direction``. In ``speed_data``
        only 'position_fraction' and 'transition_action_on_cellexit' are reset;
        in ``malfunction_data`` only 'malfunction', 'nr_malfunctions' and
        'moving_before_malfunction'. All other entries are left untouched.
        """
        self.position = None
        # TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
        self.direction = self.initial_direction
        self.status = RailAgentStatus.READY_TO_DEPART
        self.old_position = None
        self.old_direction = None
        self.moving = False

        # Reset agent values for speed
        self.speed_data['position_fraction'] = 0.
        # Integer 0, matching the attribute default and from_schedule
        # (the value was previously reset to the float 0.).
        self.speed_data['transition_action_on_cellexit'] = 0

        # Reset agent malfunction values
        self.malfunction_data['malfunction'] = 0
        self.malfunction_data['nr_malfunctions'] = 0
        self.malfunction_data['moving_before_malfunction'] = False

    def to_agent(self) -> Agent:
        """Return the current field values as an ``Agent`` named tuple.

        ``speed_data`` and ``malfunction_data`` are passed by reference, not
        copied, so the returned tuple shares those dicts with this instance.
        """
        return Agent(initial_position=self.initial_position,
                     initial_direction=self.initial_direction,
                     direction=self.direction,
                     target=self.target,
                     moving=self.moving,
                     speed_data=self.speed_data,
                     malfunction_data=self.malfunction_data,
                     handle=self.handle,
                     status=self.status,
                     position=self.position,
                     old_direction=self.old_direction,
                     old_position=self.old_position)

    @classmethod
    def from_schedule(cls, schedule: Schedule):
        """Create a list of agents from a schedule's positions, directions and targets.

        One agent is built per entry of ``schedule.agent_positions``; its handle
        is its index. ``schedule.agent_directions`` is used for both the initial
        and the current direction. When ``schedule.agent_speeds`` is None every
        agent gets speed 1.0; when ``schedule.agent_malfunction_rates`` is None
        every agent gets malfunction rate 0.

        :param schedule: source of positions, directions, targets and the
            optional speeds / malfunction rates.
        :return: list of ``cls`` instances with ``moving`` set to False.
        """
        num_agents = len(schedule.agent_positions)
        speeds = schedule.agent_speeds
        malfunction_rates = schedule.agent_malfunction_rates

        speed_datas = [
            {'position_fraction': 0.0,
             'speed': speeds[i] if speeds is not None else 1.0,
             'transition_action_on_cellexit': 0}
            for i in range(num_agents)
        ]

        # 'moving_before_malfunction' is included so that agents built here carry
        # the same keys as the attribute default, even before reset() is called.
        malfunction_datas = [
            {'malfunction': 0,
             'malfunction_rate': malfunction_rates[i] if malfunction_rates is not None else 0.,
             'next_malfunction': 0,
             'nr_malfunctions': 0,
             'moving_before_malfunction': False}
            for i in range(num_agents)
        ]

        # Positional construction: the zip order must follow the attribute order.
        return list(starmap(cls, zip(schedule.agent_positions,
                                     schedule.agent_directions,  # initial_direction
                                     schedule.agent_directions,  # direction
                                     schedule.agent_targets,
                                     [False] * num_agents,  # moving
                                     speed_datas,
                                     malfunction_datas,
                                     range(num_agents))))  # handle

    @classmethod
    def load_legacy_static_agent(cls, static_agents_data: Tuple):
        """Create a list of agents from legacy static-agent records.

        Each record is read by index as (position, direction, target, moving,
        speed_data, malfunction_data); the direction is used for both the
        initial and the current direction, and the handle is the record's index.

        :param static_agents_data: sequence of indexable legacy records.
        :return: list of ``cls`` instances, one per record.
        """
        return [
            cls(initial_position=static_agent[0],
                initial_direction=static_agent[1],
                direction=static_agent[1],
                target=static_agent[2],
                moving=static_agent[3],
                speed_data=static_agent[4],
                malfunction_data=static_agent[5],
                handle=i)
            for i, static_agent in enumerate(static_agents_data)
        ]