[RLlib] Add OPE Learning Tests (#27154)
This commit is contained in:
parent 6dc3dbdd37
commit 5b6a58ed28

9 changed files with 780 additions and 143 deletions
@@ -1722,10 +1722,10 @@ py_test(
py_test(
    name = "test_ope",
    tags = ["team:rllib", "offline", "torch_only", "gpu"],
    size = "large",
    tags = ["team:rllib", "offline"],
    size = "medium",
    srcs = ["offline/estimators/tests/test_ope.py"],
    data = ["tests/data/cartpole/large.json"],
    data = ["tests/data/cartpole/small.json"],
)

# --------------------------------------------------------------------
@@ -3437,6 +3437,7 @@ py_test_module_list(
        "tests/test_vector_env.py",
        "env/tests/test_multi_agent_env.py",
        "env/wrappers/tests/test_kaggle_wrapper.py",
        "examples/env/tests/test_cliff_walking_wall_env.py",
        "examples/env/tests/test_coin_game_non_vectorized_env.py",
        "examples/env/tests/test_coin_game_vectorized_env.py",
        "examples/env/tests/test_matrix_sequential_social_dilemma.py",
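Editor's note (not part of the commit): the srcs/data paths suggest these hunks come from rllib/BUILD, so the updated target can presumably be run through Bazel, e.g.:

bazel test //rllib:test_ope    # assumes the hunk above lives in rllib/BUILD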
rllib/examples/env/cliff_walking_wall_env.py (vendored, new file, 69 lines)
@@ -0,0 +1,69 @@
import gym
from gym import spaces

ACTION_UP = 0
ACTION_RIGHT = 1
ACTION_DOWN = 2
ACTION_LEFT = 3


class CliffWalkingWallEnv(gym.Env):
    """Modified version of the CliffWalking environment from OpenAI Gym
    with walls instead of a cliff.

    ### Description
    The board is a 4x12 matrix, with (using NumPy matrix indexing):
    - [3, 0] or obs==36 as the start at bottom-left
    - [3, 11] or obs==47 as the goal at bottom-right
    - [3, 1..10] or obs==37...46 as the cliff at bottom-center

    An episode terminates when the agent reaches the goal.

    ### Actions
    There are 4 discrete deterministic actions:
    - 0: move up
    - 1: move right
    - 2: move down
    - 3: move left
    You can also use the constants ACTION_UP, ACTION_RIGHT, ... defined above.

    ### Observations
    There are 3x12 + 2 possible states, not including the walls. If an action
    would move an agent into one of the walls, it simply stays in the same position.

    ### Reward
    Each time step incurs -1 reward, except reaching the goal which gives +10 reward.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(48)
        self.action_space = spaces.Discrete(4)

    def reset(self):
        self.position = 36
        return self.position

    def step(self, action):
        x = self.position // 12
        y = self.position % 12
        # UP
        if action == ACTION_UP:
            x = max(x - 1, 0)
        # RIGHT
        elif action == ACTION_RIGHT:
            if self.position != 36:
                y = min(y + 1, 11)
        # DOWN
        elif action == ACTION_DOWN:
            if self.position < 25 or self.position > 34:
                x = min(x + 1, 3)
        # LEFT
        elif action == ACTION_LEFT:
            if self.position != 47:
                y = max(y - 1, 0)
        else:
            raise ValueError(f"action {action} not in {self.action_space}")
        self.position = x * 12 + y
        done = self.position == 47
        reward = -1 if not done else 10
        return self.position, reward, done, {}
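Editor's note (not part of the commit): a minimal manual rollout of the environment above, assuming the file is importable from an installed RLlib checkout. The optimal route is one step up, eleven steps right along row 2, and one step down into the goal, so the undiscounted return is 12 * (-1) + 10 = -2.

from ray.rllib.examples.env.cliff_walking_wall_env import (
    CliffWalkingWallEnv,
    ACTION_UP,
    ACTION_RIGHT,
    ACTION_DOWN,
)

env = CliffWalkingWallEnv()
obs = env.reset()  # 36: bottom-left start
total_reward, done = 0, False
# One step up, eleven steps right along row 2, one step down into the goal.
for action in [ACTION_UP] + [ACTION_RIGHT] * 11 + [ACTION_DOWN]:
    obs, reward, done, _ = env.step(action)
    total_reward += reward
print(obs, total_reward, done)  # 47 -2 True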
rllib/examples/env/tests/test_cliff_walking_wall_env.py (vendored, new file, 61 lines)
@@ -0,0 +1,61 @@
from ray.rllib.examples.env.cliff_walking_wall_env import (
    CliffWalkingWallEnv,
    ACTION_UP,
    ACTION_RIGHT,
    ACTION_DOWN,
    ACTION_LEFT,
)

import unittest


class TestCliffWalkingWallEnv(unittest.TestCase):
    def test_env(self):
        env = CliffWalkingWallEnv()
        obs = env.reset()
        # Starting position
        self.assertEqual(obs, 36)
        # Left, Right, and Down are no-ops
        obs, _, _, _ = env.step(ACTION_LEFT)
        self.assertEqual(obs, 36)
        obs, _, _, _ = env.step(ACTION_DOWN)
        self.assertEqual(obs, 36)
        obs, _, _, _ = env.step(ACTION_RIGHT)
        self.assertEqual(obs, 36)

        # Up and Down returns to starting position
        obs, _, _, _ = env.step(ACTION_UP)
        self.assertEqual(obs, 24)
        obs, _, _, _ = env.step(ACTION_DOWN)
        self.assertEqual(obs, 36)
        obs, _, _, _ = env.step(ACTION_DOWN)
        self.assertEqual(obs, 36)

        # Going down at the wall is a no-op
        env.step(ACTION_UP)
        obs, _, _, _ = env.step(ACTION_RIGHT)
        self.assertEqual(obs, 25)
        obs, _, _, _ = env.step(ACTION_DOWN)
        self.assertEqual(obs, 25)

        # Move all the way to the right wall
        for _ in range(10):
            env.step(ACTION_RIGHT)
        obs, rew, done, _ = env.step(ACTION_RIGHT)
        self.assertEqual(obs, 35)
        self.assertEqual(rew, -1)
        self.assertEqual(done, False)

        # Move to goal
        obs, rew, done, _ = env.step(ACTION_DOWN)
        self.assertEqual(obs, 47)
        self.assertEqual(rew, 10)
        self.assertEqual(done, True)


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))
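Editor's note (not part of the commit): given the __main__ block above, the new test can presumably also be run directly with pytest:

pytest -v rllib/examples/env/tests/test_cliff_walking_wall_env.py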
rllib/examples/policy/cliff_walking_wall_policy.py (new file, 96 lines)
@@ -0,0 +1,96 @@
from ray.rllib.policy.policy import Policy, ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AlgorithmConfigDict, TensorStructType, TensorType
from typing import Dict, Union, List, Tuple, Optional
from ray.rllib.utils.annotations import override
from ray.rllib.models.torch.torch_action_dist import TorchCategorical

import numpy as np
import gym


class CliffWalkingWallPolicy(Policy):
    """Optimal RLlib policy for the CliffWalkingWallEnv environment, defined in
    ray/rllib/examples/env/cliff_walking_wall_env.py, with epsilon-greedy exploration.

    The policy takes a random action with probability epsilon, specified
    by `config["epsilon"]`, and the optimal action with probability 1 - epsilon.
    """

    @override(Policy)
    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        config: AlgorithmConfigDict,
    ):
        super().__init__(observation_space, action_space, config)

        # Known optimal action dist for each of the 48 states and 4 actions
        self.action_dist = np.zeros((48, 4), dtype=float)
        # Starting state: go up
        self.action_dist[36] = (1, 0, 0, 0)
        # Cliff + Goal: never actually used, set to random
        self.action_dist[37:] = (0.25, 0.25, 0.25, 0.25)
        # Row 2; always go right
        self.action_dist[24:36] = (0, 1, 0, 0)
        # Row 0 and Row 1; go down or go right
        self.action_dist[0:24] = (0, 0.5, 0.5, 0)
        # Col 11; always go down, supercedes previous values
        self.action_dist[[11, 23, 35]] = (0, 0, 1, 0)
        assert np.allclose(self.action_dist.sum(-1), 1)

        # Epsilon-Greedy action selection
        epsilon = config.get("epsilon", 0.0)
        self.action_dist = self.action_dist * (1 - epsilon) + epsilon / 4
        assert np.allclose(self.action_dist.sum(-1), 1)

        # Attributes required for RLlib; note that while CliffWalkingWallPolicy
        # inherits from Policy, it actually implements TorchPolicyV2.
        self.view_requirements[SampleBatch.ACTION_PROB] = ViewRequirement()
        self.device = "cpu"
        self.model = None
        self.dist_class = TorchCategorical

    @override(Policy)
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        obs = np.array(obs_batch, dtype=int)
        action_probs = self.action_dist[obs]
        actions = np.zeros(len(obs), dtype=int)
        for i in range(len(obs)):
            actions[i] = np.random.choice(4, p=action_probs[i])
        return (
            actions,
            [],
            {SampleBatch.ACTION_PROB: action_probs[np.arange(len(obs)), actions]},
        )

    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        **kwargs,
    ) -> TensorType:
        obs = np.array(obs_batch, dtype=int)
        actions = np.array(actions, dtype=int)
        # Compute action probs for all possible actions
        action_probs = self.action_dist[obs]
        # Take the action_probs corresponding to the specified actions
        action_probs = action_probs[np.arange(len(obs)), actions]
        # Ignore RuntimeWarning thrown by np.log(0) if action_probs is 0
        with np.errstate(divide="ignore"):
            return np.log(action_probs)

    def action_distribution_fn(
        self, model, obs_batch: TensorStructType, **kwargs
    ) -> Tuple[TensorType, type, List[TensorType]]:
        obs = np.array(obs_batch[SampleBatch.OBS], dtype=int)
        action_probs = self.action_dist[obs]
        # Ignore RuntimeWarning thrown by np.log(0) if action_probs is 0
        with np.errstate(divide="ignore"):
            return np.log(action_probs), TorchCategorical, None
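Editor's note (not part of the commit): the epsilon-greedy mixing in __init__ above is just a convex combination of the optimal action distribution with a uniform one. A standalone sketch of the same arithmetic for the start state, with a hypothetical epsilon value:

import numpy as np

epsilon = 0.2  # hypothetical exploration value
optimal = np.array([1.0, 0.0, 0.0, 0.0])  # state 36: always go up
mixed = optimal * (1 - epsilon) + epsilon / 4
print(mixed)  # [0.85 0.05 0.05 0.05]
# mixed.sum() is 1 (up to float error): still a valid probability distribution.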
@@ -4,13 +4,10 @@ from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.utils.numpy import convert_to_numpy
import numpy as np

torch, nn = try_import_torch()

logger = logging.getLogger()


@@ -51,9 +48,6 @@ class DirectMethod(OffPolicyEstimator):
        TODO (Rohan138): Unify this with RLModule API.
        """

        assert (
            policy.config["framework"] == "torch"
        ), "DirectMethod estimator only works with torch!"
        super().__init__(policy, gamma)

        q_model_config = q_model_config or {}
@@ -2,7 +2,6 @@ import logging
from typing import Dict, Any, Optional
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import SampleBatchType
import numpy as np
from ray.rllib.utils.numpy import convert_to_numpy

@@ -11,8 +10,6 @@ from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel

torch, nn = try_import_torch()

logger = logging.getLogger()

@@ -28,10 +28,10 @@ class FQETorchModel:
        model: ModelConfigDict = None,
        n_iters: int = 1,
        lr: float = 1e-3,
        delta: float = 1e-4,
        min_loss_threshold: float = 1e-4,
        clip_grad_norm: float = 100.0,
        minibatch_size: int = None,
        tau: float = 1.0,
        polyak_coef: float = 1.0,
    ) -> None:
        """
        Args:

@@ -45,11 +45,11 @@ class FQETorchModel:
            },
            n_iters: Number of gradient steps to run on batch, defaults to 1
            lr: Learning rate for Adam optimizer
            delta: Early stopping threshold if the mean loss < delta
            min_loss_threshold: Early stopping if mean loss < min_loss_threshold
            clip_grad_norm: Clip loss gradients to this maximum value
            minibatch_size: Minibatch size for training Q-function;
                if None, train on the whole batch
            tau: Polyak averaging factor for target Q-function
            polyak_coef: Polyak averaging factor for target Q-function
        """
        self.policy = policy
        assert isinstance(

@@ -85,14 +85,14 @@ class FQETorchModel:
        ).to(self.device)
        self.n_iters = n_iters
        self.lr = lr
        self.delta = delta
        self.min_loss_threshold = min_loss_threshold
        self.clip_grad_norm = clip_grad_norm
        self.minibatch_size = minibatch_size
        self.tau = tau
        self.polyak_coef = polyak_coef
        self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr)
        initializer = get_initializer("xavier_uniform", framework="torch")
        # Hard update target
        self.update_target(tau=1.0)
        self.update_target(polyak_coef=1.0)

        def f(m):
            if isinstance(m, nn.Linear):

@@ -158,7 +158,7 @@ class FQETorchModel:
                minibatch_losses.append(loss.item())
            iter_loss = sum(minibatch_losses) / len(minibatch_losses)
            losses.append(iter_loss)
            if iter_loss < self.delta:
            if iter_loss < self.min_loss_threshold:
                break
            self.update_target()
        return losses

@@ -182,16 +182,16 @@ class FQETorchModel:
        v_values = torch.sum(q_values * action_probs, axis=-1)
        return v_values

    def update_target(self, tau=None):
    def update_target(self, polyak_coef=None):
        # Update_target will be called periodically to copy Q network to
        # target Q network, using (soft) tau-synching.
        tau = tau or self.tau
        # target Q network, using (soft) polyak_coef-synching.
        polyak_coef = polyak_coef or self.polyak_coef
        model_state_dict = self.q_model.state_dict()
        # Support partial (soft) synching.
        # If tau == 1.0: Full sync from Q-model to target Q-model.
        # If polyak_coef == 1.0: Full sync from Q-model to target Q-model.
        target_state_dict = self.target_q_model.state_dict()
        model_state_dict = {
            k: tau * model_state_dict[k] + (1 - tau) * v
            k: polyak_coef * model_state_dict[k] + (1 - polyak_coef) * v
            for k, v in target_state_dict.items()
        }

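Editor's note (not part of the commit): the tau -> polyak_coef rename above keeps the same soft-update rule, target <- polyak_coef * online + (1 - polyak_coef) * target. A standalone sketch with stand-in tensors:

import torch

polyak_coef = 0.05                 # hypothetical value; 1.0 is a full hard sync
online = {"w": torch.ones(2, 2)}   # stand-in for q_model.state_dict()
target = {"w": torch.zeros(2, 2)}  # stand-in for target_q_model.state_dict()
blended = {
    k: polyak_coef * online[k] + (1 - polyak_coef) * v
    for k, v in target.items()
}
print(blended["w"])  # all entries 0.05: the target moves 5% toward the online weights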
@@ -1,193 +1,210 @@
import copy
import os
import unittest
import ray
from pathlib import Path
from typing import Type, Union, Dict, Tuple

import numpy as np
from ray.data import read_json
from ray.rllib.algorithms import AlgorithmConfig
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.env.cliff_walking_wall_env import CliffWalkingWallEnv
from ray.rllib.examples.policy.cliff_walking_wall_policy import CliffWalkingWallPolicy
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.offline.dataset_reader import DatasetReader
from ray.rllib.offline.estimators import (
    ImportanceSampling,
    WeightedImportanceSampling,
    DirectMethod,
    DoublyRobust,
    ImportanceSampling,
    WeightedImportanceSampling,
)
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.policy.sample_batch import concat_samples
from ray.rllib.utils.test_utils import check
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from pathlib import Path
import os
import copy
import numpy as np
import gym
import torch
from ray.rllib.utils.test_utils import check

import ray

torch, _ = try_import_torch()

class TestOPE(unittest.TestCase):
    """Compilation tests for using OPE both standalone and in an RLlib Algorithm"""

    @classmethod
    def setUpClass(cls):
        ray.init()
        rllib_dir = Path(__file__).parent.parent.parent.parent
        train_data = os.path.join(rllib_dir, "tests/data/cartpole/large.json")
        eval_data = train_data
        train_data = os.path.join(rllib_dir, "tests/data/cartpole/small.json")

        env_name = "CartPole-v0"
        cls.gamma = 0.99
        n_episodes = 40
        cls.q_model_config = {"n_iters": 600}
        n_episodes = 3
        cls.q_model_config = {"n_iters": 160}

        config = (
            DQNConfig()
            .environment(env=env_name)
            .training(gamma=cls.gamma)
            .rollouts(num_rollout_workers=3, batch_mode="complete_episodes")
            .rollouts(batch_mode="complete_episodes")
            .framework("torch")
            .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", 0)))
            .offline_data(input_=train_data)
            .offline_data(
                input_="dataset", input_config={"format": "json", "paths": train_data}
            )
            .evaluation(
                evaluation_interval=None,
                evaluation_interval=1,
                evaluation_duration=n_episodes,
                evaluation_num_workers=1,
                evaluation_duration_unit="episodes",
                evaluation_config={"input": eval_data},
                off_policy_estimation_methods={
                    "is": {"type": ImportanceSampling},
                    "wis": {"type": WeightedImportanceSampling},
                    "dm_fqe": {
                        "type": DirectMethod,
                        "q_model_config": {"type": FQETorchModel},
                    },
                    "dr_fqe": {
                        "type": DoublyRobust,
                        "q_model_config": {"type": FQETorchModel},
                    },
                    "dm_fqe": {"type": DirectMethod},
                    "dr_fqe": {"type": DoublyRobust},
                },
            )
        )
        cls.algo = config.build()

        # Train DQN for evaluation policy
        for _ in range(n_episodes):
            cls.algo.train()

        # Read n_episodes of data, assuming that one line is one episode
        reader = JsonReader(eval_data)
        cls.batch = reader.next()
        for _ in range(n_episodes - 1):
            cls.batch = concat_samples([cls.batch, reader.next()])
        reader = DatasetReader(read_json(train_data))
        batches = [reader.next() for _ in range(n_episodes)]
        cls.batch = concat_samples(batches)
        cls.n_episodes = len(cls.batch.split_by_episode())
        print("Episodes:", cls.n_episodes, "Steps:", cls.batch.count)

        cls.mean_ret = {}
        cls.std_ret = {}
        cls.losses = {}

        # Simulate Monte-Carlo rollouts
        mc_ret = []
        env = gym.make(env_name)
        for _ in range(n_episodes):
            obs = env.reset()
            done = False
            rewards = []
            while not done:
                act = cls.algo.compute_single_action(obs)
                obs, reward, done, _ = env.step(act)
                rewards.append(reward)
            ret = 0
            for r in reversed(rewards):
                ret = r + cls.gamma * ret
            mc_ret.append(ret)

        cls.mean_ret["simulation"] = np.mean(mc_ret)
        cls.std_ret["simulation"] = np.std(mc_ret)

    @classmethod
    def tearDownClass(cls):
        print("Standalone OPE results")
        print("Mean:")
        print(*list(cls.mean_ret.items()), sep="\n")
        print("Stddev:")
        print(*list(cls.std_ret.items()), sep="\n")
        print("Losses:")
        print(*list(cls.losses.items()), sep="\n")
        ray.shutdown()

    def test_is(self):
        name = "is"
    def test_ope_standalone(self):
        # Test all OPE methods standalone
        estimator_outputs = {
            "v_behavior",
            "v_behavior_std",
            "v_target",
            "v_target_std",
            "v_gain",
            "v_gain_std",
        }
        estimator = ImportanceSampling(
            policy=self.algo.get_policy(),
            gamma=self.gamma,
        )
        estimates = estimator.estimate(self.batch)
        self.mean_ret[name] = estimates["v_target"]
        self.std_ret[name] = estimates["v_target_std"]
        self.assertEqual(estimates.keys(), estimator_outputs)

    def test_wis(self):
        name = "wis"
        estimator = WeightedImportanceSampling(
            policy=self.algo.get_policy(),
            gamma=self.gamma,
        )
        estimates = estimator.estimate(self.batch)
        self.mean_ret[name] = estimates["v_target"]
        self.std_ret[name] = estimates["v_target_std"]
        self.assertEqual(estimates.keys(), estimator_outputs)

    def test_dm_fqe(self):
        name = "dm_fqe"
        estimator = DirectMethod(
            policy=self.algo.get_policy(),
            gamma=self.gamma,
            q_model_config={"type": FQETorchModel, **self.q_model_config},
            q_model_config=self.q_model_config,
        )
        self.losses[name] = estimator.train(self.batch)
        losses = estimator.train(self.batch)
        assert losses, "DM estimator did not return mean loss"
        estimates = estimator.estimate(self.batch)
        self.mean_ret[name] = estimates["v_target"]
        self.std_ret[name] = estimates["v_target_std"]
        self.assertEqual(estimates.keys(), estimator_outputs)

    def test_dr_fqe(self):
        name = "dr_fqe"
        estimator = DoublyRobust(
            policy=self.algo.get_policy(),
            gamma=self.gamma,
            q_model_config={"type": FQETorchModel, **self.q_model_config},
            q_model_config=self.q_model_config,
        )
        self.losses[name] = estimator.train(self.batch)
        losses = estimator.train(self.batch)
        assert losses, "DM estimator did not return mean loss"
        estimates = estimator.estimate(self.batch)
        self.mean_ret[name] = estimates["v_target"]
        self.std_ret[name] = estimates["v_target_std"]
        self.assertEqual(estimates.keys(), estimator_outputs)

    def test_ope_in_algo(self):
        # Test OPE in DQN, during training as well as by calling evaluate()
        results = self.algo.train()
        ope_results = results["evaluation"]["off_policy_estimator"]
        # Check that key exists AND is not {}
        self.assertEqual(set(ope_results.keys()), {"is", "wis", "dm_fqe", "dr_fqe"})

        # Check algo.evaluate() manually as well
        results = self.algo.evaluate()
        print("OPE in Algorithm results")
        estimates = results["evaluation"]["off_policy_estimator"]
        mean_est = {k: v["v_target"] for k, v in estimates.items()}
        std_est = {k: v["v_target_std"] for k, v in estimates.items()}
        ope_results = results["evaluation"]["off_policy_estimator"]
        self.assertEqual(set(ope_results.keys()), {"is", "wis", "dm_fqe", "dr_fqe"})

        print("Mean:")
        print(*list(mean_est.items()), sep="\n")
        print("Stddev:")
        print(*list(std_est.items()), sep="\n")
        print("\n\n\n")

    def test_fqe_model(self):
        # Test FQETorchModel for:
        # (1) Check that it does not modify the underlying batch during training
        # (2) Check that the stoppign criteria from FQE are working correctly
        # (3) Check that using fqe._compute_action_probs equals brute force
        # iterating over all actions with policy.compute_log_likelihoods
class TestFQE(unittest.TestCase):
    """Compilation and learning tests for the Fitted-Q Evaluation model"""

    @classmethod
    def setUpClass(cls) -> None:
        ray.init()
        env = CliffWalkingWallEnv()
        cls.policy = CliffWalkingWallPolicy(
            observation_space=env.observation_space,
            action_space=env.action_space,
            config={},
        )
        cls.gamma = 0.99
        # Collect single episode under optimal policy
        obs_batch = []
        new_obs = []
        actions = []
        action_prob = []
        rewards = []
        dones = []
        obs = env.reset()
        done = False
        while not done:
            obs_batch.append(obs)
            act, _, extra = cls.policy.compute_single_action(obs)
            actions.append(act)
            action_prob.append(extra["action_prob"])
            obs, rew, done, _ = env.step(act)
            new_obs.append(obs)
            rewards.append(rew)
            dones.append(done)
        cls.batch = SampleBatch(
            obs=obs_batch,
            actions=actions,
            action_prob=action_prob,
            rewards=rewards,
            dones=dones,
            new_obs=new_obs,
        )

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_fqe_compilation_and_stopping(self):
        """Compilation tests for FQETorchModel.

        (1) Check that it does not modify the underlying batch during training
        (2) Check that the stopping criteria from FQE are working correctly
        (3) Check that using fqe._compute_action_probs equals brute force
        iterating over all actions with policy.compute_log_likelihoods
        """
        fqe = FQETorchModel(
            policy=self.algo.get_policy(),
            policy=self.policy,
            gamma=self.gamma,
            **self.q_model_config,
        )
        tmp_batch = copy.deepcopy(self.batch)
        losses = fqe.train(self.batch)

        # Make sure FQETorchModel.train() does not modify self.batch
        # Make sure FQETorchModel.train() does not modify the batch
        check(tmp_batch, self.batch)

        # Make sure FQE stopping criteria are respected
        assert (
            len(losses) == fqe.n_iters or losses[-1] < fqe.delta
        ), f"FQE.train() terminated early in {len(losses)} steps with final loss"
        f"{losses[-1]} for n_iters: {fqe.n_iters} and delta: {fqe.delta}"
        assert len(losses) == fqe.n_iters or losses[-1] < fqe.min_loss_threshold, (
            f"FQE.train() terminated early in {len(losses)} steps with final loss"
            f"{losses[-1]} for n_iters: {fqe.n_iters} and "
            f"min_loss_threshold: {fqe.min_loss_threshold}"
        )

        # Test fqe._compute_action_probs against "brute force" method
        # of computing log_prob for each possible action individually
@@ -199,22 +216,424 @@ class TestOPE(unittest.TestCase):
        tmp_probs = []
        for act in range(fqe.policy.action_space.n):
            tmp_actions = np.zeros_like(self.batch["actions"]) + act
            log_probs = fqe.policy.compute_log_likelihoods(
            log_probs = self.policy.compute_log_likelihoods(
                actions=tmp_actions,
                obs_batch=self.batch["obs"],
            )
            tmp_probs.append(torch.exp(log_probs))
        tmp_probs = torch.stack(tmp_probs).transpose(0, 1)
        tmp_probs = convert_to_numpy(tmp_probs)
            tmp_probs.append(np.exp(log_probs))
        tmp_probs = np.stack(tmp_probs).T
        check(action_probs, tmp_probs, decimals=3)

    def test_multiple_inputs(self):
        # TODO (Rohan138): Test with multiple input files
        pass
    def test_fqe_optimal_convergence(self):
        """Test that FQE converges to the true Q-values for an optimal trajectory

        self.batch is deterministic since it is collected under a CliffWalkingWallPolicy
        with epsilon = 0.0; check that FQE converges to the true Q-values for self.batch
        """

        # If self.batch["rewards"] =
        # [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 10],
        # and gamma = 0.99, the discounted returns i.e. optimal Q-values are as follows:

        q_values = np.zeros(len(self.batch["rewards"]), dtype=float)
        q_values[-1] = self.batch["rewards"][-1]
        for t in range(len(self.batch["rewards"]) - 2, -1, -1):
            q_values[t] = self.batch["rewards"][t] + self.gamma * q_values[t + 1]

        print(q_values)

        q_model_config = {
            "polyak_coef": 1.0,
            "model": {
                "fcnet_hiddens": [],
                "activation": "linear",
            },
            "lr": 0.01,
            "n_iters": 5000,
        }

        fqe = FQETorchModel(
            policy=self.policy,
            gamma=self.gamma,
            **q_model_config,
        )
        losses = fqe.train(self.batch)
        print(losses[-10:])
        estimates = fqe.estimate_v(self.batch)
        print(estimates)
        check(estimates, q_values, decimals=1)

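Editor's note (not part of the commit): the backward recursion above computes discounted returns Q[t] = r[t] + gamma * Q[t+1]; for the 13-step optimal episode the standalone arithmetic works out as follows:

import numpy as np

gamma = 0.99
rewards = [-1] * 12 + [10]      # optimal CliffWalkingWall episode
q_values = np.zeros(len(rewards))
q_values[-1] = rewards[-1]      # Q at the final step is just the +10 goal reward
for t in range(len(rewards) - 2, -1, -1):
    q_values[t] = rewards[t] + gamma * q_values[t + 1]
print(round(q_values[0], 2))    # about -2.5: the discounted return from the start state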
def get_cliff_walking_wall_policy_and_data(
    num_episodes: int,
    gamma: float,
    epsilon: float,
) -> Tuple[Policy, SampleBatch, float, float]:
    """Collect a cliff_walking_wall policy and data with epsilon-greedy exploration.

    Args:
        num_episodes: Minimum number of episodes to collect
        gamma: discount factor
        epsilon: epsilon-greedy exploration value

    Returns:
        A Tuple consisting of:
          - A CliffWalkingWallPolicy with exploration parameter epsilon
          - A SampleBatch of at least `num_episodes` CliffWalkingWall episodes
          collected using epsilon-greedy exploration
          - The mean of the discounted return over the collected episodes
          - The stddev of the discounted return over the collected episodes

    """
    config = (
        AlgorithmConfig()
        .rollouts(batch_mode="complete_episodes")
        .environment(disable_env_checking=True)
        .experimental(_disable_preprocessor_api=True)
    )
    config = config.to_dict()
    config["epsilon"] = epsilon

    env = CliffWalkingWallEnv()
    policy = CliffWalkingWallPolicy(
        env.observation_space, env.action_space, {"epsilon": epsilon}
    )
    workers = WorkerSet(
        env_creator=lambda env_config: CliffWalkingWallEnv(),
        policy_class=CliffWalkingWallPolicy,
        trainer_config=config,
        num_workers=4,
    )
    ep_ret = []
    batches = []
    n_eps = 0
    while n_eps < num_episodes:
        batch = synchronous_parallel_sample(worker_set=workers)
        for episode in batch.split_by_episode():
            ret = 0
            for r in episode[SampleBatch.REWARDS][::-1]:
                ret = r + gamma * ret
            ep_ret.append(ret)
            n_eps += 1
        batches.append(batch)
    workers.stop()
    return policy, concat_samples(batches), np.mean(ep_ret), np.std(ep_ret)


def check_estimate(
    *,
    estimator_cls: Type[Union[DirectMethod, DoublyRobust]],
    gamma: float,
    q_model_config: Dict,
    policy: Policy,
    batch: SampleBatch,
    mean_ret: float,
    std_ret: float,
) -> None:
    """Compute off-policy estimates and compare them to the true discounted return.

    Args:
        estimator_cls: Off-Policy Estimator class to be used
        gamma: discount factor
        q_model_config: Optional config settings for the estimator's Q-model
        policy: The target policy we compute estimates for
        batch: The behavior data we use for off-policy estimation
        mean_ret: The mean discounted episode return over the batch
        std_ret: The standard deviation corresponding to mean_ret

    Raises:
        AssertionError if the estimated mean episode return computed by
        the off-policy estimator does not fall within one standard deviation of
        the values specified above i.e. [mean_ret - std_ret, mean_ret + std_ret]
    """
    estimator = estimator_cls(
        policy=policy,
        gamma=gamma,
        q_model_config=q_model_config,
    )
    loss = estimator.train(batch)["loss"]
    estimates = estimator.estimate(batch)
    est_mean = estimates["v_target"]
    est_std = estimates["v_target_std"]
    print(f"{est_mean:.2f}, {est_std:.2f}, {mean_ret:.2f}, {std_ret:.2f}, {loss:.2f}")
    # Assert that the two mean +- stddev intervals overlap
    assert mean_ret - std_ret <= est_mean <= mean_ret + std_ret, (
        f"DirectMethod estimate {est_mean:.2f} with stddev "
        f"{est_std:.2f} does not converge to true discounted return "
        f"{mean_ret:.2f} with stddev {std_ret:.2f}!"
    )

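Editor's note (not part of the commit): check_estimate accepts an estimate when it lands within one behavior-stddev of the true mean discounted return; a toy version of the acceptance check with made-up numbers:

mean_ret, std_ret = -10.0, 3.0   # made-up "true" return statistics
est_mean = -8.5                  # made-up OPE estimate
# Passes because -13.0 <= -8.5 <= -7.0, i.e. the estimate is within one stddev.
assert mean_ret - std_ret <= est_mean <= mean_ret + std_ret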
class TestOPELearning(unittest.TestCase):
    """Learning tests for the DirectMethod and DoublyRobust estimators.

    Generates three GridWorldWallPolicy policies and batches with epsilon = 0.2, 0.5,
    and 0.8 respectively using `get_cliff_walking_wall_policy_and_data`.

    Tests that the estimators converge on all eight combinations of evaluation policy
    and behavior batch using `check_estimates`, except random policy-expert batch.

    Note: We do not test OPE with the "random" policy (epsilon=0.8)
    and "expert" (epsilon=0.2) batch because of the large policy-data mismatch. The
    expert batch is unlikely to contain the longer trajectories that would be observed
    under the random policy, thus the OPE estimate is flaky and inaccurate.
    """

    @classmethod
    def setUpClass(cls):
        ray.init()
        # Epsilon-greedy exploration values
        random_eps = 0.8
        mixed_eps = 0.5
        expert_eps = 0.2
        num_episodes = 64
        cls.gamma = 0.99

        # Config settings for FQE model
        cls.q_model_config = {
            "n_iters": 800,
            "minibatch_size": 64,
            "polyak_coef": 1.0,
            "model": {
                "fcnet_hiddens": [],
                "activation": "linear",
            },
            "lr": 0.01,
        }

        (
            cls.random_policy,
            cls.random_batch,
            cls.random_reward,
            cls.random_std,
        ) = get_cliff_walking_wall_policy_and_data(num_episodes, cls.gamma, random_eps)
        print(
            f"Collected random batch of {cls.random_batch.count} steps "
            f"with return {cls.random_reward} stddev {cls.random_std}"
        )

        (
            cls.mixed_policy,
            cls.mixed_batch,
            cls.mixed_reward,
            cls.mixed_std,
        ) = get_cliff_walking_wall_policy_and_data(num_episodes, cls.gamma, mixed_eps)
        print(
            f"Collected mixed batch of {cls.mixed_batch.count} steps "
            f"with return {cls.mixed_reward} stddev {cls.mixed_std}"
        )

        (
            cls.expert_policy,
            cls.expert_batch,
            cls.expert_reward,
            cls.expert_std,
        ) = get_cliff_walking_wall_policy_and_data(num_episodes, cls.gamma, expert_eps)
        print(
            f"Collected expert batch of {cls.expert_batch.count} steps "
            f"with return {cls.expert_reward} stddev {cls.expert_std}"
        )

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_dm_random_policy_random_data(self):
        print("Test DirectMethod on random policy on random dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.random_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
        )

    def test_dm_random_policy_mixed_data(self):
        print("Test DirectMethod on random policy on mixed dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.mixed_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
        )

    def test_dm_mixed_policy_random_data(self):
        print("Test DirectMethod on mixed policy on random dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.random_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dm_mixed_policy_mixed_data(self):
        print("Test DirectMethod on mixed policy on mixed dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.mixed_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dm_mixed_policy_expert_data(self):
        print("Test DirectMethod on mixed policy on expert dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.expert_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dm_expert_policy_random_data(self):
        print("Test DirectMethod on expert policy on random dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.random_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )

    def test_dm_expert_policy_mixed_data(self):
        print("Test DirectMethod on expert policy on mixed dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.mixed_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )

    def test_dm_expert_policy_expert_data(self):
        print("Test DirectMethod on expert policy on expert dataset")
        check_estimate(
            estimator_cls=DirectMethod,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.expert_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )

    def test_dr_random_policy_random_data(self):
        print("Test DoublyRobust on random policy on random dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.random_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
        )

    def test_dr_random_policy_mixed_data(self):
        print("Test DoublyRobust on random policy on mixed dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.random_policy,
            batch=self.mixed_batch,
            mean_ret=self.random_reward,
            std_ret=self.random_std,
        )

    def test_dr_mixed_policy_random_data(self):
        print("Test DoublyRobust on mixed policy on random dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.random_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dr_mixed_policy_mixed_data(self):
        print("Test DoublyRobust on mixed policy on mixed dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.mixed_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dr_mixed_policy_expert_data(self):
        print("Test DoublyRobust on mixed policy on expert dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.mixed_policy,
            batch=self.expert_batch,
            mean_ret=self.mixed_reward,
            std_ret=self.mixed_std,
        )

    def test_dr_expert_policy_random_data(self):
        print("Test DoublyRobust on expert policy on random dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.random_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )

    def test_dr_expert_policy_mixed_data(self):
        print("Test DoublyRobust on expert policy on mixed dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.mixed_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )

    def test_dr_expert_policy_expert_data(self):
        print("Test DoublyRobust on expert policy on expert dataset")
        check_estimate(
            estimator_cls=DoublyRobust,
            gamma=self.gamma,
            q_model_config=self.q_model_config,
            policy=self.expert_policy,
            batch=self.expert_batch,
            mean_ret=self.expert_reward,
            std_ret=self.expert_std,
        )


if __name__ == "__main__":
    import pytest
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))
@@ -240,6 +240,6 @@ def compute_log_likelihoods_from_input_dict(
        state_batches=[batch[k] for k in state_keys],
        prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
        prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
        actions_normalized=policy.config["actions_in_input_normalized"],
        actions_normalized=policy.config.get("actions_in_input_normalized", False),
    )
    return log_likelihoods
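Editor's note (not part of the commit): switching from indexing to dict.get gives the call a default instead of a KeyError when the key is missing; a minimal illustration:

config = {}  # hypothetical policy.config without the key set
print(config.get("actions_in_input_normalized", False))  # False
# config["actions_in_input_normalized"] would raise KeyError instead.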