import logging
from typing import Any, Dict

from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID,
    MultiAgentBatch,
    SampleBatch,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
from ray.rllib.utils.typing import SampleBatchType, TensorType

logger = logging.getLogger(__name__)


@DeveloperAPI
class OffPolicyEstimator:
    """Interface for an off-policy reward estimator."""

    @DeveloperAPI
    def __init__(self, policy: Policy, gamma: float):
        """Initializes an OffPolicyEstimator instance.

        Args:
            policy: The Policy to evaluate.
            gamma: Discount factor of the environment.
        """
        self.policy = policy
        self.gamma = gamma

    @DeveloperAPI
    def estimate(self, batch: SampleBatchType) -> Dict[str, Any]:
        """Returns off-policy estimates for the given batch of episodes.

        Args:
            batch: The batch to calculate the off-policy estimates (OPE) on.
                The batch must contain the fields "obs", "actions", and
                "action_prob".

        Returns:
            The off-policy estimates (OPE) calculated on the given batch. The
            returned dict can be any arbitrary mapping of strings to metrics.
        """
        raise NotImplementedError
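
    # A minimal sketch of what a subclass's `estimate()` might look like --
    # plain per-step importance sampling, not one of RLlib's shipped
    # estimators (assumes `import numpy as np`; discounting via `self.gamma`
    # is omitted for brevity):
    #
    #     class SimpleISEstimator(OffPolicyEstimator):
    #         def estimate(self, batch):
    #             batch = self.convert_ma_batch_to_sample_batch(batch)
    #             self.check_action_prob_in_batch(batch)
    #             # Action probabilities under the target (new) policy.
    #             new_prob = np.exp(convert_to_numpy(
    #                 compute_log_likelihoods_from_input_dict(self.policy, batch)
    #             ))
    #             weights = new_prob / batch["action_prob"]
    #             return {
    #                 "v_behavior": float(batch["rewards"].mean()),
    #                 "v_target": float((weights * batch["rewards"]).mean()),
    #             }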

    @DeveloperAPI
    def convert_ma_batch_to_sample_batch(self, batch: SampleBatchType) -> SampleBatch:
        """Converts a MultiAgentBatch to a SampleBatch if necessary.

        Args:
            batch: The SampleBatchType to convert.

        Returns:
            The converted SampleBatch.

        Raises:
            ValueError: If the MultiAgentBatch has more than one policy_id,
                or if the policy_id is not `DEFAULT_POLICY_ID`.
        """
        # TODO: Make this a util of sample_batch.py.
        if isinstance(batch, MultiAgentBatch):
            policy_keys = batch.policy_batches.keys()
            if len(policy_keys) == 1 and DEFAULT_POLICY_ID in policy_keys:
                batch = batch.policy_batches[DEFAULT_POLICY_ID]
            else:
                raise ValueError(
                    "Off-Policy Estimation is not implemented for "
                    "multi-agent batches. You can set "
                    "`off_policy_estimation_methods: {}` to resolve this."
                )
        return batch
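
    # Illustrative sketch: a single-policy MultiAgentBatch collapses to its
    # underlying SampleBatch, anything else raises. `obs` and `acts` are
    # hypothetical arrays:
    #
    #     ma_batch = MultiAgentBatch(
    #         {DEFAULT_POLICY_ID: SampleBatch({"obs": obs, "actions": acts})},
    #         env_steps=len(obs),
    #     )
    #     sample_batch = estimator.convert_ma_batch_to_sample_batch(ma_batch)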

    @DeveloperAPI
    def check_action_prob_in_batch(self, batch: SampleBatchType) -> None:
        """Checks whether off-policy estimation (OPE) is supported on the batch.

        Args:
            batch: The batch to check.

        Raises:
            ValueError: In case the `action_prob` key is not in the batch.
        """
        if "action_prob" not in batch:
            raise ValueError(
                "Off-policy estimation is not possible unless the inputs "
                "include action probabilities (i.e., the policy is stochastic "
                "and emits the 'action_prob' key). For DQN this means using "
                "`exploration_config: {type: 'SoftQ'}`. You can also set "
                "`off_policy_estimation_methods: {}` to disable estimation."
            )
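
    # Hedged config sketch showing how the two settings named in the error
    # message above are typically supplied (assumes RLlib's `DQNConfig`
    # builder API):
    #
    #     config = (
    #         DQNConfig()
    #         .exploration(exploration_config={"type": "SoftQ"})
    #         # ...or disable OPE altogether:
    #         .evaluation(off_policy_estimation_methods={})
    #     )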

    @DeveloperAPI
    def train(self, batch: SampleBatchType) -> Dict[str, Any]:
        """Trains a model for off-policy estimation.

        Args:
            batch: The SampleBatch to train on.

        Returns:
            Any optional metrics to return from the estimator.
        """
        return {}

    @Deprecated(
        old="OffPolicyEstimator.action_log_likelihood",
        new="ray.rllib.utils.policy.compute_log_likelihoods_from_input_dict",
        error=False,
    )
    def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
        # Thin backward-compatibility wrapper around the new utility function.
        log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
        return convert_to_numpy(log_likelihoods)
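
# Usage sketch for the replacement API named in the deprecation notice above
# (`policy` and `batch` are assumed to already exist):
#
#     log_likelihoods = convert_to_numpy(
#         compute_log_likelihoods_from_input_dict(policy, batch)
#     )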