ray/rllib/examples/serving/unity3d_server.py

160 lines
5.6 KiB
Python
Executable file

"""
Example of running a Unity3D (MLAgents) Policy server that can learn
Policies via sampling inside many connected Unity game clients (possibly
running in the cloud on n nodes).
For a locally running Unity3D example, see:
`examples/unity3d_env_local.py`
To run this script against one or more possibly cloud-based clients:
1) Install Unity3D and `pip install mlagents`.
2) Compile a Unity3D example game with MLAgents support (e.g. 3DBall or any
other one that you created yourself) and place the compiled binary
where your RLlib client script (see below) can access it.
2.1) To find Unity3D MLAgent examples, first `pip install mlagents`,
then check out the `.../ml-agents/Project/Assets/ML-Agents/Examples/`
folder.
3) Change this RLlib Policy server code so it knows the observation and
action spaces, the different Policies (called "behaviors" in Unity3D
MLAgents), and the agent-to-policy mappings for your particular game.
Alternatively, use one of the two already existing setups (3DBall or
SoccerStrikersVsGoalie).
4) Then run (two separate shells/machines):
$ python unity3d_server.py --env 3DBall
$ python unity3d_client.py --inference-mode=local --game [path to game binary]
"""
import argparse
import os
import ray
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv
# Address passed to every `PolicyServerInput` created by this script.
SERVER_ADDRESS = "localhost"
# Default base port; can be overridden with `--port`.
SERVER_PORT = 9900
# Name template (formatted with the `--env` value) of the small text file
# that remembers where the most recent checkpoint was saved.
CHECKPOINT_FILE = "last_checkpoint_{}.out"

# Command line interface of the policy server script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    default="PPO",
    choices=["DQN", "PPO"],
    help="The RLlib-registered algorithm to use.")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument(
    "--num-workers",
    type=int,
    default=2,
    help="The number of workers to use. Each worker will create "
    "its own listening socket for incoming experiences.")
parser.add_argument(
    "--env",
    type=str,
    default="3DBall",
    choices=[
        "3DBall", "3DBallHard", "FoodCollector", "GridFoodCollector",
        "Pyramids", "SoccerStrikersVsGoalie", "Sorter", "Tennis",
        "VisualHallway", "Walker"
    ],
    help="The name of the Env to run in the Unity3D editor "
    "(feel free to add more to this script!)")
parser.add_argument(
    "--port",
    type=int,
    default=SERVER_PORT,
    help="The Policy server's port to listen on for ExternalEnv client "
    "connections.")
parser.add_argument(
    "--checkpoint-freq",
    type=int,
    default=10,
    help="The frequency with which to create checkpoint files of the learnt "
    "Policies.")
# NOTE: This is an opt-out flag. Without it, the script restores from the
# last recorded checkpoint whenever one exists.
parser.add_argument(
    "--no-restore",
    action="store_true",
    help="Do not restore the Policy weights from a previous checkpoint "
    "(by default, the latest checkpoint is restored if one exists).")
if __name__ == "__main__":
    # Script entry point: build a Trainer whose workers receive experiences
    # through `PolicyServerInput`s, optionally restore it from the last
    # recorded checkpoint, then train (and periodically checkpoint) forever.
    args = parser.parse_args()
    ray.init()

    # `InputReader` generator (returns None if no input reader is needed on
    # the respective worker).
    def _input(ioctx):
        """Create the `PolicyServerInput` for a worker, if it needs one.

        Remote worker `i` (i >= 1) listens on `args.port + i - 1`; the
        local worker (index 0) only listens - on `args.port` - when there
        are no remote workers at all.

        Args:
            ioctx: The IO context of the worker the reader is built for.

        Returns:
            A `PolicyServerInput`, or None if this worker needs no reader.
        """
        # We are remote worker or we are local worker with num_workers=0:
        # Create a PolicyServerInput.
        if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
            return PolicyServerInput(
                ioctx, SERVER_ADDRESS, args.port + ioctx.worker_index -
                (1 if ioctx.worker_index > 0 else 0))
        # No InputReader (PolicyServerInput) needed.
        else:
            return None

    # Get the multi-agent policies dict and agent->policy
    # mapping-fn.
    policies, policy_mapping_fn = \
        Unity3DEnv.get_policy_configs_for_game(args.env)

    # The entire config will be sent to connecting clients so they can
    # build their own samplers (and also Policy objects iff
    # `inference_mode=local` on clients' command line).
    config = {
        # Indicate that the Trainer we setup here doesn't need an actual env.
        # Allow spaces to be determined by user (see below).
        "env": None,

        # Use the `PolicyServerInput` to generate experiences.
        "input": _input,
        # Use n worker processes to listen on different ports.
        "num_workers": args.num_workers,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],

        # Other settings.
        "train_batch_size": 256,
        "rollout_fragment_length": 20,
        # Multi-agent setup for the given env.
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
        },
        # DL framework to use.
        "framework": args.framework,
    }

    # Create the Trainer used for Policy serving.
    trainer = get_trainer_class(args.run)(config=config)

    # The pointer file only stores the *location* of the latest checkpoint.
    # Keep its name in a dedicated variable: the training loop below must
    # always write to the pointer file and never to the checkpoint that was
    # restored from.
    checkpoint_pointer_file = CHECKPOINT_FILE.format(args.env)

    # Attempt to restore from checkpoint if possible.
    if not args.no_restore and os.path.exists(checkpoint_pointer_file):
        with open(checkpoint_pointer_file) as f:
            # Strip surrounding whitespace so that a hand-edited pointer
            # file (e.g. with a trailing newline) still yields a valid path.
            restore_path = f.read().strip()
        print("Restoring from checkpoint path", restore_path)
        trainer.restore(restore_path)

    # Serving and training loop.
    count = 0
    while True:
        # Calls to train() will block on the configured `input` in the Trainer
        # config above (PolicyServerInput).
        print(trainer.train())
        if count % args.checkpoint_freq == 0:
            print("Saving learning progress to checkpoint file.")
            checkpoint = trainer.save()
            # Write the latest checkpoint location to CHECKPOINT_FILE,
            # so we can pick up from the latest one after a server re-start.
            with open(checkpoint_pointer_file, "w") as f:
                f.write(checkpoint)
        count += 1