"""
Collection of environment-specific PredictionBuilder.
"""
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.utils.ordered_set import OrderedSet

class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and always takes the forward action.
    """

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation builder is called.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent a numpy array of
            (max_depth + 1) x 5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at index 0 is the current position, direction etc.
        """
        agents = self.env.agents
        if handle is not None:  # handle 0 is a valid agent id, so test against None explicitly
            agents = [self.env.agents[handle]]
        prediction_dict = {}

        for agent in agents:
            if agent.status != RailAgentStatus.ACTIVE:
                # TODO make this generic
                continue
            action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
            agent_virtual_position = agent.position
            agent_virtual_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
            for index in range(1, self.max_depth + 1):
                action_done = False
                # if we're at the target, stop moving...
                if agent.position == agent.target:
                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                    continue
                for action in action_priorities:
                    cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \
                        self.env._check_action_on_agent(action, agent)
                    if all([new_cell_isValid, transition_isValid]):
                        # move and change direction to face the new_direction that was performed
                        agent.position = new_position
                        agent.direction = new_direction
                        prediction[index] = [index, *new_position, new_direction, action]
                        action_done = True
                        break
                if not action_done:
                    raise Exception("Cannot move further. Something is wrong")
            prediction_dict[agent.handle] = prediction
            # restore the agent's real position and direction after the simulated walk
            agent.position = agent_virtual_position
            agent.direction = agent_virtual_direction
        return prediction_dict
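
# Illustrative sketch (not part of the original module): reading one row of a returned
# prediction. Each entry of the dict returned by get() is a numpy array of shape
# (max_depth + 1, 5), where row k is [time_offset, row, col, direction, action]:
#
#     predictions = predictor.get()                      # assumes predictor.env has been set
#     t, row, col, direction, action = predictions[agent_handle][1]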


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and each agent simply
    follows its shortest path to its target.
    """

    def __init__(self, max_depth: int = 20):
        super().__init__(max_depth)

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation builder is called.
        Requires the distance_map to extract the shortest path.
        Does not take into account future positions of other agents!
        If there is no shortest path, the agent just stands still and stops moving.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent a numpy array of
            (max_depth + 1) x 5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (not implemented yet)
            The prediction at index 0 is the current position, direction etc.
        """
        agents = self.env.agents
        if handle is not None:  # handle 0 is a valid agent id, so test against None explicitly
            agents = [self.env.agents[handle]]
        distance_map: DistanceMap = self.env.distance_map
        shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)

        prediction_dict = {}
        for agent in agents:
            if agent.status == RailAgentStatus.READY_TO_DEPART:
                agent_virtual_position = agent.initial_position
            elif agent.status == RailAgentStatus.ACTIVE:
                agent_virtual_position = agent.position
            elif agent.status == RailAgentStatus.DONE:
                agent_virtual_position = agent.target
            else:
                # agent has been removed from the grid: emit rows with no position
                # (the None entries are stored as NaN in the float array)
                prediction = np.zeros(shape=(self.max_depth + 1, 5))
                for i in range(self.max_depth):
                    prediction[i] = [i, None, None, None, None]
                prediction_dict[agent.handle] = prediction
                continue

            agent_virtual_direction = agent.direction
            agent_speed = agent.speed_data["speed"]
            times_per_cell = int(np.reciprocal(agent_speed))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
            shortest_path = shortest_paths[agent.handle]

            # if there is a shortest path, remove the initial position
            if shortest_path:
                shortest_path = shortest_path[1:]

            new_direction = agent_virtual_direction
            new_position = agent_virtual_position
            visited = OrderedSet()
            for index in range(1, self.max_depth + 1):
                # if we're at the target or there is no path left, stand still until max_depth is reached
                if new_position == agent.target or not shortest_path:
                    prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                    visited.add((*new_position, agent.direction))
                    continue

                if index % times_per_cell == 0:
                    new_position = shortest_path[0].position
                    new_direction = shortest_path[0].direction
                    shortest_path = shortest_path[1:]

                # prediction is ready
                prediction[index] = [index, *new_position, new_direction, 0]
                visited.add((*new_position, new_direction))

            # TODO: very bad side effect, for visualization only: hand the dev_pred_dict back instead of setting it on the env!
            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict
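
# Example usage (sketch): a predictor is normally consumed through an observation
# builder rather than called directly. The wiring below assumes the flatland 2.x
# RailEnv / TreeObsForRailEnv constructor signatures and a sparse rail generator;
# treat the concrete arguments as illustrative, not canonical.
#
#     from flatland.envs.observations import TreeObsForRailEnv
#     from flatland.envs.rail_env import RailEnv
#     from flatland.envs.rail_generators import sparse_rail_generator
#
#     obs_builder = TreeObsForRailEnv(max_depth=2,
#                                     predictor=ShortestPathPredictorForRailEnv(max_depth=20))
#     env = RailEnv(width=30, height=30,
#                   rail_generator=sparse_rail_generator(),
#                   number_of_agents=3,
#                   obs_builder_object=obs_builder)
#     env.reset()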