Source code for flatland.action_plan.action_plan

import pprint
from typing import Dict, List, Optional, NamedTuple

import numpy as np

from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_action_for_move
from flatland.envs.rail_trainrun_data_structures import Waypoint, Trainrun, TrainrunWaypoint

# ---- ActionPlan ---------------
class ActionPlanElement(NamedTuple):
    # A single entry of an action plan: the action an agent has to issue
    # at the given time step.
    scheduled_at: int
    action: RailEnvActions


# The time-ordered actions of one single agent.
ActionPlan = List[ActionPlanElement]

# The action plans of all agents; the key is the agent handle.
ActionPlanDict = Dict[int, ActionPlan]


class ControllerFromTrainruns():
    """Takes train runs, derives the actions from it and re-acts them.

    Attributes
    ----------
    env
        The environment the train runs refer to.
    trainrun_dict
        The train run of every controlled agent, keyed by agent handle.
    action_plan
        List of per-agent action plans; the list index is the agent handle.
    """
    pp = pprint.PrettyPrinter(indent=4)

    def __init__(self,
                 env: RailEnv,
                 trainrun_dict: Dict[int, Trainrun]):
        """
        Derive the action plan of every agent from its train run.

        Parameters
        ----------
        env
            environment whose agents and rail are used to derive the actions
        trainrun_dict
            train run per agent handle
        """
        self.env: RailEnv = env
        self.trainrun_dict: Dict[int, Trainrun] = trainrun_dict

        # `get_action_at_step` looks plans up by agent handle, so every plan
        # has to sit at the index equal to its handle, independently of the
        # iteration order of `trainrun_dict`. Handles without a train run get
        # an empty plan. For keys 0..n-1 in order this yields exactly the
        # list a plain comprehension over `trainrun_dict.items()` would give.
        plan_count = max(len(trainrun_dict), max(trainrun_dict, default=-1) + 1)
        plans: List[ActionPlan] = [[] for _ in range(plan_count)]
        for agent_id, chosen_path in trainrun_dict.items():
            plans[agent_id] = self._create_action_plan_for_agent(agent_id, chosen_path)
        # NOTE: despite the name of the `ActionPlanDict` alias used in the
        # signatures below, this attribute is a list indexed by agent handle.
        self.action_plan: List[ActionPlan] = plans

    def get_waypoint_before_or_at_step(self, agent_id: int, step: int) -> Waypoint:
        """
        Get the way point from which the current position can be extracted.

        Parameters
        ----------
        agent_id
            handle of the agent, must be a key of `trainrun_dict`
        step
            time step to look up

        Returns
        -------
        Waypoint
            The last way point scheduled at or before `step`. Its position is
            `None` while the agent has not entered the grid yet and as soon as
            the last way point of the train run is reached.
        """
        trainrun = self.trainrun_dict[agent_id]
        entry_time_step = trainrun[0].scheduled_at
        # the agent has no position before and at choosing to enter the grid
        # (one tick elapses before the agent enters the grid)
        if step <= entry_time_step:
            return Waypoint(position=None, direction=self.env.agents[agent_id].initial_direction)

        # the agent has no position as soon as the target is reached
        exit_time_step = trainrun[-1].scheduled_at
        if step >= exit_time_step:
            # agent loses position as soon as target cell is reached
            return Waypoint(position=None, direction=trainrun[-1].waypoint.direction)

        # walk the train run and remember the last way point already reached
        waypoint = None
        for trainrun_waypoint in trainrun:
            if step < trainrun_waypoint.scheduled_at:
                return waypoint
            waypoint = trainrun_waypoint.waypoint
        assert waypoint is not None
        return waypoint

    def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]:
        """
        Get the current action if any is defined in the `ActionPlan`.

        ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!

        Parameters
        ----------
        agent_id
            handle of the agent
        current_step
            time step to look up

        Returns
        -------
        RailEnvActions, optional
            The action scheduled exactly at `current_step`, `None` if there
            is none or if no plan exists for the agent.
        """
        # agents without a train run have no plan and therefore no action
        if agent_id >= len(self.action_plan):
            return None

        for action_plan_element in self.action_plan[agent_id]:
            scheduled_at = action_plan_element.scheduled_at
            if scheduled_at > current_step:
                # the scan stops at the first element lying in the future
                return None
            elif current_step == scheduled_at:
                return action_plan_element.action
        return None

    def act(self, current_step: int) -> Dict[int, RailEnvActions]:
        """
        Get the action dictionary to be replayed at the current step.
        Returns only action where required (no action for done agents or those not at the beginning of the cell).

        ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!

        Parameters
        ----------
        current_step: int

        Returns
        -------
        Dict[int, RailEnvActions]
            The scheduled action per agent handle; agents without an action
            at `current_step` are left out.
        """
        action_dict = {}
        for agent_id in range(len(self.env.agents)):
            action: Optional[RailEnvActions] = self.get_action_at_step(agent_id, current_step)
            if action is not None:
                action_dict[agent_id] = action
        return action_dict

    def print_action_plan(self):
        """Pretty-prints `ActionPlanDict` of this `ControllerFromTrainruns` to stdout."""
        self.__class__.print_action_plan_dict(self.action_plan)

    @staticmethod
    def print_action_plan_dict(action_plan: ActionPlanDict):
        """Pretty-prints `ActionPlanDict` to stdout."""
        for agent_id, plan in enumerate(action_plan):
            print("{}: ".format(agent_id))
            for step in plan:
                print(" {}".format(step))

    @staticmethod
    def assert_actions_plans_equal(expected_action_plan: ActionPlanDict, actual_action_plan: ActionPlanDict):
        """
        Assert that two action plans are equal, reporting the first difference.

        The plans are compared agent by agent and element by element first so
        that a failure names the agent and the index that differ; the final
        assertion compares the two containers as a whole.

        Parameters
        ----------
        expected_action_plan
            plans indexable by 0..n-1
        actual_action_plan
            plans indexable by 0..n-1

        Raises
        ------
        AssertionError
            if the plans differ
        """
        assert len(expected_action_plan) == len(actual_action_plan)
        # index-based access on purpose: works for lists as well as for
        # dicts keyed by 0..n-1
        for k in range(len(expected_action_plan)):
            assert len(expected_action_plan[k]) == len(actual_action_plan[k]), \
                "len for agent {} should be the same.\n\n expected ({}) = {}\n\n actual ({}) = {}".format(
                    k,
                    len(expected_action_plan[k]),
                    ControllerFromTrainruns.pp.pformat(expected_action_plan[k]),
                    len(actual_action_plan[k]),
                    ControllerFromTrainruns.pp.pformat(actual_action_plan[k]))
            for i in range(len(expected_action_plan[k])):
                assert expected_action_plan[k][i] == actual_action_plan[k][i], \
                    "not the same at agent {} at step {}\n\n expected = {}\n\n actual = {}".format(
                        k, i,
                        ControllerFromTrainruns.pp.pformat(expected_action_plan[k][i]),
                        ControllerFromTrainruns.pp.pformat(actual_action_plan[k][i]))
        assert expected_action_plan == actual_action_plan, \
            "expected {}, found {}".format(expected_action_plan, actual_action_plan)

    def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
        """
        Translate the train run of one agent into its action plan.

        Parameters
        ----------
        agent_id
            handle of the agent in `self.env.agents`
        trainrun
            the scheduled way points of the agent

        Returns
        -------
        ActionPlan
        """
        action_plan = []
        agent = self.env.agents[agent_id]
        # number of time steps the agent needs at least to traverse one cell
        minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed']))
        for path_loop, trainrun_waypoint in enumerate(trainrun):
            position = trainrun_waypoint.waypoint.position

            # nothing left to do once the target cell is reached
            if Vec2d.is_equal(agent.target, position):
                break

            # NOTE(review): assumes the train run ends at the agent's target;
            # otherwise this look-ahead raises IndexError on the last element
            next_trainrun_waypoint: TrainrunWaypoint = trainrun[path_loop + 1]
            next_position = next_trainrun_waypoint.waypoint.position

            if path_loop == 0:
                self._add_action_plan_elements_for_first_path_element_of_agent(
                    action_plan,
                    trainrun_waypoint,
                    next_trainrun_waypoint,
                    minimum_cell_time
                )
                continue

            just_before_target = Vec2d.is_equal(agent.target, next_position)

            self._add_action_plan_elements_for_current_path_element(
                action_plan,
                minimum_cell_time,
                trainrun_waypoint,
                next_trainrun_waypoint)

            # add a final element
            if just_before_target:
                self._add_action_plan_elements_for_target_at_path_element_just_before_target(
                    action_plan,
                    minimum_cell_time,
                    trainrun_waypoint,
                    next_trainrun_waypoint)
        return action_plan

    def _add_action_plan_elements_for_current_path_element(self,
                                                           action_plan: ActionPlan,
                                                           minimum_cell_time: int,
                                                           trainrun_waypoint: TrainrunWaypoint,
                                                           next_trainrun_waypoint: TrainrunWaypoint):
        """Append the action(s) that lead from a way point to the next one.

        `action_plan` is extended in place.
        """
        scheduled_at = trainrun_waypoint.scheduled_at
        next_entry_value = next_trainrun_waypoint.scheduled_at

        position = trainrun_waypoint.waypoint.position
        direction = trainrun_waypoint.waypoint.direction
        next_position = next_trainrun_waypoint.waypoint.position
        next_direction = next_trainrun_waypoint.waypoint.direction
        next_action = get_action_for_move(position,
                                          direction,
                                          next_position,
                                          next_direction,
                                          self.env.rail)

        # if the next entry is later than minimum_cell_time, then stop here and
        # move minimum_cell_time before the exit
        # we have to do this since agents in the RailEnv are processed in the step() in the order of their handle
        if next_entry_value > scheduled_at + minimum_cell_time:
            action_plan.append(ActionPlanElement(scheduled_at, RailEnvActions.STOP_MOVING))
            action_plan.append(ActionPlanElement(next_entry_value - minimum_cell_time, next_action))
        else:
            action_plan.append(ActionPlanElement(scheduled_at, next_action))

    def _add_action_plan_elements_for_target_at_path_element_just_before_target(self,
                                                                                action_plan: ActionPlan,
                                                                                minimum_cell_time: int,
                                                                                trainrun_waypoint: TrainrunWaypoint,
                                                                                next_trainrun_waypoint: TrainrunWaypoint):
        """Append the final stop, `minimum_cell_time` after the way point just before the target.

        `action_plan` is extended in place; `next_trainrun_waypoint` is not used.
        """
        scheduled_at = trainrun_waypoint.scheduled_at

        action_plan.append(ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING))

    def _add_action_plan_elements_for_first_path_element_of_agent(self,
                                                                  action_plan: ActionPlan,
                                                                  trainrun_waypoint: TrainrunWaypoint,
                                                                  next_trainrun_waypoint: TrainrunWaypoint,
                                                                  minimum_cell_time: int):
        """Append the actions for entering the grid and leaving the first cell.

        `action_plan` is extended in place.
        """
        scheduled_at = trainrun_waypoint.scheduled_at
        position = trainrun_waypoint.waypoint.position
        direction = trainrun_waypoint.waypoint.direction
        next_position = next_trainrun_waypoint.waypoint.position
        next_direction = next_trainrun_waypoint.waypoint.direction

        # add initial do nothing if we do not enter immediately, actually not necessary
        if scheduled_at > 0:
            action_plan.append(ActionPlanElement(0, RailEnvActions.DO_NOTHING))
        # add action to enter the grid
        action_plan.append(ActionPlanElement(scheduled_at, RailEnvActions.MOVE_FORWARD))

        next_action = get_action_for_move(position,
                                          direction,
                                          next_position,
                                          next_direction,
                                          self.env.rail)

        # if the agent is blocked in the cell, we have to call stop upon entering!
        if next_trainrun_waypoint.scheduled_at > scheduled_at + 1 + minimum_cell_time:
            action_plan.append(ActionPlanElement(scheduled_at + 1, RailEnvActions.STOP_MOVING))

        # execute the action exactly minimum_cell_time before the entry into the next cell
        action_plan.append(ActionPlanElement(next_trainrun_waypoint.scheduled_at - minimum_cell_time, next_action))