import hashlib
import json
import logging
import os
import random
import time
import msgpack
import msgpack_numpy as m
import numpy as np
import redis
import flatland
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.malfunction_generators import malfunction_from_file
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
from flatland.evaluators import messages
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
m.patch()
class FlatlandRemoteClient(object):
"""
Redis client to interface with flatland-rl remote-evaluation-service
The Docker container hosts a redis-server inside the container.
This client connects to the same redis-server,
and communicates with the service.
The service eventually will reside outside the docker container,
and will communicate
with the client only via the redis-server of the docker container.
On the instantiation of the docker container, one service will be
instantiated parallely.
The service will accepts commands at "`service_id`::commands"
where `service_id` is either provided as an `env` variable or is
instantiated to "flatland_rl_redis_service_id"
"""
def __init__(self,
remote_host='127.0.0.1',
remote_port=6379,
remote_db=0,
remote_password=None,
test_envs_root=None,
verbose=False):
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_db = remote_db
self.remote_password = remote_password
self.redis_pool = redis.ConnectionPool(
host=remote_host,
port=remote_port,
db=remote_db,
password=remote_password)
self.redis_conn = redis.Redis(connection_pool=self.redis_pool)
self.namespace = "flatland-rl"
self.service_id = os.getenv(
'FLATLAND_RL_SERVICE_ID',
'FLATLAND_RL_SERVICE_ID'
)
self.command_channel = "{}::{}::commands".format(
self.namespace,
self.service_id
)
if test_envs_root:
self.test_envs_root = test_envs_root
else:
self.test_envs_root = os.getenv(
'AICROWD_TESTS_FOLDER',
'/tmp/flatland_envs'
)
self.current_env_path = None
self.verbose = verbose
self.env = None
self.ping_pong()
self.env_step_times = []
self.stats = {}
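    # Minimal usage sketch (assumes a redis-server on 127.0.0.1:6379 and test
    # env pickles under $AICROWD_TESTS_FOLDER or /tmp/flatland_envs):
    #
    #   client = FlatlandRemoteClient(remote_host='127.0.0.1', remote_port=6379)
    #   obs, info = client.env_create(obs_builder_object=DummyObservationBuilder())
    #
    # The __main__ block at the bottom of this module shows a full episode loop.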
    def update_running_mean_stats(self, key, scalar):
"""
Computes the running mean for certain params
"""
mean_key = "{}_mean".format(key)
counter_key = "{}_counter".format(key)
        try:
            self.stats[mean_key] = \
                ((self.stats[mean_key] * self.stats[counter_key]) + scalar) / (self.stats[counter_key] + 1)
            self.stats[counter_key] += 1
        except KeyError:
            # First observation for this key: seed the running mean with it
            # instead of silently discarding the first sample.
            self.stats[mean_key] = scalar
            self.stats[counter_key] = 1
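        # The update above is the standard incremental mean:
        #   mean_{n+1} = (mean_n * n + scalar) / (n + 1)
        # e.g. feeding 2, 4, 6 yields means 2.0, 3.0, 4.0 after each call.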
    def get_redis_connection(self):
return self.redis_conn
def _generate_response_channel(self):
random_hash = hashlib.md5(
"{}".format(
random.randint(0, 10 ** 10)
).encode('utf-8')).hexdigest()
response_channel = "{}::{}::response::{}".format(self.namespace,
self.service_id,
random_hash)
return response_channel
def _remote_request(self, _request, blocking=True):
"""
request:
-command_type
-payload
-response_channel
response: (on response_channel)
- RESULT
* Send the payload on command_channel (self.namespace+"::command")
** redis-left-push (LPUSH)
* Keep listening on response_channel (BLPOP)
"""
assert isinstance(_request, dict)
_request['response_channel'] = self._generate_response_channel()
_request['timestamp'] = time.time()
_redis = self.get_redis_connection()
"""
The client always pushes in the left
and the service always pushes in the right
"""
if self.verbose:
print("Request : ", _request)
# Push request in command_channels
# Note: The patched msgpack supports numpy arrays
payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
_redis.lpush(self.command_channel, payload)
if blocking:
# Wait with a blocking pop for the response
_response = _redis.blpop(_request['response_channel'])[1]
if self.verbose:
print("Response : ", _response)
_response = msgpack.unpackb(
_response,
object_hook=m.decode,
encoding="utf8")
if _response['type'] == messages.FLATLAND_RL.ERROR:
raise Exception(str(_response["payload"]))
else:
return _response
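    # For reference, a service-side counterpart would look roughly like the
    # sketch below (assumed, not part of this client): it BRPOPs requests from
    # the command channel and pushes the result onto the per-request response
    # channel; handle() is a hypothetical dispatcher.
    #
    #   _, raw = redis_conn.brpop(command_channel)
    #   request = msgpack.unpackb(raw, object_hook=m.decode, encoding="utf8")
    #   result = handle(request['type'], request['payload'])
    #   redis_conn.rpush(request['response_channel'],
    #                    msgpack.packb(result, default=m.encode, use_bin_type=True))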
    def ping_pong(self):
"""
Official Handshake with the evaluation service
Send a PING
and wait for PONG
If not PONG, raise error
"""
_request = {}
_request['type'] = messages.FLATLAND_RL.PING
_request['payload'] = {
"version": flatland.__version__
}
_response = self._remote_request(_request)
if _response['type'] != messages.FLATLAND_RL.PONG:
raise Exception(
"Unable to perform handshake with the evaluation service. \
Expected PONG; received {}".format(json.dumps(_response)))
else:
return True
    def env_create(self, obs_builder_object):
"""
Create a local env and remote env on which the
local agent can operate.
The observation builder is only used in the local env
and the remote env uses a DummyObservationBuilder
"""
time_start = time.time()
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_CREATE
_request['payload'] = {}
_response = self._remote_request(_request)
observation = _response['payload']['observation']
info = _response['payload']['info']
random_seed = _response['payload']['random_seed']
test_env_file_path = _response['payload']['env_file_path']
time_diff = time.time() - time_start
self.update_running_mean_stats("env_creation_wait_time", time_diff)
        if not observation:
            # If the observation is False,
            # then all evaluations are complete,
            # so return it as-is
            return observation, info
if self.verbose:
print("Received Env : ", test_env_file_path)
test_env_file_path = os.path.join(
self.test_envs_root,
test_env_file_path
)
if not os.path.exists(test_env_file_path):
raise Exception(
"\nWe cannot seem to find the env file paths at the required location.\n"
"Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
"to point to the location of the Tests folder ? \n"
"We are currently looking at `{}` for the tests".format(self.test_envs_root)
)
if self.verbose:
print("Current env path : ", test_env_file_path)
self.current_env_path = test_env_file_path
self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
schedule_generator=schedule_from_file(test_env_file_path),
malfunction_generator_and_process_data=malfunction_from_file(test_env_file_path),
obs_builder_object=obs_builder_object)
time_start = time.time()
local_observation, info = self.env.reset(
regenerate_rail=True,
regenerate_schedule=True,
activate_agents=False,
random_seed=random_seed
)
time_diff = time.time() - time_start
self.update_running_mean_stats("internal_env_reset_time", time_diff)
# Use the local observation
# as the remote server uses a dummy observation builder
return local_observation, info
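    # A minimal caller sketch (assumes TreeObsForRailEnv from
    # flatland.envs.observations; any ObservationBuilder works):
    #
    #   obs, info = client.env_create(obs_builder_object=TreeObsForRailEnv(max_depth=2))
    #   if not obs:
    #       # the service has no more test envs; time to call client.submit()
    #       ...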
    def env_step(self, action, render=False):
"""
Respond with [observation, reward, done, info]
"""
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_STEP
_request['payload'] = {}
_request['payload']['action'] = action
# Relay the action in a non-blocking way to the server
# so that it can start doing an env.step on it in ~ parallel
self._remote_request(_request, blocking=False)
# Apply the action in the local env
time_start = time.time()
local_observation, local_reward, local_done, local_info = \
self.env.step(action)
time_diff = time.time() - time_start
# Compute a running mean of env step times
self.update_running_mean_stats("internal_env_step_time", time_diff)
return [local_observation, local_reward, local_done, local_info]
    def submit(self):
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_SUBMIT
_request['payload'] = {}
_response = self._remote_request(_request)
######################################################################
# Print Local Stats
######################################################################
print("=" * 100)
print("=" * 100)
print("## Client Performance Stats")
print("=" * 100)
for _key in self.stats:
if _key.endswith("_mean"):
print("\t - {}\t:{}".format(_key, self.stats[_key]))
print("=" * 100)
if os.getenv("AICROWD_BLOCKING_SUBMIT"):
"""
If the submission is supposed to happen as a blocking submit,
then wait indefinitely for the evaluator to decide what to
do with the container.
"""
while True:
time.sleep(10)
return _response['payload']
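    # Note: AICROWD_BLOCKING_SUBMIT is assumed to be set by the evaluation
    # orchestrator; when it is unset (e.g. local runs), submit() prints the
    # client-side stats and returns the evaluation payload immediately.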
if __name__ == "__main__":
remote_client = FlatlandRemoteClient()
def my_controller(obs, _env):
_action = {}
for _idx, _ in enumerate(_env.agents):
_action[_idx] = np.random.randint(0, 5)
return _action
my_observation_builder = DummyObservationBuilder()
episode = 0
obs = True
while obs:
obs, info = remote_client.env_create(
obs_builder_object=my_observation_builder
)
if not obs:
"""
The remote env returns False as the first obs
when it is done evaluating all the individual episodes
"""
break
print("Episode : {}".format(episode))
episode += 1
print(remote_client.env.dones['__all__'])
while True:
action = my_controller(obs, remote_client.env)
time_start = time.time()
observation, all_rewards, done, info = remote_client.env_step(action)
time_diff = time.time() - time_start
print("Step Time : ", time_diff)
if done['__all__']:
print("Current Episode : ", episode)
print("Episode Done")
print("Reward : ", sum(list(all_rewards.values())))
break
print("Evaluation Complete...")
print(remote_client.submit())