ray/rllib/offline/estimators/doubly_robust.py


from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate
from ray.rllib.offline.estimators.direct_method import DirectMethod, train_test_split
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.numpy import convert_to_numpy
import numpy as np


@ExperimentalAPI
class DoublyRobust(DirectMethod):
    """The Doubly Robust (DR) estimator.

    DR estimator described in https://arxiv.org/pdf/1511.03722.pdf
    """
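
    # estimate() below walks each test episode backwards and applies the
    # doubly robust backup at every step t:
    #   V_DR(t) = V_hat(s_t) + rho_t * (r_t + gamma * V_DR(t + 1) - Q_hat(s_t, a_t))
    # where rho_t is the importance ratio between the target and behavior
    # policies, and V_hat / Q_hat come from the model trained on the train
    # split of each fold.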
    @override(DirectMethod)
    def estimate(
        self,
        batch: SampleBatchType,
    ) -> OffPolicyEstimate:
        self.check_can_estimate_for(batch)
        estimates = []
        # Split data into train and test batches
        for train_episodes, test_episodes in train_test_split(
            batch,
            self.train_test_split_val,
            self.k,
        ):

            # Train Q-function
            if train_episodes:
                # Reinitialize model
                self.model.reset()
                train_batch = SampleBatch.concat_samples(train_episodes)
                losses = self.train(train_batch)
                self.losses.append(losses)

            # Calculate doubly robust OPE estimates
            for episode in test_episodes:
                rewards, old_prob = episode["rewards"], episode["action_prob"]
                new_prob = np.exp(self.action_log_likelihood(episode))

                v_old = 0.0
                v_new = 0.0
                q_values = self.model.estimate_q(
                    episode[SampleBatch.OBS], episode[SampleBatch.ACTIONS]
                )
                q_values = convert_to_numpy(q_values)
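
                # Build the V_hat(s_t) baselines: enumerate every discrete
                # action for each observation, get the target policy's
                # probability for each of them, and let the Q-model turn those
                # probabilities into state-value estimates via estimate_v().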
                all_actions = np.zeros([episode.count, self.policy.action_space.n])
                all_actions[:] = np.arange(self.policy.action_space.n)
                # Two transposes required for torch.distributions to work
                tmp_episode = episode.copy()
                tmp_episode[SampleBatch.ACTIONS] = all_actions.T
                action_probs = np.exp(self.action_log_likelihood(tmp_episode)).T
                v_values = self.model.estimate_v(
                    episode[SampleBatch.OBS], action_probs
                )
                v_values = convert_to_numpy(v_values)
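
                # Backward pass over the episode: v_old accumulates the plain
                # discounted return of the logged data, while v_new applies the
                # DR backup from the class-level comment with
                # rho_t = new_prob[t] / old_prob[t]. After the loop, v_new is
                # the DR estimate of this episode's return under the target
                # policy; .item() just unwraps the numpy scalar.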
                for t in reversed(range(episode.count)):
                    v_old = rewards[t] + self.gamma * v_old
                    v_new = v_values[t] + (new_prob[t] / old_prob[t]) * (
                        rewards[t] + self.gamma * v_new - q_values[t]
                    )
                v_new = v_new.item()

                estimates.append(
                    OffPolicyEstimate(
                        self.name,
                        {
                            "v_old": v_old,
                            "v_new": v_new,
                            "v_gain": v_new / max(1e-8, v_old),
                        },
                    )
                )
        return estimates
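

if __name__ == "__main__":
    # Minimal standalone sketch of the DR recursion used in estimate() above,
    # run on hand-made numbers. The arrays below are made up for illustration
    # only; in real use they come from a logged episode (rewards, behavior
    # action_prob), the target policy (new_prob), and the trained Q-model
    # (q_values, v_values).
    gamma = 0.99
    rewards = np.array([1.0, 0.0, 1.0])
    old_prob = np.array([0.5, 0.5, 0.5])  # behavior policy pi_b(a_t | s_t)
    new_prob = np.array([0.8, 0.6, 0.7])  # target policy pi(a_t | s_t)
    q_values = np.array([1.9, 1.0, 1.0])  # Q_hat(s_t, a_t)
    v_values = np.array([1.8, 0.9, 0.9])  # V_hat(s_t)

    v_old, v_new = 0.0, 0.0
    for t in reversed(range(len(rewards))):
        v_old = rewards[t] + gamma * v_old
        v_new = v_values[t] + (new_prob[t] / old_prob[t]) * (
            rewards[t] + gamma * v_new - q_values[t]
        )
    print({"v_old": v_old, "v_new": v_new, "v_gain": v_new / max(1e-8, v_old)})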