"""Example: RNN-SAC on StatelessCartPole.

Source path: ray/rllib/examples/rnnsac_stateless_cartpole.py
"""
import json
import os
from pathlib import Path
import ray
from ray import air, tune
from ray.rllib.algorithms.registry import get_algorithm_class
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
# Hyperparameter / config search space handed to Tune. RNNSAC needs an LSTM in
# both the policy and Q networks, and a sequence-aware replay buffer, because
# StatelessCartPole hides velocity info and the agent must infer it over time.
param_space = {
    # --- Resources / execution ---
    "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
    "framework": "torch",
    "num_workers": 4,
    "num_envs_per_worker": 1,
    "num_cpus_per_worker": 1,
    "log_level": "INFO",
    # --- Environment / rollout ---
    "env": StatelessCartPole,
    "horizon": 1000,
    "gamma": 0.95,
    # Whole episodes per batch so replay sequences are never cut mid-episode.
    "batch_mode": "complete_episodes",
    # --- Replay buffer (sequence storage for the recurrent model) ---
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        "storage_unit": "sequences",
        "capacity": 100000,
        # Burn-in steps used to warm up the LSTM state before loss is computed.
        "replay_burn_in": 4,
    },
    "num_steps_sampled_before_learning_starts": 1000,
    # --- Training ---
    "train_batch_size": 480,
    "target_network_update_freq": 480,
    "tau": 0.3,
    "zero_init_states": False,
    "optimization": {
        "actor_learning_rate": 0.005,
        "critic_learning_rate": 0.005,
        "entropy_learning_rate": 0.0001,
    },
    # --- Model (shared) ---
    "model": {
        "max_seq_len": 20,
    },
    # --- Recurrent policy network ---
    "policy_model_config": {
        "use_lstm": True,
        "lstm_cell_size": 64,
        "fcnet_hiddens": [64, 64],
        "lstm_use_prev_action": True,
        "lstm_use_prev_reward": True,
    },
    # --- Recurrent Q network (mirrors the policy net) ---
    "q_model_config": {
        "use_lstm": True,
        "lstm_cell_size": 64,
        "fcnet_hiddens": [64, 64],
        "lstm_use_prev_action": True,
        "lstm_use_prev_reward": True,
    },
}
if __name__ == "__main__":
    # --- INIT ---------------------------------------------------------------
    # 4 rollout workers + 1 driver -> 5 CPUs total.
    ray.init(num_cpus=5)

    # --- TRAIN --------------------------------------------------------------
    # Run RNNSAC until either the reward target or the timestep budget is hit;
    # keep only the best checkpoint (by episode_reward_mean).
    results = tune.Tuner(
        "RNNSAC",
        run_config=air.RunConfig(
            name="RNNSAC_example",
            local_dir=str(Path(__file__).parent / "example_out"),
            verbose=2,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                num_to_keep=1,
                checkpoint_score_attribute="episode_reward_mean",
            ),
            stop={
                "episode_reward_mean": 65.0,
                "timesteps_total": 50000,
            },
        ),
        tune_config=tune.TuneConfig(
            metric="episode_reward_mean",
            mode="max",
        ),
        param_space=param_space,
    ).fit()

    # --- TEST ---------------------------------------------------------------
    # Rebuild the algorithm from the persisted trial config (with exploration
    # disabled for deterministic evaluation) and restore the best checkpoint.
    checkpoint_config_path = os.path.join(
        results.get_best_result().log_dir, "params.json"
    )
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)
    checkpoint_config["explore"] = False
    best_checkpoint = results.get_best_result().best_checkpoints[0][0]
    print("Loading checkpoint: {}".format(best_checkpoint))
    algo = get_algorithm_class("RNNSAC")(
        env=StatelessCartPole, config=checkpoint_config
    )
    algo.restore(best_checkpoint)

    # Roll out 10 evaluation episodes, threading the LSTM state and
    # prev-action/prev-reward through compute_single_action, and resetting
    # all of them at each episode boundary.
    env = algo.env_creator({})
    state = algo.get_policy().get_initial_state()
    prev_action = 0
    prev_reward = 0
    obs = env.reset()
    eps = 0
    ep_reward = 0
    while eps < 10:
        action, state, info_algo = algo.compute_single_action(
            obs,
            state=state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            full_fetch=True,
        )
        # NOTE(review): old Gym 4-tuple step API — assumes the env predates
        # the (obs, reward, terminated, truncated, info) 5-tuple; confirm.
        obs, reward, done, info = env.step(action)
        prev_action = action
        prev_reward = reward
        ep_reward += reward
        # Rendering is best-effort only (may fail headless); ignore errors.
        try:
            env.render()
        except Exception:
            pass
        if done:
            eps += 1
            print("Episode {}: {}".format(eps, ep_reward))
            # Reset per-episode bookkeeping and the recurrent state.
            ep_reward = 0
            state = algo.get_policy().get_initial_state()
            prev_action = 0
            prev_reward = 0
            obs = env.reset()
    ray.shutdown()