# TODO (@Kourosh) move this to a better location and consolidate the parent class with
# OPE

from typing import Callable, List

from ray.rllib.policy import Policy
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.offline.estimators.off_policy_estimator import (
    OffPolicyEstimator,
    OffPolicyEstimate,
)

import numpy as np
import copy


def perturb_fn(batch: np.ndarray, index: int):
    # Shuffle the index-th feature column in place across the batch dimension.
    random_inds = np.random.permutation(batch.shape[0])
    batch[:, index] = batch[random_inds, index]
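

# Illustrative example (added for clarity; not part of the original module):
# `perturb_fn` shuffles exactly one column in place and leaves the others
# untouched. One possible outcome, since the permutation is random:
#
#   batch = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
#   perturb_fn(batch, index=1)
#   batch[:, 0]          -> array([1., 2., 3.])   (unchanged)
#   sorted(batch[:, 1])  -> [10.0, 20.0, 30.0]    (same values, reordered)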


class FeatureImportance(OffPolicyEstimator):
    def __init__(
        self,
        name: str,
        policy: Policy,
        gamma: float,
        repeat: int = 1,
        perturb_fn: Callable[[np.ndarray, int], None] = perturb_fn,
    ):
"""Feature importance in a model inspection technique that can be used for any
|
||
|
fitted predictor when the data is tablular.
|
||
|
|
||
|
This implementation is also known as permutation importance that is defined to
|
||
|
be the variation of the model's prediction when a single feature value is
|
||
|
randomly shuffled. In RLlib it is implemented as a custom OffPolicyEstimator
|
||
|
which is used to evaluate RLlib policies without performing environment
|
||
|
interactions.
|
||
|
|
||
|
Example usage: In the example below the feature importance module is used to
|
||
|
evaluate the policy and the each feature's importance is computed after each
|
||
|
training iteration. The permutation are repeated `self.repeat` times and the
|
||
|
results are averages across repeats.
|
||
|
|
||
|
```python
|
||
|
config = (
|
||
|
AlgorithmConfig()
|
||
|
.offline_data(
|
||
|
off_policy_estimation_methods=
|
||
|
{
|
||
|
"feature_importance": {
|
||
|
"type": FeatureImportance,
|
||
|
"repeat": 10
|
||
|
}
|
||
|
}
|
||
|
)
|
||
|
)
|
||
|
|
||
|
algorithm = DQN(config=config)
|
||
|
results = algorithm.train()
|
||
|
```
|
||
|
|
||
|
Args:
|
||
|
name: string to save the feature importance results under.
|
||
|
policy: the policy to use for feature importance.
|
||
|
repeat: number of times to repeat the perturbation.
|
||
|
gamma: dummy discount factor to be passed to the super class.
|
||
|
perturb_fn: function to perturb the features. By default reshuffle the
|
||
|
features within the batch.
|
||
|
"""
        super().__init__(name, policy, gamma=gamma)
        self.repeat = repeat
        self.perturb_fn = perturb_fn

    def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
        """Estimate the feature importance of the policy.

        Given a batch of tabular observations, the importance of each feature is
        computed by perturbing that feature and measuring the difference between
        the actions of the perturbed policy and those of the reference policy.
        The importance is computed for each feature, and each perturbation is
        repeated `self.repeat` times.

        Args:
            batch: the batch of data to use for feature importance.

        Returns:
            A list of OffPolicyEstimate objects. Each OffPolicyEstimate object
            contains a metrics name and a dictionary of metrics mapping each
            feature index to its importance.
        """
        obs_batch = batch["obs"]
        n_features = obs_batch.shape[-1]
        importance = np.zeros((self.repeat, n_features))

        # Reference actions on the unperturbed observations.
        ref_actions, _, _ = self.policy.compute_actions(obs_batch, explore=False)
        for r in range(self.repeat):
            for i in range(n_features):
                copy_obs_batch = copy.deepcopy(obs_batch)
                # Use the configured perturbation function (by default, an
                # in-place shuffle of column i), not the module-level default.
                self.perturb_fn(copy_obs_batch, index=i)
                perturbed_actions, _, _ = self.policy.compute_actions(
                    copy_obs_batch, explore=False
                )

                importance[r, i] = np.mean(np.abs(perturbed_actions - ref_actions))

        # Take the average across repeats.
        importance = importance.mean(0)
        metrics = {f"feature_{i}": importance[i] for i in range(len(importance))}
        ret = [OffPolicyEstimate(self.name, metrics)]

        return ret
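

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# module). It exercises only the interface `estimate` relies on: a policy
# object exposing `compute_actions(obs, explore=False)` and a dict-like batch
# with an "obs" array. `_ToyPolicy` is a hypothetical stand-in for a real
# RLlib Policy; in a real setup you would pass an actual trained Policy.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyPolicy:
        """Toy policy whose action depends only on feature 0."""

        def compute_actions(self, obs, explore=False):
            return obs[:, 0], None, None

    estimator = FeatureImportance(
        name="feature_importance",
        policy=_ToyPolicy(),
        gamma=0.0,
        repeat=5,
    )
    batch = {"obs": np.random.randn(64, 3)}
    # Expect feature_0 to dominate; feature_1 and feature_2 should be ~0.
    print(estimator.estimate(batch)[0])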