# ray/rllib/offline/estimators/off_policy_estimator.py

import logging
from typing import Any, Dict

from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID,
    MultiAgentBatch,
    SampleBatch,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
from ray.rllib.utils.typing import SampleBatchType, TensorType

logger = logging.getLogger(__name__)


@DeveloperAPI
class OffPolicyEstimator:
    """Interface for an off-policy reward estimator."""

    @DeveloperAPI
    def __init__(self, policy: Policy, gamma: float):
        """Initializes an OffPolicyEstimator instance.

        Args:
            policy: Policy to evaluate.
            gamma: Discount factor of the environment.
        """
        self.policy = policy
        self.gamma = gamma

    @DeveloperAPI
    def estimate(self, batch: SampleBatchType) -> Dict[str, Any]:
        """Returns off-policy estimates for the given batch of episodes.

        Args:
            batch: The batch to calculate the off-policy estimates (OPE) on.
                The batch must contain the fields "obs", "actions", and
                "action_prob".

        Returns:
            The off-policy estimates (OPE) calculated on the given batch. The
            returned dict can be any arbitrary mapping of strings to metrics.
        """
        raise NotImplementedError
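
    # Illustrative only (not part of the original module): a minimal sketch of
    # how a concrete subclass might implement `estimate()` via per-episode
    # importance sampling. `MyISEstimator` is hypothetical, and the sketch
    # assumes `import numpy as np`:
    #
    #     class MyISEstimator(OffPolicyEstimator):
    #         def estimate(self, batch: SampleBatchType) -> Dict[str, Any]:
    #             batch = self.convert_ma_batch_to_sample_batch(batch)
    #             self.check_action_prob_in_batch(batch)
    #             returns = []
    #             for episode in batch.split_by_episode():
    #                 rewards = episode["rewards"]
    #                 old_prob = episode["action_prob"]
    #                 log_p = compute_log_likelihoods_from_input_dict(
    #                     self.policy, episode
    #                 )
    #                 new_prob = np.exp(convert_to_numpy(log_p))
    #                 # Cumulative importance ratio up to each timestep t.
    #                 p = np.cumprod(new_prob / old_prob)
    #                 # Discounted, importance-weighted episode return.
    #                 discounts = self.gamma ** np.arange(len(rewards))
    #                 returns.append(np.sum(discounts * p * rewards))
    #             return {"v_target": float(np.mean(returns))}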

    @DeveloperAPI
    def convert_ma_batch_to_sample_batch(self, batch: SampleBatchType) -> SampleBatch:
        """Converts a MultiAgentBatch to a SampleBatch if necessary.

        Args:
            batch: The SampleBatchType to convert.

        Returns:
            The converted SampleBatch.

        Raises:
            ValueError: If the MultiAgentBatch has more than one policy_id,
                or if the policy_id is not `DEFAULT_POLICY_ID`.
        """
        # TODO: Make this a util in sample_batch.py.
        if isinstance(batch, MultiAgentBatch):
            policy_keys = batch.policy_batches.keys()
            if len(policy_keys) == 1 and DEFAULT_POLICY_ID in policy_keys:
                batch = batch.policy_batches[DEFAULT_POLICY_ID]
            else:
                raise ValueError(
                    "Off-Policy Estimation is not implemented for "
                    "multi-agent batches. You can set "
                    "`off_policy_estimation_methods: {}` to resolve this."
                )
        return batch
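
    # For illustration (hypothetical values): a single-policy MultiAgentBatch
    # keyed by DEFAULT_POLICY_ID is unwrapped to its underlying SampleBatch:
    #
    #     ma_batch = MultiAgentBatch(
    #         {DEFAULT_POLICY_ID: sample_batch}, env_steps=10
    #     )
    #     estimator.convert_ma_batch_to_sample_batch(ma_batch)  # -> sample_batch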

    @DeveloperAPI
    def check_action_prob_in_batch(self, batch: SampleBatchType) -> None:
        """Checks if we support off-policy estimation (OPE) on a given batch.

        Args:
            batch: The batch to check.

        Raises:
            ValueError: If the `action_prob` key is not in the batch.
        """
        if "action_prob" not in batch:
            raise ValueError(
                "Off-policy estimation is not possible unless the inputs "
                "include action probabilities (i.e., the policy is stochastic "
                "and emits the 'action_prob' key). For DQN this means using "
                "`exploration_config: {type: 'SoftQ'}`. You can also set "
                "`off_policy_estimation_methods: {}` to disable estimation."
            )
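
    # For reference, a minimal batch (with made-up values) that passes this
    # check; "action_prob" holds the behavior policy's probability of the
    # logged action at each timestep (assumes `import numpy as np`):
    #
    #     batch = SampleBatch({
    #         "obs": np.array([[0.1], [0.2]]),
    #         "actions": np.array([0, 1]),
    #         "action_prob": np.array([0.9, 0.75]),
    #         "rewards": np.array([1.0, 0.5]),
    #     })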

    @DeveloperAPI
    def train(self, batch: SampleBatchType) -> Dict[str, Any]:
        """Trains a model for off-policy estimation.

        Args:
            batch: The SampleBatch to train on.

        Returns:
            Any optional metrics to return from the estimator.
        """
        return {}
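
    # Illustrative only: a model-based estimator could override `train()` to
    # fit an auxiliary model on the batch, e.g. (hypothetical `self.model`
    # whose `train()` returns per-step losses; assumes `import numpy as np`):
    #
    #     def train(self, batch: SampleBatchType) -> Dict[str, Any]:
    #         batch = self.convert_ma_batch_to_sample_batch(batch)
    #         losses = self.model.train(batch)
    #         return {"loss": float(np.mean(losses))}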

    @Deprecated(
        old="OffPolicyEstimator.action_log_likelihood",
        new="ray.rllib.utils.policy.compute_log_likelihoods_from_input_dict",
        error=False,
    )
    def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
        log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
        return convert_to_numpy(log_likelihoods)
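

# Usage sketch (illustrative, with hypothetical names): running off-policy
# estimation over an offline dataset with a concrete estimator subclass such
# as the `MyISEstimator` sketched above:
#
#     from ray.rllib.offline.json_reader import JsonReader
#
#     reader = JsonReader("/path/to/offline_data")
#     estimator = MyISEstimator(policy=trained_policy, gamma=0.99)
#     batch = reader.next()
#     print(estimator.estimate(batch))  # e.g. {"v_target": 42.0}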