ray/rllib/offline/estimators/feature_importance.py

# TODO (@Kourosh) move this to a better location and consolidate the parent class with
# OPE
from typing import Callable, List
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.offline.estimators.off_policy_estimator import (
    OffPolicyEstimator,
    OffPolicyEstimate,
)
import numpy as np
import copy


def perturb_fn(batch: np.ndarray, index: int) -> None:
    """Shuffle the values of the `index`-th feature column of `batch` in place."""
    random_inds = np.random.permutation(batch.shape[0])
    batch[:, index] = batch[random_inds, index]
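
# Rough illustration of the default perturbation (hypothetical values, not part of
# the original module): each call shuffles exactly one column in place and leaves
# all other columns untouched.
#
#   batch = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
#   perturb_fn(batch, index=1)
#   # batch[:, 0] is still [1., 2., 3.]; batch[:, 1] holds the same three
#   # values (10., 20., 30.) in a possibly different order.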


class FeatureImportance(OffPolicyEstimator):
    def __init__(
        self,
        name: str,
        policy: Policy,
        gamma: float,
        repeat: int = 1,
        perturb_fn: Callable[[np.ndarray, int], None] = perturb_fn,
    ):
"""Feature importance in a model inspection technique that can be used for any
fitted predictor when the data is tablular.
This implementation is also known as permutation importance that is defined to
be the variation of the model's prediction when a single feature value is
randomly shuffled. In RLlib it is implemented as a custom OffPolicyEstimator
which is used to evaluate RLlib policies without performing environment
interactions.
Example usage: In the example below the feature importance module is used to
evaluate the policy and the each feature's importance is computed after each
training iteration. The permutation are repeated `self.repeat` times and the
results are averages across repeats.
```python
config = (
AlgorithmConfig()
.offline_data(
off_policy_estimation_methods=
{
"feature_importance": {
"type": FeatureImportance,
"repeat": 10
}
}
)
)
algorithm = DQN(config=config)
results = algorithm.train()
```
Args:
name: string to save the feature importance results under.
policy: the policy to use for feature importance.
repeat: number of times to repeat the perturbation.
gamma: dummy discount factor to be passed to the super class.
perturb_fn: function to perturb the features. By default reshuffle the
features within the batch.
"""
        super().__init__(name, policy, gamma=gamma)
        self.repeat = repeat
        self.perturb_fn = perturb_fn

    def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
        """Estimate the feature importance of the policy.

        Given a batch of tabular observations, the importance of each feature is
        computed by perturbing that feature and measuring the difference between
        the actions the policy computes on the perturbed batch and on the
        unperturbed reference batch. The importance is computed for each feature,
        and each perturbation is repeated `self.repeat` times.

        Args:
            batch: the batch of data to use for feature importance.

        Returns:
            A list of OffPolicyEstimate objects. Each OffPolicyEstimate object
            contains a metrics name and a dictionary of metrics mapping each
            feature index to its importance.
        """
obs_batch = batch["obs"]
n_features = obs_batch.shape[-1]
importance = np.zeros((self.repeat, n_features))
ref_actions, _, _ = self.policy.compute_actions(obs_batch, explore=False)
for r in range(self.repeat):
for i in range(n_features):
copy_obs_batch = copy.deepcopy(obs_batch)
perturb_fn(copy_obs_batch, index=i)
perturbed_actions, _, _ = self.policy.compute_actions(
copy_obs_batch, explore=False
)
importance[r, i] = np.mean(np.abs(perturbed_actions - ref_actions))
# take an average across repeats
importance = importance.mean(0)
metrics = {f"feature_{i}": importance[i] for i in range(len(importance))}
ret = [OffPolicyEstimate(self.name, metrics)]
return ret
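

# Minimal usage sketch (hypothetical `my_policy` and `my_sample_batch`, neither of
# which is defined in this module): the estimator can also be driven directly
# instead of through `off_policy_estimation_methods` in an AlgorithmConfig.
#
#   estimator = FeatureImportance(
#       name="feature_importance", policy=my_policy, gamma=0.0, repeat=5
#   )
#   estimates = estimator.estimate(my_sample_batch)  # batch must contain "obs"
#   # estimates[0] is an OffPolicyEstimate whose metrics dict maps
#   # "feature_0", "feature_1", ... to their averaged importance scores.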