ray/rllib/offline/is_estimator.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix 2xLINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
2020-01-09 00:15:48 -08:00
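
The `__future__` blocks removed by this commit are the standard Python 2/3 compatibility imports, which are no-ops on Python 3. A representative block of the kind deleted from the top of the affected files (illustrative, not taken from any specific file):

# Typical __future__ compatibility block removed by the commit above
# (all of these are default behavior on Python 3).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals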

from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
    OffPolicyEstimate
from ray.rllib.utils.annotations import override


class ImportanceSamplingEstimator(OffPolicyEstimator):
    """The step-wise IS estimator.

    Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""

    def __init__(self, policy, gamma):
        OffPolicyEstimator.__init__(self, policy, gamma)

    @override(OffPolicyEstimator)
    def estimate(self, batch):
        self.check_can_estimate_for(batch)

        rewards, old_prob = batch["rewards"], batch["action_prob"]
        new_prob = self.action_prob(batch)

        # Calculate importance ratios: p[t] is the cumulative product of
        # new_prob / old_prob over timesteps 0..t.
        p = []
        for t in range(batch.count):
            if t == 0:
                pt_prev = 1.0
            else:
                pt_prev = p[t - 1]
            p.append(pt_prev * new_prob[t] / old_prob[t])

        # Calculate the step-wise IS estimate over every timestep in the
        # batch, alongside the discounted return of the behavior policy.
        V_prev, V_step_IS = 0.0, 0.0
        for t in range(batch.count):
            V_prev += rewards[t] * self.gamma**t
            V_step_IS += p[t] * rewards[t] * self.gamma**t

        estimation = OffPolicyEstimate(
            "is", {
                "V_prev": V_prev,
                "V_step_IS": V_step_IS,
                "V_gain_est": V_step_IS / max(1e-8, V_prev),
            })
        return estimation
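
For reference, estimate() implements the step-wise importance-sampling value estimate: with cumulative importance ratios p_t = prod_{t' <= t} pi_new(a_{t'} | s_{t'}) / pi_old(a_{t'} | s_{t'}), it reports V_step_IS = sum_t gamma^t * p_t * r_t alongside the behavior-policy return V_prev = sum_t gamma^t * r_t and their ratio V_gain_est. The sketch below shows one way to drive the estimator by hand on a single-episode batch; it is illustrative only, since the estimator is normally constructed by RLlib's offline-evaluation machinery, and the estimate_episode helper, the policy object, and the batch contents are assumptions rather than part of the file above.

# Minimal usage sketch (hypothetical helper, not part of is_estimator.py).
# Assumes `policy` is an RLlib Policy able to score the logged actions and
# `episode_batch` is a single-episode SampleBatch carrying at least the
# "rewards" and "action_prob" columns that estimate() reads.
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.policy.sample_batch import SampleBatch


def estimate_episode(policy, episode_batch: SampleBatch, gamma=0.99):
    estimator = ImportanceSamplingEstimator(policy, gamma)
    # Returns the OffPolicyEstimate built at the end of estimate(),
    # carrying V_prev, V_step_IS and V_gain_est.
    return estimator.estimate(episode_batch)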