[RLLib] MBMPO Fixes (#10296)
Commit 8e613652af (parent d22980a5c3)
10 changed files with 168 additions and 46 deletions

@@ -22,6 +22,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi-
 `IMPALA`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Transformer`_, `+autoreg`_
 `MAML`_ tf + torch No **Yes** No
 `MARWIL`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
+`MBMPO`_ torch No **Yes** No
 `PG`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Transformer`_, `+autoreg`_
 `PPO`_, `APPO`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Transformer`_, `+autoreg`_
 `SAC`_ tf + torch **Yes** **Yes** **Yes**

@@ -442,6 +443,35 @@ Tuned examples: HalfCheetahRandDirecEnv (`Env <https://github.com/ray-project/ra
     :start-after: __sphinx_doc_begin__
     :end-before: __sphinx_doc_end__

+.. _mbmpo:
+
+Model-Based Meta-Policy-Optimization (MB-MPO)
+---------------------------------------------
+|pytorch|
+`[paper] <https://arxiv.org/pdf/1809.05214.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/mbmpo/mbmpo.py>`__
+
+RLlib's MBMPO implementation is a Dyna-style model-based RL method that learns from the predictions of an ensemble of transition-dynamics models. Similar to MAML, MBMPO meta-learns an optimal policy by treating each dynamics model as a different task. The code here is adapted from https://github.com/jonasrothfuss/model_ensemble_meta_learning. As in the original paper, MBMPO is evaluated on MuJoCo, with the horizon set to 200 instead of the default 1000.
+
+Additional statistics are logged in MBMPO. Each MBMPO iteration corresponds to multiple MAML iterations, and ``MAMLIter$i$_DynaTrajInner_$j$_episode_reward_mean`` measures the agent's returns across the dynamics models at iteration ``i`` of MAML and step ``j`` of inner adaptation. Examples can be seen `here <https://github.com/ray-project/rl-experiments/tree/master/mbmpo>`__.
+
+Tuned examples: `HalfCheetah <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/mbmpo/halfcheetah-mbmpo.yaml>`__, `Hopper <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/mbmpo/hopper-mbmpo.yaml>`__
+
+**MuJoCo results @100K steps:** `more details <https://github.com/ray-project/rl-experiments>`__
+
+============= ============ ====================
+MuJoCo env    RLlib MBMPO  Clavera et al MBMPO
+============= ============ ====================
+HalfCheetah   520          ~550
+Hopper        620          ~650
+============= ============ ====================
+
+**MBMPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
+
+.. literalinclude:: ../../rllib/agents/mbmpo/mbmpo.py
+    :language: python
+    :start-after: __sphinx_doc_begin__
+    :end-before: __sphinx_doc_end__
+
 .. _dreamer:

 Dreamer

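The per-iteration metric keys described in the MBMPO docs section above follow a fixed naming pattern. As a small illustration, here is a hypothetical helper (not part of RLlib) that builds such a key and reads it back from a result dict; the reward value is a toy number:

    # Hypothetical helper, not part of RLlib: builds the metric key pattern
    # documented above (MAMLIter$i$_DynaTrajInner_$j$_episode_reward_mean).
    def dyna_inner_key(i, j):
        return "MAMLIter{}_DynaTrajInner_{}_episode_reward_mean".format(i, j)

    # Toy stand-in for a training result dict.
    result = {"MAMLIter0_DynaTrajInner_0_episode_reward_mean": -95.2}
    print(result[dyna_inner_key(0, 0)])  # -95.2
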
@@ -110,6 +110,8 @@ Algorithms

 - |pytorch| |tensorflow| :ref:`Model-Agnostic Meta-Learning (MAML) <maml>`

+- |pytorch| :ref:`Model-Based Meta-Policy-Optimization (MBMPO) <mbmpo>`
+
 - |pytorch| |tensorflow| :ref:`Policy Gradients <pg>`

 - |pytorch| |tensorflow| :ref:`Proximal Policy Optimization (PPO) <ppo>`

@@ -199,10 +199,9 @@ class MAMLLoss(object):
                 current_policy_vars[i] = adapted_policy_vars
                 kls.append(kl_loss)
                 inner_ppo_loss.append(ppo_loss)
-            inner_kls.append(kls)
+            inner_kls.extend(kls)

-        mean_inner_kl = [torch.mean(torch.stack(kls)) for kls in inner_kls]
-        self.mean_inner_kl = mean_inner_kl
+        self.mean_inner_kl = inner_kls

         ppo_obj = []
         for i in range(self.num_tasks):

@@ -230,10 +229,10 @@ class MAMLLoss(object):
         self.mean_entropy = entropy_loss

         self.inner_kl_loss = torch.mean(
-            torch.stack(
-                [a * b for a, b in zip(self.cur_kl_coeff, mean_inner_kl)]))
+            torch.stack([
+                a * b for a, b in zip(self.cur_kl_coeff, self.mean_inner_kl)
+            ]))
         self.loss = torch.mean(torch.stack(ppo_obj)) + self.inner_kl_loss
-        print("Meta-Loss: ", self.loss, ", Inner KL:", self.inner_kl_loss)

     def feed_forward(self, obs, policy_vars, policy_config):
         # Hacky for now, reconstruct FC network with adapted weights

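As a standalone illustration of the penalty term rebuilt in this hunk, the snippet below (plain PyTorch, with made-up coefficient and KL values) scales each mean inner KL by its own coefficient and averages the results:

    import torch

    # Made-up values: one coefficient and one mean inner KL per entry of
    # self.mean_inner_kl in the loss above.
    cur_kl_coeff = [0.2, 0.2]
    mean_inner_kl = [torch.tensor(0.01), torch.tensor(0.03)]

    inner_kl_loss = torch.mean(
        torch.stack([a * b for a, b in zip(cur_kl_coeff, mean_inner_kl)]))
    print(inner_kl_loss)  # tensor(0.0040)
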
@@ -298,7 +297,6 @@ class MAMLLoss(object):
         return pi_new_logits, torch.squeeze(value_fn)

     def compute_updated_variables(self, loss, network_vars, model):
-
         grad = torch.autograd.grad(
             loss,
             inputs=model.parameters(),

@@ -389,8 +387,9 @@ def maml_stats(policy, train_batch):

 class KLCoeffMixin:
     def __init__(self, config):
-        self.kl_coeff_val = [config["kl_coeff"]
-                             ] * config["inner_adaptation_steps"]
+        self.kl_coeff_val = [
+            config["kl_coeff"]
+        ] * config["inner_adaptation_steps"] * config["num_workers"]
         self.kl_target = self.config["kl_target"]

     def update_kls(self, sampled_kls):

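A minimal sketch (plain Python, not the RLlib mixin itself) of what the resizing above amounts to: the mixin now keeps one KL coefficient per inner-adaptation step per worker, rather than one per inner-adaptation step only:

    def init_kl_coeffs(kl_coeff, inner_adaptation_steps, num_workers):
        # One coefficient per (inner-adaptation step, worker) pair.
        return [kl_coeff] * inner_adaptation_steps * num_workers

    coeffs = init_kl_coeffs(1e-10, inner_adaptation_steps=1, num_workers=20)
    assert len(coeffs) == 20  # matches the tuned examples further below
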
@@ -18,7 +18,8 @@ from ray.rllib.utils.torch_ops import convert_to_torch_tensor
 from ray.rllib.evaluation.metrics import collect_episodes
 from ray.rllib.agents.mbmpo.model_vector_env import custom_model_vector_env
 from ray.rllib.evaluation.metrics import collect_metrics
-from ray.rllib.agents.mbmpo.utils import calculate_gae_advantages
+from ray.rllib.agents.mbmpo.utils import calculate_gae_advantages, \
+    MBMPOExploration

 logger = logging.getLogger(__name__)

@@ -69,7 +70,7 @@ DEFAULT_CONFIG = with_common_config({
         # Number of Transition-Dynamics Models for Ensemble
         "ensemble_size": 5,
         # Hidden Layers for Model Ensemble
-        "fcnet_hiddens": [512, 512],
+        "fcnet_hiddens": [512, 512, 512],
         # Model Learning Rate
         "lr": 1e-3,
         # Max number of training epochs per MBMPO iter

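For intuition only, here is a sketch of an MLP with the new three 512-unit hidden layers. This is not the actual DynamicsEnsembleCustomModel; the assumption that each ensemble member maps a concatenated (obs, action) input to a predicted observation delta is mine, based on the surrounding config comments:

    import torch.nn as nn

    def make_dynamics_member(obs_dim, act_dim, hiddens=(512, 512, 512)):
        # Illustrative ensemble member: (obs, action) -> predicted obs delta.
        layers, in_size = [], obs_dim + act_dim
        for h in hiddens:
            layers += [nn.Linear(in_size, h), nn.ReLU()]
            in_size = h
        layers.append(nn.Linear(in_size, obs_dim))
        return nn.Sequential(*layers)

    # HalfCheetah-sized member; 5 of these would form the default ensemble.
    ensemble = [make_dynamics_member(17, 6) for _ in range(5)]
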
@@ -81,10 +82,11 @@ DEFAULT_CONFIG = with_common_config({
         # Normalize Data (obs, action, and deltas)
         "normalize_data": True,
     },
+    "exploration_config": {
+        "type": MBMPOExploration,
+    },
     # Workers sample from dynamics models
     "custom_vector_env": custom_model_vector_env,
-    # How many enviornments there are per worker (vectorized)
-    "num_worker_envs": 20,
     # How many iterations through MAML per MBMPO iteration
     "num_maml_steps": 10,
 })

@@ -152,7 +154,7 @@ class MetaUpdate:
             metrics.info[LEARNER_INFO] = fetches
             metrics.counters[STEPS_TRAINED_COUNTER] += samples.count

-            if self.step_counter == self.num_steps:
+            if self.step_counter == self.num_steps - 1:
                 td_metric = self.workers.local_worker().foreach_policy(
                     fit_dynamics)[0]

@@ -158,7 +158,7 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):

         for i in range(self.num_models):
             self.add_module("TD-model-" + str(i), self.dynamics_ensemble[i])
-        self.replay_buffer_max = 100000
+        self.replay_buffer_max = 10000
         self.replay_buffer = None
         self.optimizers = [
             torch.optim.Adam(

@@ -170,7 +170,8 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
         self.metrics[STEPS_SAMPLED_COUNTER] = 0

         # For each worker, choose a random model to choose trajectories from
-        self.sample_index = np.random.randint(self.num_models)
+        worker_index = get_global_worker().worker_index
+        self.sample_index = int((worker_index - 1) / self.num_models)
         self.global_itr = 0
         self.device = (torch.device("cuda")
                        if torch.cuda.is_available() else torch.device("cpu"))

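A quick, plain-Python check of the new deterministic assignment that replaces the random model choice. The assumption that remote worker indices start at 1 is carried over from the worker_index - 1 term above:

    num_models = 5  # "ensemble_size" in DEFAULT_CONFIG

    for worker_index in range(1, 21):  # 20 remote workers, as in the tuned examples
        sample_index = int((worker_index - 1) / num_models)
        print(worker_index, "->", sample_index)
    # Workers 1-5 map to model 0, 6-10 to 1, 11-15 to 2, 16-20 to 3.
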
@@ -195,9 +196,10 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
         # Add env samples to Replay Buffer
         local_worker = get_global_worker()
         new_samples = local_worker.sample()
+        # Initial Exploration of 8000 timesteps
         if not self.global_itr:
-            tmp = local_worker.sample()
-            new_samples.concat(tmp)
+            extra = local_worker.sample()
+            new_samples.concat(extra)

         # Process Samples
         new_samples = process_samples(new_samples)

@@ -257,9 +259,6 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
                 train_losses[ind] = train_losses[
                     ind].detach().cpu().numpy()
-
-            del x
-            del y

             # Validation
             val_lists = []
             for data in zip(*val_loaders):

@@ -273,8 +272,6 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):

             for ind in range(self.num_models):
                 val_losses[ind] = val_losses[ind].detach().cpu().numpy()
-            del x
-            del y

             val_lists = np.array(val_lists)
             avg_val_losses = np.mean(val_lists, axis=0)

@@ -81,7 +81,7 @@ class _VectorizedModelGymEnv(VectorEnv):
         next_obs_batch = self.model.predict_model_batches(
             obs_batch, action_batch, device=self.device)

-        next_obs_batch = np.clip(next_obs_batch, -50, 50)
+        next_obs_batch = np.clip(next_obs_batch, -1000, 1000)

         rew_batch = self.envs[0].reward(obs_batch, action_batch,
                                         next_obs_batch)

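A toy numpy illustration of the widened clipping range applied to the model-predicted observations (the input values are made up):

    import numpy as np

    next_obs_batch = np.array([[-3000.0, 12.0, 75.0]])
    print(np.clip(next_obs_batch, -1000, 1000))  # [[-1000.  12.  75.]]
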
@@ -95,7 +95,8 @@ class _VectorizedModelGymEnv(VectorEnv):

         self.cur_obs = next_obs_batch

-        return list(obs_batch), list(rew_batch), list(dones_batch), info_batch
+        return list(next_obs_batch), list(rew_batch), list(
+            dones_batch), info_batch

     @override(VectorEnv)
     def get_unwrapped(self):

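A stripped-down sketch (plain Python, not the actual _VectorizedModelGymEnv) of why this return statement matters: the caller must receive the freshly predicted observations, otherwise every policy step keeps seeing the pre-step state:

    def vector_step_sketch(obs_batch, predict_fn, reward_fn):
        next_obs_batch = predict_fn(obs_batch)
        rew_batch = reward_fn(obs_batch, next_obs_batch)
        # Before the fix, list(obs_batch) was returned here instead.
        return list(next_obs_batch), list(rew_batch)
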
@@ -1,5 +1,16 @@
 import numpy as np
 import scipy
+from typing import Union
+
+from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.exploration.exploration import Exploration
+from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
+    TensorType
+
+tf1, tf, tfv = try_import_tf()
+torch, _ = try_import_torch()


 class LinearFeatureBaseline():

@@ -66,3 +77,50 @@ def discount_cumsum(x, discount):
     """
     return scipy.signal.lfilter(
         [1], [1, float(-discount)], x[::-1], axis=0)[::-1]
+
+
+class MBMPOExploration(Exploration):
+    """An exploration that simply samples from a distribution.
+
+    The sampling can be made deterministic by passing explore=False into
+    the call to `get_exploration_action`.
+    Also allows for scheduled parameters for the distributions, such as
+    lowering stddev, temperature, etc.. over time.
+    """
+
+    def __init__(self, action_space, *, framework: str, model: ModelV2,
+                 **kwargs):
+        """Initializes a StochasticSampling Exploration object.
+
+        Args:
+            action_space (Space): The gym action space used by the environment.
+            framework (str): One of None, "tf", "torch".
+        """
+        assert framework is not None
+        self.timestep = 0
+        self.worker_index = kwargs["worker_index"]
+        super().__init__(
+            action_space, model=model, framework=framework, **kwargs)
+
+    @override(Exploration)
+    def get_exploration_action(self,
+                               *,
+                               action_distribution: ActionDistribution,
+                               timestep: Union[int, TensorType],
+                               explore: bool = True):
+        assert self.framework == "torch"
+        return self._get_torch_exploration_action(action_distribution, explore)
+
+    def _get_torch_exploration_action(self, action_dist, explore):
+        action = action_dist.sample()
+        logp = action_dist.sampled_action_logp()
+
+        batch_size = action.size()[0]
+
+        # Initial Random Exploration for Real Env Interaction
+        if self.worker_index == 0 and self.timestep < 8000:
+            print("Using Random")
+            action = [self.action_space.sample() for _ in range(batch_size)]
+            logp = [0.0 for _ in range(batch_size)]
+        self.timestep += batch_size
+        return action, logp

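To connect this class with the DEFAULT_CONFIG change earlier in this diff, the fragment below simply restates how the new exploration is wired in (the import path matches the mbmpo.py import hunk above):

    from ray.rllib.agents.mbmpo.utils import MBMPOExploration

    config_fragment = {
        "exploration_config": {
            "type": MBMPOExploration,
        },
    }
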
@@ -1,21 +1,5 @@
 import numpy as np
-from gym.envs.mujoco import HalfCheetahEnv
-import inspect
-
-
-def get_all_function_arguments(function, locals):
-    kwargs_dict = {}
-    for arg in inspect.getfullargspec(function).kwonlyargs:
-        if arg not in ["args", "kwargs"]:
-            kwargs_dict[arg] = locals[arg]
-    args = [locals[arg] for arg in inspect.getfullargspec(function).args]
-
-    if "args" in locals:
-        args += locals["args"]
-
-    if "kwargs" in locals:
-        kwargs_dict.update(locals["kwargs"])
-    return args, kwargs_dict
+from gym.envs.mujoco import HalfCheetahEnv, HopperEnv


 class HalfCheetahWrapper(HalfCheetahEnv):

@@ -42,8 +26,28 @@ class HalfCheetahWrapper(HalfCheetahEnv):
         return np.minimum(np.maximum(-1000.0, reward), 1000.0)


+class HopperWrapper(HopperEnv):
+    """Hopper Wrapper that wraps Mujoco Hopper-v2 env
+    with an additional defined reward function for model-based RL.
+
+    This is currently used for MBMPO.
+    """
+
+    def __init__(self, *args, **kwargs):
+        HopperEnv.__init__(self, *args, **kwargs)
+
+    def reward(self, obs, action, obs_next):
+        alive_bonus = 1.0
+        assert obs.ndim == 2 and action.ndim == 2
+        assert obs.shape == obs_next.shape and action.shape[0] == obs.shape[0]
+        vel = obs_next[:, 5]
+        ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1)
+        reward = vel + alive_bonus - ctrl_cost
+        return np.minimum(np.maximum(-1000.0, reward), 1000.0)
+
+
 if __name__ == "__main__":
-    env = HalfCheetahWrapper()
+    env = HopperWrapper()
     env.reset()
     for _ in range(1000):
         env.step(env.action_space.sample())

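A standalone numpy re-statement of HopperWrapper.reward, useful as a quick shape sanity check. The observation and action dimensions (11 and 3 for Hopper-v2) and the velocity values are only illustrative:

    import numpy as np

    def hopper_reward(obs, action, obs_next):
        # Mirrors the batched reward above: forward velocity plus an alive
        # bonus, minus a small control cost, clipped to [-1000, 1000].
        alive_bonus = 1.0
        vel = obs_next[:, 5]
        ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1)
        reward = vel + alive_bonus - ctrl_cost
        return np.minimum(np.maximum(-1000.0, reward), 1000.0)

    obs = np.zeros((4, 11))
    action = np.zeros((4, 3))
    obs_next = np.zeros((4, 11))
    obs_next[:, 5] = 1.5
    print(hopper_reward(obs, action, obs_next))  # [2.5 2.5 2.5 2.5]
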
@@ -1,11 +1,12 @@
-halfcheetah-mb-mpo:
-    env: ray.rllib.examples.env.halfcheetah.HalfCheetahWrapper
+halfcheetah-mbmpo:
+    env: ray.rllib.examples.env.mbmpo_env.HalfCheetahWrapper
     run: MBMPO
     stop:
         training_iteration: 500
     config:
         # Only supported in torch right now
         framework: torch
         # 200 in paper, 1000 will take forever
         horizon: 200
+        num_envs_per_worker: 20
         inner_adaptation_steps: 1

@@ -14,12 +15,13 @@ halfcheetah-mb-mpo:
         lambda: 1.0
         lr: 0.001
         clip_param: 0.5
-        kl_target: 0.01
+        kl_target: 0.003
         kl_coeff: 0.0000000001
         num_workers: 20
         num_gpus: 1
         inner_lr: 0.001
         clip_actions: False
         num_maml_steps: 15
         model:
             fcnet_hiddens: [32, 32]
             free_log_std: True

rllib/tuned_examples/mbmpo/hopper-mbmpo.yaml (new file, 27 lines)
@@ -0,0 +1,27 @@
+hopper-mbmpo:
+    env: ray.rllib.examples.env.mbmpo_env.HopperWrapper
+    run: MBMPO
+    stop:
+        training_iteration: 500
+    config:
+        # Only supported in torch right now
+        framework: torch
+        # 200 in paper, 1000 will take forever
+        horizon: 200
+        num_envs_per_worker: 20
+        inner_adaptation_steps: 1
+        maml_optimizer_steps: 8
+        gamma: 0.99
+        lambda: 1.0
+        lr: 0.001
+        clip_param: 0.5
+        kl_target: 0.003
+        kl_coeff: 0.0000000001
+        num_workers: 20
+        num_gpus: 1
+        inner_lr: 0.001
+        clip_actions: False
+        num_maml_steps: 15
+        model:
+            fcnet_hiddens: [32, 32]
+            free_log_std: True