import logging
from typing import Generator, List, Tuple

import numpy as np
from gym.spaces import Discrete

from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.offline.estimators.off_policy_estimator import (
    OffPolicyEstimate,
    OffPolicyEstimator,
)
from ray.rllib.offline.estimators.qreg_torch_model import QRegTorchModel
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import SampleBatchType

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


@ExperimentalAPI
def train_test_split(
    batch: SampleBatchType,
    train_test_split_val: float = 0.0,
    k: int = 0,
) -> Generator[Tuple[List[SampleBatch], List[SampleBatch]], None, None]:
    """Utility function that returns either a train/test split or
    a k-fold cross validation generator over episodes from the given batch.

    By default, `train_test_split_val` and `k` are both 0, in which case the
    whole batch is yielded as evaluation episodes and no training episodes
    are generated.

    Args:
        batch: A SampleBatch of episodes to split.
        train_test_split_val: Split the batch into a training batch with
            `train_test_split_val * n_episodes` episodes and an evaluation
            batch with `(1 - train_test_split_val) * n_episodes` episodes.
            If not specified, use `k` for k-fold cross validation instead.
        k: The number of folds to use for k-fold cross validation when
            training the model and evaluating OPE.

    Returns:
        A generator that yields tuples (eval_episodes, train_episodes), where
        each element is a list of per-episode SampleBatches.
    """
    if not train_test_split_val and not k:
        logger.warning(
            "`train_test_split_val` and `k` are both 0; "
            "not generating training batch"
        )
        # Evaluate on the entire batch; there is nothing to train on.
        yield [batch], []
        return
    episodes = batch.split_by_episode()
    n_episodes = len(episodes)
    # Train-test split
    if train_test_split_val:
        train_episodes = episodes[: int(n_episodes * train_test_split_val)]
        eval_episodes = episodes[int(n_episodes * train_test_split_val) :]
        yield eval_episodes, train_episodes
        return
    # k-fold cv
    assert n_episodes >= k, f"Not enough eval episodes in batch for {k}-fold cv!"
    n_fold = n_episodes // k
    for i in range(k):
        train_episodes = episodes[: i * n_fold] + episodes[(i + 1) * n_fold :]
        if i != k - 1:
            eval_episodes = episodes[i * n_fold : (i + 1) * n_fold]
        else:
            # Append remaining episodes onto the last eval_episodes
            eval_episodes = episodes[i * n_fold :]
        yield eval_episodes, train_episodes
    return
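

# The helper below is illustrative only: it is never called and its name and
# argument are hypothetical, not part of this module's API. It documents the
# (eval_episodes, train_episodes) ordering yielded by `train_test_split`,
# assuming the caller supplies a batch containing at least three complete
# episodes.
def _train_test_split_example(episode_batch: SampleBatchType) -> None:
    # 3-fold cross validation: each fold evaluates on roughly one third of the
    # episodes and trains on the rest.
    for eval_episodes, train_episodes in train_test_split(episode_batch, k=3):
        assert isinstance(eval_episodes, list)
        assert isinstance(train_episodes, list)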


@ExperimentalAPI
class DirectMethod(OffPolicyEstimator):
    """The Direct Method (DM) estimator.

    DM estimator described in https://arxiv.org/pdf/1511.03722.pdf
    """

    @override(OffPolicyEstimator)
    def __init__(
        self,
        name: str,
        policy: Policy,
        gamma: float,
        q_model_type: str = "fqe",
        train_test_split_val: float = 0.0,
        k: int = 0,
        **kwargs,
    ):
        """Initializes a Direct Method OPE Estimator.

        Args:
            name: String to save OPE results under.
            policy: Policy to evaluate.
            gamma: Discount factor of the environment.
            q_model_type: Either "fqe" for Fitted Q-Evaluation,
                "qreg" for Q-Regression, or a custom model class that
                implements:
                - `estimate_q(states, actions)`
                - `estimate_v(states, action_probs)`
                (see the illustrative `_ZeroQModel` sketch at the end of
                this module).
            train_test_split_val: Split the batch into a training batch with
                `train_test_split_val * n_episodes` episodes and an evaluation
                batch with `(1 - train_test_split_val) * n_episodes` episodes.
                If not specified, use `k` for k-fold cross validation instead.
            k: The number of folds to use for k-fold cross validation when
                training the Q model and evaluating OPE.
            kwargs: Optional arguments for the specified Q model.
        """
        super().__init__(name, policy, gamma)
        # TODO (rohan): Add support for continuous action spaces
        assert isinstance(
            policy.action_space, Discrete
        ), "DM Estimator only supports discrete action spaces!"
        assert (
            policy.config["batch_mode"] == "complete_episodes"
        ), "DM Estimator only supports `batch_mode`=`complete_episodes`"

        # TODO (rohan): Add support for TF!
        if policy.framework == "torch":
            if q_model_type == "qreg":
                model_cls = QRegTorchModel
            elif q_model_type == "fqe":
                model_cls = FQETorchModel
            else:
                assert hasattr(
                    q_model_type, "estimate_q"
                ), "q_model_type must implement `estimate_q`!"
                assert hasattr(
                    q_model_type, "estimate_v"
                ), "q_model_type must implement `estimate_v`!"
                model_cls = q_model_type
        else:
            raise ValueError(
                f"{self.__class__.__name__} "
                "estimator only supports `policy.framework`=`torch`"
            )

        self.model = model_cls(
            policy=policy,
            gamma=gamma,
            **kwargs,
        )
        self.train_test_split_val = train_test_split_val
        self.k = k
        self.losses = []
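
    # Hypothetical usage sketch (illustrative only; `trained_policy` and
    # `offline_batch` are placeholders supplied by the caller and are not
    # defined in this module; the policy must be a torch Policy with a
    # discrete action space and `batch_mode="complete_episodes"`):
    #
    #   dm = DirectMethod(
    #       name="dm", policy=trained_policy, gamma=0.99, q_model_type="fqe", k=5
    #   )
    #   estimates = dm.estimate(offline_batch)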

    @override(OffPolicyEstimator)
    def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
        self.check_can_estimate_for(batch)
        estimates = []
        # Split the data into evaluation and training episodes.
        for eval_episodes, train_episodes in train_test_split(
            batch,
            self.train_test_split_val,
            self.k,
        ):

            # Train Q-function
            if train_episodes:
                # Reinitialize model
                self.model.reset()
                train_batch = SampleBatch.concat_samples(train_episodes)
                losses = self.train(train_batch)
                self.losses.append(losses)

            # Calculate direct method OPE estimates on the evaluation episodes
            for episode in eval_episodes:
                rewards = episode["rewards"]
                v_old = 0.0
                v_new = 0.0
                # v_old: discounted return actually collected by the behavior
                # policy on this episode.
                for t in range(episode.count):
                    v_old += rewards[t] * self.gamma ** t

                # v_new: estimated value of the evaluation policy at the
                # episode's initial observation, using the trained Q model and
                # the evaluation policy's action probabilities.
                init_step = episode[0:1]
                init_obs = np.array([init_step[SampleBatch.OBS]])
                all_actions = np.arange(self.policy.action_space.n, dtype=float)
                init_step[SampleBatch.ACTIONS] = all_actions
                action_probs = np.exp(self.action_log_likelihood(init_step))
                v_value = self.model.estimate_v(init_obs, action_probs)
                v_new = convert_to_numpy(v_value).item()

                estimates.append(
                    OffPolicyEstimate(
                        self.name,
                        {
                            "v_old": v_old,
                            "v_new": v_new,
                            "v_gain": v_new / max(1e-8, v_old),
                        },
                    )
                )
        return estimates
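
    # For reference, the per-episode metrics computed in `estimate()` are:
    #   v_old  = sum_t gamma^t * r_t     (discounted return of the logged data)
    #   v_new  = V(s_0)                  (model-based value of the evaluation
    #                                     policy at the first observation,
    #                                     obtained via `estimate_v`)
    #   v_gain = v_new / max(1e-8, v_old)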

    @override(OffPolicyEstimator)
    def train(self, batch: SampleBatchType):
        return self.model.train_q(batch)
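

# The sketch below is illustrative only and is not part of RLlib: it shows the
# minimal duck-typed surface a custom `q_model_type` class needs in order to be
# used by `DirectMethod` above (it is constructed with `policy`, `gamma` and any
# extra kwargs; then `reset`, `train_q` and `estimate_v` are called, and the
# `estimate_q`/`estimate_v` attributes are checked in `__init__`). The constant
# zero outputs are placeholders, not a real estimator.
class _ZeroQModel:
    def __init__(self, policy: Policy, gamma: float, **kwargs):
        self.policy = policy
        self.gamma = gamma

    def reset(self) -> None:
        # Nothing to re-initialize in this placeholder model.
        pass

    def train_q(self, batch: SampleBatchType) -> List[float]:
        # A real model would fit Q(s, a) on `batch` here and return its losses.
        return [0.0]

    def estimate_q(self, states: np.ndarray, actions: np.ndarray) -> np.ndarray:
        return np.zeros(len(actions), dtype=np.float32)

    def estimate_v(self, states: np.ndarray, action_probs: np.ndarray) -> np.ndarray:
        # V(s) = sum_a pi(a | s) * Q(s, a), which is identically 0 here.
        return np.zeros(len(states), dtype=np.float32)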