"""
TransitionMap and derived classes.
"""
import numpy as np
from importlib_resources import path
from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
from flatland.utils.ordered_set import OrderedSet
# TODO are these general classes or for grid4 only?
class TransitionMap:
    """
    Base TransitionMap class.

    Generic interface for a collection of transitions over a set of cells.
    Every method raises NotImplementedError; concrete maps are expected to
    override them.
    """

    def get_transitions(self, cell_id):
        """
        Return a tuple of transitions available in a cell specified by
        `cell_id` (e.g., a tuple of size of the maximum number of transitions,
        with values 0 or 1, or potentially in between,
        for stochastic transitions).

        Parameters
        ----------
        cell_id : [cell identifier]
            The cell_id object depends on the specific implementation.
            It generally is an int (e.g., an index) or a tuple of indices.

        Returns
        -------
        tuple
            Validity of each transition in the cell.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError()

    def set_transitions(self, cell_id, new_transitions):
        """
        Replace the available transitions in cell `cell_id` with the tuple
        `new_transitions`. `new_transitions` must have
        one element for each possible transition.

        Parameters
        ----------
        cell_id : [cell identifier]
            The cell_id object depends on the specific implementation.
            It generally is an int (e.g., an index) or a tuple of indices.
        new_transitions : tuple
            Tuple of new transitions validity for the cell.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError()

    def get_transition(self, cell_id, transition_index):
        """
        Return the status of whether an agent in cell `cell_id` can perform a
        movement along transition `transition_index` (e.g., the NESW direction
        of movement, for agents on a grid).

        Parameters
        ----------
        cell_id : [cell identifier]
            The cell_id object depends on the specific implementation.
            It generally is an int (e.g., an index) or a tuple of indices.
        transition_index : int
            Index of the transition to probe, as index in the tuple returned by
            get_transitions(). e.g., the NESW direction of movement, for agents
            on a grid.

        Returns
        -------
        int or float (depending on Transitions used)
            Validity of the requested transition (e.g.,
            0/1 allowed/not allowed, a probability in [0,1], etc...)

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError()

    def set_transition(self, cell_id, transition_index, new_transition):
        """
        Replace the validity of transition to `transition_index` in cell
        `cell_id` with the new `new_transition`.

        Parameters
        ----------
        cell_id : [cell identifier]
            The cell_id object depends on the specific implementation.
            It generally is an int (e.g., an index) or a tuple of indices.
        transition_index : int
            Index of the transition to set, as index in the tuple returned by
            get_transitions(). e.g., the NESW direction of movement, for agents
            on a grid.
        new_transition : int or float (depending on Transitions used)
            Validity of the requested transition (e.g.,
            0/1 allowed/not allowed, a probability in [0,1], etc...)

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError()
class GridTransitionMap(TransitionMap):
    """
    Implements a TransitionMap over a 2D grid.

    The grid is a numpy array of shape (height, width); each entry holds the
    full transition value of one cell, encoded by the `transitions` object.
    Wherever a `cell_id` tuple is used, it is indexed as
    (row, column[, orientation]).

    GridTransitionMap also implements utility functions to validate and
    repair the connectivity between neighbouring cells.
    """

    def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]), random_seed=None):
        """
        Builder for GridTransitionMap object.

        Parameters
        ----------
        width : int
            Width of the grid.
        height : int
            Height of the grid.
        transitions : Transitions object
            The Transitions object to use to encode/decode transitions over the
            grid.
        random_seed : int, optional
            Seed for the internal random generator used by fix_transitions().
            When None, the fixed seed 12 is used.
        """
        self.width = width
        self.height = height
        self.transitions = transitions
        self.random_generator = np.random.RandomState()
        self.random_generator.seed(12 if random_seed is None else random_seed)
        self.grid = np.zeros((height, width), dtype=self.transitions.get_type())

    def get_full_transitions(self, row, column):
        """
        Return the full transitions for the cell at (row, column) in the
        format of this map's Transitions.

        Parameters
        ----------
        row : int
        column : int
            (row, column) specifies the cell in this transition map.

        Returns
        -------
        self.transitions.get_type()
            The cell content in the format of this map's Transitions.
        """
        return self.grid[row][column]

    def get_transitions(self, row, column, orientation):
        """
        Return a tuple of transitions available in the cell at (row, column)
        for an agent facing `orientation` (e.g., a tuple of size of the maximum
        number of transitions, with values 0 or 1, or potentially in between,
        for stochastic transitions).

        Parameters
        ----------
        row : int
        column : int
            (row, column) specifies the cell in this transition map.
        orientation : int
            The direction an agent is facing within the cell.

        Returns
        -------
        tuple
            Validity of transitions in the cell as given by the map's
            Transitions object.
        """
        return self.transitions.get_transitions(self.grid[row][column], orientation)

    def set_transitions(self, cell_id, new_transitions):
        """
        Replace the available transitions in cell `cell_id` with
        `new_transitions`.

        Parameters
        ----------
        cell_id : tuple
            Either (row, column, orientation), where orientation is the
            direction an agent is facing within the cell, to replace only the
            transitions of that orientation; or (row, column) to replace the
            full cell content.
        new_transitions : tuple or full cell value
            For a 3-tuple `cell_id`: the new transitions validity for that
            orientation, one element per possible transition.
            For a 2-tuple `cell_id`: the value stored directly in the grid.
        """
        assert len(cell_id) in (2, 3), \
            'GridTransitionMap.set_transitions() ERROR: cell_id tuple must have length 2 or 3.'
        row, column = cell_id[0], cell_id[1]
        if len(cell_id) == 3:
            self.grid[row][column] = self.transitions.set_transitions(self.grid[row][column],
                                                                      cell_id[2],
                                                                      new_transitions)
        else:
            self.grid[row][column] = new_transitions

    def get_transition(self, cell_id, transition_index):
        """
        Return the status of whether an agent in cell `cell_id` can perform a
        movement along transition `transition_index` (e.g., the NESW direction
        of movement, for agents on a grid).

        Parameters
        ----------
        cell_id : tuple
            The cell_id indexes a cell as (row, column, orientation),
            where orientation is the direction an agent is facing within a cell.
        transition_index : int
            Index of the transition to probe, as index in the tuple returned by
            get_transitions(). e.g., the NESW direction of movement, for agents
            on a grid.

        Returns
        -------
        int or float (depending on Transitions used in the map)
            Validity of the requested transition (e.g.,
            0/1 allowed/not allowed, a probability in [0,1], etc...)
        """
        assert len(cell_id) == 3, \
            'GridTransitionMap.get_transition() ERROR: cell_id tuple must have length 3.'
        return self.transitions.get_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index)

    def set_transition(self, cell_id, transition_index, new_transition, remove_deadends=False):
        """
        Replace the validity of transition to `transition_index` in cell
        `cell_id` with the new `new_transition`.

        Parameters
        ----------
        cell_id : tuple
            The cell_id indexes a cell as (row, column, orientation),
            where orientation is the direction an agent is facing within a cell.
        transition_index : int
            Index of the transition to set, as index in the tuple returned by
            get_transitions(). e.g., the NESW direction of movement, for agents
            on a grid.
        new_transition : int or float (depending on Transitions used in the map.)
            Validity of the requested transition (e.g.,
            0/1 allowed/not allowed, a probability in [0,1], etc...)
        remove_deadends : bool
            Forwarded unchanged to the Transitions object's set_transition().
        """
        assert len(cell_id) == 3, \
            'GridTransitionMap.set_transition() ERROR: cell_id tuple must have length 3.'
        row, column = cell_id[0], cell_id[1]
        self.grid[row][column] = self.transitions.set_transition(
            self.grid[row][column],
            cell_id[2],
            transition_index,
            new_transition,
            remove_deadends)

    def save_transition_map(self, filename):
        """
        Save the transitions grid as `filename`, in npy format.

        Parameters
        ----------
        filename : string
            Name of the file to which to save the transitions grid.
        """
        np.save(filename, self.grid)

    def load_transition_map(self, package, resource):
        """
        Load the transitions grid from a package resource (npy format).

        The load function replaces the transitions grid and sets width and
        height to the shape of the loaded grid; the object has to be
        initialized with the correct `transitions` object anyway.

        Parameters
        ----------
        package : string
            Name of the package from which to load the transitions grid.
        resource : string
            Name of the file from which to load the transitions grid within the package.
        """
        with path(package, resource) as file_in:
            new_grid = np.load(file_in)

        self.height = new_grid.shape[0]
        self.width = new_grid.shape[1]
        self.grid = new_grid

    def is_dead_end(self, rcPos: IntVector2DArray):
        """
        Check if the cell is a dead-end, i.e. exactly one bit is set in the
        full transition value of the cell.

        Parameters
        ----------
        rcPos : Tuple[int,int]
            tuple(row, column) with grid coordinate

        Returns
        -------
        boolean
            True if and only if the cell is a dead-end.
        """
        nbits = 0
        remaining = self.get_full_transitions(rcPos[0], rcPos[1])
        while remaining > 0:
            nbits += (remaining & 1)
            remaining = remaining >> 1
        return nbits == 1

    def is_simple_turn(self, rcPos: IntVector2DArray):
        """
        Check if the cell is a left/right simple turn.

        Parameters
        ----------
        rcPos : Tuple[int,int]
            tuple(row, column) with grid coordinate

        Returns
        -------
        boolean
            True if and only if the cell content equals one of the rotations
            of the two simple-turn base transitions.
        """
        cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])

        simple_turn_bases = (
            int('0100000000000010', 2),  # Case 1b (8) - simple turn right
            int('0001001000000000', 2),  # Case 1c (9) - simple turn left
        )
        all_simple_turns = OrderedSet()
        for base in simple_turn_bases:
            rotated = base
            # Four successive quarter turns; presumably the last one maps back to the base itself.
            for _ in range(4):
                rotated = self.transitions.rotate_transition(rotated, rotation=90)
                all_simple_turns.add(rotated)

        # Compare the *cell's* value against the set (the loop must not overwrite it).
        return cell_transition in all_simple_turns

    def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
        """
        Depth-first search (LIFO stack) for a possible path from one node with
        a certain orientation to a target node.

        :param start: Start cell from where we want to check the path
        :param direction: Start direction for the path we are testing
        :param end: Cell that we try to reach from the start cell
        :return: True if a path exists, False otherwise
        """
        visited = OrderedSet()
        stack = [(start, direction)]
        while stack:
            node = stack.pop()
            node_position, node_direction = node[0], node[1]

            if Vec2d.is_equal(node_position, end):
                return True

            if node in visited:
                continue
            visited.add(node)

            moves = self.get_transitions(node_position[0], node_position[1], node_direction)
            for move_index in range(4):
                if moves[move_index]:
                    # The agent leaves in direction move_index and keeps that heading in the next cell.
                    stack.append((get_new_position(node_position, move_index), move_index))

        return False

    def _outbound_directions(self, rcPos):
        """Return the direction indices (0..3) that at least one transition of the cell at rcPos leads to."""
        binTrans = self.get_full_transitions(*rcPos)  # 16bit integer - all trans in/out
        lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8)  # 2 x uint8
        g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4)  # 4x4 x uint8 binary(0,1)
        gDirOut = g2binTrans.any(axis=0)  # outbound directions as boolean array (4)
        return np.argwhere(gDirOut)[:, 0]  # valid outbound directions as array of int

    def _is_within_grid(self, gPos):
        """Return True if the (row, column) array gPos lies inside the grid bounds."""
        return not (np.any(gPos < 0) or np.any(gPos >= self.grid.shape))

    def _neighbour_connects_back(self, grcPos, iDir):
        """Return True if the in-grid neighbour of grcPos in direction iDir has any transition heading back."""
        gPos2 = grcPos + self.transitions.gDir2dRC[iDir]  # next cell in that direction
        if not self._is_within_grid(gPos2):
            # neighbours outside the grid are ignored
            return False
        return any(self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDir))
                   for orientation in range(4))

    def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
        """
        Check validity of cell at rcPos = tuple(row, column)

        Checks that:
        - surrounding cells have inbound transitions for all the outbound transitions of this cell.
        - an empty cell has no neighbour with a transition leading into it.

        These are NOT checked - see transition.is_valid:
        - all transitions have the mirror transitions (N->E <=> W->S)
        - Reverse transitions (N -> S) only exist for a dead-end
        - a cell contains either no dead-ends or exactly one

        If `check_this_cell` is True, the cell content itself is additionally
        checked with `self.transitions.is_valid`.

        Returns: True (valid) or False (invalid)
        """
        cell_transition = self.grid[tuple(rcPos)]

        if check_this_cell and not self.transitions.is_valid(cell_transition):
            return False

        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
        grcPos = array(rcPos)

        # loop over available outbound directions (indices) for rcPos
        for iDirOut in self._outbound_directions(rcPos):
            gPos2 = grcPos + gDir2dRC[iDirOut]  # next cell in that direction

            # An outbound transition leaving the grid is invalid
            if not self._is_within_grid(gPos2):
                return False

            # Get the transitions out of gPos2, using iDirOut as the inbound direction
            # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
            if not any(self.get_transitions(*gPos2, iDirOut)):
                return False

        # If the cell is empty but has incoming connections we return false
        if cell_transition < 1:
            for iDirOut in range(4):
                if self._neighbour_connects_back(grcPos, iDirOut):
                    return False

        return True

    def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
        """
        Repair the first neighbour that does not accept an outbound transition
        of the cell at rcPos = tuple(row, column).

        For each outbound direction of the cell, the adjacent cell is inspected.
        The first in-grid neighbour that offers no transition for an agent
        entering in that direction gets one transition set (entering
        orientation -> mirrored direction), and the method returns False
        immediately, so at most one neighbour is modified per call.

        If `check_this_cell` is True, the cell content itself is first checked
        with `self.transitions.is_valid`.

        Returns: True if nothing had to be fixed; False if the cell is invalid,
        an outbound transition leaves the grid, or a neighbour was modified.
        """
        cell_transition = self.grid[tuple(rcPos)]

        if check_this_cell and not self.transitions.is_valid(cell_transition):
            return False

        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
        grcPos = array(rcPos)

        # loop over available outbound directions (indices) for rcPos
        for iDirOut in self._outbound_directions(rcPos):
            gPos2 = grcPos + gDir2dRC[iDirOut]  # next cell in that direction

            # An outbound transition leaving the grid cannot be fixed
            if not self._is_within_grid(gPos2):
                return False

            # Get the transitions out of gPos2, using iDirOut as the inbound direction
            if any(self.get_transitions(*gPos2, iDirOut)):
                continue

            # No available transitions, ie (0,0,0,0): add one to the neighbour and stop
            self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1)
            return False

        return True

    def fix_transitions(self, rcPos: IntVector2DArray, direction: int = -1):
        """
        Fix broken transitions of the cell at rcPos by rebuilding its content
        from the neighbours that connect into it.

        Depending on the number of neighbours with a transition leading into
        the cell:
        - 1: the cell becomes a dead-end towards that neighbour,
        - 2: the two neighbours are connected with each other,
        - 3: a simple switch is placed (chosen from `direction` when it is >= 0
          and decisive, otherwise picked with the map's random generator),
        - 4: a double slip switch is placed with a random rotation,
        - 0: the cell is left untouched.

        Parameters
        ----------
        rcPos : Tuple[int,int]
            tuple(row, column) with grid coordinate
        direction : int
            Direction used to select the switch type when three neighbours
            connect; a negative value (default) means "choose randomly".

        Returns
        -------
        boolean
            Always True.
        """
        grcPos = array(rcPos)

        # Transition elements
        transitions = RailEnvTransitions()
        cells = transitions.transition_list
        simple_switch_east_south = transitions.rotate_transition(cells[10], 90)
        simple_switch_west_south = transitions.rotate_transition(cells[2], 270)
        double_slip = cells[5]
        three_way_transitions = [simple_switch_east_south, simple_switch_west_south]

        # Flag each direction whose in-grid neighbour has a transition leading into rcPos
        incoming_connections = np.zeros(4)
        for iDirOut in range(4):
            if self._neighbour_connects_back(grcPos, iDirOut):
                incoming_connections[iDirOut] = 1

        number_of_incoming = np.sum(incoming_connections)
        # Plain ints, so that single grid cells are never assigned length-1 arrays
        incoming_directions = [int(d) for d in np.flatnonzero(incoming_connections)]

        if number_of_incoming == 1:
            # Only one incoming direction --> set deadend
            self.set_transitions(rcPos, 0)
            for incoming in incoming_directions:
                self.set_transition((rcPos[0], rcPos[1], mirror(incoming)), incoming, 1)

        elif number_of_incoming == 2:
            # Connect all incoming connections
            self.set_transitions(rcPos, 0)
            first, second = incoming_directions
            self.set_transition((rcPos[0], rcPos[1], mirror(first)), second, 1)
            self.set_transition((rcPos[0], rcPos[1], mirror(second)), first, 1)

        elif number_of_incoming == 3:
            # Find feasible connection for three entries
            self.set_transitions(rcPos, 0)
            hole = int(np.argwhere(incoming_connections < 1)[0][0])

            transition = None
            if direction >= 0:
                switch_type_idx = (direction - hole + 3) % 4
                if switch_type_idx == 0:
                    transition = simple_switch_west_south
                elif switch_type_idx == 2:
                    transition = simple_switch_east_south
            if transition is None:
                # Same draw as before (choice with size 1), reduced to a scalar
                transition = int(self.random_generator.choice(three_way_transitions, 1)[0])

            transition = transitions.rotate_transition(transition, int(hole * 90))
            self.set_transitions((rcPos[0], rcPos[1]), transition)

        elif number_of_incoming == 4:
            # Make a double slip switch
            rotation = self.random_generator.randint(2)
            transition = transitions.rotate_transition(double_slip, int(rotation * 90))
            self.set_transitions((rcPos[0], rcPos[1]), transition)

        return True

    def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
                                new_pos: IntVector2D, end_pos: IntVector2D):
        """
        Utility function to test that a path drawn by the a-star algorithm uses valid transition objects.
        We use this to guide a-star as there are many transition elements that are not allowed in RailEnv.

        The grid itself is not modified; the candidate transitions are built on
        copies of the cell values and checked with `self.transitions.is_valid`.

        :param prev_pos: The previous position we were checking (None at the start of a path)
        :param current_pos: The current position we are checking
        :param new_pos: Possible child position we move into
        :param end_pos: End cell of path we are drawing
        :return: True if the transition is valid, False if transition element is illegal
        """
        # start by getting direction used to get to current node
        # and direction from current node to possible child node
        new_dir = get_direction(current_pos, new_pos)
        if prev_pos is not None:
            current_dir = get_direction(prev_pos, current_pos)
        else:
            current_dir = new_dir

        # create new transition that would go to child
        new_trans = self.grid[current_pos]
        if prev_pos is None:
            if new_trans == 0:
                # need to flip direction because of how end points are defined
                new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1)
            else:
                # check if matches existing layout
                new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
        else:
            # set the forward path
            new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
            # set the backwards path
            new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)

        if Vec2d.is_equal(new_pos, end_pos):
            # need to validate end pos setup as well
            new_trans_e = self.grid[end_pos]
            if new_trans_e == 0:
                # need to flip direction because of how end points are defined
                new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
            else:
                # check if matches existing layout
                new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, new_dir, 1)

            if not self.transitions.is_valid(new_trans_e):
                return False

        # is transition valid?
        return self.transitions.is_valid(new_trans)
def mirror(dir):
    """
    Return the direction opposite to `dir` on a 4-direction grid.

    Parameters
    ----------
    dir : int
        Direction index; the result is `(dir + 2) % 4`, so 0 <-> 2 and 1 <-> 3.

    Returns
    -------
    int
        Index of the opposite direction, in 0..3.
    """
    return (dir + 2) % 4
# TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)