mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[rllib] Add a simple REST policy server and client example (#2232)
* wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * policy serve * spaces * checkpoint * no train * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * fix race condition * update * com * updat * add test * Update run_multi_node_tests.sh * use curl * curl * kill * Update run_multi_node_tests.sh * Update run_multi_node_tests.sh * fix import * update
This commit is contained in: parent 418cd6804a, commit e5724a9cfe
15 changed files with 384 additions and 78 deletions
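For orientation, the core addition is a PolicyClient/PolicyServer pair that lets an external process drive episodes over HTTP. A condensed sketch of the client-side loop, adapted from the cartpole_client.py file added below (it assumes cartpole_server.py, or an equivalent PolicyServer, is already listening on localhost:8900):

    # Condensed sketch adapted from cartpole_client.py in this commit; assumes a
    # policy server is already running on localhost:8900.
    import gym

    from ray.rllib.utils.policy_client import PolicyClient

    env = gym.make("CartPole-v0")
    client = PolicyClient("http://localhost:8900")

    eid = client.start_episode(training_enabled=True)
    obs = env.reset()
    done = False
    while not done:
        action = client.get_action(eid, obs)        # on-policy action from the server
        obs, reward, done, info = env.step(action)
        client.log_returns(eid, reward, info=info)  # attributed to the last action
    client.end_episode(eid, obs)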
@@ -1,5 +0,0 @@
-# flake8: noqa
-from ray.rllib.examples.multiagent_mountaincar_env \
-    import MultiAgentMountainCarEnv
-from ray.rllib.examples.multiagent_pendulum_env \
-    import MultiAgentPendulumEnv
0 python/ray/rllib/examples/legacy_multiagent/__init__.py (new normal file)
@@ -21,7 +21,9 @@ def pass_params_to_gym(env_name):
 
     register(
         id=env_name,
-        entry_point='ray.rllib.examples:' + "MultiAgentMountainCarEnv",
+        entry_point=(
+            "ray.rllib.examples.legacy_multiagent.multiagent_mountaincar_env:"
+            "MultiAgentMountainCarEnv"),
         max_episode_steps=200,
         kwargs={}
     )
@@ -21,7 +21,9 @@ def pass_params_to_gym(env_name):
 
     register(
         id=env_name,
-        entry_point='ray.rllib.examples:' + "MultiAgentPendulumEnv",
+        entry_point=(
+            "ray.rllib.examples.legacy_multiagent.multiagent_pendulum_env:"
+            "MultiAgentPendulumEnv"),
         max_episode_steps=100,
         kwargs={}
     )
55 python/ray/rllib/examples/serving/cartpole_client.py (new executable file)
@@ -0,0 +1,55 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+"""Example of querying a policy server. Copy this file for your use case.
+
+To try this out, in two separate shells run:
+    $ python cartpole_server.py
+    $ python cartpole_client.py
+"""
+
+import argparse
+import gym
+
+from ray.rllib.utils.policy_client import PolicyClient
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--no-train", action="store_true", help="Whether to disable training.")
+parser.add_argument(
+    "--off-policy", action="store_true",
+    help="Whether to take random instead of on-policy actions.")
+parser.add_argument(
+    "--stop-at-reward", type=int, default=9999,
+    help="Stop once the specified reward is reached.")
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    env = gym.make("CartPole-v0")
+    client = PolicyClient("http://localhost:8900")
+
+    eid = client.start_episode(training_enabled=not args.no_train)
+    obs = env.reset()
+    rewards = 0
+
+    while True:
+        if args.off_policy:
+            action = env.action_space.sample()
+            client.log_action(eid, obs, action)
+        else:
+            action = client.get_action(eid, obs)
+        obs, reward, done, info = env.step(action)
+        rewards += reward
+        client.log_returns(eid, reward, info=info)
+        if done:
+            print("Total reward:", rewards)
+            if rewards >= args.stop_at_reward:
+                print("Target reward achieved, exiting")
+                exit(0)
+            rewards = 0
+            client.end_episode(eid, obs)
+            obs = env.reset()
+            eid = client.start_episode(training_enabled=not args.no_train)
66 python/ray/rllib/examples/serving/cartpole_server.py (new executable file)
@@ -0,0 +1,66 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+"""Example of running a policy server. Copy this file for your use case.
+
+To try this out, in two separate shells run:
+    $ python cartpole_server.py
+    $ python cartpole_client.py
+"""
+
+import os
+
+from gym import spaces
+
+import ray
+from ray.rllib.dqn import DQNAgent
+from ray.rllib.utils.serving_env import ServingEnv
+from ray.rllib.utils.policy_server import PolicyServer
+from ray.tune.logger import pretty_print
+from ray.tune.registry import register_env
+
+SERVER_ADDRESS = "localhost"
+SERVER_PORT = 8900
+CHECKPOINT_FILE = "last_checkpoint.out"
+
+
+class CartpoleServing(ServingEnv):
+    def __init__(self):
+        ServingEnv.__init__(
+            self, spaces.Discrete(2), spaces.Box(low=-10, high=10, shape=(4,)))
+
+    def run(self):
+        print("Starting policy server at {}:{}".format(
+            SERVER_ADDRESS, SERVER_PORT))
+        server = PolicyServer(self, SERVER_ADDRESS, SERVER_PORT)
+        server.serve_forever()
+
+
+if __name__ == "__main__":
+    ray.init()
+    register_env("srv", lambda _: CartpoleServing())
+
+    # We use DQN since it supports off-policy actions, but you can choose and
+    # configure any agent.
+    dqn = DQNAgent(env="srv", config={
+        # Use a single process to avoid needing to set up a load balancer
+        "num_workers": 0,
+        # Configure the agent to run short iterations for debugging
+        "exploration_fraction": 0.01,
+        "learning_starts": 100,
+        "timesteps_per_iteration": 200,
+    })
+
+    # Attempt to restore from checkpoint if possible.
+    if os.path.exists(CHECKPOINT_FILE):
+        checkpoint_path = open(CHECKPOINT_FILE).read()
+        print("Restoring from checkpoint path", checkpoint_path)
+        dqn.restore(checkpoint_path)
+
+    # Serving and training loop
+    while True:
+        print(pretty_print(dqn.train()))
+        checkpoint_path = dqn.save()
+        print("Last checkpoint", checkpoint_path)
+        with open(CHECKPOINT_FILE, "w") as f:
+            f.write(checkpoint_path)
12 python/ray/rllib/examples/serving/test.sh (new executable file)
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+pkill -f cartpole_server.py
+(python cartpole_server.py 2>&1 | grep -v 200) &
+pid=$!
+
+while ! curl localhost:8900; do
+    sleep 1
+done
+
+python cartpole_client.py --stop-at-reward=100
+kill $pid
@@ -24,16 +24,16 @@ class SimpleServing(ServingEnv):
         self.env = env
 
     def run(self):
-        self.start_episode()
+        eid = self.start_episode()
         obs = self.env.reset()
         while True:
-            action = self.get_action(obs)
+            action = self.get_action(eid, obs)
             obs, reward, done, info = self.env.step(action)
-            self.log_returns(reward, info=info)
+            self.log_returns(eid, reward, info=info)
             if done:
-                self.end_episode(obs)
+                self.end_episode(eid, obs)
                 obs = self.env.reset()
-                self.start_episode()
+                eid = self.start_episode()
 
 
 class PartOffPolicyServing(ServingEnv):
@@ -43,20 +43,20 @@ class PartOffPolicyServing(ServingEnv):
         self.off_pol_frac = off_pol_frac
 
     def run(self):
-        self.start_episode()
+        eid = self.start_episode()
        obs = self.env.reset()
         while True:
             if random.random() < self.off_pol_frac:
                 action = self.env.action_space.sample()
-                self.log_action(obs, action)
+                self.log_action(eid, obs, action)
             else:
-                action = self.get_action(obs)
+                action = self.get_action(eid, obs)
             obs, reward, done, info = self.env.step(action)
-            self.log_returns(reward, info=info)
+            self.log_returns(eid, reward, info=info)
             if done:
-                self.end_episode(obs)
+                self.end_episode(eid, obs)
                 obs = self.env.reset()
-                self.start_episode()
+                eid = self.start_episode()
 
 
 class SimpleOffPolicyServing(ServingEnv):
@@ -65,18 +65,18 @@ class SimpleOffPolicyServing(ServingEnv):
         self.env = env
 
     def run(self):
-        self.start_episode()
+        eid = self.start_episode()
         obs = self.env.reset()
         while True:
             # Take random actions
             action = self.env.action_space.sample()
-            self.log_action(obs, action)
+            self.log_action(eid, obs, action)
             obs, reward, done, info = self.env.step(action)
-            self.log_returns(reward, info=info)
+            self.log_returns(eid, reward, info=info)
             if done:
-                self.end_episode(obs)
+                self.end_episode(eid, obs)
                 obs = self.env.reset()
-                self.start_episode()
+                eid = self.start_episode()
 
 
 class MultiServing(ServingEnv):
@@ -98,14 +98,13 @@ class MultiServing(ServingEnv):
                     self.start_episode(episode_id=eids[i])
                     cur_obs[i] = envs[i].reset()
             actions = [
-                self.get_action(
-                    cur_obs[i], episode_id=eids[i]) for i in active]
+                self.get_action(eids[i], cur_obs[i]) for i in active]
             for i, action in zip(active, actions):
                 obs, reward, done, _ = envs[i].step(action)
                 cur_obs[i] = obs
-                self.log_returns(reward, episode_id=eids[i])
+                self.log_returns(eids[i], reward)
                 if done:
-                    self.end_episode(obs, episode_id=eids[i])
+                    self.end_episode(eids[i], obs)
                     del cur_obs[i]
116 python/ray/rllib/utils/policy_client.py (new normal file)
@@ -0,0 +1,116 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import pickle
+
+try:
+    import requests  # `requests` is not part of stdlib.
+except ImportError:
+    requests = None
+    print("Couldn't import `requests` library. Be sure to install it on"
+          " the client side.")
+
+
+class PolicyClient(object):
+    """Client to interact with a RLlib policy server."""
+
+    START_EPISODE = "START_EPISODE"
+    GET_ACTION = "GET_ACTION"
+    LOG_ACTION = "LOG_ACTION"
+    LOG_RETURNS = "LOG_RETURNS"
+    END_EPISODE = "END_EPISODE"
+
+    def __init__(self, address):
+        self._address = address
+
+    def start_episode(self, episode_id=None, training_enabled=True):
+        """Record the start of an episode.
+
+        Arguments:
+            episode_id (str): Unique string id for the episode or None for
+                it to be auto-assigned.
+            training_enabled (bool): Whether to use experiences for this
+                episode to improve the policy.
+
+        Returns:
+            episode_id (str): Unique string id for the episode.
+        """
+
+        return self._send({
+            "episode_id": episode_id,
+            "command": PolicyClient.START_EPISODE,
+            "training_enabled": training_enabled,
+        })["episode_id"]
+
+    def get_action(self, episode_id, observation):
+        """Record an observation and get the on-policy action.
+
+        Arguments:
+            episode_id (str): Episode id returned from start_episode().
+            observation (obj): Current environment observation.
+
+        Returns:
+            action (obj): Action from the env action space.
+        """
+        return self._send({
+            "command": PolicyClient.GET_ACTION,
+            "observation": observation,
+            "episode_id": episode_id,
+        })["action"]
+
+    def log_action(self, episode_id, observation, action):
+        """Record an observation and (off-policy) action taken.
+
+        Arguments:
+            episode_id (str): Episode id returned from start_episode().
+            observation (obj): Current environment observation.
+            action (obj): Action for the observation.
+        """
+        self._send({
+            "command": PolicyClient.LOG_ACTION,
+            "observation": observation,
+            "action": action,
+            "episode_id": episode_id,
+        })
+
+    def log_returns(self, episode_id, reward, info=None):
+        """Record returns from the environment.
+
+        The reward will be attributed to the previous action taken by the
+        episode. Rewards accumulate until the next action. If no reward is
+        logged before the next action, a reward of 0.0 is assumed.
+
+        Arguments:
+            episode_id (str): Episode id returned from start_episode().
+            reward (float): Reward from the environment.
+        """
+        self._send({
+            "command": PolicyClient.LOG_RETURNS,
+            "reward": reward,
+            "info": info,
+            "episode_id": episode_id,
+        })
+
+    def end_episode(self, episode_id, observation):
+        """Record the end of an episode.
+
+        Arguments:
+            episode_id (str): Episode id returned from start_episode().
+            observation (obj): Current environment observation.
+        """
+        self._send({
+            "command": PolicyClient.END_EPISODE,
+            "observation": observation,
+            "episode_id": episode_id,
+        })
+
+    def _send(self, data):
+        payload = pickle.dumps(data)
+        response = requests.post(self._address, data=payload)
+        if response.status_code != 200:
+            print("Request failed", data)
+            print(response.text)
+        response.raise_for_status()
+        parsed = pickle.loads(response.content)
+        return parsed
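Under the hood, the client and server exchange pickled command dicts over plain HTTP POSTs. A rough standalone sketch of what _send puts on the wire (illustrative only; the command string and dict keys mirror the constants and handlers in this commit):

    # Illustrative sketch of the wire format used by PolicyClient._send and
    # handled by PolicyServer.do_POST; assumes a server on localhost:8900.
    import pickle
    import requests

    payload = pickle.dumps({
        "command": "START_EPISODE",   # one of the PolicyClient.* command strings
        "episode_id": None,           # None lets the server auto-assign an id
        "training_enabled": True,
    })
    response = requests.post("http://localhost:8900", data=payload)
    episode_id = pickle.loads(response.content)["episode_id"]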
62 python/ray/rllib/utils/policy_server.py (new normal file)
@@ -0,0 +1,62 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import pickle
+import sys
+import traceback
+
+from ray.rllib.utils.policy_client import PolicyClient
+
+if sys.version_info[0] == 2:
+    from SimpleHTTPServer import SimpleHTTPRequestHandler
+    from SocketServer import TCPServer as HTTPServer
+    from SocketServer import ThreadingMixIn
+elif sys.version_info[0] == 3:
+    from http.server import SimpleHTTPRequestHandler, HTTPServer
+    from socketserver import ThreadingMixIn
+
+
+class PolicyServer(ThreadingMixIn, HTTPServer):
+    def __init__(self, serving_env, address, port):
+        handler = _make_handler(serving_env)
+        HTTPServer.__init__(self, (address, port), handler)
+
+
+def _make_handler(serving_env):
+    class Handler(SimpleHTTPRequestHandler):
+        def do_POST(self):
+            content_len = int(self.headers.get('Content-Length'), 0)
+            raw_body = self.rfile.read(content_len)
+            parsed_input = pickle.loads(raw_body)
+            try:
+                response = self.execute_command(parsed_input)
+                self.send_response(200)
+                self.end_headers()
+                self.wfile.write(pickle.dumps(response))
+            except Exception:
+                self.send_error(500, traceback.format_exc())
+
+        def execute_command(self, args):
+            command = args["command"]
+            response = {}
+            if command == PolicyClient.START_EPISODE:
+                response["episode_id"] = serving_env.start_episode(
+                    args["episode_id"], args["training_enabled"])
+            elif command == PolicyClient.GET_ACTION:
+                response["action"] = serving_env.get_action(
+                    args["episode_id"], args["observation"])
+            elif command == PolicyClient.LOG_ACTION:
+                serving_env.log_action(
+                    args["episode_id"], args["observation"], args["action"])
+            elif command == PolicyClient.LOG_RETURNS:
+                serving_env.log_returns(
+                    args["episode_id"], args["reward"], args["info"])
+            elif command == PolicyClient.END_EPISODE:
+                serving_env.end_episode(
+                    args["episode_id"], args["observation"])
+            else:
+                raise Exception("Unknown command: {}".format(command))
+            return response
+
+    return Handler
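For completeness, the server side follows the pattern shown in cartpole_server.py above: a ServingEnv subclass starts a PolicyServer inside its run() thread, and any RLlib agent trains against it. A condensed sketch (the spaces, port, and agent choice are placeholders):

    # Condensed sketch of the server-side wiring from cartpole_server.py above;
    # the observation/action spaces, port, and agent are placeholders.
    from gym import spaces

    import ray
    from ray.rllib.dqn import DQNAgent
    from ray.rllib.utils.policy_server import PolicyServer
    from ray.rllib.utils.serving_env import ServingEnv
    from ray.tune.registry import register_env

    class MyServing(ServingEnv):
        def __init__(self):
            ServingEnv.__init__(
                self, spaces.Discrete(2), spaces.Box(low=-10, high=10, shape=(4,)))

        def run(self):
            # Runs in its own thread; blocks serving client requests forever.
            PolicyServer(self, "localhost", 8900).serve_forever()

    ray.init()
    register_env("srv", lambda _: MyServing())
    dqn = DQNAgent(env="srv", config={"num_workers": 0})
    while True:
        dqn.train()  # learns from experiences reported by remote clients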
@@ -188,7 +188,7 @@ def _env_runner(
 
     while True:
         # Get observations from ready envs
-        unfiltered_obs, rewards, dones, _, off_policy_actions = \
+        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
             async_vector_env.poll()
         ready_eids = []
         ready_obs = []
@@ -216,24 +216,25 @@ def _env_runner(
            else:
                done = False
 
-            episode.batch_builder.add_values(
-                obs=episode.last_observation,
-                actions=episode.last_action_flat(),
-                rewards=rewards[eid],
-                dones=done,
-                new_obs=filtered_obs,
-                **episode.last_pi_info)
+            if infos[eid].get("training_enabled", True):
+                episode.batch_builder.add_values(
+                    obs=episode.last_observation,
+                    actions=episode.last_action_flat(),
+                    rewards=rewards[eid],
+                    dones=done,
+                    new_obs=filtered_obs,
+                    **episode.last_pi_info)
 
-            # Cut the batch if we're not packing multiple episodes into one,
-            # or if we've exceeded the requested batch size.
-            if (done and not pack) or \
-                    episode.batch_builder.count >= num_local_steps:
-                yield episode.batch_builder.build_and_reset(
-                    policy.postprocess_trajectory)
-            elif done:
-                # Make sure postprocessor never goes across episode boundaries
-                episode.batch_builder.postprocess_batch_so_far(
-                    policy.postprocess_trajectory)
+                # Cut the batch if we're not packing multiple episodes into
+                # one, or if we've exceeded the requested batch size.
+                if (done and not pack) or \
+                        episode.batch_builder.count >= num_local_steps:
+                    yield episode.batch_builder.build_and_reset(
+                        policy.postprocess_trajectory)
+                elif done:
+                    # Make sure postprocessor never crosses episode boundaries
+                    episode.batch_builder.postprocess_batch_so_far(
+                        policy.postprocess_trajectory)
 
            if done:
                # Handle episode termination
@@ -4,6 +4,7 @@ from __future__ import print_function
 
 from six.moves import queue
 import threading
+import uuid
 
 from ray.rllib.utils.async_vector_env import AsyncVectorEnv
 
@@ -26,8 +27,6 @@ class ServingEnv(threading.Thread):
 
     This env is thread-safe, but individual episodes must be executed serially.
 
-    TODO: Provide a HTTP server/client example based on ServingEnv.
-
     Examples:
         >>> register_env("my_env", lambda config: YourServingEnv(config))
         >>> agent = DQNAgent(env="my_env")
@@ -51,8 +50,6 @@ class ServingEnv(threading.Thread):
         self.observation_space = observation_space
         self._episodes = {}
         self._finished = set()
-        self._num_episodes = 0
-        self._cur_default_episode_id = None
         self._results_avail_condition = threading.Condition()
         self._max_concurrent_episodes = max_concurrent
 
@@ -70,24 +67,21 @@ class ServingEnv(threading.Thread):
         """
         raise NotImplementedError
 
-    def start_episode(self, episode_id=None):
+    def start_episode(self, episode_id=None, training_enabled=True):
         """Record the start of an episode.
 
         Arguments:
             episode_id (str): Unique string id for the episode or None for
-                it to be auto-assigned. Auto-assignment only works if there
-                is at most one active episode at a time.
+                it to be auto-assigned.
+            training_enabled (bool): Whether to use experiences for this
+                episode to improve the policy.
+
+        Returns:
+            episode_id (str): Unique string id for the episode.
         """
 
         if episode_id is None:
-            if self._cur_default_episode_id:
-                raise ValueError(
-                    "An existing episode is still active. You must pass "
-                    "`episode_id` if there are going to be multiple active "
-                    "episodes at once.")
-            episode_id = "default_{}".format(self._num_episodes)
-            self._cur_default_episode_id = episode_id
-            self._num_episodes += 1
+            episode_id = uuid.uuid4().hex
 
         if episode_id in self._finished:
             raise ValueError(
@@ -98,14 +92,16 @@ class ServingEnv(threading.Thread):
                 "Episode {} is already started".format(episode_id))
 
         self._episodes[episode_id] = _Episode(
-            episode_id, self._results_avail_condition)
+            episode_id, self._results_avail_condition, training_enabled)
+        return episode_id
 
-    def get_action(self, observation, episode_id=None):
+    def get_action(self, episode_id, observation):
         """Record an observation and get the on-policy action.
 
         Arguments:
+            episode_id (str): Episode id returned from start_episode().
             observation (obj): Current environment observation.
-            episode_id (str): Episode id passed to start_episode() or None.
 
         Returns:
             action (obj): Action from the env action space.
@@ -114,19 +110,19 @@ class ServingEnv(threading.Thread):
         episode = self._get(episode_id)
         return episode.wait_for_action(observation)
 
-    def log_action(self, observation, action, episode_id=None):
+    def log_action(self, episode_id, observation, action):
         """Record an observation and (off-policy) action taken.
 
         Arguments:
+            episode_id (str): Episode id returned from start_episode().
             observation (obj): Current environment observation.
             action (obj): Action for the observation.
-            episode_id (str): Episode id passed to start_episode() or None.
         """
 
         episode = self._get(episode_id)
         episode.log_action(observation, action)
 
-    def log_returns(self, reward, info=None, episode_id=None):
+    def log_returns(self, episode_id, reward, info=None):
         """Record returns from the environment.
 
         The reward will be attributed to the previous action taken by the
@@ -134,34 +130,31 @@ class ServingEnv(threading.Thread):
         logged before the next action, a reward of 0.0 is assumed.
 
         Arguments:
+            episode_id (str): Episode id returned from start_episode().
             reward (float): Reward from the environment.
             info (dict): Optional info dict.
-            episode_id (str): Episode id passed to start_episode() or None.
         """
 
         episode = self._get(episode_id)
         episode.cur_reward += reward
         if info:
-            episode.cur_info = info
+            episode.cur_info = info or {}
 
-    def end_episode(self, observation, episode_id=None):
+    def end_episode(self, episode_id, observation):
         """Record the end of an episode.
 
         Arguments:
-            episode_id (str): Episode id passed by start_episode() or None.
+            episode_id (str): Episode id returned from start_episode().
             observation (obj): Current environment observation.
         """
 
         episode = self._get(episode_id)
         self._finished.add(episode.episode_id)
-        self._cur_default_episode_id = None
         episode.done(observation)
 
-    def _get(self, episode_id=None):
+    def _get(self, episode_id):
         """Get a started episode or raise an error."""
 
-        if episode_id is None:
-            episode_id = self._cur_default_episode_id
         if episode_id in self._finished:
             raise ValueError(
                 "Episode {} has already completed.".format(episode_id))
@@ -217,9 +210,10 @@ class _ServingEnvToAsync(AsyncVectorEnv):
 class _Episode(object):
     """Tracked state for each active episode."""
 
-    def __init__(self, episode_id, results_avail_condition):
+    def __init__(self, episode_id, results_avail_condition, training_enabled):
         self.episode_id = episode_id
         self.results_avail_condition = results_avail_condition
+        self.training_enabled = training_enabled
         self.data_queue = queue.Queue()
         self.action_queue = queue.Queue()
         self.new_observation = None
@@ -258,6 +252,8 @@ class _Episode(object):
         }
         if self.new_action is not None:
             item["off_policy_action"] = self.new_action
+        if not self.training_enabled:
+            item["info"]["training_enabled"] = False
         self.new_observation = None
         self.new_action = None
         self.cur_reward = 0.0
@@ -253,7 +253,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     --smoke-test
 
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
-    python /ray/python/ray/rllib/examples/multiagent_mountaincar.py
+    python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py
 
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
-    python /ray/python/ray/rllib/examples/multiagent_pendulum.py
+    python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py