[RLlib] Prototype of a DynaTrainer (for env dynamics learning in upcoming MBMPO algo). (#8860)

Sven Mika 2020-06-16 09:01:20 +02:00 committed by GitHub
parent 7008902cff
commit 14405b90d5
10 changed files with 387 additions and 26 deletions

rllib/BUILD

@@ -433,6 +433,14 @@ py_test(
srcs = ["agents/dqn/tests/test_simple_q.py"]
)
# DYNATrainer
py_test(
name = "test_dyna",
tags = ["agents_dir"],
size = "small",
srcs = ["agents/dyna/tests/test_dyna.py"]
)
# ES
py_test(
name = "test_es",

rllib/agents/dyna/__init__.py Normal file

@@ -0,0 +1,10 @@
from ray.rllib.agents.dyna.dyna import DYNATrainer, DEFAULT_CONFIG
from ray.rllib.agents.dyna.dyna_torch_policy import dyna_torch_loss, \
DYNATorchPolicy
__all__ = [
"dyna_torch_loss",
"DEFAULT_CONFIG",
"DYNATorchPolicy",
"DYNATrainer",
]

rllib/agents/dyna/dyna.py Normal file

@@ -0,0 +1,101 @@
import logging
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
logger = logging.getLogger(__name__)
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# Default Trainer setting overrides.
"num_workers": 1,
"num_envs_per_worker": 1,
# The size of an entire epoch (for supervised learning of the dynamics).
# The train batch will be split into training and validation sets according
# to `train_set_ratio`, then epochs (with minibatch
# size=`sgd_minibatch_size`) will be trained until the sliding average
# of the validation performance decreases.
"train_batch_size": 10000,
"sgd_minibatch_size": 500,
"rollout_fragment_length": 200,
# Learning rate for the dynamics optimizer.
"lr": 0.0003,
# Fraction of the entire data that should be used for training the dynamics
# model. The validation fraction is 1.0 - `train_set_ratio`. Training of
# a dynamics model over some number of epochs (1 epoch = entire training
# set) stops when the validation set's performance starts to decrease.
"train_set_ratio": 0.8,
# The exploration strategy to apply on top of the (acting) policy.
# TODO: (sven) Use random for testing purposes for now.
"exploration_config": {"type": "Random"},
# Whether to predict the action that led from obs(t) to obs(t+1), instead
# of predicting obs(t+1).
"predict_action": False,
# Whether the dynamics model should predict the reward, given obs(t)+a(t).
# NOTE: Only supported if `predict_action`=False.
"predict_reward": False,
# Whether to use the same network for predicting rewards as for
# predicting the next observation.
"reward_share_layers": True,
# TODO: (sven) figure out API to query the latent space vector given
# some observation (not needed for MBMPO).
"learn_latent_space": False,
# Whether to predict `obs(t+1) - obs(t)` instead of `obs(t+1)` directly.
# NOTE: This only works for 1D Box observation spaces, e.g. Box(5,) and
# if `predict_action`=False.
"predict_obs_delta": True,
# TODO: loss function types: neg_log_llh, etc..?
"loss_function": "l2",
# Config for the dynamics learning model architecture.
"dynamics_model": {
"fcnet_hiddens": [512, 512],
"fcnet_activation": "relu",
},
# TODO: (sven) allow for having a default model config over many
# sub-models: e.g. "model": {"ModelA": {[default_config]},
# "ModelB": [default_config]}
})
# __sphinx_doc_end__
# yapf: enable
def validate_config(config):
if config["train_set_ratio"] <= 0.0 or \
config["train_set_ratio"] >= 1.0:
raise ValueError("`train_set_ratio` must be within (0.0, 1.0)!")
if config["predict_action"] or config["predict_reward"]:
raise ValueError(
"`predict_action`=True or `predict_reward`=True not supported "
"yet!")
if config["learn_latent_space"]:
raise ValueError("`learn_latent_space` not supported yet!")
if config["loss_function"] != "l2":
raise ValueError("`loss_function` other than 'l2' not supported yet!")
def get_policy_class(config):
if config["framework"] == "torch":
from ray.rllib.agents.dyna.dyna_torch_policy import DYNATorchPolicy
return DYNATorchPolicy
else:
raise ValueError("tf not supported yet!")
DYNATrainer = build_trainer(
name="DYNA",
default_policy=None,
get_policy_class=get_policy_class,
default_config=DEFAULT_CONFIG,
validate_config=validate_config,
)
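For orientation, a minimal usage sketch of the trainer defined above. It follows the test further below (env, iteration count, and the printed stats keys are taken from there); everything else uses DEFAULT_CONFIG, e.g. `train_batch_size`=10000 with `train_set_ratio`=0.8 gives an 8000/2000 train/validation split per epoch:

import copy
import ray
import ray.rllib.agents.dyna as dyna

ray.init()
config = copy.deepcopy(dyna.DEFAULT_CONFIG)
config["framework"] = "torch"  # Only torch is supported so far (see above).
trainer = dyna.DYNATrainer(config=config, env="CartPole-v0")
for i in range(10):
    # Each train() call runs supervised learning over `train_batch_size`
    # samples, split 8000/2000 into train/validation at the default ratio.
    stats = trainer.train()["info"]["learner"]["default_policy"]
    print(i, stats["dynamics_train_loss"], stats["dynamics_validation_loss"])
ray.shutdown()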

rllib/agents/dyna/dyna_torch_model.py Normal file

@@ -0,0 +1,58 @@
import gym
from gym.spaces import Discrete
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
torch, nn = try_import_torch()
class DYNATorchModel(TorchModelV2, nn.Module):
"""Extension of standard TorchModelV2 for Env dynamics learning.
Data flow:
obs.cat(action) -> forward() -> next_obs|next_obs_delta
get_next_observation(obs, action) -> next_obs|next_obs_delta
Note that this class by itself is not a valid model unless you
implement forward() in a subclass.
"""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
"""Initializes a DYNATorchModel object.
"""
nn.Module.__init__(self)
# Construct the wrapped model handing it a concat'd observation and
# action space as "input_space" and our obs_space as "output_space".
# TODO: (sven) get rid of these restrictions on obs/action spaces.
assert isinstance(action_space, Discrete)
input_space = gym.spaces.Box(
obs_space.low[0],
obs_space.high[0],
shape=(obs_space.shape[0] + action_space.n, ))
super(DYNATorchModel, self).__init__(input_space, action_space,
num_outputs, model_config, name)
def get_next_observation(self, observations, actions):
"""Returns a next obs prediction given current observation and action.
This implements p^(s'|s, a), with p being the environment dynamics.
Arguments:
observations (Tensor): The current observation Tensor.
actions (Tensor): The actions taken in `observations`.
Returns:
TensorType: The predicted next observations.
"""
# One-hot the actions.
actions_flat = nn.functional.one_hot(
actions, num_classes=self.action_space.n).float()
# Push through our underlying Model.
next_obs, _ = self.forward({
"obs_flat": torch.cat([observations, actions_flat], -1)
}, [], None)
return next_obs
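A short usage sketch of the data flow above. Shapes assume a CartPole-like Box(4,) obs space and Discrete(2) actions; `model` stands for an instance built via ModelCatalog.get_model_v2(..., model_interface=DYNATorchModel) as done in dyna_torch_policy.py below, which is what supplies forward():

obs = torch.zeros((32, 4))        # [batch, obs_dim]
actions = torch.zeros(32).long()  # [batch] of Discrete(2) action indices
# Internally: one_hot(actions) is concatenated to obs -> input dim 4 + 2 = 6;
# forward() then returns an obs-sized prediction per row (next obs or its
# delta, depending on `predict_obs_delta`).
pred = model.get_next_observation(obs, actions)  # shape [32, 4]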

rllib/agents/dyna/dyna_torch_policy.py Normal file

@@ -0,0 +1,94 @@
import gym
import logging
import ray
from ray.rllib.agents.dyna.dyna_torch_model import DYNATorchModel
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils import try_import_torch
torch, nn = try_import_torch()
logger = logging.getLogger(__name__)
def make_model_and_dist(policy, obs_space, action_space, config):
# Get the output distribution class for predicting rewards and next-obs.
policy.distr_cls_next_obs, num_outputs = ModelCatalog.get_action_dist(
obs_space, config, dist_type="deterministic", framework="torch")
if config["predict_reward"]:
# TODO: (sven) implement reward prediction.
_ = ModelCatalog.get_action_dist(
gym.spaces.Box(float("-inf"), float("inf"), ()),
config,
dist_type="")
# Build one dynamics model if we are a worker.
# If we are the main MAML learner, build n (num_workers) dynamics models
# so that checkpoints of the current training state can be created.
policy.dynamics_model = ModelCatalog.get_model_v2(
obs_space,
action_space,
num_outputs=num_outputs,
model_config=config["dynamics_model"],
framework="torch",
name="dynamics_model",
model_interface=DYNATorchModel,
)
action_dist, num_outputs = ModelCatalog.get_action_dist(
action_space, config, dist_type="deterministic", framework="torch")
# Create the pi-model and register it with the Policy.
policy.pi = ModelCatalog.get_model_v2(
obs_space,
action_space,
num_outputs=num_outputs,
model_config=config["model"],
framework="torch",
name="policy_model",
)
return policy.pi, action_dist
def dyna_torch_loss(policy, model, dist_class, train_batch):
# Split batch into train and validation sets according to
# `train_set_ratio`.
predicted_next_state_deltas = \
policy.dynamics_model.get_next_observation(
train_batch[SampleBatch.CUR_OBS], train_batch[SampleBatch.ACTIONS])
labels = train_batch[SampleBatch.NEXT_OBS] - train_batch[SampleBatch.
CUR_OBS]
loss = torch.pow(
torch.sum(
torch.pow(labels - predicted_next_state_deltas, 2.0), dim=-1), 0.5)
batch_size = int(loss.shape[0])
train_set_size = int(batch_size * policy.config["train_set_ratio"])
train_loss, validation_loss = \
torch.split(loss, (train_set_size, batch_size - train_set_size), dim=0)
policy.dynamics_train_loss = torch.mean(train_loss)
policy.dynamics_validation_loss = torch.mean(validation_loss)
return policy.dynamics_train_loss
def stats_fn(policy, train_batch):
return {
"dynamics_train_loss": policy.dynamics_train_loss,
"dynamics_validation_loss": policy.dynamics_validation_loss,
}
def torch_optimizer(policy, config):
return torch.optim.Adam(
policy.dynamics_model.parameters(), lr=config["lr"])
DYNATorchPolicy = build_torch_policy(
name="DYNATorchPolicy",
loss_fn=dyna_torch_loss,
get_default_config=lambda: ray.rllib.agents.dyna.dyna.DEFAULT_CONFIG,
stats_fn=stats_fn,
optimizer_fn=torch_optimizer,
make_model_and_action_dist=make_model_and_dist,
)
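To make the loss above concrete: it is the per-sample L2 norm of the error between the observation-delta label `next_obs - obs` and the predicted delta, split (without shuffling) into a train and a validation part by `train_set_ratio`; only the train part is backpropagated. A standalone sketch of that arithmetic with toy tensors (values are illustrative):

import torch

labels = torch.tensor([[0.1, 0.0], [0.0, 0.2], [0.3, 0.0], [0.0, 0.0]])
preds = torch.tensor([[0.0, 0.0], [0.0, 0.2], [0.3, 0.1], [0.1, 0.0]])
# Per-sample L2 norm: sqrt(sum_i (label_i - pred_i)^2) -> shape [4].
loss = torch.pow(torch.sum(torch.pow(labels - preds, 2.0), dim=-1), 0.5)
train_set_size = int(loss.shape[0] * 0.8)  # train_set_ratio=0.8 -> 3 samples
train_loss, validation_loss = torch.split(
    loss, (train_set_size, loss.shape[0] - train_set_size), dim=0)
# Only train_loss.mean() is returned to the optimizer; validation_loss.mean()
# is reported via stats_fn.
print(train_loss.mean(), validation_loss.mean())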

rllib/agents/dyna/tests/test_dyna.py Normal file

@@ -0,0 +1,72 @@
import copy
import gym
import numpy as np
import unittest
import ray
import ray.rllib.agents.dyna as dyna
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import check_compute_single_action, \
framework_iterator
torch, _ = try_import_torch()
class TestDYNA(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(local_mode=True)
@classmethod
def tearDownClass(cls):
ray.shutdown()
def test_dyna_compilation(self):
"""Test whether a DYNATrainer can be built with both frameworks."""
config = copy.deepcopy(dyna.DEFAULT_CONFIG)
config["num_workers"] = 1
config["train_batch_size"] = 1000
num_iterations = 30
env = "CartPole-v0"
test_env = gym.make(env)
for _ in framework_iterator(config, frameworks="torch"):
trainer = dyna.DYNATrainer(config=config, env=env)
policy = trainer.get_policy()
# Do n supervised epochs, each over `train_batch_size`.
# Ignore validation loss here as a stopping criterion.
for i in range(num_iterations):
info = trainer.train()["info"]["learner"]["default_policy"]
print("SL iteration: {}".format(i))
print("train loss {}".format(info["dynamics_train_loss"]))
print("validation loss {}".format(
info["dynamics_validation_loss"]))
# Check whether normal action stepping works with DYNA's policy.
# Note that DYNA does not train its Policy. It must be pushed
# down from the main model-based algo from time to time.
check_compute_single_action(trainer)
# Check whether the env dynamics were actually (more or less) learnt.
obs = test_env.reset()
for _ in range(10):
action = trainer.compute_action(obs)
obs = torch.from_numpy(np.array([obs])).float()
# Make the prediction over the next state (deterministic delta
# like in MBMPO).
predicted_next_obs_delta = \
policy.dynamics_model.get_next_observation(
obs,
torch.from_numpy(np.array([action])))
predicted_next_obs = obs + predicted_next_obs_delta
obs, _, done, _ = test_env.step(action)
# Compare via the absolute error (signed errors would cancel out).
self.assertLess(
np.sum(np.abs(obs - predicted_next_obs.detach().numpy())), 0.05)
# Reset if done.
if done:
obs = test_env.reset()
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
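Because the model predicts per-step deltas, multi-step "imagined" rollouts (as needed by the upcoming MBMPO) can be chained by feeding each prediction back in. A hedged sketch, reusing `trainer`, `policy`, and `test_env` as built in the test above (horizon and action choice are illustrative):

obs = torch.from_numpy(np.array([test_env.reset()])).float()
for _ in range(5):
    action = trainer.compute_action(obs.numpy()[0])
    delta = policy.dynamics_model.get_next_observation(
        obs, torch.from_numpy(np.array([action])))
    obs = obs + delta.detach()  # Predicted next obs becomes the new input.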


@@ -1,7 +1,7 @@
import math
import gym
from gym import spaces
from gym.utils import seeding
import math
import numpy as np

rllib/models/catalog.py

@@ -125,14 +125,18 @@ class ModelCatalog:
Args:
action_space (Space): Action space of the target gym env.
config (Optional[dict]): Optional model config.
dist_type (Optional[str]): Identifier of the action distribution.
dist_type (Optional[str]): Identifier of the action distribution
interpreted as a hint.
framework (str): One of "tf", "tfe", or "torch".
kwargs (dict): Optional kwargs to pass on to the Distribution's
constructor.
Returns:
dist_class (ActionDistribution): Python class of the distribution.
dist_dim (int): The size of the input vector to the distribution.
Tuple:
- dist_class (ActionDistribution): Python class of the
distribution.
- dist_dim (int): The size of the input vector to the
distribution.
"""
dist = None
@@ -523,6 +527,30 @@ class ModelCatalog:
return wrapper
@staticmethod
def _get_v2_model_class(input_space, model_config, framework="tf"):
if framework == "torch":
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
FCNet)
from ray.rllib.models.torch.visionnet import (VisionNetwork as
VisionNet)
else:
from ray.rllib.models.tf.fcnet import \
FullyConnectedNetwork as FCNet
from ray.rllib.models.tf.visionnet import \
VisionNetwork as VisionNet
# Discrete/1D obs-spaces.
if isinstance(input_space, gym.spaces.Discrete) or \
len(input_space.shape) <= 2:
return FCNet
# Default Conv2D net.
else:
return VisionNet
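# Example of the class selection above (a sketch; spaces are illustrative
# assumptions, not from this commit):
#   gym.spaces.Box(-1.0, 1.0, (4,))      -> FCNet (1D obs)
#   gym.spaces.Discrete(6)               -> FCNet
#   gym.spaces.Box(0, 255, (84, 84, 3))  -> VisionNet (default Conv2D net)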
# -------------------
# DEPRECATED METHODS.
# -------------------
@staticmethod
def get_model(input_dict,
obs_space,
@@ -581,27 +609,6 @@ class ModelCatalog:
return FullyConnectedNetwork(input_dict, obs_space, action_space,
num_outputs, options)
@staticmethod
def _get_v2_model_class(obs_space, model_config, framework="tf"):
if framework == "torch":
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
FCNet)
from ray.rllib.models.torch.visionnet import (VisionNetwork as
VisionNet)
else:
from ray.rllib.models.tf.fcnet import \
FullyConnectedNetwork as FCNet
from ray.rllib.models.tf.visionnet import \
VisionNetwork as VisionNet
# Discrete/1D obs-spaces.
if isinstance(obs_space, gym.spaces.Discrete) or \
len(obs_space.shape) <= 2:
return FCNet
# Default Conv2D net.
else:
return VisionNet
@staticmethod
def get_torch_model(obs_space,
num_outputs,

rllib/policy/torch_policy.py

@@ -1,6 +1,8 @@
import functools
import numpy as np
import time
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
@@ -149,6 +151,13 @@ class TorchPolicy(Policy):
dist_class = self.dist_class
dist_inputs, state_out = self.model(
input_dict, state_batches, seq_lens)
if not (isinstance(dist_class, functools.partial)
or issubclass(dist_class, TorchDistributionWrapper)):
raise ValueError(
"`dist_class` ({}) not a TorchDistributionWrapper "
"subclass! Make sure your `action_distribution_fn` or "
"`make_model_and_action_dist` return a correct "
"distribution class.".format(dist_class.__name__))
action_dist = dist_class(dist_inputs, self.model)
# Get the exploration action from the forward results.

rllib/policy/torch_policy_template.py

@@ -99,7 +99,9 @@ def build_torch_policy(name,
# Model is customized (use default action dist class).
if make_model:
assert make_model_and_action_dist is None
assert make_model_and_action_dist is None, \
"Either `make_model` or `make_model_and_action_dist`" \
" must be None!"
self.model = make_model(self, obs_space, action_space, config)
dist_class, _ = ModelCatalog.get_action_dist(
action_space, self.config["model"], framework="torch")
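For reference, a sketch of the `make_model_and_action_dist` hook guarded by this assert (the variant the DYNA policy above uses). This is only an illustration under the assumption that the policy's model/action spaces are whatever it is built with; returning both the model and the dist class means the default `get_action_dist` lookup here is skipped:

from ray.rllib.models.catalog import ModelCatalog

def make_model_and_action_dist(policy, obs_space, action_space, config):
    # Return both a ModelV2 and an action-dist class; `make_model` must then
    # be left as None when calling build_torch_policy().
    dist_class, num_outputs = ModelCatalog.get_action_dist(
        action_space, config["model"], framework="torch")
    model = ModelCatalog.get_model_v2(
        obs_space, action_space, num_outputs=num_outputs,
        model_config=config["model"], framework="torch")
    return model, dist_class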