ray/rllib/offline/estimators/tests/test_feature_importance.py
Rohan Potdar 38c9e1d52a
[RLlib]: Fix OPE trainables (#26279)
Co-authored-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
2022-07-17 14:25:53 -07:00

35 lines
1 KiB
Python

import unittest
import ray
from ray.rllib.algorithms.crr import CRRConfig, CRR
from ray.rllib.execution import synchronous_parallel_sample
from ray.rllib.offline.estimators.feature_importance import FeatureImportance
class TestFeatureImportance(unittest.TestCase):
    """Smoke test for the ``FeatureImportance`` off-policy estimator."""

    def setUp(self):
        # Start a fresh Ray instance for each test; shut down in tearDown().
        ray.init()

    def tearDown(self):
        ray.shutdown()

    def test_feat_importance_cartpole(self):
        """All feature-importance scores on CartPole samples are positive.

        Builds a torch CRR algorithm on CartPole-v0, collects samples from
        its workers with ``synchronous_parallel_sample``, and runs
        ``FeatureImportance`` on that batch for several ``repeat`` values.
        """
        config = CRRConfig().framework("torch")
        runner = CRR(config, env="CartPole-v0")
        try:
            policy = runner.workers.local_worker().get_policy()
            sample_batch = synchronous_parallel_sample(worker_set=runner.workers)
            for repeat in [1, 10]:
                evaluator = FeatureImportance(
                    policy=policy, gamma=0.0, repeat=repeat
                )
                estimate = evaluator.estimate(sample_batch)
                # An empty estimate would make the all() check below pass
                # vacuously, so require at least one score first.
                self.assertTrue(
                    estimate,
                    msg=f"Empty feature-importance estimate (repeat={repeat})",
                )
                # Every per-feature score must be strictly positive.
                self.assertTrue(
                    all(val > 0 for val in estimate.values()),
                    msg=(
                        f"Non-positive feature importance (repeat={repeat}): "
                        f"{estimate}"
                    ),
                )
        finally:
            # Stop the algorithm explicitly instead of relying on
            # ray.shutdown() in tearDown() to reclaim its resources.
            runner.stop()
if __name__ == "__main__":
    # Allow running this file directly; exit with pytest's status code.
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))