[RLlib] Prototype of a DynaTrainer (for env dynamics learning in upcoming MBMPO algo). (#8860)
parent 7008902cff
commit 14405b90d5
10 changed files with 387 additions and 26 deletions
rllib/BUILD
@@ -433,6 +433,14 @@ py_test(
     srcs = ["agents/dqn/tests/test_simple_q.py"]
 )
 
+# DYNATrainer
+py_test(
+    name = "test_dyna",
+    tags = ["agents_dir"],
+    size = "small",
+    srcs = ["agents/dyna/tests/test_dyna.py"]
+)
+
 # ES
 py_test(
     name = "test_es",
10 rllib/agents/dyna/__init__.py Normal file
@@ -0,0 +1,10 @@
from ray.rllib.agents.dyna.dyna import DYNATrainer, DEFAULT_CONFIG
from ray.rllib.agents.dyna.dyna_torch_policy import dyna_torch_loss, \
    DYNATorchPolicy

__all__ = [
    "dyna_torch_loss",
    "DEFAULT_CONFIG",
    "DYNATorchPolicy",
    "DYNATrainer",
]
101 rllib/agents/dyna/dyna.py Normal file
@@ -0,0 +1,101 @@
import logging

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # Default Trainer setting overrides.
    "num_workers": 1,
    "num_envs_per_worker": 1,

    # The size of an entire epoch (for supervised learning of the dynamics).
    # The train batch will be split into training and validation sets
    # according to `train_set_ratio`, then n epochs (with minibatch
    # size=`sgd_minibatch_size`) will be trained until the sliding average
    # of the validation performance decreases.
    "train_batch_size": 10000,
    "sgd_minibatch_size": 500,
    "rollout_fragment_length": 200,
    # Learning rate for the dynamics optimizer.
    "lr": 0.0003,

    # Fraction of the entire data that should be used for training the
    # dynamics model. The validation fraction is 1.0 - `train_set_ratio`.
    # Training of a dynamics model over n epochs (1 epoch = entire training
    # set) stops when the validation set's performance starts to decrease.
    "train_set_ratio": 0.8,

    # The exploration strategy to apply on top of the (acting) policy.
    # TODO: (sven) Use random for testing purposes for now.
    "exploration_config": {"type": "Random"},

    # Whether to predict the action that led from obs(t) to obs(t+1),
    # instead of predicting obs(t+1).
    "predict_action": False,

    # Whether the dynamics model should predict the reward, given obs(t)+a(t).
    # NOTE: Only supported if `predict_action`=False.
    "predict_reward": False,

    # Whether to use the same network for predicting rewards as for
    # predicting the next observation.
    "reward_share_layers": True,

    # TODO: (sven) figure out API to query the latent space vector given
    # some observation (not needed for MBMPO).
    "learn_latent_space": False,

    # Whether to predict `obs(t+1) - obs(t)` instead of `obs(t+1)` directly.
    # NOTE: This only works for 1D Box observation spaces, e.g. Box(5,), and
    # only if `predict_action`=False.
    "predict_obs_delta": True,
    # TODO: loss function types: neg_log_llh, etc.?
    "loss_function": "l2",

    # Config for the dynamics learning model architecture.
    "dynamics_model": {
        "fcnet_hiddens": [512, 512],
        "fcnet_activation": "relu",
    },

    # TODO: (sven) allow for having a default model config over many
    # sub-models: e.g. "model": {"ModelA": {[default_config]},
    # "ModelB": [default_config]}
})
# __sphinx_doc_end__
# yapf: enable
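
A quick sketch of how the settings above interact (editor's illustration, not part of the patch; the actual split happens in `dyna_torch_loss` further down, and in this prototype the validation share only feeds the reported metric):

    # With the defaults above, each 10000-sample train batch splits 80/20:
    train_batch_size = 10000
    train_set_ratio = 0.8
    train_set_size = int(train_batch_size * train_set_ratio)   # 8000 -> gradient updates
    validation_set_size = train_batch_size - train_set_size    # 2000 -> validation metric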


def validate_config(config):
    if config["train_set_ratio"] <= 0.0 or \
            config["train_set_ratio"] >= 1.0:
        raise ValueError("`train_set_ratio` must be within (0.0, 1.0)!")
    if config["predict_action"] or config["predict_reward"]:
        raise ValueError(
            "`predict_action`=True or `predict_reward`=True not supported "
            "yet!")
    if config["learn_latent_space"]:
        raise ValueError("`learn_latent_space` not supported yet!")
    if config["loss_function"] != "l2":
        raise ValueError("`loss_function` other than 'l2' not supported yet!")


def get_policy_class(config):
    if config["framework"] == "torch":
        from ray.rllib.agents.dyna.dyna_torch_policy import DYNATorchPolicy
        return DYNATorchPolicy
    else:
        raise ValueError("tf not supported yet!")


DYNATrainer = build_trainer(
    name="DYNA",
    default_policy=None,
    get_policy_class=get_policy_class,
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config,
)
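
How the resulting trainer is used, as a hedged sketch (it assumes a running Ray instance and mirrors `test_dyna.py` below; "CartPole-v0" is just the example env):

    import ray
    from ray.rllib.agents.dyna import DYNATrainer

    ray.init()
    trainer = DYNATrainer(config={"framework": "torch"}, env="CartPole-v0")
    result = trainer.train()  # one supervised epoch over `train_batch_size` samples
    info = result["info"]["learner"]["default_policy"]
    print(info["dynamics_train_loss"], info["dynamics_validation_loss"])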
58 rllib/agents/dyna/dyna_torch_model.py Normal file
@@ -0,0 +1,58 @@
import gym
from gym.spaces import Discrete

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


class DYNATorchModel(TorchModelV2, nn.Module):
    """Extension of the standard TorchModelV2 for env dynamics learning.

    Data flow:
        obs.cat(action) -> forward() -> next_obs|next_obs_delta
        get_next_observation(obs, action) -> next_obs|next_obs_delta

    Note that this class by itself is not a valid model unless you
    implement forward() in a subclass.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Initializes a DYNATorchModel object."""
        nn.Module.__init__(self)
        # Construct the wrapped model, handing it a concat'd observation and
        # action space as "input_space" and our obs_space as "output_space".
        # TODO: (sven) get rid of these restrictions on obs/action spaces.
        assert isinstance(action_space, Discrete)
        input_space = gym.spaces.Box(
            obs_space.low[0],
            obs_space.high[0],
            shape=(obs_space.shape[0] + action_space.n, ))
        super(DYNATorchModel, self).__init__(input_space, action_space,
                                             num_outputs, model_config, name)

    def get_next_observation(self, observations, actions):
        """Returns a next-obs prediction given current observation and action.

        This implements p^(s'|s, a), where p is the environment dynamics.

        Arguments:
            observations (Tensor): The current observation Tensor.
            actions (Tensor): The actions taken in `observations`.

        Returns:
            TensorType: The predicted next observations.
        """
        # One-hot the actions.
        actions_flat = nn.functional.one_hot(
            actions, num_classes=self.action_space.n).float()
        # Push through our underlying Model.
        next_obs, _ = self.forward({
            "obs_flat": torch.cat([observations, actions_flat], -1)
        }, [], None)
        return next_obs
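
A hedged usage sketch for `get_next_observation` (not part of the patch; shapes assume CartPole-style Box(4,) observations and Discrete(2) actions, and `model` stands for an already-built DYNATorchModel instance):

    import numpy as np
    obs = torch.from_numpy(np.array([[0.02, -0.01, 0.03, 0.00]])).float()  # (1, 4)
    actions = torch.from_numpy(np.array([1]))                              # (1,)
    # Internally: actions -> one-hot (1, 2), concat'd with obs -> (1, 6) model input.
    delta_or_next = model.get_next_observation(obs, actions)
    # With the default `predict_obs_delta=True`, add the delta back on:
    predicted_next_obs = obs + delta_or_next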
94 rllib/agents/dyna/dyna_torch_policy.py Normal file
@@ -0,0 +1,94 @@
import gym
import logging

import ray
from ray.rllib.agents.dyna.dyna_torch_model import DYNATorchModel
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils import try_import_torch

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


def make_model_and_dist(policy, obs_space, action_space, config):
    # Get the output distribution class for predicting rewards and next-obs.
    policy.distr_cls_next_obs, num_outputs = ModelCatalog.get_action_dist(
        obs_space, config, dist_type="deterministic", framework="torch")
    if config["predict_reward"]:
        # TODO: (sven) implement reward prediction.
        _ = ModelCatalog.get_action_dist(
            gym.spaces.Box(float("-inf"), float("inf"), ()),
            config,
            dist_type="")

    # Build one dynamics model if we are a Worker.
    # If we are the main MAML learner, build n (num_workers) dynamics models
    # to be able to create checkpoints for the current state of training.
    policy.dynamics_model = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        num_outputs=num_outputs,
        model_config=config["dynamics_model"],
        framework="torch",
        name="dynamics_model",
        model_interface=DYNATorchModel,
    )

    action_dist, num_outputs = ModelCatalog.get_action_dist(
        action_space, config, dist_type="deterministic", framework="torch")
    # Create the pi-model and register it with the Policy.
    policy.pi = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        name="policy_model",
    )

    return policy.pi, action_dist


def dyna_torch_loss(policy, model, dist_class, train_batch):
    # Split the batch into train and validation sets according to
    # `train_set_ratio`.
    predicted_next_state_deltas = \
        policy.dynamics_model.get_next_observation(
            train_batch[SampleBatch.CUR_OBS], train_batch[SampleBatch.ACTIONS])
    labels = train_batch[SampleBatch.NEXT_OBS] - \
        train_batch[SampleBatch.CUR_OBS]
    loss = torch.pow(
        torch.sum(
            torch.pow(labels - predicted_next_state_deltas, 2.0), dim=-1), 0.5)
    batch_size = int(loss.shape[0])
    train_set_size = int(batch_size * policy.config["train_set_ratio"])
    train_loss, validation_loss = \
        torch.split(loss, (train_set_size, batch_size - train_set_size), dim=0)
    policy.dynamics_train_loss = torch.mean(train_loss)
    policy.dynamics_validation_loss = torch.mean(validation_loss)
    return policy.dynamics_train_loss
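
To make the "l2" loss above concrete, a small worked example (standalone, toy tensors only): each sample's loss is the Euclidean norm of its prediction-error vector, and only the first `train_set_size` samples contribute to the gradient.

    import torch
    labels = torch.tensor([[1.0, 2.0], [0.0, 0.0]])  # true obs deltas
    preds = torch.tensor([[1.0, 0.0], [0.0, 1.0]])   # model predictions
    per_sample = torch.pow(
        torch.sum(torch.pow(labels - preds, 2.0), dim=-1), 0.5)
    # -> tensor([2., 1.]): sqrt(0^2 + 2^2) and sqrt(0^2 + 1^2).
    # With train_set_ratio=0.5 on this batch of 2, sample 0 trains the model
    # and sample 1 only feeds the validation metric.
    train, val = torch.split(per_sample, (1, 1), dim=0)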


def stats_fn(policy, train_batch):
    return {
        "dynamics_train_loss": policy.dynamics_train_loss,
        "dynamics_validation_loss": policy.dynamics_validation_loss,
    }


def torch_optimizer(policy, config):
    return torch.optim.Adam(
        policy.dynamics_model.parameters(), lr=config["lr"])


DYNATorchPolicy = build_torch_policy(
    name="DYNATorchPolicy",
    loss_fn=dyna_torch_loss,
    get_default_config=lambda: ray.rllib.agents.dyna.dyna.DEFAULT_CONFIG,
    stats_fn=stats_fn,
    optimizer_fn=torch_optimizer,
    make_model_and_action_dist=make_model_and_dist,
)
72 rllib/agents/dyna/tests/test_dyna.py Normal file
@@ -0,0 +1,72 @@
import copy
import gym
import numpy as np
import unittest

import ray
import ray.rllib.agents.dyna as dyna
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import check_compute_single_action, \
    framework_iterator

torch, _ = try_import_torch()


class TestDYNA(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init(local_mode=True)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_dyna_compilation(self):
        """Test whether a DYNATrainer can be built with both frameworks."""
        config = copy.deepcopy(dyna.DEFAULT_CONFIG)
        config["num_workers"] = 1
        config["train_batch_size"] = 1000
        num_iterations = 30
        env = "CartPole-v0"
        test_env = gym.make(env)

        for _ in framework_iterator(config, frameworks="torch"):
            trainer = dyna.DYNATrainer(config=config, env=env)
            policy = trainer.get_policy()
            # Do n supervised epochs, each over `train_batch_size` samples.
            # Ignore the validation loss here as a stopping criterion.
            for i in range(num_iterations):
                info = trainer.train()["info"]["learner"]["default_policy"]
                print("SL iteration: {}".format(i))
                print("train loss {}".format(info["dynamics_train_loss"]))
                print("validation loss {}".format(
                    info["dynamics_validation_loss"]))
            # Check whether normal action stepping works with DYNA's policy.
            # Note that DYNA does not train its policy. It must be pushed
            # down from the main model-based algo from time to time.
            check_compute_single_action(trainer)

            # Check whether env dynamics were actually learned, more or less.
            obs = test_env.reset()
            for _ in range(10):
                action = trainer.compute_action(obs)
                obs = torch.from_numpy(np.array([obs])).float()
                # Make the prediction over the next state (deterministic
                # delta, like in MBMPO).
                predicted_next_obs_delta = \
                    policy.dynamics_model.get_next_observation(
                        obs,
                        torch.from_numpy(np.array([action])))
                predicted_next_obs = obs + predicted_next_obs_delta
                obs, _, done, _ = test_env.step(action)
                # Sum of absolute per-dim errors must be small (signed errors
                # could cancel out and pass trivially).
                self.assertLess(
                    np.sum(np.abs(obs - predicted_next_obs.detach().numpy())),
                    0.05)
                # Reset if done.
                if done:
                    obs = test_env.reset()


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))
2 rllib/examples/env/stateless_cartpole.py vendored
@@ -1,7 +1,7 @@
-import math
 import gym
 from gym import spaces
 from gym.utils import seeding
+import math
 import numpy as np
rllib/models/catalog.py
@@ -125,14 +125,18 @@ class ModelCatalog:
         Args:
             action_space (Space): Action space of the target gym env.
             config (Optional[dict]): Optional model config.
-            dist_type (Optional[str]): Identifier of the action distribution.
+            dist_type (Optional[str]): Identifier of the action distribution
+                interpreted as a hint.
             framework (str): One of "tf", "tfe", or "torch".
             kwargs (dict): Optional kwargs to pass on to the Distribution's
                 constructor.

         Returns:
-            dist_class (ActionDistribution): Python class of the distribution.
-            dist_dim (int): The size of the input vector to the distribution.
+            Tuple:
+                - dist_class (ActionDistribution): Python class of the
+                  distribution.
+                - dist_dim (int): The size of the input vector to the
+                  distribution.
         """

         dist = None
@@ -523,6 +527,30 @@ class ModelCatalog:
         return wrapper

+    @staticmethod
+    def _get_v2_model_class(input_space, model_config, framework="tf"):
+        if framework == "torch":
+            from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
+                                                      FCNet)
+            from ray.rllib.models.torch.visionnet import (VisionNetwork as
+                                                          VisionNet)
+        else:
+            from ray.rllib.models.tf.fcnet import \
+                FullyConnectedNetwork as FCNet
+            from ray.rllib.models.tf.visionnet import \
+                VisionNetwork as VisionNet
+
+        # Discrete/1D obs-spaces.
+        if isinstance(input_space, gym.spaces.Discrete) or \
+                len(input_space.shape) <= 2:
+            return FCNet
+        # Default Conv2D net.
+        else:
+            return VisionNet
+
     # -------------------
     # DEPRECATED METHODS.
     # -------------------
     @staticmethod
     def get_model(input_dict,
                   obs_space,
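The dispatch in the added `_get_v2_model_class` boils down to: Discrete or rank <= 2 input spaces get an MLP, higher-rank (image-like) spaces get a conv net. A small sketch (editor's illustration with made-up spaces, not part of the patch):

    from gym.spaces import Box, Discrete
    from ray.rllib.models.catalog import ModelCatalog

    for space in (Discrete(4), Box(-1.0, 1.0, (8, )), Box(0.0, 255.0, (84, 84, 3))):
        cls = ModelCatalog._get_v2_model_class(space, {}, framework="torch")
        # -> FullyConnectedNetwork, FullyConnectedNetwork, then VisionNetwork
        print(space, "->", cls.__name__)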
@@ -581,27 +609,6 @@ class ModelCatalog:
         return FullyConnectedNetwork(input_dict, obs_space, action_space,
                                      num_outputs, options)

-    @staticmethod
-    def _get_v2_model_class(obs_space, model_config, framework="tf"):
-        if framework == "torch":
-            from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
-                                                      FCNet)
-            from ray.rllib.models.torch.visionnet import (VisionNetwork as
-                                                          VisionNet)
-        else:
-            from ray.rllib.models.tf.fcnet import \
-                FullyConnectedNetwork as FCNet
-            from ray.rllib.models.tf.visionnet import \
-                VisionNetwork as VisionNet
-
-        # Discrete/1D obs-spaces.
-        if isinstance(obs_space, gym.spaces.Discrete) or \
-                len(obs_space.shape) <= 2:
-            return FCNet
-        # Default Conv2D net.
-        else:
-            return VisionNet
-
     @staticmethod
     def get_torch_model(obs_space,
                         num_outputs,
rllib/policy/torch_policy.py
@@ -1,6 +1,8 @@
+import functools
 import numpy as np
 import time

+from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
@@ -149,6 +151,13 @@ class TorchPolicy(Policy):
             dist_class = self.dist_class
             dist_inputs, state_out = self.model(
                 input_dict, state_batches, seq_lens)
+            if not (isinstance(dist_class, functools.partial)
+                    or issubclass(dist_class, TorchDistributionWrapper)):
+                raise ValueError(
+                    "`dist_class` ({}) not a TorchDistributionWrapper "
+                    "subclass! Make sure your `action_distribution_fn` or "
+                    "`make_model_and_action_dist` return a correct "
+                    "distribution class.".format(dist_class.__name__))
             action_dist = dist_class(dist_inputs, self.model)

             # Get the exploration action from the forward results.
rllib/policy/torch_policy_template.py
@@ -99,7 +99,9 @@ def build_torch_policy(name,
         # Model is customized (use default action dist class).
         if make_model:
-            assert make_model_and_action_dist is None
+            assert make_model_and_action_dist is None, \
+                "Either `make_model` or `make_model_and_action_dist`" \
+                " must be None!"
             self.model = make_model(self, obs_space, action_space, config)
             dist_class, _ = ModelCatalog.get_action_dist(
                 action_space, self.config["model"], framework="torch")