ray/rllib/offline/estimators/direct_method.py

import logging
from typing import Tuple, Generator, List
from ray.rllib.offline.estimators.off_policy_estimator import (
    OffPolicyEstimator,
    OffPolicyEstimate,
)
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.offline.estimators.qreg_torch_model import QRegTorchModel
from gym.spaces import Discrete
import numpy as np

torch, nn = try_import_torch()

logger = logging.getLogger()


@ExperimentalAPI
def train_test_split(
    batch: SampleBatchType,
    train_test_split_val: float = 0.0,
    k: int = 0,
) -> Generator[Tuple[List[SampleBatch], List[SampleBatch]], None, None]:
    """Utility function that returns either a train/test split or
    a k-fold cross validation generator over episodes from the given batch.

    By default, both `train_test_split_val` and `k` are 0, in which case the
    generator yields eval_batch = batch and an empty train_batch.

    Args:
        batch: A SampleBatch of episodes to split
        train_test_split_val: Split the batch into a training batch with
            `train_test_split_val * n_episodes` episodes and an evaluation
            batch with `(1 - train_test_split_val) * n_episodes` episodes.
            If not specified, use `k` for k-fold cross validation instead.
        k: k-fold cross validation for training model and evaluating OPE.

    Yields:
        A tuple of two lists of episode SampleBatches,
        (eval_episodes, train_episodes).
    """
    if not train_test_split_val and not k:
        logger.warning(
            "`train_test_split_val` and `k` are both 0; "
            "not generating training batch"
        )
        yield [batch], [SampleBatch()]
        return

    episodes = batch.split_by_episode()
    n_episodes = len(episodes)

    # Train-test split
    if train_test_split_val:
        train_episodes = episodes[: int(n_episodes * train_test_split_val)]
        eval_episodes = episodes[int(n_episodes * train_test_split_val) :]
        yield eval_episodes, train_episodes
        return

    # k-fold cv
    assert n_episodes >= k, f"Not enough eval episodes in batch for {k}-fold cv!"
    n_fold = n_episodes // k
    for i in range(k):
        train_episodes = episodes[: i * n_fold] + episodes[(i + 1) * n_fold :]
        if i != k - 1:
            eval_episodes = episodes[i * n_fold : (i + 1) * n_fold]
        else:
            # Append remaining episodes onto the last eval_episodes
            eval_episodes = episodes[i * n_fold :]
        yield eval_episodes, train_episodes
    return
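
# Example (illustrative sketch, not part of the original module): with 10
# episodes in `batch`, `train_test_split_val=0.8` yields a single
# (eval_episodes, train_episodes) pair with 2 and 8 episodes respectively,
# while `k=5` yields five folds of 2 eval episodes each:
#
#   for eval_episodes, train_episodes in train_test_split(batch, k=5):
#       assert len(eval_episodes) == 2
#       assert len(train_episodes) == 8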


@ExperimentalAPI
class DirectMethod(OffPolicyEstimator):
    """The Direct Method estimator.

    DM estimator described in https://arxiv.org/pdf/1511.03722.pdf"""

    @override(OffPolicyEstimator)
    def __init__(
        self,
        name: str,
        policy: Policy,
        gamma: float,
        q_model_type: str = "fqe",
        train_test_split_val: float = 0.0,
        k: int = 0,
        **kwargs,
    ):
        """Initializes a Direct Method OPE Estimator.

        Args:
            name: string to save OPE results under
            policy: Policy to evaluate.
            gamma: Discount factor of the environment.
            q_model_type: Either "fqe" for Fitted Q-Evaluation
                or "qreg" for Q-Regression, or a custom model that implements:
                - `estimate_q(states, actions)`
                - `estimate_v(states, action_probs)`
            train_test_split_val: Split the batch into a training batch with
                `train_test_split_val * n_episodes` episodes and an evaluation
                batch with `(1 - train_test_split_val) * n_episodes` episodes.
                If not specified, use `k` for k-fold cross validation instead.
            k: k-fold cross validation for training model and evaluating OPE.
            kwargs: Optional arguments for the specified Q model.
        """
        super().__init__(name, policy, gamma)
        # TODO (rohan): Add support for continuous action spaces
        assert isinstance(
            policy.action_space, Discrete
        ), "DM Estimator only supports discrete action spaces!"
        assert (
            policy.config["batch_mode"] == "complete_episodes"
        ), "DM Estimator only supports `batch_mode`=`complete_episodes`"

        # TODO (rohan): Add support for TF!
        if policy.framework == "torch":
            if q_model_type == "qreg":
                model_cls = QRegTorchModel
            elif q_model_type == "fqe":
                model_cls = FQETorchModel
            else:
                assert hasattr(
                    q_model_type, "estimate_q"
                ), "q_model_type must implement `estimate_q`!"
                assert hasattr(
                    q_model_type, "estimate_v"
                ), "q_model_type must implement `estimate_v`!"
                # Use the provided custom model class directly, so `model_cls`
                # is defined for the instantiation below.
                model_cls = q_model_type
        else:
            raise ValueError(
                f"{self.__class__.__name__} "
                "estimator only supports `policy.framework`=`torch`"
            )

        self.model = model_cls(
            policy=policy,
            gamma=gamma,
            **kwargs,
        )
        self.train_test_split_val = train_test_split_val
        self.k = k
        self.losses = []
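
    # Note (illustrative, not part of the original file): a custom
    # `q_model_type` object passed to the constructor is expected to expose
    # roughly the same interface this estimator relies on elsewhere, e.g.:
    #
    #   class MyQModel:
    #       def __init__(self, policy, gamma, **kwargs): ...
    #       def reset(self): ...                             # used in estimate()
    #       def train_q(self, batch): ...                    # used in train()
    #       def estimate_q(self, states, actions): ...       # asserted above
    #       def estimate_v(self, states, action_probs): ...  # used in estimate()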

    @override(OffPolicyEstimator)
    def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
        self.check_can_estimate_for(batch)
        estimates = []
        # Split data into train and test batches
        for test_episodes, train_episodes in train_test_split(
            batch,
            self.train_test_split_val,
            self.k,
        ):
            # Train Q-function
            if train_episodes:
                # Reinitialize model
                self.model.reset()
                train_batch = SampleBatch.concat_samples(train_episodes)
                losses = self.train(train_batch)
                self.losses.append(losses)

            # Calculate direct method OPE estimates
            for episode in test_episodes:
                rewards = episode["rewards"]
                v_old = 0.0
                v_new = 0.0
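                # v_old: discounted return actually observed in the logged
                # episode (i.e. the empirical value of the data-collecting
                # policy for this episode).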
                for t in range(episode.count):
                    v_old += rewards[t] * self.gamma ** t

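                # v_new: the Q-model's value estimate for the episode's first
                # observation, weighted by the evaluation policy's action
                # probabilities (obtained from its action log-likelihoods).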
                init_step = episode[0:1]
                init_obs = np.array([init_step[SampleBatch.OBS]])
                all_actions = np.arange(self.policy.action_space.n, dtype=float)
                init_step[SampleBatch.ACTIONS] = all_actions
                action_probs = np.exp(self.action_log_likelihood(init_step))
                v_value = self.model.estimate_v(init_obs, action_probs)
                v_new = convert_to_numpy(v_value).item()

                estimates.append(
                    OffPolicyEstimate(
                        self.name,
                        {
                            "v_old": v_old,
                            "v_new": v_new,
                            "v_gain": v_new / max(1e-8, v_old),
                        },
                    )
                )
        return estimates

    @override(OffPolicyEstimator)
    def train(self, batch: SampleBatchType):
        return self.model.train_q(batch)
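

# Example usage (illustrative sketch, not part of the original module). It
# assumes `policy` is a torch Policy with a Discrete action space whose config
# sets `batch_mode="complete_episodes"`, and that `batch` is a SampleBatch of
# complete logged episodes:
#
#   estimator = DirectMethod(
#       name="dm",
#       policy=policy,
#       gamma=0.99,
#       q_model_type="fqe",
#       k=5,
#   )
#   estimates = estimator.estimate(batch)
#   for estimate in estimates:
#       print(estimate)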