Mirror of https://github.com/vale981/ray (synced 2025-03-05 10:01:43 -05:00)
[RLlib] Fix test_ope flakiness (#27676)
This commit is contained in:
parent: bc5d8d9176
commit: 4607e788c1
8 changed files with 558 additions and 381 deletions
rllib/BUILD (14 changed lines)
@@ -1752,6 +1752,20 @@ py_test(
    data = ["tests/data/cartpole/small.json"],
)

py_test(
    name = "test_dm_learning",
    tags = ["team:rllib", "offline"],
    size = "large",
    srcs = ["offline/estimators/tests/test_dm_learning.py"],
)

py_test(
    name = "test_dr_learning",
    tags = ["team:rllib", "offline"],
    size = "large",
    srcs = ["offline/estimators/tests/test_dr_learning.py"],
)

# --------------------------------------------------------------------
# Policies
# rllib/policy/

rllib/examples/env/cliff_walking_wall_env.py (vendored, 4 changed lines)
@@ -35,9 +35,11 @@ class CliffWalkingWallEnv(gym.Env):
    Each time step incurs -1 reward, except reaching the goal which gives +10 reward.
    """

    def __init__(self) -> None:
    def __init__(self, seed=42) -> None:
        self.observation_space = spaces.Discrete(48)
        self.action_space = spaces.Discrete(4)
        self.observation_space.seed(seed)
        self.action_space.seed(seed)

    def reset(self):
        self.position = 36
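
The flakiness fix on the environment side hinges on the fact that a seeded gym space draws a reproducible sample() stream. A minimal sketch of that behavior (plain gym, not part of this commit; the seed value 42 mirrors the new default):

# Minimal sketch: seeding a Discrete space makes sample() deterministic,
# which is what CliffWalkingWallEnv(seed=...) now relies on.
import gym.spaces as spaces

def sample_actions(seed: int, n: int = 5):
    action_space = spaces.Discrete(4)
    action_space.seed(seed)  # same seed -> same sample() sequence
    return [int(action_space.sample()) for _ in range(n)]

assert sample_actions(42) == sample_actions(42)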

rllib/examples/policy/cliff_walking_wall_policy.py

@@ -1,11 +1,13 @@
import gym
from typing import Dict, Union, List, Tuple, Optional
import numpy as np

from ray.rllib.policy.policy import Policy, ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AlgorithmConfigDict, TensorStructType, TensorType
from typing import Dict, Union, List, Tuple, Optional
from ray.rllib.utils.annotations import override
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
import numpy as np
import gym
from ray.rllib.utils.typing import AlgorithmConfigDict, TensorStructType, TensorType
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import update_global_seed_if_necessary


class CliffWalkingWallPolicy(Policy):

@@ -23,6 +25,7 @@ class CliffWalkingWallPolicy(Policy):
        action_space: gym.Space,
        config: AlgorithmConfigDict,
    ):
        update_global_seed_if_necessary(seed=config.get("seed"))
        super().__init__(observation_space, action_space, config)

        # Known optimal action dist for each of the 48 states and 4 actions
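
The policy now seeds the global RNGs via update_global_seed_if_necessary(seed=config.get("seed")) before calling super().__init__(). As a rough illustration (this is not RLlib's implementation), global seeding of this kind amounts to:

# Illustrative sketch only: approximately what seeding the global RNGs
# for the "torch" framework involves; RLlib's helper also handles other
# frameworks and skips seeding when seed is None.
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    random.seed(seed)        # Python stdlib RNG
    np.random.seed(seed)     # NumPy global RNG
    torch.manual_seed(seed)  # Torch CPU (and default CUDA) RNG

seed_everything(42)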

rllib/offline/estimators/fqe_torch_model.py

@@ -61,7 +61,7 @@ class FQETorchModel:

        if model is None:
            model = {
                "fcnet_hiddens": [8, 8],
                "fcnet_hiddens": [32, 32, 32],
                "fcnet_activation": "relu",
                "vf_share_layers": True,
            }

@@ -75,6 +75,7 @@ class FQETorchModel:
            framework="torch",
            name="TorchQModel",
        ).to(self.device)

        self.target_q_model: TorchModelV2 = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,

@@ -83,6 +84,7 @@ class FQETorchModel:
            framework="torch",
            name="TargetTorchQModel",
        ).to(self.device)

        self.n_iters = n_iters
        self.lr = lr
        self.min_loss_threshold = min_loss_threshold
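
The default FQE Q-model grows from two hidden layers of 8 units to three layers of 32 units, giving fitted Q-evaluation enough capacity to fit CliffWalkingWall returns consistently. A rough torch equivalent of the new "fcnet_hiddens": [32, 32, 32] default, assuming a one-hot encoding of the 48 states (illustrative only; RLlib builds the actual model through ModelCatalog.get_model_v2):

# Rough stand-in for the new default FQE network shape; not RLlib code.
import torch.nn as nn

def make_q_net(obs_dim: int = 48, num_actions: int = 4, hiddens=(32, 32, 32)) -> nn.Sequential:
    layers = []
    in_dim = obs_dim
    for h in hiddens:
        layers += [nn.Linear(in_dim, h), nn.ReLU()]  # "fcnet_activation": "relu"
        in_dim = h
    layers.append(nn.Linear(in_dim, num_actions))  # one Q-value per action
    return nn.Sequential(*layers)

q_net = make_q_net()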

rllib/offline/estimators/tests/test_dm_learning.py (new file, 202 lines)
@@ -0,0 +1,202 @@
import unittest

import ray
from ray.rllib.offline.estimators import DirectMethod
from ray.rllib.offline.estimators.tests.utils import (
    get_cliff_walking_wall_policy_and_data,
    check_estimate,
)

SEED = 0


class TestDMLearning(unittest.TestCase):
    """Learning tests for the DirectMethod estimator.

    Generates three GridWorldWallPolicy policies and batches with epsilon = 0.2, 0.5,
    and 0.8 respectively using `get_cliff_walking_wall_policy_and_data`.

    Tests that the estimators converge on all eight combinations of evaluation policy
    and behavior batch using `check_estimates`, except random policy-expert batch.

    Note: We do not test OPE with the "random" policy (epsilon=0.8)
    and "expert" (epsilon=0.2) batch because of the large policy-data mismatch. The
    expert batch is unlikely to contain the longer trajectories that would be observed
    under the random policy, thus the OPE estimate is flaky and inaccurate.
    """

    @classmethod
    def setUpClass(cls):
        ray.init()
        # Epsilon-greedy exploration values
        random_eps = 0.8
        mixed_eps = 0.5
        expert_eps = 0.2
        num_episodes = 64
        cls.gamma = 0.99

        # Config settings for FQE model
        cls.q_model_config = {
            "n_iters": 800,
            "minibatch_size": 64,
            "polyak_coef": 1.0,
            "model": {
                "fcnet_hiddens": [32, 32, 32],
                "activation": "relu",
            },
            "lr": 1e-3,
        }

        (
            cls.random_policy,
            cls.random_batch,
            cls.random_reward,
            cls.random_std,
        ) = get_cliff_walking_wall_policy_and_data(
            num_episodes, cls.gamma, random_eps, seed=SEED
        )
        print(
            f"Collected random batch of {cls.random_batch.count} steps "
            f"with return {cls.random_reward} stddev {cls.random_std}"
        )

        (
            cls.mixed_policy,
            cls.mixed_batch,
            cls.mixed_reward,
            cls.mixed_std,
        ) = get_cliff_walking_wall_policy_and_data(
            num_episodes, cls.gamma, mixed_eps, seed=SEED
        )
        print(
            f"Collected mixed batch of {cls.mixed_batch.count} steps "
            f"with return {cls.mixed_reward} stddev {cls.mixed_std}"
        )

        (
            cls.expert_policy,
            cls.expert_batch,
            cls.expert_reward,
            cls.expert_std,
        ) = get_cliff_walking_wall_policy_and_data(
            num_episodes, cls.gamma, expert_eps, seed=SEED
        )
        print(
            f"Collected expert batch of {cls.expert_batch.count} steps "
            f"with return {cls.expert_reward} stddev {cls.expert_std}"
        )

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_dm_random_policy_random_data(self):
        print("Test DirectMethod on random policy on random dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.random_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
            seed=SEED,
        )

    def test_dm_random_policy_mixed_data(self):
        print("Test DirectMethod on random policy on mixed dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.mixed_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
            seed=SEED,
        )

    def test_dm_mixed_policy_random_data(self):
        print("Test DirectMethod on mixed policy on random dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.random_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
            seed=SEED,
        )

    def test_dm_mixed_policy_mixed_data(self):
        print("Test DirectMethod on mixed policy on mixed dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.mixed_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
            seed=SEED,
        )

    def test_dm_mixed_policy_expert_data(self):
        print("Test DirectMethod on mixed policy on expert dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.expert_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
            seed=SEED,
        )

    def test_dm_expert_policy_random_data(self):
        print("Test DirectMethod on expert policy on random dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.random_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
            seed=SEED,
        )

    def test_dm_expert_policy_mixed_data(self):
        print("Test DirectMethod on expert policy on mixed dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.mixed_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
            seed=SEED,
        )

    def test_dm_expert_policy_expert_data(self):
        print("Test DirectMethod on expert policy on expert dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.expert_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
            seed=SEED,
        )


if __name__ == "__main__":
    import sys
    import pytest

    sys.exit(pytest.main(["-v", __file__]))
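
The three behavior policies in these tests differ only in their epsilon-greedy exploration rate: 0.8 ("random"), 0.5 ("mixed"), and 0.2 ("expert"). A hypothetical helper, not taken from CliffWalkingWallPolicy, showing what that knob means:

# Hypothetical illustration of epsilon-greedy selection; higher epsilon
# means more uniform-random actions and longer, noisier trajectories.
import numpy as np

def epsilon_greedy(q_values: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))  # explore
    return int(np.argmax(q_values))  # exploit

rng = np.random.default_rng(0)
q = np.array([0.1, 0.9, 0.2, 0.0])
mostly_random = [epsilon_greedy(q, 0.8, rng) for _ in range(10)]
mostly_greedy = [epsilon_greedy(q, 0.2, rng) for _ in range(10)]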

rllib/offline/estimators/tests/test_dr_learning.py (new file, 202 lines)
@@ -0,0 +1,202 @@
import unittest

import ray
from ray.rllib.offline.estimators import DoublyRobust
from ray.rllib.offline.estimators.tests.utils import (
    get_cliff_walking_wall_policy_and_data,
    check_estimate,
)

SEED = 0


class TestDRLearning(unittest.TestCase):
    """Learning tests for the DoublyRobust estimator.

    Generates three GridWorldWallPolicy policies and batches with epsilon = 0.2, 0.5,
    and 0.8 respectively using `get_cliff_walking_wall_policy_and_data`.

    Tests that the estimators converge on all eight combinations of evaluation policy
    and behavior batch using `check_estimates`, except random policy-expert batch.

    Note: We do not test OPE with the "random" policy (epsilon=0.8)
    and "expert" (epsilon=0.2) batch because of the large policy-data mismatch. The
    expert batch is unlikely to contain the longer trajectories that would be observed
    under the random policy, thus the OPE estimate is flaky and inaccurate.
    """

    @classmethod
    def setUpClass(cls):
        ray.init()
        # Epsilon-greedy exploration values
        random_eps = 0.8
        mixed_eps = 0.5
        expert_eps = 0.2
        num_episodes = 64
        cls.gamma = 0.99

        # Config settings for FQE model
        cls.q_model_config = {
            "n_iters": 800,
            "minibatch_size": 64,
            "polyak_coef": 1.0,
            "model": {
                "fcnet_hiddens": [32, 32, 32],
                "activation": "relu",
            },
            "lr": 1e-3,
        }

        (
            cls.random_policy,
            cls.random_batch,
            cls.random_reward,
            cls.random_std,
        ) = get_cliff_walking_wall_policy_and_data(
            num_episodes, cls.gamma, random_eps, seed=SEED
        )
        print(
            f"Collected random batch of {cls.random_batch.count} steps "
            f"with return {cls.random_reward} stddev {cls.random_std}"
        )

        (
            cls.mixed_policy,
            cls.mixed_batch,
            cls.mixed_reward,
            cls.mixed_std,
        ) = get_cliff_walking_wall_policy_and_data(
            num_episodes, cls.gamma, mixed_eps, seed=SEED
        )
        print(
            f"Collected mixed batch of {cls.mixed_batch.count} steps "
            f"with return {cls.mixed_reward} stddev {cls.mixed_std}"
        )

        (
            cls.expert_policy,
            cls.expert_batch,
            cls.expert_reward,
            cls.expert_std,
        ) = get_cliff_walking_wall_policy_and_data(
            num_episodes, cls.gamma, expert_eps, seed=SEED
        )
        print(
            f"Collected expert batch of {cls.expert_batch.count} steps "
            f"with return {cls.expert_reward} stddev {cls.expert_std}"
        )

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_dr_random_policy_random_data(self):
        print("Test DoublyRobust on random policy on random dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.random_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
            seed=SEED,
        )

    def test_dr_random_policy_mixed_data(self):
        print("Test DoublyRobust on random policy on mixed dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.mixed_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
            seed=SEED,
        )

    def test_dr_mixed_policy_random_data(self):
        print("Test DoublyRobust on mixed policy on random dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.random_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
            seed=SEED,
        )

    def test_dr_mixed_policy_mixed_data(self):
        print("Test DoublyRobust on mixed policy on mixed dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.mixed_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
            seed=SEED,
        )

    def test_dr_mixed_policy_expert_data(self):
        print("Test DoublyRobust on mixed policy on expert dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.expert_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
            seed=SEED,
        )

    def test_dr_expert_policy_random_data(self):
        print("Test DoublyRobust on expert policy on random dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.random_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
            seed=SEED,
        )

    def test_dr_expert_policy_mixed_data(self):
        print("Test DoublyRobust on expert policy on mixed dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.mixed_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
            seed=SEED,
        )

    def test_dr_expert_policy_expert_data(self):
        print("Test DoublyRobust on expert policy on expert dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.expert_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
            seed=SEED,
        )


if __name__ == "__main__":
    import sys
    import pytest

    sys.exit(pytest.main(["-v", __file__]))

rllib/offline/estimators/tests/test_ope.py

@@ -2,16 +2,12 @@ import copy
import os
import unittest
from pathlib import Path
from typing import Type, Union, Dict, Tuple

import numpy as np
from ray.data import read_json
from ray.rllib.algorithms import AlgorithmConfig
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.env.cliff_walking_wall_env import CliffWalkingWallEnv
from ray.rllib.examples.policy.cliff_walking_wall_policy import CliffWalkingWallPolicy
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.offline.dataset_reader import DatasetReader
from ray.rllib.offline.estimators import (
    DirectMethod,

@@ -20,7 +16,6 @@ from ray.rllib.offline.estimators import (
    WeightedImportanceSampling,
)
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy

@@ -264,376 +259,8 @@ class TestFQE(unittest.TestCase):
        check(estimates, q_values, decimals=1)


def get_cliff_walking_wall_policy_and_data(
    num_episodes: int,
    gamma: float,
    epsilon: float,
) -> Tuple[Policy, SampleBatch, float, float]:
    """Collect a cliff_walking_wall policy and data with epsilon-greedy exploration.

    Args:
        num_episodes: Minimum number of episodes to collect
        gamma: discount factor
        epsilon: epsilon-greedy exploration value

    Returns:
        A Tuple consisting of:
          - A CliffWalkingWallPolicy with exploration parameter epsilon
          - A SampleBatch of at least `num_episodes` CliffWalkingWall episodes
            collected using epsilon-greedy exploration
          - The mean of the discounted return over the collected episodes
          - The stddev of the discounted return over the collected episodes

    """
    config = (
        AlgorithmConfig()
        .rollouts(batch_mode="complete_episodes")
        .environment(disable_env_checking=True)
        .experimental(_disable_preprocessor_api=True)
    )
    config = config.to_dict()
    config["epsilon"] = epsilon

    env = CliffWalkingWallEnv()
    policy = CliffWalkingWallPolicy(
        env.observation_space, env.action_space, {"epsilon": epsilon}
    )
    workers = WorkerSet(
        env_creator=lambda env_config: CliffWalkingWallEnv(),
        policy_class=CliffWalkingWallPolicy,
        trainer_config=config,
        num_workers=4,
    )
    ep_ret = []
    batches = []
    n_eps = 0
    while n_eps < num_episodes:
        batch = synchronous_parallel_sample(worker_set=workers)
        for episode in batch.split_by_episode():
            ret = 0
            for r in episode[SampleBatch.REWARDS][::-1]:
                ret = r + gamma * ret
            ep_ret.append(ret)
            n_eps += 1
        batches.append(batch)
    workers.stop()
    return policy, concat_samples(batches), np.mean(ep_ret), np.std(ep_ret)


def check_estimate(
    *,
    estimator_cls: Type[Union[DirectMethod, DoublyRobust]],
    gamma: float,
    q_model_config: Dict,
    policy: Policy,
    batch: SampleBatch,
    mean_ret: float,
    std_ret: float,
) -> None:
    """Compute off-policy estimates and compare them to the true discounted return.

    Args:
        estimator_cls: Off-Policy Estimator class to be used
        gamma: discount factor
        q_model_config: Optional config settings for the estimator's Q-model
        policy: The target policy we compute estimates for
        batch: The behavior data we use for off-policy estimation
        mean_ret: The mean discounted episode return over the batch
        std_ret: The standard deviation corresponding to mean_ret

    Raises:
        AssertionError if the estimated mean episode return computed by
        the off-policy estimator does not fall within one standard deviation of
        the values specified above i.e. [mean_ret - std_ret, mean_ret + std_ret]
    """
    estimator = estimator_cls(
        policy=policy,
        gamma=gamma,
        q_model_config=q_model_config,
    )
    loss = estimator.train(batch)["loss"]
    estimates = estimator.estimate(batch)
    est_mean = estimates["v_target"]
    est_std = estimates["v_target_std"]
    print(f"{est_mean:.2f}, {est_std:.2f}, {mean_ret:.2f}, {std_ret:.2f}, {loss:.2f}")
    # Assert that the two mean +- stddev intervals overlap
    assert mean_ret - std_ret <= est_mean <= mean_ret + std_ret, (
        f"DirectMethod estimate {est_mean:.2f} with stddev "
        f"{est_std:.2f} does not converge to true discounted return "
        f"{mean_ret:.2f} with stddev {std_ret:.2f}!"
    )


class TestOPELearning(unittest.TestCase):
    """Learning tests for the DirectMethod and DoublyRobust estimators.

    Generates three GridWorldWallPolicy policies and batches with epsilon = 0.2, 0.5,
    and 0.8 respectively using `get_cliff_walking_wall_policy_and_data`.

    Tests that the estimators converge on all eight combinations of evaluation policy
    and behavior batch using `check_estimates`, except random policy-expert batch.

    Note: We do not test OPE with the "random" policy (epsilon=0.8)
    and "expert" (epsilon=0.2) batch because of the large policy-data mismatch. The
    expert batch is unlikely to contain the longer trajectories that would be observed
    under the random policy, thus the OPE estimate is flaky and inaccurate.
    """

    @classmethod
    def setUpClass(cls):
        ray.init()
        # Epsilon-greedy exploration values
        random_eps = 0.8
        mixed_eps = 0.5
        expert_eps = 0.2
        num_episodes = 64
        cls.gamma = 0.99

        # Config settings for FQE model
        cls.q_model_config = {
            "n_iters": 800,
            "minibatch_size": 64,
            "polyak_coef": 1.0,
            "model": {
                "fcnet_hiddens": [],
                "activation": "linear",
            },
            "lr": 0.01,
        }

        (
            cls.random_policy,
            cls.random_batch,
            cls.random_reward,
            cls.random_std,
        ) = get_cliff_walking_wall_policy_and_data(num_episodes, cls.gamma, random_eps)
        print(
            f"Collected random batch of {cls.random_batch.count} steps "
            f"with return {cls.random_reward} stddev {cls.random_std}"
        )

        (
            cls.mixed_policy,
            cls.mixed_batch,
            cls.mixed_reward,
            cls.mixed_std,
        ) = get_cliff_walking_wall_policy_and_data(num_episodes, cls.gamma, mixed_eps)
        print(
            f"Collected mixed batch of {cls.mixed_batch.count} steps "
            f"with return {cls.mixed_reward} stddev {cls.mixed_std}"
        )

        (
            cls.expert_policy,
            cls.expert_batch,
            cls.expert_reward,
            cls.expert_std,
        ) = get_cliff_walking_wall_policy_and_data(num_episodes, cls.gamma, expert_eps)
        print(
            f"Collected expert batch of {cls.expert_batch.count} steps "
            f"with return {cls.expert_reward} stddev {cls.expert_std}"
        )

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_dm_random_policy_random_data(self):
        print("Test DirectMethod on random policy on random dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.random_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
        )

    def test_dm_random_policy_mixed_data(self):
        print("Test DirectMethod on random policy on mixed dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.mixed_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
        )

    def test_dm_mixed_policy_random_data(self):
        print("Test DirectMethod on mixed policy on random dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.random_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dm_mixed_policy_mixed_data(self):
        print("Test DirectMethod on mixed policy on mixed dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.mixed_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dm_mixed_policy_expert_data(self):
        print("Test DirectMethod on mixed policy on expert dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.expert_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dm_expert_policy_random_data(self):
        print("Test DirectMethod on expert policy on random dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.random_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )

    def test_dm_expert_policy_mixed_data(self):
        print("Test DirectMethod on expert policy on mixed dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.mixed_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )

    def test_dm_expert_policy_expert_data(self):
        print("Test DirectMethod on expert policy on expert dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.expert_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )

    def test_dr_random_policy_random_data(self):
        print("Test DoublyRobust on random policy on random dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.random_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
        )

    def test_dr_random_policy_mixed_data(self):
        print("Test DoublyRobust on random policy on mixed dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.mixed_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
        )

    def test_dr_mixed_policy_random_data(self):
        print("Test DoublyRobust on mixed policy on random dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.random_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dr_mixed_policy_mixed_data(self):
        print("Test DoublyRobust on mixed policy on mixed dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.mixed_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dr_mixed_policy_expert_data(self):
        print("Test DoublyRobust on mixed policy on expert dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.expert_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dr_expert_policy_random_data(self):
        print("Test DoublyRobust on expert policy on random dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.random_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )

    def test_dr_expert_policy_mixed_data(self):
        print("Test DoublyRobust on expert policy on mixed dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.mixed_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )

    def test_dr_expert_policy_expert_data(self):
        print("Test DoublyRobust on expert policy on expert dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.expert_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )


if __name__ == "__main__":
    import pytest
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))

rllib/offline/estimators/tests/utils.py (new file, 125 lines)
@@ -0,0 +1,125 @@
from typing import Type, Union, Dict, Tuple

import numpy as np
from ray.rllib.algorithms import AlgorithmConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.env.cliff_walking_wall_env import CliffWalkingWallEnv
from ray.rllib.examples.policy.cliff_walking_wall_policy import CliffWalkingWallPolicy
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.offline.estimators import (
    DirectMethod,
    DoublyRobust,
)
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.utils.debug import update_global_seed_if_necessary


def get_cliff_walking_wall_policy_and_data(
    num_episodes: int, gamma: float, epsilon: float, seed: int
) -> Tuple[Policy, SampleBatch, float, float]:
    """Collect a cliff_walking_wall policy and data with epsilon-greedy exploration.

    Args:
        num_episodes: Minimum number of episodes to collect
        gamma: discount factor
        epsilon: epsilon-greedy exploration value

    Returns:
        A Tuple consisting of:
          - A CliffWalkingWallPolicy with exploration parameter epsilon
          - A SampleBatch of at least `num_episodes` CliffWalkingWall episodes
            collected using epsilon-greedy exploration
          - The mean of the discounted return over the collected episodes
          - The stddev of the discounted return over the collected episodes

    """

    config = (
        AlgorithmConfig()
        .debugging(seed=seed)
        .rollouts(batch_mode="complete_episodes")
        .environment(disable_env_checking=True)
        .experimental(_disable_preprocessor_api=True)
    )
    config = config.to_dict()
    config["epsilon"] = epsilon

    env = CliffWalkingWallEnv(seed=seed)
    policy = CliffWalkingWallPolicy(
        env.observation_space, env.action_space, {"epsilon": epsilon, "seed": seed}
    )
    workers = WorkerSet(
        env_creator=lambda env_config: CliffWalkingWallEnv(),
        policy_class=CliffWalkingWallPolicy,
        trainer_config=config,
        num_workers=4,
    )
    ep_ret = []
    batches = []
    n_eps = 0
    while n_eps < num_episodes:
        batch = synchronous_parallel_sample(worker_set=workers)
        for episode in batch.split_by_episode():
            ret = 0
            for r in episode[SampleBatch.REWARDS][::-1]:
                ret = r + gamma * ret
            ep_ret.append(ret)
            n_eps += 1
        batches.append(batch)
    workers.stop()

    return policy, concat_samples(batches), np.mean(ep_ret), np.std(ep_ret)


def check_estimate(
    *,
    estimator_cls: Type[Union[DirectMethod, DoublyRobust]],
    gamma: float,
    q_model_config: Dict,
    policy: Policy,
    batch: SampleBatch,
    mean_ret: float,
    std_ret: float,
    seed: int,
) -> None:
    """Compute off-policy estimates and compare them to the true discounted return.

    Args:
        estimator_cls: Off-Policy Estimator class to be used
        gamma: discount factor
        q_model_config: Optional config settings for the estimator's Q-model
        policy: The target policy we compute estimates for
        batch: The behavior data we use for off-policy estimation
        mean_ret: The mean discounted episode return over the batch
        std_ret: The standard deviation corresponding to mean_ret

    Raises:
        AssertionError if the estimated mean episode return computed by
        the off-policy estimator does not fall within one standard deviation of
        the values specified above i.e. [mean_ret - std_ret, mean_ret + std_ret]
    """
    # only torch is supported for now
    update_global_seed_if_necessary(framework="torch", seed=seed)
    estimator = estimator_cls(
        policy=policy,
        gamma=gamma,
        q_model_config=q_model_config,
    )
    loss = estimator.train(batch)["loss"]
    estimates = estimator.estimate(batch)
    est_mean = estimates["v_target"]
    est_std = estimates["v_target_std"]
    print(
        f"est_mean={est_mean:.2f}, "
        f"est_std={est_std:.2f}, "
        f"target_mean={mean_ret:.2f}, "
        f"target_std={std_ret:.2f}, "
        f"loss={loss:.2f}"
    )
    # Assert that the two mean +- stddev intervals overlap
    assert mean_ret - std_ret <= est_mean <= mean_ret + std_ret, (
        f"OPE estimate {est_mean:.2f} with stddev "
        f"{est_std:.2f} does not converge to true discounted return "
        f"{mean_ret:.2f} with stddev {std_ret:.2f}!"
    )
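
Two numerical pieces above carry the whole acceptance criterion: get_cliff_walking_wall_policy_and_data accumulates each episode's discounted return backwards (ret = r + gamma * ret), and check_estimate then requires the OPE estimate to land within one standard deviation of the empirical mean. A standalone sketch with made-up numbers:

# Illustration only: backward discounted-return accumulation plus the
# one-stddev acceptance band used by check_estimate.
import numpy as np

def discounted_return(rewards, gamma: float) -> float:
    ret = 0.0
    for r in reversed(rewards):  # same recursion as the helper above
        ret = r + gamma * ret
    return ret

ep_ret = [discounted_return([-1.0] * k + [10.0], gamma=0.99) for k in (3, 5, 8)]
mean_ret, std_ret = float(np.mean(ep_ret)), float(np.std(ep_ret))

est_mean = mean_ret + 0.5 * std_ret  # pretend this came from an estimator
assert mean_ret - std_ret <= est_mean <= mean_ret + std_ret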