import pprint
from typing import Dict, List, Optional, NamedTuple
import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_action_for_move
from flatland.envs.rail_trainrun_data_structures import Waypoint, Trainrun, TrainrunWaypoint
# ---- ActionPlan ---------------


# One entry of an action plan: the action an agent has to take at the given time step.
class ActionPlanElement(NamedTuple):
    scheduled_at: int
    action: RailEnvActions


# All the actions to be taken by a single agent, each tagged with its time step.
ActionPlan = List[ActionPlanElement]

# The action plans of every agent; the dictionary key is the agent handle.
ActionPlanDict = Dict[int, ActionPlan]
class ControllerFromTrainruns():
    """Takes train runs, derives the actions from it and re-acts them.

    An action plan is derived once per agent at construction time from the
    agent's train run; `act` then replays those plans time step by time step.
    """
    pp = pprint.PrettyPrinter(indent=4)

    def __init__(self,
                 env: RailEnv,
                 trainrun_dict: Dict[int, Trainrun]):
        """
        Parameters
        ----------
        env
            environment providing the agents and the rail grid used to derive the actions
        trainrun_dict
            the train run of every agent, keyed by agent handle
        """
        self.env: RailEnv = env
        self.trainrun_dict: Dict[int, Trainrun] = trainrun_dict
        # The plans are kept in a list that `get_action_at_step` indexes by agent handle.
        # Build the list in ascending handle order so that position and handle agree even
        # if `trainrun_dict` was filled in a different order.
        # NOTE(review): this still assumes the handles are exactly 0..n-1 - confirm with callers.
        self.action_plan: List[ActionPlan] = [
            self._create_action_plan_for_agent(agent_id, trainrun_dict[agent_id])
            for agent_id in sorted(trainrun_dict)
        ]

    def get_waypoint_before_or_at_step(self, agent_id: int, step: int) -> Waypoint:
        """
        Get the waypoint from which the current position can be extracted.

        Parameters
        ----------
        agent_id
        step

        Returns
        -------
        Waypoint
            the last waypoint of the agent's train run scheduled at or before `step`;
            its position is `None` at or before the first scheduled time step and
            at or after the last one.
        """
        trainrun = self.trainrun_dict[agent_id]
        entry_time_step = trainrun[0].scheduled_at
        # the agent has no position before and at choosing to enter the grid
        # (one tick elapses before the agent enters the grid)
        if step <= entry_time_step:
            return Waypoint(position=None, direction=self.env.agents[agent_id].initial_direction)

        # the agent has no position as soon as the target is reached
        exit_time_step = trainrun[-1].scheduled_at
        if step >= exit_time_step:
            # agent loses position as soon as target cell is reached
            return Waypoint(position=None, direction=trainrun[-1].waypoint.direction)

        waypoint = None
        for trainrun_waypoint in trainrun:
            # the first waypoint scheduled after `step` ends the search:
            # the one remembered so far is the one we are looking for
            if step < trainrun_waypoint.scheduled_at:
                return waypoint
            waypoint = trainrun_waypoint.waypoint
        assert waypoint is not None
        return waypoint

    def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]:
        """
        Get the current action if any is defined in the `ActionPlan`.
        ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!

        Parameters
        ----------
        agent_id
        current_step

        Returns
        -------
        RailEnvActions, optional
            the action scheduled exactly at `current_step`, `None` if there is none.
        """
        for action_plan_element in self.action_plan[agent_id]:
            scheduled_at = action_plan_element.scheduled_at
            # stop searching at the first element scheduled later than the current step
            # (relies on the plan being ordered by `scheduled_at`)
            if scheduled_at > current_step:
                return None
            elif current_step == scheduled_at:
                return action_plan_element.action
        return None

    def act(self, current_step: int) -> Dict[int, RailEnvActions]:
        """
        Get the action dictionary to be replayed at the current step.
        Returns only action where required (no action for done agents or those not at the beginning of the cell).

        ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!

        Parameters
        ----------
        current_step: int

        Returns
        -------
        Dict[int, RailEnvActions]
            the actions to take at `current_step`, keyed by agent handle; agents
            without an action at this step are left out.
        """
        action_dict = {}
        for agent_id in range(len(self.env.agents)):
            action: Optional[RailEnvActions] = self.get_action_at_step(agent_id, current_step)
            if action is not None:
                action_dict[agent_id] = action
        return action_dict

    def print_action_plan(self):
        """Pretty-prints `ActionPlanDict` of this `ControllerFromTrainruns` to stdout."""
        self.__class__.print_action_plan_dict(self.action_plan)

    @staticmethod
    def _agent_plan_items(action_plan):
        """Return `(agent_id, plan)` pairs for plans given either as a dict or as a list indexed by agent."""
        if isinstance(action_plan, dict):
            return list(action_plan.items())
        return list(enumerate(action_plan))

    @staticmethod
    def print_action_plan_dict(action_plan: ActionPlanDict):
        """Pretty-prints `ActionPlanDict` to stdout.

        Parameters
        ----------
        action_plan
            the plans of all agents, either as a dict keyed by agent handle or
            as a list indexed by agent handle.
        """
        for agent_id, plan in ControllerFromTrainruns._agent_plan_items(action_plan):
            print("{}: ".format(agent_id))
            for step in plan:
                print(" {}".format(step))

    @staticmethod
    def assert_actions_plans_equal(expected_action_plan: ActionPlanDict, actual_action_plan: ActionPlanDict):
        """Assert that two collections of action plans are equal, with a detailed message on the first difference.

        Parameters
        ----------
        expected_action_plan
            the expected plans of all agents (dict keyed by agent handle or list indexed by agent handle)
        actual_action_plan
            the plans to check, in the same shape as `expected_action_plan`

        Raises
        ------
        AssertionError
            if the number of agents, the length of an agent's plan or any plan element differs.
        """
        assert len(expected_action_plan) == len(actual_action_plan)
        for k, expected_plan in ControllerFromTrainruns._agent_plan_items(expected_action_plan):
            actual_plan = actual_action_plan[k]
            assert len(expected_plan) == len(actual_plan), \
                "len for agent {} should be the same.\n\n expected ({}) = {}\n\n actual ({}) = {}".format(
                    k,
                    len(expected_plan),
                    ControllerFromTrainruns.pp.pformat(expected_plan),
                    len(actual_plan),
                    ControllerFromTrainruns.pp.pformat(actual_plan))
            # both plans have the same length here, so zip does not drop anything
            for i, (expected_element, actual_element) in enumerate(zip(expected_plan, actual_plan)):
                assert expected_element == actual_element, \
                    "not the same at agent {} at step {}\n\n expected = {}\n\n actual = {}".format(
                        k, i,
                        ControllerFromTrainruns.pp.pformat(expected_element),
                        ControllerFromTrainruns.pp.pformat(actual_element))
        assert expected_action_plan == actual_action_plan, \
            "expected {}, found {}".format(expected_action_plan, actual_action_plan)

    def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
        """Derive the action plan of one agent from its train run.

        Parameters
        ----------
        agent_id
            handle of the agent, used to look up the agent in `self.env.agents`
        trainrun
            the scheduled waypoints of the agent

        Returns
        -------
        ActionPlan
            the actions for the agent in the order in which they were derived

        Raises
        ------
        IndexError
            if the train run ends before a waypoint at the agent's target position is found.
        """
        action_plan = []
        agent = self.env.agents[agent_id]
        # number of time steps derived from the agent's speed, rounded up
        # (presumably the minimum number of steps needed to traverse one cell - confirm)
        minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed']))
        for path_loop, trainrun_waypoint in enumerate(trainrun):
            position = trainrun_waypoint.waypoint.position

            # nothing is left to plan once the target position is reached
            if Vec2d.is_equal(agent.target, position):
                break

            next_trainrun_waypoint: TrainrunWaypoint = trainrun[path_loop + 1]
            next_position = next_trainrun_waypoint.waypoint.position

            if path_loop == 0:
                self._add_action_plan_elements_for_first_path_element_of_agent(
                    action_plan,
                    trainrun_waypoint,
                    next_trainrun_waypoint,
                    minimum_cell_time
                )
                continue

            just_before_target = Vec2d.is_equal(agent.target, next_position)

            self._add_action_plan_elements_for_current_path_element(
                action_plan,
                minimum_cell_time,
                trainrun_waypoint,
                next_trainrun_waypoint)

            # add a final element
            if just_before_target:
                self._add_action_plan_elements_for_target_at_path_element_just_before_target(
                    action_plan,
                    minimum_cell_time,
                    trainrun_waypoint,
                    next_trainrun_waypoint)
        return action_plan

    def _add_action_plan_elements_for_current_path_element(self,
                                                           action_plan: ActionPlan,
                                                           minimum_cell_time: int,
                                                           trainrun_waypoint: TrainrunWaypoint,
                                                           next_trainrun_waypoint: TrainrunWaypoint):
        """Append the actions that take the agent from `trainrun_waypoint` to `next_trainrun_waypoint`.

        Parameters
        ----------
        action_plan
            the plan to extend in place
        minimum_cell_time
            number of time steps between issuing the move action and the scheduled entry into the next cell
        trainrun_waypoint
            the waypoint the agent is at
        next_trainrun_waypoint
            the waypoint the agent has to reach next
        """
        scheduled_at = trainrun_waypoint.scheduled_at
        next_entry_value = next_trainrun_waypoint.scheduled_at

        position = trainrun_waypoint.waypoint.position
        direction = trainrun_waypoint.waypoint.direction
        next_position = next_trainrun_waypoint.waypoint.position
        next_direction = next_trainrun_waypoint.waypoint.direction
        next_action = get_action_for_move(position,
                                          direction,
                                          next_position,
                                          next_direction,
                                          self.env.rail)

        # if the next entry is later than minimum_cell_time, then stop here and
        # move minimum_cell_time before the exit
        # we have to do this since agents in the RailEnv are processed in the step() in the order of their handle
        if next_entry_value > scheduled_at + minimum_cell_time:
            action = ActionPlanElement(scheduled_at, RailEnvActions.STOP_MOVING)
            action_plan.append(action)

            action = ActionPlanElement(next_entry_value - minimum_cell_time, next_action)
            action_plan.append(action)
        else:
            action = ActionPlanElement(scheduled_at, next_action)
            action_plan.append(action)

    def _add_action_plan_elements_for_target_at_path_element_just_before_target(self,
                                                                                action_plan: ActionPlan,
                                                                                minimum_cell_time: int,
                                                                                trainrun_waypoint: TrainrunWaypoint,
                                                                                next_trainrun_waypoint: TrainrunWaypoint):
        """Append the final `STOP_MOVING` for the waypoint just before the target.

        Parameters
        ----------
        action_plan
            the plan to extend in place
        minimum_cell_time
            offset added to the scheduled time of `trainrun_waypoint`
        trainrun_waypoint
            the waypoint just before the target
        next_trainrun_waypoint
            the waypoint at the target; not used here, kept for a uniform signature
        """
        scheduled_at = trainrun_waypoint.scheduled_at

        # NOTE(review): the stop is scheduled relative to the entry into the cell before the
        # target, not relative to `next_trainrun_waypoint.scheduled_at` - confirm this is
        # intended when the agent waits in that cell.
        action = ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING)
        action_plan.append(action)

    def _add_action_plan_elements_for_first_path_element_of_agent(self,
                                                                  action_plan: ActionPlan,
                                                                  trainrun_waypoint: TrainrunWaypoint,
                                                                  next_trainrun_waypoint: TrainrunWaypoint,
                                                                  minimum_cell_time: int):
        """Append the actions for the first waypoint of a train run: enter the grid and leave the first cell.

        Parameters
        ----------
        action_plan
            the plan to extend in place
        trainrun_waypoint
            the first waypoint of the train run
        next_trainrun_waypoint
            the second waypoint of the train run
        minimum_cell_time
            number of time steps between issuing the move action and the scheduled entry into the next cell
        """
        scheduled_at = trainrun_waypoint.scheduled_at
        position = trainrun_waypoint.waypoint.position
        direction = trainrun_waypoint.waypoint.direction
        next_position = next_trainrun_waypoint.waypoint.position
        next_direction = next_trainrun_waypoint.waypoint.direction

        # add initial do nothing if we do not enter immediately, actually not necessary
        if scheduled_at > 0:
            action = ActionPlanElement(0, RailEnvActions.DO_NOTHING)
            action_plan.append(action)

        # add action to enter the grid
        action = ActionPlanElement(scheduled_at, RailEnvActions.MOVE_FORWARD)
        action_plan.append(action)

        next_action = get_action_for_move(position,
                                          direction,
                                          next_position,
                                          next_direction,
                                          self.env.rail)

        # if the agent is blocked in the cell, we have to call stop upon entering!
        if next_trainrun_waypoint.scheduled_at > scheduled_at + 1 + minimum_cell_time:
            action = ActionPlanElement(scheduled_at + 1, RailEnvActions.STOP_MOVING)
            action_plan.append(action)

        # execute the action exactly minimum_cell_time before the entry into the next cell
        action = ActionPlanElement(next_trainrun_waypoint.scheduled_at - minimum_cell_time, next_action)
        action_plan.append(action)