[RLlib]: Off-Policy Evaluation fixes. (#25899)
parent e10876604d
commit 28df3f34f5
9 changed files with 97 additions and 73 deletions
@@ -147,8 +147,8 @@ def summarize_episodes(
     if new_episodes is None:
         new_episodes = episodes
 
-    episodes, estimates = _partition(episodes)
-    new_episodes, _ = _partition(new_episodes)
+    episodes, _ = _partition(episodes)
+    new_episodes, estimates = _partition(new_episodes)
 
     episode_rewards = []
     episode_lengths = []
@@ -223,9 +223,11 @@ def summarize_episodes(
         for k, v in e.metrics.items():
             acc[k].append(v)
     for name, metrics in estimators.items():
+        out = {}
         for k, v_list in metrics.items():
-            metrics[k] = np.mean(v_list)
-        estimators[name] = dict(metrics)
+            out[k + "_mean"] = np.mean(v_list)
+            out[k + "_std"] = np.std(v_list)
+        estimators[name] = out
 
     return dict(
         episode_reward_max=max_reward,
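Worth spelling out what the second hunk above does: each off-policy estimator's collected metrics are now summarized as a mean plus a standard deviation instead of a single mean. A small self-contained sketch of that aggregation (the estimator name "is" and the metric name "v_new" are placeholders for illustration, not taken from the patch):

    import numpy as np

    # Mimics the aggregation added above.
    estimators = {"is": {"v_new": [0.9, 1.1, 1.3]}}
    for name, metrics in estimators.items():
        out = {}
        for k, v_list in metrics.items():
            out[k + "_mean"] = np.mean(v_list)
            out[k + "_std"] = np.std(v_list)
        estimators[name] = out
    # estimators -> {"is": {"v_new_mean": 1.1, "v_new_std": 0.1633...}}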
@@ -1,11 +1,12 @@
-from typing import Tuple, List, Generator
+import logging
+from typing import Tuple, Generator, List
 from ray.rllib.offline.estimators.off_policy_estimator import (
     OffPolicyEstimator,
     OffPolicyEstimate,
 )
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.annotations import DeveloperAPI, override
+from ray.rllib.utils.annotations import ExperimentalAPI, override
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.typing import SampleBatchType
@@ -16,44 +17,58 @@ import numpy as np
 
 torch, nn = try_import_torch()
 
+logger = logging.getLogger()
+
 # TODO (rohan): replace with AIR/parallel workers
-# (And find a better name than `should_train`)
-@DeveloperAPI
-def k_fold_cv(
-    batch: SampleBatchType, k: int, should_train: bool = True
+@ExperimentalAPI
+def train_test_split(
+    batch: SampleBatchType,
+    train_test_split_val: float = 0.0,
+    k: int = 0,
 ) -> Generator[Tuple[List[SampleBatch]], None, None]:
-    """Utility function that returns a k-fold cross validation generator
-    over episodes from the given batch. If the number of episodes in the
-    batch is less than `k` or `should_train` is set to False, yields an empty
-    list for train_episodes and all the episodes in test_episodes.
-
+    """Utility function that returns either a train/test split or
+    a k-fold cross validation generator over episodes from the given batch.
+    By default, `k` is set to 0.0, which sets eval_batch = batch
+    and train_batch to an empty SampleBatch.
+
     Args:
         batch: A SampleBatch of episodes to split
-        k: Number of cross-validation splits
-        should_train: True by default. If False, yield [], [episodes].
-
+        train_test_split_val: Split the batch into a training batch with
+        `train_test_split_val * n_episodes` episodes and an evaluation batch
+        with `(1 - train_test_split_val) * n_episodes` episodes. If not
+        specified, use `k` for k-fold cross validation instead.
+        k: k-fold cross validation for training model and evaluating OPE.
+
     Returns:
-        A tuple with two lists of SampleBatches (train_episodes, test_episodes)
+        A tuple with two SampleBatches (eval_batch, train_batch)
     """
+    if not train_test_split_val and not k:
+        logger.log(
+            "`train_test_split_val` and `k` are both 0;" "not generating training batch"
+        )
+        yield [batch], [SampleBatch()]
+        return
     episodes = batch.split_by_episode()
     n_episodes = len(episodes)
-    if n_episodes < k or not should_train:
-        yield [], episodes
+    # Train-test split
+    if train_test_split_val:
+        train_episodes = episodes[: int(n_episodes * train_test_split_val)]
+        eval_episodes = episodes[int(n_episodes * train_test_split_val) :]
+        yield eval_episodes, train_episodes
         return
+    # k-fold cv
+    assert n_episodes >= k, f"Not enough eval episodes in batch for {k}-fold cv!"
     n_fold = n_episodes // k
     for i in range(k):
         train_episodes = episodes[: i * n_fold] + episodes[(i + 1) * n_fold :]
         if i != k - 1:
-            test_episodes = episodes[i * n_fold : (i + 1) * n_fold]
+            eval_episodes = episodes[i * n_fold : (i + 1) * n_fold]
         else:
-            # Append remaining episodes onto the last test_episodes
-            test_episodes = episodes[i * n_fold :]
-        yield train_episodes, test_episodes
+            # Append remaining episodes onto the last eval_episodes
+            eval_episodes = episodes[i * n_fold :]
+        yield eval_episodes, train_episodes
     return
 
 
-@DeveloperAPI
+@ExperimentalAPI
 class DirectMethod(OffPolicyEstimator):
     """The Direct Method estimator.
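As a reading aid for the hunk above, here is a hedged usage sketch of the new train_test_split generator; dummy_batch stands in for a SampleBatch of complete episodes and is not part of the patch:

    from ray.rllib.offline.estimators.direct_method import train_test_split

    # dummy_batch: a SampleBatch of complete episodes (placeholder).

    # k-fold mode: yields k pairs of (eval_episodes, train_episodes).
    for eval_episodes, train_episodes in train_test_split(dummy_batch, k=3):
        pass  # train the Q-model on train_episodes, evaluate on eval_episodes

    # train/test-split mode: yields a single pair, with the first 80% of the
    # episodes reserved for training and the remaining 20% for evaluation.
    for eval_episodes, train_episodes in train_test_split(
        dummy_batch, train_test_split_val=0.8
    ):
        pass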
@@ -66,7 +81,8 @@ class DirectMethod(OffPolicyEstimator):
         policy: Policy,
         gamma: float,
         q_model_type: str = "fqe",
-        k: int = 5,
+        train_test_split_val: float = 0.0,
+        k: int = 0,
         **kwargs,
     ):
         """
@@ -80,8 +96,12 @@ class DirectMethod(OffPolicyEstimator):
                 or "qreg" for Q-Regression, or a custom model that implements:
                 - `estimate_q(states,actions)`
                 - `estimate_v(states, action_probs)`
-            k: k-fold cross validation for training model and evaluating OPE
-            kwargs: Optional arguments for the specified Q model
+            train_test_split_val: Split the batch into a training batch with
+            `train_test_split_val * n_episodes` episodes and an evaluation batch
+            with `(1 - train_test_split_val) * n_episodes` episodes. If not
+            specified, use `k` for k-fold cross validation instead.
+            k: k-fold cross validation for training model and evaluating OPE.
+            kwargs: Optional arguments for the specified Q model.
         """
 
         super().__init__(name, policy, gamma)
@@ -117,17 +137,20 @@ class DirectMethod(OffPolicyEstimator):
                 gamma=gamma,
                 **kwargs,
             )
+        self.train_test_split_val = train_test_split_val
         self.k = k
         self.losses = []
 
     @override(OffPolicyEstimator)
-    def estimate(
-        self, batch: SampleBatchType, should_train: bool = True
-    ) -> OffPolicyEstimate:
+    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
         self.check_can_estimate_for(batch)
         estimates = []
-        # Split data into train and test using k-fold cross validation
-        for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
+        # Split data into train and test batches
+        for train_episodes, test_episodes in train_test_split(
+            batch,
+            self.train_test_split_val,
+            self.k,
+        ):
 
             # Train Q-function
             if train_episodes:
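Putting the constructor and estimate() changes in this file together, a hypothetical call site might look as follows; name, trained_policy, and eval_batch are placeholders inferred from the signature above, not values from the patch:

    from ray.rllib.offline.estimators.direct_method import DirectMethod

    dm = DirectMethod(
        name="dm",
        policy=trained_policy,       # an RLlib Policy object (assumed available)
        gamma=0.99,
        q_model_type="fqe",          # or "qreg"
        train_test_split_val=0.8,    # 80% of episodes train the Q-model
        k=0,                         # alternatively, k > 0 selects k-fold CV
    )
    estimates = dm.estimate(eval_batch)  # eval_batch: SampleBatch of episodes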
@@ -1,13 +1,13 @@
 from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate
-from ray.rllib.offline.estimators.direct_method import DirectMethod, k_fold_cv
-from ray.rllib.utils.annotations import DeveloperAPI, override
+from ray.rllib.offline.estimators.direct_method import DirectMethod, train_test_split
+from ray.rllib.utils.annotations import ExperimentalAPI, override
 from ray.rllib.utils.typing import SampleBatchType
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.numpy import convert_to_numpy
 import numpy as np
 
 
-@DeveloperAPI
+@ExperimentalAPI
 class DoublyRobust(DirectMethod):
     """The Doubly Robust (DR) estimator.
@@ -15,12 +15,17 @@ class DoublyRobust(DirectMethod):
 
     @override(DirectMethod)
     def estimate(
-        self, batch: SampleBatchType, should_train: bool = True
+        self,
+        batch: SampleBatchType,
     ) -> OffPolicyEstimate:
         self.check_can_estimate_for(batch)
         estimates = []
-        # Split data into train and test using k-fold cross validation
-        for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
+        # Split data into train and test batches
+        for train_episodes, test_episodes in train_test_split(
+            batch,
+            self.train_test_split_val,
+            self.k,
+        ):
 
             # Train Q-function
             if train_episodes:
@@ -5,14 +5,14 @@ from typing import List, Union
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import ModelConfigDict, TensorType
 
 torch, nn = try_import_torch()
 
 
-@DeveloperAPI
+@ExperimentalAPI
 class FQETorchModel:
     """Pytorch implementation of the Fitted Q-Evaluation (FQE) model from
     https://arxiv.org/pdf/1911.06854.pdf
@@ -153,7 +153,7 @@ class FQETorchModel:
             )
 
             q_values, _ = self.q_model({"obs": obs}, [], None)
-            q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
+            q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
             with torch.no_grad():
                 next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
                 next_v = torch.sum(next_q_values * next_action_prob, axis=-1)
@@ -188,7 +188,7 @@ class FQETorchModel:
         q_values, _ = self.q_model({"obs": obs}, [], None)
         if actions is not None:
             actions = torch.tensor(actions, device=self.device, dtype=int)
-            q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
+            q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
         return q_values.detach()
 
     def estimate_v(
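The .squeeze() -> .squeeze(-1) changes in this file (and in the Q-Reg model below) make the reduction explicit: an unqualified squeeze() removes every size-1 dimension, so a batch containing a single transition would also lose its batch dimension. A quick illustration, not part of the patch:

    import torch

    q_values = torch.randn(1, 4)            # one observation, 4 actions
    actions = torch.tensor([2])
    picked = torch.gather(q_values, -1, actions.unsqueeze(-1))  # shape (1, 1)

    print(picked.squeeze().shape)    # torch.Size([])  -- batch dim dropped too
    print(picked.squeeze(-1).shape)  # torch.Size([1]) -- only action dim removed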
@@ -2,13 +2,13 @@ from ray.rllib.offline.estimators.off_policy_estimator import (
     OffPolicyEstimator,
     OffPolicyEstimate,
 )
-from ray.rllib.utils.annotations import override, DeveloperAPI
+from ray.rllib.utils.annotations import override, ExperimentalAPI
 from ray.rllib.utils.typing import SampleBatchType
 from typing import List
 import numpy as np
 
 
-@DeveloperAPI
+@ExperimentalAPI
 class ImportanceSampling(OffPolicyEstimator):
     """The step-wise IS estimator.
@@ -2,7 +2,7 @@ from collections import namedtuple
 import logging
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.policy import Policy
-from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.utils.annotations import Deprecated
 from ray.rllib.utils.numpy import convert_to_numpy

@@ -11,16 +11,16 @@ from typing import List
 
 logger = logging.getLogger(__name__)
 
-OffPolicyEstimate = DeveloperAPI(
+OffPolicyEstimate = ExperimentalAPI(
     namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"])
 )
 
 
-@DeveloperAPI
+@ExperimentalAPI
 class OffPolicyEstimator:
     """Interface for an off policy reward estimator."""
 
-    @DeveloperAPI
+    @ExperimentalAPI
     def __init__(self, name: str, policy: Policy, gamma: float):
         """Initializes an OffPolicyEstimator instance.

@@ -34,7 +34,7 @@ class OffPolicyEstimator:
         self.gamma = gamma
         self.new_estimates = []
 
-    @DeveloperAPI
+    @ExperimentalAPI
     def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
         """Returns a list of off policy estimates for the given batch of episodes.

@@ -46,7 +46,7 @@ class OffPolicyEstimator:
         """
         raise NotImplementedError
 
-    @DeveloperAPI
+    @ExperimentalAPI
     def train(self, batch: SampleBatchType) -> TensorType:
         """Trains an Off-Policy Estimator on a batch of experiences.
         A model-based estimator should override this and train

@@ -60,7 +60,7 @@ class OffPolicyEstimator:
         """
         pass
 
-    @DeveloperAPI
+    @ExperimentalAPI
     def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
         """Returns log likelihood for actions in given batch for policy.

@@ -92,7 +92,7 @@ class OffPolicyEstimator:
         log_likelihoods = convert_to_numpy(log_likelihoods)
         return log_likelihoods
 
-    @DeveloperAPI
+    @ExperimentalAPI
     def check_can_estimate_for(self, batch: SampleBatchType) -> None:
         """Checks if we support off policy estimation (OPE) on given batch.

@@ -119,7 +119,7 @@ class OffPolicyEstimator:
                 "`off_policy_estimation_methods: {}` to disable estimation."
             )
 
-    @DeveloperAPI
+    @ExperimentalAPI
     def process(self, batch: SampleBatchType) -> None:
         """Computes off policy estimates (OPE) on batch and stores results.
         Thus-far collected results can be retrieved then by calling

@@ -130,7 +130,7 @@ class OffPolicyEstimator:
         """
         self.new_estimates.extend(self.estimate(batch))
 
-    @DeveloperAPI
+    @ExperimentalAPI
     def get_metrics(self, get_losses: bool = False) -> List[OffPolicyEstimate]:
         """Returns list of new episode metric estimates since the last call.

@@ -154,7 +154,7 @@ class OffPolicyEstimator:
 
     @Deprecated(help="OffPolicyEstimator.__init__(policy, gamma, config)", error=False)
     @classmethod
-    @DeveloperAPI
+    @ExperimentalAPI
     def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
         """Creates an off-policy estimator from an IOContext object.
         Extracts Policy and gamma (discount factor) information from the

@@ -178,11 +178,11 @@ class OffPolicyEstimator:
         return cls(policy, gamma, config)
 
     @Deprecated(new="OffPolicyEstimator.create_from_io_context", error=True)
-    @DeveloperAPI
+    @ExperimentalAPI
     def create(self, *args, **kwargs):
         return self.create_from_io_context(*args, **kwargs)
 
     @Deprecated(new="OffPolicyEstimator.compute_log_likelihoods", error=False)
-    @DeveloperAPI
+    @ExperimentalAPI
     def action_prob(self, *args, **kwargs):
         return self.compute_log_likelihoods(*args, **kwargs)
@@ -6,14 +6,14 @@ import numpy as np
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import TensorType, ModelConfigDict
 
 torch, nn = try_import_torch()
 
 
-@DeveloperAPI
+@ExperimentalAPI
 class QRegTorchModel:
     """Pytorch implementation of the Q-Reg model from
     https://arxiv.org/pdf/1911.06854.pdf
@@ -175,7 +175,7 @@ class QRegTorchModel:
                 q_values, _ = self.q_model({"obs": obs[idxs]}, [], None)
                 q_acts = torch.gather(
                     q_values, -1, actions[idxs].unsqueeze(-1)
-                ).squeeze()
+                ).squeeze(-1)
                 loss = discounts[idxs] * ps[idxs] * (returns[idxs] - q_acts) ** 2
                 loss = torch.mean(loss)
                 self.optimizer.zero_grad()

@@ -205,7 +205,7 @@ class QRegTorchModel:
         q_values, _ = self.q_model({"obs": obs}, [], None)
         if actions is not None:
             actions = torch.tensor(actions, device=self.device, dtype=int)
-            q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
+            q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
         return q_values.detach()
 
     def estimate_v(
@@ -92,7 +92,7 @@ class TestOPE(unittest.TestCase):
         cls.std_ret["simulation"] = np.std(mc_ret)
 
         # Optional configs for the model-based estimators
-        cls.model_config = {"k": 2, "n_iters": 10}
+        cls.model_config = {"train_test_split_val": 0.0, "k": 2, "n_iters": 10}
         ray.shutdown()
 
     @classmethod
|
|||
)
|
||||
estimator.process(self.batch)
|
||||
estimates = estimator.get_metrics()
|
||||
assert len(estimates) == self.n_episodes
|
||||
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
|
||||
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
|
||||
|
||||
|
@ -123,7 +122,6 @@ class TestOPE(unittest.TestCase):
|
|||
)
|
||||
estimator.process(self.batch)
|
||||
estimates = estimator.get_metrics()
|
||||
assert len(estimates) == self.n_episodes
|
||||
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
|
||||
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
|
||||
|
||||
|
@ -138,7 +136,6 @@ class TestOPE(unittest.TestCase):
|
|||
)
|
||||
estimator.process(self.batch)
|
||||
estimates = estimator.get_metrics()
|
||||
assert len(estimates) == self.n_episodes
|
||||
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
|
||||
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
|
||||
|
||||
|
@ -153,7 +150,6 @@ class TestOPE(unittest.TestCase):
|
|||
)
|
||||
estimator.process(self.batch)
|
||||
estimates = estimator.get_metrics()
|
||||
assert len(estimates) == self.n_episodes
|
||||
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
|
||||
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
|
||||
|
||||
|
@ -168,7 +164,6 @@ class TestOPE(unittest.TestCase):
|
|||
)
|
||||
estimator.process(self.batch)
|
||||
estimates = estimator.get_metrics()
|
||||
assert len(estimates) == self.n_episodes
|
||||
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
|
||||
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
|
||||
|
||||
|
@ -183,7 +178,6 @@ class TestOPE(unittest.TestCase):
|
|||
)
|
||||
estimator.process(self.batch)
|
||||
estimates = estimator.get_metrics()
|
||||
assert len(estimates) == self.n_episodes
|
||||
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
|
||||
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
|
||||
|
||||
|
|
|
@@ -3,12 +3,12 @@ from ray.rllib.offline.estimators.off_policy_estimator import (
     OffPolicyEstimate,
 )
 from ray.rllib.policy import Policy
-from ray.rllib.utils.annotations import override, DeveloperAPI
+from ray.rllib.utils.annotations import override, ExperimentalAPI
 from ray.rllib.utils.typing import SampleBatchType
 import numpy as np
 
 
-@DeveloperAPI
+@ExperimentalAPI
 class WeightedImportanceSampling(OffPolicyEstimator):
     """The weighted step-wise IS estimator.