import json
import os
from pathlib import Path

import ray
from ray import air, tune
from ray.rllib.algorithms.registry import get_algorithm_class

from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole

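# RNNSAC config for the partially observable StatelessCartPole env (a CartPole
# variant whose observations omit the velocity terms). Both the policy and the
# Q-network use an LSTM so the agent can recover the missing state from its
# observation history; the replay buffer therefore stores whole sequences
# rather than individual timesteps.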
param_space = {
    # Use a GPU only if the RLLIB_NUM_GPUS env var is set to a value > 0.
    "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
    "framework": "torch",
    "num_workers": 4,
    "num_envs_per_worker": 1,
    "num_cpus_per_worker": 1,
    "log_level": "INFO",
    "env": StatelessCartPole,
    "horizon": 1000,
    "gamma": 0.95,
    # Collect complete episodes so replayed sequences never cross episode
    # boundaries.
    "batch_mode": "complete_episodes",
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        # Store (and sample) whole sequences instead of single timesteps.
        "storage_unit": "sequences",
        "capacity": 100000,
        "learning_starts": 1000,
        # Use the first 4 timesteps of each replayed sequence only to warm up
        # the LSTM state before the loss is computed.
        "replay_burn_in": 4,
    },
    "train_batch_size": 480,
    "target_network_update_freq": 480,
    "tau": 0.3,
    "zero_init_states": False,
    "optimization": {
        "actor_learning_rate": 0.005,
        "critic_learning_rate": 0.005,
        "entropy_learning_rate": 0.0001,
    },
    "model": {
        "max_seq_len": 20,
    },
    # Recurrent policy network.
    "policy_model_config": {
        "use_lstm": True,
        "lstm_cell_size": 64,
        "fcnet_hiddens": [64, 64],
        "lstm_use_prev_action": True,
        "lstm_use_prev_reward": True,
    },
    # Recurrent Q-network (same architecture as the policy network).
    "q_model_config": {
        "use_lstm": True,
        "lstm_cell_size": 64,
        "fcnet_hiddens": [64, 64],
        "lstm_use_prev_action": True,
        "lstm_use_prev_reward": True,
    },
}

if __name__ == "__main__":
    # INIT
    ray.init(num_cpus=5)

    # TRAIN
    results = tune.Tuner(
        "RNNSAC",
        run_config=air.RunConfig(
            name="RNNSAC_example",
            local_dir=str(Path(__file__).parent / "example_out"),
            verbose=2,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                num_to_keep=1,
                checkpoint_score_attribute="episode_reward_mean",
            ),
            # Stop once the mean episode reward reaches 65 or after 50k
            # sampled timesteps, whichever comes first.
            stop={
                "episode_reward_mean": 65.0,
                "timesteps_total": 50000,
            },
        ),
        tune_config=tune.TuneConfig(
            metric="episode_reward_mean",
            mode="max",
        ),
        param_space=param_space,
    ).fit()

    # TEST
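    # Rebuild the algorithm from the best trial's saved config (params.json),
    # restore its best checkpoint, and roll the trained policy out greedily.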
    checkpoint_config_path = os.path.join(
        results.get_best_result().log_dir, "params.json"
    )
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)

    # Disable exploration so the restored policy acts deterministically.
    checkpoint_config["explore"] = False

    best_checkpoint = results.get_best_result().best_checkpoints[0][0]
    print("Loading checkpoint: {}".format(best_checkpoint))

    algo = get_algorithm_class("RNNSAC")(
        env=StatelessCartPole, config=checkpoint_config
    )
    algo.restore(best_checkpoint)

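    # The models are recurrent, so the rollout loop has to carry the LSTM
    # state plus the previous action and reward between steps and feed them
    # back into compute_single_action().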
    env = algo.env_creator({})
    state = algo.get_policy().get_initial_state()
    prev_action = 0
    prev_reward = 0
    obs = env.reset()

    eps = 0
    ep_reward = 0
    while eps < 10:
        action, state, info_algo = algo.compute_single_action(
            obs,
            state=state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            full_fetch=True,
        )
        obs, reward, done, info = env.step(action)
        prev_action = action
        prev_reward = reward
        ep_reward += reward
        try:
            env.render()
        except Exception:
            pass
        if done:
            eps += 1
            print("Episode {}: {}".format(eps, ep_reward))
            ep_reward = 0
            # Reset the recurrent state and the prev-action/reward inputs for
            # the next episode.
            state = algo.get_policy().get_initial_state()
            prev_action = 0
            prev_reward = 0
            obs = env.reset()
    ray.shutdown()