ray/rllib/offline/estimators/tests/test_feature_importance.py
2022-07-11 18:12:50 -07:00

37 lines
1.1 KiB
Python

import unittest
import ray
from ray.rllib.algorithms.crr import CRRConfig, CRR
from ray.rllib.execution import synchronous_parallel_sample
from ray.rllib.offline.estimators.feature_importance import FeatureImportance
class TestFeatureImportance(unittest.TestCase):
    """Smoke test for the ``FeatureImportance`` offline estimator.

    A fresh Ray runtime is started before each test and shut down afterwards,
    so nothing started by one test is carried over into the next.
    """

    def setUp(self):
        """Start Ray before the test runs."""
        ray.init()

    def tearDown(self):
        """Shut Ray down once the test has finished."""
        ray.shutdown()

    def test_feat_importance_cartpole(self):
        """Check that every feature-importance estimate on CartPole is positive.

        Builds a torch CRR algorithm on ``CartPole-v0``, samples one batch
        from its workers, and runs ``FeatureImportance`` on that batch with
        ``repeat`` set to 1 and to 10.
        """
        config = CRRConfig().framework("torch")
        runner = CRR(config, env="CartPole-v0")
        policy = runner.workers.local_worker().get_policy()
        sample_batch = synchronous_parallel_sample(worker_set=runner.workers)

        for repeat in [1, 10]:
            evaluator = FeatureImportance(
                name="feature_importance",
                policy=policy,
                gamma=0.0,
                repeat=repeat,
            )
            estimate = evaluator.estimate(sample_batch)[0].metrics

            # Use unittest assertions instead of a bare ``assert all([...])``:
            # they are not stripped under ``python -O`` and the failure
            # message names the offending entry and the ``repeat`` setting.
            for feature, importance in estimate.items():
                self.assertGreater(
                    importance,
                    0,
                    msg=(
                        f"non-positive estimate for {feature!r} "
                        f"with repeat={repeat}"
                    ),
                )
if __name__ == "__main__":
    # Script entry point: run this file's tests verbosely through pytest and
    # hand pytest's return value to the interpreter as the exit status.
    import sys
    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)