"""Example of a custom experiment wrapped around an RLlib trainer."""
import argparse

import ray
from ray import tune
from ray.rllib.agents import ppo

parser = argparse.ArgumentParser()
parser.add_argument("--train-iterations", type=int, default=10)
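
# A custom experiment is a plain function that Tune calls once per trial,
# passing in that trial's `config`. Training, checkpointing and evaluation
# all happen inside this one function, and metrics are surfaced through
# `tune.report()`.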
def experiment(config):
    iterations = config.pop("train-iterations")
    train_agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
    checkpoint = None
    train_results = {}
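
    # The loop below checkpoints every other iteration (and on the final one)
    # so that the evaluation phase further down can restore the last weights.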
    # Train
    for i in range(iterations):
        train_results = train_agent.train()
        if i % 2 == 0 or i == iterations - 1:
            checkpoint = train_agent.save(tune.get_trial_dir())
        tune.report(**train_results)
    train_agent.stop()

    # Manual Eval: restore the latest checkpoint into a fresh, local-only
    # trainer (num_workers=0) and evaluate it by hand.
    config["num_workers"] = 0
    eval_agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
    eval_agent.restore(checkpoint)
    env = eval_agent.workers.local_worker().env
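
    # Roll out a single episode: query the restored policy for an action at
    # every step and accumulate reward and episode length.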
    obs = env.reset()
    done = False
    eval_results = {"eval_reward": 0, "eval_eps_length": 0}
    while not done:
        action = eval_agent.compute_action(obs)
        obs, reward, done, info = env.step(action)
        eval_results["eval_reward"] += reward
        eval_results["eval_eps_length"] += 1
    results = {**train_results, **eval_results}
    tune.report(**results)
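

# The custom "train-iterations" key rides along in the Tune config and is
# popped off inside `experiment()` above, since it is not a standard PPO
# config option.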
if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=3)
    config = ppo.DEFAULT_CONFIG.copy()
    config["train-iterations"] = args.train_iterations
    config["env"] = "CartPole-v0"
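
    # `default_resource_request()` reports the CPU/GPU resources a PPOTrainer
    # with this config needs, so Tune schedules this function trial with the
    # same resources a regular PPO trial would get.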
    tune.run(
        experiment,
        config=config,
        resources_per_trial=ppo.PPOTrainer.default_resource_request(config))