ray/rllib/examples/serving/cartpole_server.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix 2xLINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
2020-01-09 00:15:48 -08:00

66 lines
2.1 KiB
Python
Executable file

"""Example of running a policy server. Copy this file for your use case.
To try this out, in two separate shells run:
$ python cartpole_server.py
$ python cartpole_client.py
"""
import os
from gym import spaces
import numpy as np
import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.utils.policy_server import PolicyServer
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
SERVER_ADDRESS = "localhost"
SERVER_PORT = 9900
CHECKPOINT_FILE = "last_checkpoint.out"
class CartpoleServing(ExternalEnv):
def __init__(self):
ExternalEnv.__init__(
self, spaces.Discrete(2),
spaces.Box(low=-10, high=10, shape=(4, ), dtype=np.float32))
def run(self):
print("Starting policy server at {}:{}".format(SERVER_ADDRESS,
SERVER_PORT))
server = PolicyServer(self, SERVER_ADDRESS, SERVER_PORT)
server.serve_forever()
if __name__ == "__main__":
ray.init()
register_env("srv", lambda _: CartpoleServing())
# We use DQN since it supports off-policy actions, but you can choose and
# configure any agent.
dqn = DQNTrainer(
env="srv",
config={
# Use a single process to avoid needing to set up a load balancer
"num_workers": 0,
# Configure the agent to run short iterations for debugging
"exploration_fraction": 0.01,
"learning_starts": 100,
"timesteps_per_iteration": 200,
})
# Attempt to restore from checkpoint if possible.
if os.path.exists(CHECKPOINT_FILE):
checkpoint_path = open(CHECKPOINT_FILE).read()
print("Restoring from checkpoint path", checkpoint_path)
dqn.restore(checkpoint_path)
# Serving and training loop
while True:
print(pretty_print(dqn.train()))
checkpoint_path = dqn.save()
print("Last checkpoint", checkpoint_path)
with open(CHECKPOINT_FILE, "w") as f:
f.write(checkpoint_path)