"""
|
|
Example of running a Unity3D (MLAgents) Policy server that can learn
|
|
Policies via sampling inside many connected Unity game clients (possibly
|
|
running in the cloud on n nodes).
|
|
For a locally running Unity3D example, see:
|
|
`examples/unity3d_env_local.py`
|
|
|
|
To run this script against one or more possibly cloud-based clients:
|
|
1) Install Unity3D and `pip install mlagents`.
|
|
|
|
2) Compile a Unity3D example game with MLAgents support (e.g. 3DBall or any
|
|
other one that you created yourself) and place the compiled binary
|
|
somewhere, where your RLlib client script (see below) can access it.
|
|
|
|
2.1) To find Unity3D MLAgent examples, first `pip install mlagents`,
|
|
then check out the `.../ml-agents/Project/Assets/ML-Agents/Examples/`
|
|
folder.
|
|
|
|
3) Change this RLlib Policy server code so it knows the observation- and
|
|
action Spaces, the different Policies (called "behaviors" in Unity3D
|
|
MLAgents), and Agent-to-Policy mappings for your particular game.
|
|
Alternatively, use one of the two already existing setups (3DBall or
|
|
SoccerStrikersVsGoalie).
|
|
|
|
4) Then run (two separate shells/machines):
|
|
$ python unity3d_server.py --env 3DBall
|
|
$ python unity3d_client.py --inference-mode=local --game [path to game binary]
|
|
"""

import argparse
import os

import ray
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv

SERVER_ADDRESS = "localhost"
SERVER_PORT = 9900
CHECKPOINT_FILE = "last_checkpoint_{}.out"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    default="PPO",
    choices=["DQN", "PPO"],
    help="The RLlib-registered algorithm to use.")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument(
    "--num-workers",
    type=int,
    default=2,
    help="The number of workers to use. Each worker will create "
    "its own listening socket for incoming experiences.")
parser.add_argument(
    "--env",
    type=str,
    default="3DBall",
    choices=[
        "3DBall", "3DBallHard", "FoodCollector", "GridFoodCollector",
        "Pyramids", "SoccerStrikersVsGoalie", "Sorter", "Tennis",
        "VisualHallway", "Walker"
    ],
    help="The name of the Env to run in the Unity3D editor "
    "(feel free to add more to this script!)")
parser.add_argument(
    "--port",
    type=int,
    default=SERVER_PORT,
    help="The Policy server's port to listen on for ExternalEnv client "
    "connections.")
parser.add_argument(
    "--checkpoint-freq",
    type=int,
    default=10,
    help="The frequency with which to create checkpoint files of the learnt "
    "Policies.")
parser.add_argument(
    "--no-restore",
    action="store_true",
    help="Do not restore the Policy weights from a previous checkpoint.")

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # `InputReader` generator (returns None if no input reader is needed on
    # the respective worker).
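    # Note: given the port logic in `_input` below, worker 1 listens on
    # `--port` (9900 by default), worker 2 on `--port` + 1, etc. With
    # num_workers=0, the single (local) worker listens on `--port` itself.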
    def _input(ioctx):
        # We are a remote worker or the local worker with num_workers=0:
        # Create a PolicyServerInput.
        if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
            return PolicyServerInput(
                ioctx, SERVER_ADDRESS, args.port + ioctx.worker_index -
                (1 if ioctx.worker_index > 0 else 0))
        # No InputReader (PolicyServerInput) needed.
        else:
            return None

    # Get the multi-agent policies dict and agent->policy
    # mapping-fn.
    policies, policy_mapping_fn = \
        Unity3DEnv.get_policy_configs_for_game(args.env)
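
    # If your game isn't covered by `get_policy_configs_for_game`, you could
    # define these two by hand instead. The behavior name and spaces below
    # are purely hypothetical placeholders; use whatever your game reports:
    # from gym.spaces import Box
    # policies = {
    #     "MyBehavior": (None, Box(float("-inf"), float("inf"), (8, )),
    #                    Box(-1.0, 1.0, (2, )), {}),
    # }
    # policy_mapping_fn = lambda agent_id, *args, **kwargs: "MyBehavior"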

    # The entire config will be sent to connecting clients so they can
    # build their own samplers (and also Policy objects iff
    # `inference_mode=local` on the clients' command line).
    config = {
        # Indicate that the Trainer we set up here doesn't need an actual env.
        # Allow spaces to be determined by user (see below).
        "env": None,

        # Use the `PolicyServerInput` to generate experiences.
        "input": _input,
        # Use n worker processes to listen on different ports.
        "num_workers": args.num_workers,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],

        # Other settings.
        "train_batch_size": 256,
        "rollout_fragment_length": 20,
        # Multi-agent setup for the given env.
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
        },
        # DL framework to use.
        "framework": args.framework,
    }

    # Create the Trainer used for Policy serving.
    trainer = get_trainer_class(args.run)(config=config)

    # Attempt to restore from checkpoint if possible.
    checkpoint_path = CHECKPOINT_FILE.format(args.env)
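    # Note: this file only stores the *location* of the latest checkpoint
    # (see the save logic in the loop below), not the checkpoint data itself.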
    if not args.no_restore and os.path.exists(checkpoint_path):
        restore_path = open(checkpoint_path).read()
        print("Restoring from checkpoint path", restore_path)
        trainer.restore(restore_path)

    # Serving and training loop.
    count = 0
    while True:
        # Calls to train() will block on the configured `input` in the Trainer
        # config above (PolicyServerInput).
        print(trainer.train())
        if count % args.checkpoint_freq == 0:
            print("Saving learning progress to checkpoint file.")
            checkpoint = trainer.save()
            # Write the latest checkpoint location to CHECKPOINT_FILE,
            # so we can pick up from the latest one after a server re-start.
            with open(checkpoint_path, "w") as f:
                f.write(checkpoint)
        count += 1