import json
import os
from pathlib import Path

import ray
from ray import tune
from ray.rllib.algorithms.registry import get_algorithm_class

from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole


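# The outer keys of `config` are passed through as tune.run() keyword
# arguments (experiment name, checkpointing, stopping criteria, search
# metric); the nested "config" dict is the RNNSAC algorithm config itself.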
config = {
    "name": "RNNSAC_example",
    "local_dir": str(Path(__file__).parent / "example_out"),
    "checkpoint_at_end": True,
    "keep_checkpoints_num": 1,
    "checkpoint_score_attr": "episode_reward_mean",
    "stop": {
        "episode_reward_mean": 65.0,
        "timesteps_total": 50000,
    },
    "metric": "episode_reward_mean",
    "mode": "max",
    "verbose": 2,
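    # Algorithm config. StatelessCartPole hides the velocity components of
    # the observation, so the agent needs a recurrent model to infer them
    # from the history of positions.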
"config": {
|
|
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
|
"framework": "torch",
|
|
"num_workers": 4,
|
|
"num_envs_per_worker": 1,
|
|
"num_cpus_per_worker": 1,
|
|
"log_level": "INFO",
|
|
"env": StatelessCartPole,
|
|
"horizon": 1000,
|
|
"gamma": 0.95,
|
|
"batch_mode": "complete_episodes",
|
|
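        # Store and replay whole sequences so the LSTM can be unrolled over
        # them; the first `replay_burn_in` steps of each sampled sequence
        # only warm up the LSTM state and are excluded from the loss.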
"replay_buffer_config": {
|
|
"type": "MultiAgentReplayBuffer",
|
|
"storage_unit": "sequences",
|
|
"capacity": 100000,
|
|
"learning_starts": 1000,
|
|
"replay_burn_in": 4,
|
|
},
|
|
"train_batch_size": 480,
|
|
"target_network_update_freq": 480,
|
|
"tau": 0.3,
|
|
"zero_init_states": False,
|
|
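        # SAC optimizes three objectives with separate learning rates: the
        # actor (policy), the critic (Q-functions), and the entropy
        # (temperature) coefficient.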
"optimization": {
|
|
"actor_learning_rate": 0.005,
|
|
"critic_learning_rate": 0.005,
|
|
"entropy_learning_rate": 0.0001,
|
|
},
|
|
"model": {
|
|
"max_seq_len": 20,
|
|
},
|
|
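        # The policy (actor) and Q (critic) networks each get their own
        # LSTM, fed the previous action and reward on top of the
        # observation.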
"policy_model_config": {
|
|
"use_lstm": True,
|
|
"lstm_cell_size": 64,
|
|
"fcnet_hiddens": [64, 64],
|
|
"lstm_use_prev_action": True,
|
|
"lstm_use_prev_reward": True,
|
|
},
|
|
"q_model_config": {
|
|
"use_lstm": True,
|
|
"lstm_cell_size": 64,
|
|
"fcnet_hiddens": [64, 64],
|
|
"lstm_use_prev_action": True,
|
|
"lstm_use_prev_reward": True,
|
|
},
|
|
},
|
|
}
|
|
|
|
if __name__ == "__main__":
|
|
# INIT
|
|
ray.init(num_cpus=5)
|
|
|
|
# TRAIN
|
|
results = tune.run("RNNSAC", **config)
|
|
|
|
# TEST
|
|
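    # Rebuild the best trial's config from its params.json and turn off
    # exploration so that evaluation actions are deterministic.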
    checkpoint_config_path = str(Path(results.best_logdir) / "params.json")
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)

    checkpoint_config["explore"] = False

    best_checkpoint = results.best_checkpoint
    print("Loading checkpoint: {}".format(best_checkpoint))

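    # Re-instantiate the algorithm with the evaluation config, then restore
    # the trained weights from the best checkpoint.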
    algo = get_algorithm_class("RNNSAC")(
        env=config["config"]["env"], config=checkpoint_config
    )
    algo.restore(best_checkpoint)

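    # Roll out episodes manually. Because the models are recurrent and were
    # trained on the previous action/reward, we must carry the LSTM state,
    # prev_action, and prev_reward across steps ourselves.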
    env = algo.env_creator({})
    state = algo.get_policy().get_initial_state()
    prev_action = 0
    prev_reward = 0
    obs = env.reset()

    eps = 0
    ep_reward = 0
while eps < 10:
|
|
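        # full_fetch=True makes compute_single_action() return the action,
        # the next RNN state, and a dict of extra action outputs.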
        action, state, info_algo = algo.compute_single_action(
            obs,
            state=state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            full_fetch=True,
        )
        obs, reward, done, info = env.step(action)
        prev_action = action
        prev_reward = reward
        ep_reward += reward
        try:
            env.render()
        except Exception:
            pass
        if done:
            eps += 1
            print("Episode {}: {}".format(eps, ep_reward))
            ep_reward = 0
            state = algo.get_policy().get_initial_state()
            prev_action = 0
            prev_reward = 0
            obs = env.reset()
    ray.shutdown()