# Mirror of https://github.com/vale981/ray
# Synced 2025-03-10 05:16:49 -04:00
# 35 lines, 1 KiB, Python
import unittest
|
|
import ray
|
|
|
|
from ray.rllib.algorithms.crr import CRRConfig, CRR
|
|
from ray.rllib.execution import synchronous_parallel_sample
|
|
from ray.rllib.offline.estimators.feature_importance import FeatureImportance
|
|
|
|
|
|
class TestFeatureImportance(unittest.TestCase):
    """Tests the ``FeatureImportance`` estimator against a CRR policy."""

    def setUp(self):
        """Start Ray before each test."""
        ray.init()

    def tearDown(self):
        """Shut Ray down after each test."""
        ray.shutdown()

    def test_feat_importance_cartpole(self):
        """Check that all estimated values are positive on CartPole-v0.

        Builds a torch CRR algorithm, samples one batch from its workers,
        and runs the estimator with ``repeat`` set to 1 and to 10.
        """
        config = CRRConfig().framework("torch")
        runner = CRR(config, env="CartPole-v0")
        policy = runner.workers.local_worker().get_policy()
        sample_batch = synchronous_parallel_sample(worker_set=runner.workers)

        for repeat in [1, 10]:
            evaluator = FeatureImportance(policy=policy, gamma=0.0, repeat=repeat)
            estimate = evaluator.estimate(sample_batch)

            # check if the estimate is positive
            # (generator form avoids a throwaway list; the message shows
            # the offending values when the check fails)
            assert all(val > 0 for val in estimate.values()), (
                f"non-positive estimate for repeat={repeat}: {estimate}"
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed directly, run this file's tests verbosely through pytest
    # and hand pytest's return value to sys.exit as the process exit status.
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
|