From 8e613652af436514df67af035987f1df920617e1 Mon Sep 17 00:00:00 2001 From: Michael Luo Date: Wed, 9 Sep 2020 00:34:34 -0700 Subject: [PATCH] [RLLib] MBMPO Fixes (#10296) --- doc/source/rllib-algorithms.rst | 30 ++++++++++ doc/source/rllib-toc.rst | 2 + rllib/agents/maml/maml_torch_policy.py | 17 +++--- rllib/agents/mbmpo/mbmpo.py | 12 ++-- rllib/agents/mbmpo/model_ensemble.py | 15 ++--- rllib/agents/mbmpo/model_vector_env.py | 5 +- rllib/agents/mbmpo/utils.py | 58 +++++++++++++++++++ .../env/{halfcheetah.py => mbmpo_env.py} | 40 +++++++------ .../mbmpo/halfcheetah-mbmpo.yaml | 8 ++- rllib/tuned_examples/mbmpo/hopper-mbmpo.yaml | 27 +++++++++ 10 files changed, 168 insertions(+), 46 deletions(-) rename rllib/examples/env/{halfcheetah.py => mbmpo_env.py} (57%) create mode 100644 rllib/tuned_examples/mbmpo/hopper-mbmpo.yaml diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 3ff0c76cd..7e6553e29 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -22,6 +22,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi- `IMPALA`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Transformer`_, `+autoreg`_ `MAML`_ tf + torch No **Yes** No `MARWIL`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_ +`MBMPO`_ torch No **Yes** No `PG`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Transformer`_, `+autoreg`_ `PPO`_, `APPO`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Transformer`_, `+autoreg`_ `SAC`_ tf + torch **Yes** **Yes** **Yes** @@ -442,6 +443,35 @@ Tuned examples: HalfCheetahRandDirecEnv (`Env `__ `[implementation] `__ + +RLlib's MBMPO implementation is a Dyna-style model-based RL method that learns based on the predictions of an ensemble of transition-dynamics models. Similar to MAML, MBMPO meta-learns an optimal policy by treating each dynamics model as a different task. Code here is adapted from https://github.com/jonasrothfuss/model_ensemble_meta_learning. As in the original paper, MBMPO is evaluated on MuJoCo, with the horizon set to 200 instead of the default 1000. + +Additional statistics are logged in MBMPO. Each MBMPO iteration corresponds to multiple MAML iterations, and ``MAMLIter$i$_DynaTrajInner_$j$_episode_reward_mean`` measures the agent's returns across the dynamics models at iteration ``i`` of MAML and step ``j`` of inner adaptation. Examples can be seen `here `__. + +Tuned examples: `HalfCheetah `__, `Hopper `__ + +**MuJoCo results @100K steps:** `more details `__ + +============= ============ ==================== +MuJoCo env RLlib MBMPO Clavera et al. MBMPO +============= ============ ==================== +HalfCheetah 520 ~550 +Hopper 620 ~650 +============= ============ ==================== + +**MBMPO-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../rllib/agents/mbmpo/mbmpo.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + .. 
_dreamer: Dreamer diff --git a/doc/source/rllib-toc.rst b/doc/source/rllib-toc.rst index d9e7896cc..bc6b8b630 100644 --- a/doc/source/rllib-toc.rst +++ b/doc/source/rllib-toc.rst @@ -110,6 +110,8 @@ Algorithms - |pytorch| |tensorflow| :ref:`Model-Agnostic Meta-Learning (MAML) ` + - |pytorch| :ref:`Model-Based Meta-Policy-Optimization (MBMPO) ` + - |pytorch| |tensorflow| :ref:`Policy Gradients ` - |pytorch| |tensorflow| :ref:`Proximal Policy Optimization (PPO) ` diff --git a/rllib/agents/maml/maml_torch_policy.py b/rllib/agents/maml/maml_torch_policy.py index cf378a4ba..8a143b455 100644 --- a/rllib/agents/maml/maml_torch_policy.py +++ b/rllib/agents/maml/maml_torch_policy.py @@ -199,10 +199,9 @@ class MAMLLoss(object): current_policy_vars[i] = adapted_policy_vars kls.append(kl_loss) inner_ppo_loss.append(ppo_loss) - inner_kls.append(kls) + inner_kls.extend(kls) - mean_inner_kl = [torch.mean(torch.stack(kls)) for kls in inner_kls] - self.mean_inner_kl = mean_inner_kl + self.mean_inner_kl = inner_kls ppo_obj = [] for i in range(self.num_tasks): @@ -230,10 +229,10 @@ class MAMLLoss(object): self.mean_entropy = entropy_loss self.inner_kl_loss = torch.mean( - torch.stack( - [a * b for a, b in zip(self.cur_kl_coeff, mean_inner_kl)])) + torch.stack([ + a * b for a, b in zip(self.cur_kl_coeff, self.mean_inner_kl) + ])) self.loss = torch.mean(torch.stack(ppo_obj)) + self.inner_kl_loss - print("Meta-Loss: ", self.loss, ", Inner KL:", self.inner_kl_loss) def feed_forward(self, obs, policy_vars, policy_config): # Hacky for now, reconstruct FC network with adapted weights @@ -298,7 +297,6 @@ class MAMLLoss(object): return pi_new_logits, torch.squeeze(value_fn) def compute_updated_variables(self, loss, network_vars, model): - grad = torch.autograd.grad( loss, inputs=model.parameters(), @@ -389,8 +387,9 @@ def maml_stats(policy, train_batch): class KLCoeffMixin: def __init__(self, config): - self.kl_coeff_val = [config["kl_coeff"] - ] * config["inner_adaptation_steps"] + self.kl_coeff_val = [ + config["kl_coeff"] + ] * config["inner_adaptation_steps"] * config["num_workers"] self.kl_target = self.config["kl_target"] def update_kls(self, sampled_kls): diff --git a/rllib/agents/mbmpo/mbmpo.py b/rllib/agents/mbmpo/mbmpo.py index cf24f8a78..eebbb1dcd 100644 --- a/rllib/agents/mbmpo/mbmpo.py +++ b/rllib/agents/mbmpo/mbmpo.py @@ -18,7 +18,8 @@ from ray.rllib.utils.torch_ops import convert_to_torch_tensor from ray.rllib.evaluation.metrics import collect_episodes from ray.rllib.agents.mbmpo.model_vector_env import custom_model_vector_env from ray.rllib.evaluation.metrics import collect_metrics -from ray.rllib.agents.mbmpo.utils import calculate_gae_advantages +from ray.rllib.agents.mbmpo.utils import calculate_gae_advantages, \ + MBMPOExploration logger = logging.getLogger(__name__) @@ -69,7 +70,7 @@ DEFAULT_CONFIG = with_common_config({ # Number of Transition-Dynamics Models for Ensemble "ensemble_size": 5, # Hidden Layers for Model Ensemble - "fcnet_hiddens": [512, 512], + "fcnet_hiddens": [512, 512, 512], # Model Learning Rate "lr": 1e-3, # Max number of training epochs per MBMPO iter @@ -81,10 +82,11 @@ DEFAULT_CONFIG = with_common_config({ # Normalize Data (obs, action, and deltas) "normalize_data": True, }, + "exploration_config": { + "type": MBMPOExploration, + }, # Workers sample from dynamics models "custom_vector_env": custom_model_vector_env, - # How many enviornments there are per worker (vectorized) - "num_worker_envs": 20, # How many iterations through MAML per MBMPO iteration "num_maml_steps": 10, 
}) @@ -152,7 +154,7 @@ class MetaUpdate: metrics.info[LEARNER_INFO] = fetches metrics.counters[STEPS_TRAINED_COUNTER] += samples.count - if self.step_counter == self.num_steps: + if self.step_counter == self.num_steps - 1: td_metric = self.workers.local_worker().foreach_policy( fit_dynamics)[0] diff --git a/rllib/agents/mbmpo/model_ensemble.py b/rllib/agents/mbmpo/model_ensemble.py index c252e0464..1c8d03562 100644 --- a/rllib/agents/mbmpo/model_ensemble.py +++ b/rllib/agents/mbmpo/model_ensemble.py @@ -158,7 +158,7 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module): for i in range(self.num_models): self.add_module("TD-model-" + str(i), self.dynamics_ensemble[i]) - self.replay_buffer_max = 100000 + self.replay_buffer_max = 10000 self.replay_buffer = None self.optimizers = [ torch.optim.Adam( @@ -170,7 +170,8 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module): self.metrics[STEPS_SAMPLED_COUNTER] = 0 # For each worker, choose a random model to choose trajectories from - self.sample_index = np.random.randint(self.num_models) + worker_index = get_global_worker().worker_index + self.sample_index = int((worker_index - 1) / self.num_models) self.global_itr = 0 self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) @@ -195,9 +196,10 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module): # Add env samples to Replay Buffer local_worker = get_global_worker() new_samples = local_worker.sample() + # Initial Exploration of 8000 timesteps if not self.global_itr: - tmp = local_worker.sample() - new_samples.concat(tmp) + extra = local_worker.sample() + new_samples.concat(extra) # Process Samples new_samples = process_samples(new_samples) @@ -257,9 +259,6 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module): train_losses[ind] = train_losses[ ind].detach().cpu().numpy() - del x - del y - # Validation val_lists = [] for data in zip(*val_loaders): @@ -273,8 +272,6 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module): for ind in range(self.num_models): val_losses[ind] = val_losses[ind].detach().cpu().numpy() - del x - del y val_lists = np.array(val_lists) avg_val_losses = np.mean(val_lists, axis=0) diff --git a/rllib/agents/mbmpo/model_vector_env.py b/rllib/agents/mbmpo/model_vector_env.py index 655169e06..4a0b56836 100644 --- a/rllib/agents/mbmpo/model_vector_env.py +++ b/rllib/agents/mbmpo/model_vector_env.py @@ -81,7 +81,7 @@ class _VectorizedModelGymEnv(VectorEnv): next_obs_batch = self.model.predict_model_batches( obs_batch, action_batch, device=self.device) - next_obs_batch = np.clip(next_obs_batch, -50, 50) + next_obs_batch = np.clip(next_obs_batch, -1000, 1000) rew_batch = self.envs[0].reward(obs_batch, action_batch, next_obs_batch) @@ -95,7 +95,8 @@ class _VectorizedModelGymEnv(VectorEnv): self.cur_obs = next_obs_batch - return list(obs_batch), list(rew_batch), list(dones_batch), info_batch + return list(next_obs_batch), list(rew_batch), list( + dones_batch), info_batch @override(VectorEnv) def get_unwrapped(self): diff --git a/rllib/agents/mbmpo/utils.py b/rllib/agents/mbmpo/utils.py index 16bb922da..b6efcdb47 100644 --- a/rllib/agents/mbmpo/utils.py +++ b/rllib/agents/mbmpo/utils.py @@ -1,5 +1,16 @@ import numpy as np import scipy +from typing import Union + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.framework import 
try_import_tf, try_import_torch, \ + TensorType + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() class LinearFeatureBaseline(): @@ -66,3 +77,50 @@ def discount_cumsum(x, discount): """ return scipy.signal.lfilter( [1], [1, float(-discount)], x[::-1], axis=0)[::-1] + + +class MBMPOExploration(Exploration): + """Exploration used by MBMPO. + + Acts like regular stochastic sampling from the policy's action + distribution, except that the real-env worker (worker_index 0) takes + uniformly random actions for its first 8000 timesteps in order to + collect initial data for the dynamics-model ensemble. + """ + + def __init__(self, action_space, *, framework: str, model: ModelV2, + **kwargs): + """Initializes an MBMPOExploration instance. + + Args: + action_space (Space): The gym action space used by the environment. + framework (str): The DL framework; only "torch" is supported. + """ + assert framework is not None + self.timestep = 0 + self.worker_index = kwargs["worker_index"] + super().__init__( + action_space, model=model, framework=framework, **kwargs) + + @override(Exploration) + def get_exploration_action(self, + *, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: bool = True): + assert self.framework == "torch" + return self._get_torch_exploration_action(action_distribution, explore) + + def _get_torch_exploration_action(self, action_dist, explore): + action = action_dist.sample() + logp = action_dist.sampled_action_logp() + + batch_size = action.size()[0] + + # Initial Random Exploration for Real Env Interaction + if self.worker_index == 0 and self.timestep < 8000: + print("Using Random") + action = [self.action_space.sample() for _ in range(batch_size)] + logp = [0.0 for _ in range(batch_size)] + self.timestep += batch_size + return action, logp diff --git a/rllib/examples/env/halfcheetah.py b/rllib/examples/env/mbmpo_env.py similarity index 57% rename from rllib/examples/env/halfcheetah.py rename to rllib/examples/env/mbmpo_env.py index 70f946468..22315e547 100644 --- a/rllib/examples/env/halfcheetah.py +++ b/rllib/examples/env/mbmpo_env.py @@ -1,21 +1,5 @@ import numpy as np -from gym.envs.mujoco import HalfCheetahEnv -import inspect - - -def get_all_function_arguments(function, locals): - kwargs_dict = {} - for arg in inspect.getfullargspec(function).kwonlyargs: - if arg not in ["args", "kwargs"]: - kwargs_dict[arg] = locals[arg] - args = [locals[arg] for arg in inspect.getfullargspec(function).args] - - if "args" in locals: - args += locals["args"] - - if "kwargs" in locals: - kwargs_dict.update(locals["kwargs"]) - return args, kwargs_dict +from gym.envs.mujoco import HalfCheetahEnv, HopperEnv class HalfCheetahWrapper(HalfCheetahEnv): @@ -42,8 +26,28 @@ class HalfCheetahWrapper(HalfCheetahEnv): return np.minimum(np.maximum(-1000.0, reward), 1000.0) +class HopperWrapper(HopperEnv): + """Wrapper around the MuJoCo Hopper-v2 env that adds a batched + reward function for use in model-based RL. + + This is currently used for MBMPO. 
+ """ + + def __init__(self, *args, **kwargs): + HopperEnv.__init__(self, *args, **kwargs) + + def reward(self, obs, action, obs_next): + alive_bonus = 1.0 + assert obs.ndim == 2 and action.ndim == 2 + assert obs.shape == obs_next.shape and action.shape[0] == obs.shape[0] + vel = obs_next[:, 5] + ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1) + reward = vel + alive_bonus - ctrl_cost + return np.minimum(np.maximum(-1000.0, reward), 1000.0) + + if __name__ == "__main__": - env = HalfCheetahWrapper() + env = HopperWrapper() env.reset() for _ in range(1000): env.step(env.action_space.sample()) diff --git a/rllib/tuned_examples/mbmpo/halfcheetah-mbmpo.yaml b/rllib/tuned_examples/mbmpo/halfcheetah-mbmpo.yaml index 9e69fde03..7980894af 100644 --- a/rllib/tuned_examples/mbmpo/halfcheetah-mbmpo.yaml +++ b/rllib/tuned_examples/mbmpo/halfcheetah-mbmpo.yaml @@ -1,11 +1,12 @@ -halfcheetah-mb-mpo: - env: ray.rllib.examples.env.halfcheetah.HalfCheetahWrapper +halfcheetah-mbmpo: + env: ray.rllib.examples.env.mbmpo_env.HalfCheetahWrapper run: MBMPO stop: training_iteration: 500 config: # Only supported in torch right now framework: torch + # 200 in paper, 1000 will take forever horizon: 200 num_envs_per_worker: 20 inner_adaptation_steps: 1 @@ -14,12 +15,13 @@ halfcheetah-mb-mpo: lambda: 1.0 lr: 0.001 clip_param: 0.5 - kl_target: 0.01 + kl_target: 0.003 kl_coeff: 0.0000000001 num_workers: 20 num_gpus: 1 inner_lr: 0.001 clip_actions: False + num_maml_steps: 15 model: fcnet_hiddens: [32, 32] free_log_std: True diff --git a/rllib/tuned_examples/mbmpo/hopper-mbmpo.yaml b/rllib/tuned_examples/mbmpo/hopper-mbmpo.yaml new file mode 100644 index 000000000..28d6a0b54 --- /dev/null +++ b/rllib/tuned_examples/mbmpo/hopper-mbmpo.yaml @@ -0,0 +1,27 @@ +hopper-mbmpo: + env: ray.rllib.examples.env.mbmpo_env.HopperWrapper + run: MBMPO + stop: + training_iteration: 500 + config: + # Only supported in torch right now + framework: torch + # 200 in paper, 1000 will take forever + horizon: 200 + num_envs_per_worker: 20 + inner_adaptation_steps: 1 + maml_optimizer_steps: 8 + gamma: 0.99 + lambda: 1.0 + lr: 0.001 + clip_param: 0.5 + kl_target: 0.003 + kl_coeff: 0.0000000001 + num_workers: 20 + num_gpus: 1 + inner_lr: 0.001 + clip_actions: False + num_maml_steps: 15 + model: + fcnet_hiddens: [32, 32] + free_log_std: True
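For reference, the new Hopper tuned example above can typically be launched with RLlib's yaml runner (rllib train -f rllib/tuned_examples/mbmpo/hopper-mbmpo.yaml). Below is a minimal, equivalent Python sketch using Ray Tune, assuming Ray is installed with RLlib, PyTorch, and mujoco-py. The registered env name "hopper-wrapper" is only illustrative; the "MBMPO" trainer string, the HopperWrapper class, and the config values come from the files added in this patch:

import ray
from ray import tune
from ray.tune.registry import register_env

# HopperWrapper is added by this patch in rllib/examples/env/mbmpo_env.py.
from ray.rllib.examples.env.mbmpo_env import HopperWrapper

if __name__ == "__main__":
    ray.init()

    # Register the wrapped Hopper env under an illustrative name.
    register_env("hopper-wrapper", lambda env_config: HopperWrapper())

    # Mirrors the settings in hopper-mbmpo.yaml above; scale num_workers and
    # num_gpus down if fewer resources are available than the tuned example assumes.
    tune.run(
        "MBMPO",
        stop={"training_iteration": 500},
        config={
            "env": "hopper-wrapper",
            "framework": "torch",  # MBMPO is torch-only.
            "horizon": 200,
            "num_envs_per_worker": 20,
            "inner_adaptation_steps": 1,
            "maml_optimizer_steps": 8,
            "gamma": 0.99,
            "lambda": 1.0,
            "lr": 0.001,
            "clip_param": 0.5,
            "kl_target": 0.003,
            "kl_coeff": 1e-10,
            "num_workers": 20,
            "num_gpus": 1,
            "inner_lr": 0.001,
            "clip_actions": False,
            "num_maml_steps": 15,
            "model": {"fcnet_hiddens": [32, 32], "free_log_std": True},
        },
    )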