2022-07-11 18:12:50 -07:00
|
|
|
# TODO (@Kourosh) move this to a better location and consolidate the parent class with
|
|
|
|
# OPE
|
|
|
|
|
2022-07-17 14:25:53 -07:00
|
|
|
from typing import Callable, Dict, Any
|
2022-07-11 18:12:50 -07:00
|
|
|
from ray.rllib.policy import Policy
|
|
|
|
from ray.rllib.utils.typing import SampleBatchType
|
2022-07-17 14:25:53 -07:00
|
|
|
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
|
2022-07-11 18:12:50 -07:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import copy
|
|
|
|
|
|
|
|
|
|
|
|
def perturb_fn(batch: np.ndarray, index: int):
    """Shuffle one feature column of `batch` in place.

    The values found in column `index` are randomly permuted across the rows
    (first axis) of `batch`; every other column is left untouched.

    Args:
        batch: array indexed as `batch[row, column]`; it is modified in place.
        index: position of the column whose values are shuffled.
    """
    n_rows = len(batch)
    shuffled_rows = np.random.permutation(n_rows)
    column = batch[:, index]
    # Fancy indexing produces a copy, so the write-back below never reads
    # from cells that have already been overwritten.
    batch[:, index] = column[shuffled_rows]
|
|
|
|
|
|
|
|
|
|
|
|
class FeatureImportance(OffPolicyEstimator):
    def __init__(
        self,
        policy: Policy,
        gamma: float,
        repeat: int = 1,
        perturb_fn: Callable[[np.ndarray, int], None] = perturb_fn,
    ):
        """Feature importance is a model inspection technique that can be used for
        any fitted predictor when the data is tabular.

        This implementation is also known as permutation importance that is defined to
        be the variation of the model's prediction when a single feature value is
        randomly shuffled. In RLlib it is implemented as a custom OffPolicyEstimator
        which is used to evaluate RLlib policies without performing environment
        interactions.

        Example usage: In the example below the feature importance module is used to
        evaluate the policy and each feature's importance is computed after each
        training iteration. The permutations are repeated `self.repeat` times and the
        results are averaged across repeats.

        ```python
        config = (
            AlgorithmConfig()
            .offline_data(
                off_policy_estimation_methods=
                    {
                        "feature_importance": {
                            "type": FeatureImportance,
                            "repeat": 10
                        }
                    }
            )
        )

        algorithm = DQN(config=config)
        results = algorithm.train()
        ```

        Args:
            policy: the policy to use for feature importance.
            gamma: dummy discount factor to be passed to the super class.
            repeat: number of times to repeat the perturbation.
            perturb_fn: function to perturb the features. It is called as
                `perturb_fn(obs_copy, index=i)` and must modify `obs_copy` in
                place. By default reshuffle the features within the batch.
        """
        super().__init__(policy, gamma)
        self.repeat = repeat
        self.perturb_fn = perturb_fn

    def estimate(self, batch: SampleBatchType) -> Dict[str, Any]:
        """Estimate the feature importance of the policy.

        Given a batch of tabular observations, the importance of each feature is
        computed by perturbing each feature and computing the difference between the
        perturbed policy and the reference policy. The importance is computed for each
        feature and each perturbation is repeated `self.repeat` times.

        Args:
            batch: the batch of data to use for feature importance. Its "obs"
                entry is read and its last axis is treated as the feature axis.

        Returns:
            A dict mapping each feature index string (`"feature_<i>"`) to its
            importance, i.e. the mean absolute difference between perturbed
            and reference actions, averaged across repeats.
        """
        obs_batch = batch["obs"]
        n_features = obs_batch.shape[-1]
        importance = np.zeros((self.repeat, n_features))

        # Actions on the untouched observations serve as the baseline.
        ref_actions, _, _ = self.policy.compute_actions(obs_batch, explore=False)
        for r in range(self.repeat):
            for i in range(n_features):
                # Perturb a fresh copy so each feature is evaluated in isolation
                # and the caller's batch is never modified.
                copy_obs_batch = copy.deepcopy(obs_batch)
                # Use the perturbation function configured on this instance,
                # not the module-level default, so custom functions take effect.
                self.perturb_fn(copy_obs_batch, index=i)
                perturbed_actions, _, _ = self.policy.compute_actions(
                    copy_obs_batch, explore=False
                )

                importance[r, i] = np.mean(np.abs(perturbed_actions - ref_actions))

        # take an average across repeats
        importance = importance.mean(0)
        metrics = {f"feature_{i}": value for i, value in enumerate(importance)}

        return metrics
|