mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Examples for training, saving, loading, testing an agent with SB & RLlib (#15897)
This commit is contained in:
parent 474f04e322
commit 55709bac7a
2 changed files with 86 additions and 0 deletions
46  rllib/examples/sb2rllib_rllib_example.py  Normal file
@@ -0,0 +1,46 @@
"""
Example script on how to train, save, load, and test an RLlib agent.
Equivalent script with stable baselines: sb2rllib_sb_example.py
"""
import gym
import ray
import ray.rllib.agents.ppo as ppo

# settings used for both stable baselines and rllib
env_name = "CartPole-v1"
train_steps = 10000
learning_rate = 1e-3
save_dir = "saved_models"

# training and saving
analysis = ray.tune.run(
    "PPO",
    stop={"timesteps_total": train_steps},
    config={
        "env": env_name,
        "lr": learning_rate
    },
    checkpoint_at_end=True,
    local_dir=save_dir,
)
# retrieve the checkpoint path
analysis.default_metric = "episode_reward_mean"
analysis.default_mode = "max"
checkpoint_path = analysis.get_best_checkpoint(trial=analysis.get_best_trial())
print(f"Trained model saved at {checkpoint_path}")

# load and restore model
agent = ppo.PPOTrainer(env=env_name)
agent.restore(checkpoint_path)
print(f"Agent loaded from saved model at {checkpoint_path}")

# inference
env = gym.make(env_name)
obs = env.reset()
for i in range(1000):
    action = agent.compute_action(obs)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        print(f"Cart pole dropped after {i} steps.")
        break
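Note (not part of the commit): the RLlib script above drives training through ray.tune.run, which also handles checkpointing. A minimal sketch of the same train/save/restore cycle using the Trainer API directly is shown below; the iteration count and the "saved_models_trainer" directory are illustrative assumptions, not values from the commit.

import ray
import ray.rllib.agents.ppo as ppo

ray.init()
# illustrative config; the commit itself trains via ray.tune.run
trainer = ppo.PPOTrainer(config={"env": "CartPole-v1", "lr": 1e-3})
for _ in range(5):
    result = trainer.train()  # one training iteration per call
    print(result["episode_reward_mean"])
checkpoint_path = trainer.save("saved_models_trainer")  # returns the checkpoint path
trainer.restore(checkpoint_path)  # load the weights back into the trainer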
40  rllib/examples/sb2rllib_sb_example.py  Normal file
@@ -0,0 +1,40 @@
"""
Example script on how to train, save, load, and test a stable baselines 2 agent.
Code taken and adjusted from SB2 docs:
https://stable-baselines.readthedocs.io/en/master/guide/quickstart.html
Equivalent script with RLlib: sb2rllib_rllib_example.py
"""
import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines import PPO2

# settings used for both stable baselines and rllib
env_name = "CartPole-v1"
train_steps = 10000
learning_rate = 1e-3
save_dir = "saved_models"

save_path = f"{save_dir}/sb_model_{train_steps}steps"
env = gym.make(env_name)

# training and saving
model = PPO2(MlpPolicy, env, learning_rate=learning_rate, verbose=1)
model.learn(total_timesteps=train_steps)
model.save(save_path)
print(f"Trained model saved at {save_path}")

# delete and load model (just for illustration)
del model
model = PPO2.load(save_path)
print(f"Agent loaded from saved model at {save_path}")

# inference
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        print(f"Cart pole dropped after {i} steps.")
        break
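Note (not part of the commit): the inference loop above only renders a single rollout. A minimal sketch of a quick quantitative check for the loaded stable baselines 2 model follows, averaging the episode reward over a few episodes; the episode count and the reuse of the save path produced by the script are assumptions.

import gym
from stable_baselines import PPO2

env = gym.make("CartPole-v1")
model = PPO2.load("saved_models/sb_model_10000steps")  # path written by the script above
episode_rewards = []
for _ in range(10):
    obs, done, total = env.reset(), False, 0.0
    while not done:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        total += reward
    episode_rewards.append(total)
print(f"Mean reward over {len(episode_rewards)} episodes: "
      f"{sum(episode_rewards) / len(episode_rewards):.1f}")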