"""Example: RNNSAC (recurrent Soft Actor-Critic) on StatelessCartPole.

Original location: ray/rllib/examples/rnnsac_stateless_cartpole.py
"""
import json
import os
from pathlib import Path
import ray
from ray import tune
from ray.rllib.algorithms.registry import get_algorithm_class
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
# Tune experiment spec (outer keys) plus the RNNSAC algorithm config
# (under "config") for the partially observable StatelessCartPole env.
# NOTE(review): two blame-view date stamps that had been spliced into this
# dict literal (syntax errors) were removed.
config = {
    "name": "RNNSAC_example",
    # Write Tune trial output next to this file.
    "local_dir": str(Path(__file__).parent / "example_out"),
    "checkpoint_at_end": True,
    # Keep only the single best checkpoint, ranked by mean episode reward.
    "keep_checkpoints_num": 1,
    "checkpoint_score_attr": "episode_reward_mean",
    # Stop when either criterion is reached first.
    "stop": {
        "episode_reward_mean": 65.0,
        "timesteps_total": 50000,
    },
    "metric": "episode_reward_mean",
    "mode": "max",
    "verbose": 2,
    "config": {
        # GPU count is taken from the env var so CI can run CPU-only.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": "torch",
        "num_workers": 4,
        "num_envs_per_worker": 1,
        "num_cpus_per_worker": 1,
        "log_level": "INFO",
        "env": StatelessCartPole,
        "horizon": 1000,
        "gamma": 0.95,
        # Collect whole episodes so the buffer can store full sequences
        # for the recurrent (LSTM) models below.
        "batch_mode": "complete_episodes",
        "replay_buffer_config": {
            "type": "MultiAgentReplayBuffer",
            # Replay whole sequences (not single timesteps).
            "storage_unit": "sequences",
            "capacity": 100000,
            "learning_starts": 1000,
            # Steps used to warm up the LSTM state before computing the
            # loss on a replayed sequence.
            "replay_burn_in": 4,
        },
        "train_batch_size": 480,
        "target_network_update_freq": 480,
        "tau": 0.3,
        "zero_init_states": False,
        "optimization": {
            "actor_learning_rate": 0.005,
            "critic_learning_rate": 0.005,
            "entropy_learning_rate": 0.0001,
        },
        "model": {
            # Max LSTM sequence length used when chunking episodes.
            "max_seq_len": 20,
        },
        # Recurrent policy network.
        "policy_model_config": {
            "use_lstm": True,
            "lstm_cell_size": 64,
            "fcnet_hiddens": [64, 64],
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
        # Recurrent Q network (same architecture as the policy).
        "q_model_config": {
            "use_lstm": True,
            "lstm_cell_size": 64,
            "fcnet_hiddens": [64, 64],
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
    },
}
if __name__ == "__main__":
    # INIT: 4 workers + 1 driver.
    ray.init(num_cpus=5)

    # TRAIN: run the Tune experiment defined in `config`.
    results = tune.run("RNNSAC", **config)

    # TEST: reload the best trial's saved config, but turn off
    # exploration so evaluation uses the deterministic policy.
    # NOTE(review): a blame-view date stamp that had been spliced into the
    # try-block below (syntax error) was removed.
    checkpoint_config_path = str(Path(results.best_logdir) / "params.json")
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)
    checkpoint_config["explore"] = False

    best_checkpoint = results.best_checkpoint
    print("Loading checkpoint: {}".format(best_checkpoint))
    algo = get_algorithm_class("RNNSAC")(
        env=config["config"]["env"], config=checkpoint_config
    )
    algo.restore(best_checkpoint)

    # Roll out 10 episodes, threading the recurrent state and the
    # previous action/reward through every step (the LSTM consumes them).
    env = algo.env_creator({})
    state = algo.get_policy().get_initial_state()
    prev_action = 0
    prev_reward = 0
    obs = env.reset()
    eps = 0
    ep_reward = 0
    while eps < 10:
        action, state, info_algo = algo.compute_single_action(
            obs,
            state=state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            full_fetch=True,
        )
        obs, reward, done, info = env.step(action)
        prev_action = action
        prev_reward = reward
        ep_reward += reward
        try:
            env.render()
        except Exception:
            # Rendering is best-effort; headless machines may raise here.
            pass
        if done:
            eps += 1
            print("Episode {}: {}".format(eps, ep_reward))
            # Reset recurrent state and per-episode trackers.
            ep_reward = 0
            state = algo.get_policy().get_initial_state()
            prev_action = 0
            prev_reward = 0
            obs = env.reset()
    ray.shutdown()