[RLlib]: Off-Policy Evaluation fixes. (#25899)

Rohan Potdar 2022-06-21 04:24:24 -07:00 committed by GitHub
parent e10876604d
commit 28df3f34f5
9 changed files with 97 additions and 73 deletions


@@ -147,8 +147,8 @@ def summarize_episodes(
     if new_episodes is None:
         new_episodes = episodes
-    episodes, estimates = _partition(episodes)
-    new_episodes, _ = _partition(new_episodes)
+    episodes, _ = _partition(episodes)
+    new_episodes, estimates = _partition(new_episodes)
     episode_rewards = []
     episode_lengths = []
@@ -223,9 +223,11 @@ def summarize_episodes(
        for k, v in e.metrics.items():
            acc[k].append(v)
    for name, metrics in estimators.items():
+       out = {}
        for k, v_list in metrics.items():
-           metrics[k] = np.mean(v_list)
-       estimators[name] = dict(metrics)
+           out[k + "_mean"] = np.mean(v_list)
+           out[k + "_std"] = np.std(v_list)
+       estimators[name] = out
    return dict(
        episode_reward_max=max_reward,
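
For reference, a minimal standalone sketch of the new aggregation (illustrative values only): each per-episode metric list an estimator reports is now reduced to a mean and a standard deviation instead of a single mean.

    import numpy as np

    # Hypothetical per-episode metrics collected for one estimator.
    metrics = {"v_old": [10.0, 12.0], "v_new": [11.5, 13.0]}

    out = {}
    for k, v_list in metrics.items():
        out[k + "_mean"] = np.mean(v_list)  # e.g. "v_new_mean"
        out[k + "_std"] = np.std(v_list)    # e.g. "v_new_std"
    print(out)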


@@ -1,11 +1,12 @@
-from typing import Tuple, List, Generator
+import logging
+from typing import Tuple, Generator, List
 from ray.rllib.offline.estimators.off_policy_estimator import (
     OffPolicyEstimator,
     OffPolicyEstimate,
 )
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.annotations import DeveloperAPI, override
+from ray.rllib.utils.annotations import ExperimentalAPI, override
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.typing import SampleBatchType
@@ -16,44 +17,58 @@ import numpy as np
 torch, nn = try_import_torch()
+logger = logging.getLogger()

 # TODO (rohan): replace with AIR/parallel workers
-# (And find a better name than `should_train`)
-@DeveloperAPI
-def k_fold_cv(
-    batch: SampleBatchType, k: int, should_train: bool = True
+@ExperimentalAPI
+def train_test_split(
+    batch: SampleBatchType,
+    train_test_split_val: float = 0.0,
+    k: int = 0,
 ) -> Generator[Tuple[List[SampleBatch]], None, None]:
-    """Utility function that returns a k-fold cross validation generator
-    over episodes from the given batch. If the number of episodes in the
-    batch is less than `k` or `should_train` is set to False, yields an empty
-    list for train_episodes and all the episodes in test_episodes.
+    """Utility function that returns either a train/test split or
+    a k-fold cross validation generator over episodes from the given batch.
+    By default, both `train_test_split_val` and `k` are 0, in which case the
+    whole batch is used for evaluation and the training batch is empty.

     Args:
         batch: A SampleBatch of episodes to split
-        k: Number of cross-validation splits
-        should_train: True by default. If False, yield [], [episodes].
+        train_test_split_val: Split the batch into a training batch with
+            `train_test_split_val * n_episodes` episodes and an evaluation batch
+            with `(1 - train_test_split_val) * n_episodes` episodes. If not
+            specified, use `k` for k-fold cross validation instead.
+        k: Number of folds for k-fold cross validation; used when
+            `train_test_split_val` is 0.

     Returns:
-        A tuple with two lists of SampleBatches (train_episodes, test_episodes)
+        A tuple of two lists of SampleBatches, (eval_batch, train_batch)
     """
+    if not train_test_split_val and not k:
+        logger.warning(
+            "`train_test_split_val` and `k` are both 0; not generating training batch"
+        )
+        yield [batch], [SampleBatch()]
+        return
     episodes = batch.split_by_episode()
     n_episodes = len(episodes)
-    if n_episodes < k or not should_train:
-        yield [], episodes
+    # Train-test split
+    if train_test_split_val:
+        train_episodes = episodes[: int(n_episodes * train_test_split_val)]
+        eval_episodes = episodes[int(n_episodes * train_test_split_val) :]
+        yield eval_episodes, train_episodes
         return
+    # k-fold cv
+    assert n_episodes >= k, f"Not enough eval episodes in batch for {k}-fold cv!"
     n_fold = n_episodes // k
     for i in range(k):
         train_episodes = episodes[: i * n_fold] + episodes[(i + 1) * n_fold :]
         if i != k - 1:
-            test_episodes = episodes[i * n_fold : (i + 1) * n_fold]
+            eval_episodes = episodes[i * n_fold : (i + 1) * n_fold]
         else:
-            # Append remaining episodes onto the last test_episodes
-            test_episodes = episodes[i * n_fold :]
-        yield train_episodes, test_episodes
+            # Append remaining episodes onto the last eval_episodes
+            eval_episodes = episodes[i * n_fold :]
+        yield eval_episodes, train_episodes
     return

-@DeveloperAPI
+@ExperimentalAPI
 class DirectMethod(OffPolicyEstimator):
     """The Direct Method estimator.
@@ -66,7 +81,8 @@ class DirectMethod(OffPolicyEstimator):
         policy: Policy,
         gamma: float,
         q_model_type: str = "fqe",
-        k: int = 5,
+        train_test_split_val: float = 0.0,
+        k: int = 0,
         **kwargs,
     ):
         """
@@ -80,8 +96,12 @@ class DirectMethod(OffPolicyEstimator):
                 or "qreg" for Q-Regression, or a custom model that implements:
                 - `estimate_q(states,actions)`
                 - `estimate_v(states, action_probs)`
-            k: k-fold cross validation for training model and evaluating OPE
-            kwargs: Optional arguments for the specified Q model
+            train_test_split_val: Split the batch into a training batch with
+                `train_test_split_val * n_episodes` episodes and an evaluation batch
+                with `(1 - train_test_split_val) * n_episodes` episodes. If not
+                specified, use `k` for k-fold cross validation instead.
+            k: Number of folds for k-fold cross validation; used when `train_test_split_val` is 0.
+            kwargs: Optional arguments for the specified Q model.
         """
         super().__init__(name, policy, gamma)
@@ -117,17 +137,20 @@ class DirectMethod(OffPolicyEstimator):
                 gamma=gamma,
                 **kwargs,
             )
+        self.train_test_split_val = train_test_split_val
         self.k = k
         self.losses = []

     @override(OffPolicyEstimator)
-    def estimate(
-        self, batch: SampleBatchType, should_train: bool = True
-    ) -> OffPolicyEstimate:
+    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
         self.check_can_estimate_for(batch)
         estimates = []
-        # Split data into train and test using k-fold cross validation
-        for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
+        # Split data into train and test batches
+        for train_episodes, test_episodes in train_test_split(
+            batch,
+            self.train_test_split_val,
+            self.k,
+        ):
             # Train Q-function
             if train_episodes:
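
For context, the Direct Method scores each held-out episode by replacing its observed return with the trained Q-model's value of the initial state under the target policy. A toy sketch with made-up numbers, following the v_old / v_new / v_gain naming used by the estimators:

    import numpy as np

    gamma = 0.99
    rewards = [1.0, 0.0, 1.0]            # logged rewards of one held-out episode
    q_hat_s0 = np.array([0.8, 1.6])      # fitted Q(s0, a) for two actions
    pi_s0 = np.array([0.3, 0.7])         # target policy action probs in s0

    v_old = sum(r * gamma**t for t, r in enumerate(rewards))  # behavior return
    v_new = float(np.dot(q_hat_s0, pi_s0))                    # E_{a~pi}[Q(s0, a)]
    v_gain = v_new / max(v_old, 1e-8)
    print({"v_old": v_old, "v_new": v_new, "v_gain": v_gain})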


@@ -1,13 +1,13 @@
 from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate
-from ray.rllib.offline.estimators.direct_method import DirectMethod, k_fold_cv
-from ray.rllib.utils.annotations import DeveloperAPI, override
+from ray.rllib.offline.estimators.direct_method import DirectMethod, train_test_split
+from ray.rllib.utils.annotations import ExperimentalAPI, override
 from ray.rllib.utils.typing import SampleBatchType
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.numpy import convert_to_numpy
 import numpy as np

-@DeveloperAPI
+@ExperimentalAPI
 class DoublyRobust(DirectMethod):
     """The Doubly Robust (DR) estimator.
@@ -15,12 +15,17 @@ class DoublyRobust(DirectMethod):
     @override(DirectMethod)
     def estimate(
-        self, batch: SampleBatchType, should_train: bool = True
+        self,
+        batch: SampleBatchType,
     ) -> OffPolicyEstimate:
         self.check_can_estimate_for(batch)
         estimates = []
-        # Split data into train and test using k-fold cross validation
-        for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
+        # Split data into train and test batches
+        for train_episodes, test_episodes in train_test_split(
+            batch,
+            self.train_test_split_val,
+            self.k,
+        ):
             # Train Q-function
             if train_episodes:
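
For context, Doubly Robust combines the step-wise importance ratios with the fitted Q-model: working backwards through an episode, the model's state-value estimate is corrected by the importance-weighted TD error of each logged transition. A toy sketch with made-up numbers:

    gamma = 0.99
    rewards = [1.0, 0.0, 1.0]
    rho = [1.2, 0.8, 1.0]      # pi(a_t|s_t) / pi_behavior(a_t|s_t) per step
    q_hat = [0.9, 0.7, 1.0]    # Q_hat(s_t, a_t) for the logged actions
    v_hat = [1.1, 0.8, 1.0]    # E_{a~pi}[Q_hat(s_t, a)] per state

    v_dr = 0.0
    for t in reversed(range(len(rewards))):
        v_dr = v_hat[t] + rho[t] * (rewards[t] + gamma * v_dr - q_hat[t])
    print("DR estimate of v_pi(s0):", v_dr)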


@@ -5,14 +5,14 @@ from typing import List, Union
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import ModelConfigDict, TensorType

 torch, nn = try_import_torch()

-@DeveloperAPI
+@ExperimentalAPI
 class FQETorchModel:
     """Pytorch implementation of the Fitted Q-Evaluation (FQE) model from
     https://arxiv.org/pdf/1911.06854.pdf
@@ -153,7 +153,7 @@ class FQETorchModel:
            )
            q_values, _ = self.q_model({"obs": obs}, [], None)
-           q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
+           q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
            with torch.no_grad():
                next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
                next_v = torch.sum(next_q_values * next_action_prob, axis=-1)
@@ -188,7 +188,7 @@ class FQETorchModel:
        q_values, _ = self.q_model({"obs": obs}, [], None)
        if actions is not None:
            actions = torch.tensor(actions, device=self.device, dtype=int)
-           q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
+           q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
        return q_values.detach()

    def estimate_v(
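
The `.squeeze()` to `.squeeze(-1)` changes above matter when a batch contains a single row: a bare `squeeze()` also drops the batch dimension and returns a 0-d tensor. A quick illustration (assumes PyTorch is installed):

    import torch

    q_values = torch.tensor([[0.1, 0.9]])   # one state, two actions
    actions = torch.tensor([1])

    gathered = torch.gather(q_values, -1, actions.unsqueeze(-1))
    print(gathered.squeeze().shape)    # torch.Size([])  -- batch dim lost
    print(gathered.squeeze(-1).shape)  # torch.Size([1]) -- batch dim kept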


@@ -2,13 +2,13 @@ from ray.rllib.offline.estimators.off_policy_estimator import (
     OffPolicyEstimator,
     OffPolicyEstimate,
 )
-from ray.rllib.utils.annotations import override, DeveloperAPI
+from ray.rllib.utils.annotations import override, ExperimentalAPI
 from ray.rllib.utils.typing import SampleBatchType
 from typing import List
 import numpy as np

-@DeveloperAPI
+@ExperimentalAPI
 class ImportanceSampling(OffPolicyEstimator):
     """The step-wise IS estimator.


@@ -2,7 +2,7 @@ from collections import namedtuple
 import logging
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.policy import Policy
-from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.utils.annotations import Deprecated
 from ray.rllib.utils.numpy import convert_to_numpy
@@ -11,16 +11,16 @@ from typing import List
 logger = logging.getLogger(__name__)

-OffPolicyEstimate = DeveloperAPI(
+OffPolicyEstimate = ExperimentalAPI(
     namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"])
 )

-@DeveloperAPI
+@ExperimentalAPI
 class OffPolicyEstimator:
     """Interface for an off policy reward estimator."""

-    @DeveloperAPI
+    @ExperimentalAPI
     def __init__(self, name: str, policy: Policy, gamma: float):
         """Initializes an OffPolicyEstimator instance.
@@ -34,7 +34,7 @@ class OffPolicyEstimator:
         self.gamma = gamma
         self.new_estimates = []

-    @DeveloperAPI
+    @ExperimentalAPI
     def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
         """Returns a list of off policy estimates for the given batch of episodes.
@@ -46,7 +46,7 @@ class OffPolicyEstimator:
         """
         raise NotImplementedError

-    @DeveloperAPI
+    @ExperimentalAPI
     def train(self, batch: SampleBatchType) -> TensorType:
         """Trains an Off-Policy Estimator on a batch of experiences.
         A model-based estimator should override this and train
@@ -60,7 +60,7 @@ class OffPolicyEstimator:
         """
         pass

-    @DeveloperAPI
+    @ExperimentalAPI
     def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
         """Returns log likelihood for actions in given batch for policy.
@@ -92,7 +92,7 @@ class OffPolicyEstimator:
         log_likelihoods = convert_to_numpy(log_likelihoods)
         return log_likelihoods

-    @DeveloperAPI
+    @ExperimentalAPI
     def check_can_estimate_for(self, batch: SampleBatchType) -> None:
         """Checks if we support off policy estimation (OPE) on given batch.
@@ -119,7 +119,7 @@ class OffPolicyEstimator:
                 "`off_policy_estimation_methods: {}` to disable estimation."
             )

-    @DeveloperAPI
+    @ExperimentalAPI
     def process(self, batch: SampleBatchType) -> None:
         """Computes off policy estimates (OPE) on batch and stores results.
         Thus-far collected results can be retrieved then by calling
@@ -130,7 +130,7 @@ class OffPolicyEstimator:
         """
         self.new_estimates.extend(self.estimate(batch))

-    @DeveloperAPI
+    @ExperimentalAPI
     def get_metrics(self, get_losses: bool = False) -> List[OffPolicyEstimate]:
         """Returns list of new episode metric estimates since the last call.
@@ -154,7 +154,7 @@ class OffPolicyEstimator:
     @Deprecated(help="OffPolicyEstimator.__init__(policy, gamma, config)", error=False)
     @classmethod
-    @DeveloperAPI
+    @ExperimentalAPI
     def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
         """Creates an off-policy estimator from an IOContext object.
         Extracts Policy and gamma (discount factor) information from the
@@ -178,11 +178,11 @@ class OffPolicyEstimator:
         return cls(policy, gamma, config)

     @Deprecated(new="OffPolicyEstimator.create_from_io_context", error=True)
-    @DeveloperAPI
+    @ExperimentalAPI
     def create(self, *args, **kwargs):
         return self.create_from_io_context(*args, **kwargs)

     @Deprecated(new="OffPolicyEstimator.compute_log_likelihoods", error=False)
-    @DeveloperAPI
+    @ExperimentalAPI
     def action_prob(self, *args, **kwargs):
         return self.compute_log_likelihoods(*args, **kwargs)
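
The changes in this file are annotation-only; the contract is unchanged: subclasses implement `estimate()` returning a list of `OffPolicyEstimate` namedtuples, while `process()` and `get_metrics()` accumulate and drain them. Below is a minimal hypothetical subclass as a usage sketch; it assumes the base constructor only stores its `(name, policy, gamma)` arguments, so no real Policy is needed for this toy case:

    import numpy as np
    from ray.rllib.offline.estimators.off_policy_estimator import (
        OffPolicyEstimator,
        OffPolicyEstimate,
    )
    from ray.rllib.policy.sample_batch import SampleBatch

    class BehaviorReturn(OffPolicyEstimator):
        """Toy estimator: reports the discounted return of the logged data."""

        def estimate(self, batch):
            estimates = []
            for episode in batch.split_by_episode():
                ret = sum(r * self.gamma**t for t, r in enumerate(episode["rewards"]))
                estimates.append(OffPolicyEstimate("behavior_return", {"v_old": ret}))
            return estimates

    batch = SampleBatch({"eps_id": [0, 0, 1, 1], "rewards": [1.0, 0.0, 0.5, 0.5]})
    est = BehaviorReturn("behavior_return", policy=None, gamma=0.99)
    est.process(batch)
    print([e.metrics for e in est.get_metrics()])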


@@ -6,14 +6,14 @@ import numpy as np
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import TensorType, ModelConfigDict

 torch, nn = try_import_torch()

-@DeveloperAPI
+@ExperimentalAPI
 class QRegTorchModel:
     """Pytorch implementation of the Q-Reg model from
     https://arxiv.org/pdf/1911.06854.pdf
@@ -175,7 +175,7 @@ class QRegTorchModel:
            q_values, _ = self.q_model({"obs": obs[idxs]}, [], None)
            q_acts = torch.gather(
                q_values, -1, actions[idxs].unsqueeze(-1)
-           ).squeeze()
+           ).squeeze(-1)
            loss = discounts[idxs] * ps[idxs] * (returns[idxs] - q_acts) ** 2
            loss = torch.mean(loss)
            self.optimizer.zero_grad()
@@ -205,7 +205,7 @@ class QRegTorchModel:
        q_values, _ = self.q_model({"obs": obs}, [], None)
        if actions is not None:
            actions = torch.tensor(actions, device=self.device, dtype=int)
-           q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
+           q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
        return q_values.detach()

    def estimate_v(
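
The loss in the first hunk above is an importance-weighted regression of Q(s_t, a_t) toward the observed returns. A toy numpy version of that single line (variable names mirror the diff, values are made up):

    import numpy as np

    discounts = np.array([1.0, 0.99, 0.98])    # gamma ** t
    ps = np.array([1.2, 0.96, 0.96])           # cumulative importance weights
    returns = np.array([1.5, 0.5, 1.0])        # observed discounted returns
    q_acts = np.array([1.3, 0.6, 0.8])         # Q_hat(s_t, a_t)

    loss = np.mean(discounts * ps * (returns - q_acts) ** 2)
    print("weighted regression loss:", loss)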


@@ -92,7 +92,7 @@ class TestOPE(unittest.TestCase):
         cls.std_ret["simulation"] = np.std(mc_ret)

         # Optional configs for the model-based estimators
-        cls.model_config = {"k": 2, "n_iters": 10}
+        cls.model_config = {"train_test_split_val": 0.0, "k": 2, "n_iters": 10}
         ray.shutdown()

     @classmethod
@@ -110,7 +110,6 @@ class TestOPE(unittest.TestCase):
         )
         estimator.process(self.batch)
         estimates = estimator.get_metrics()
-        assert len(estimates) == self.n_episodes
         self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
         self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
@@ -123,7 +122,6 @@ class TestOPE(unittest.TestCase):
         )
         estimator.process(self.batch)
         estimates = estimator.get_metrics()
-        assert len(estimates) == self.n_episodes
         self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
         self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
@@ -138,7 +136,6 @@ class TestOPE(unittest.TestCase):
         )
         estimator.process(self.batch)
         estimates = estimator.get_metrics()
-        assert len(estimates) == self.n_episodes
         self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
         self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
@@ -153,7 +150,6 @@ class TestOPE(unittest.TestCase):
         )
         estimator.process(self.batch)
         estimates = estimator.get_metrics()
-        assert len(estimates) == self.n_episodes
         self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
         self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
@@ -168,7 +164,6 @@ class TestOPE(unittest.TestCase):
         )
         estimator.process(self.batch)
         estimates = estimator.get_metrics()
-        assert len(estimates) == self.n_episodes
         self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
         self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
@@ -183,7 +178,6 @@ class TestOPE(unittest.TestCase):
         )
         estimator.process(self.batch)
         estimates = estimator.get_metrics()
-        assert len(estimates) == self.n_episodes
         self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
         self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
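
The removed assertions reflect that the number of returned estimates no longer necessarily equals the number of logged episodes once the batch is split for model training. The aggregation the tests keep is sketched below with made-up values (only `v_new` appears in the diff; the other metric names are illustrative):

    import numpy as np
    from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate

    estimates = [
        OffPolicyEstimate("dm_fqe", {"v_old": 9.1, "v_new": 10.2, "v_gain": 1.12}),
        OffPolicyEstimate("dm_fqe", {"v_old": 9.4, "v_new": 9.8, "v_gain": 1.04}),
    ]
    mean_ret = np.mean([e.metrics["v_new"] for e in estimates])
    std_ret = np.std([e.metrics["v_new"] for e in estimates])
    print(mean_ret, std_ret)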


@@ -3,12 +3,12 @@ from ray.rllib.offline.estimators.off_policy_estimator import (
     OffPolicyEstimate,
 )
 from ray.rllib.policy import Policy
-from ray.rllib.utils.annotations import override, DeveloperAPI
+from ray.rllib.utils.annotations import override, ExperimentalAPI
 from ray.rllib.utils.typing import SampleBatchType
 import numpy as np

-@DeveloperAPI
+@ExperimentalAPI
 class WeightedImportanceSampling(OffPolicyEstimator):
     """The weighted step-wise IS estimator.