from typing import Any, List

import numpy as np

from ray.rllib.connectors.connector import (
    AgentConnector,
    ConnectorContext,
    register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class ClipRewardAgentConnector(AgentConnector):
    """Agent connector that clips incoming rewards, either to their sign or to a fixed limit."""

    def __init__(self, ctx: ConnectorContext, sign=False, limit=None):
        super().__init__(ctx)
        assert (
            not sign or not limit
        ), "should not enable both sign and limit reward clipping."
        self.sign = sign
        self.limit = limit

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        d = ac_data.data
        assert (
            type(d) == dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        assert SampleBatch.REWARDS in d, "input data does not have reward column."
        if self.sign:
            # Replace each reward with its sign (-1, 0, or +1).
            d[SampleBatch.REWARDS] = np.sign(d[SampleBatch.REWARDS])
        elif self.limit:
            # Clip rewards into the interval [-limit, limit].
            d[SampleBatch.REWARDS] = np.clip(
                d[SampleBatch.REWARDS],
                a_min=-self.limit,
                a_max=self.limit,
            )
        return ac_data

    def to_state(self):
        return ClipRewardAgentConnector.__name__, {
            "sign": self.sign,
            "limit": self.limit,
        }

    @staticmethod
    def from_state(ctx: ConnectorContext, params: List[Any]):
        return ClipRewardAgentConnector(ctx, **params)


register_connector(ClipRewardAgentConnector.__name__, ClipRewardAgentConnector)
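

# A minimal usage sketch (not part of the original module). It assumes that
# ConnectorContext() can be constructed with its default arguments and that
# AgentConnectorDataType takes (env_id, agent_id, data) positionally; adjust
# to the actual signatures of your RLlib version if they differ.
if __name__ == "__main__":
    ctx = ConnectorContext()
    connector = ClipRewardAgentConnector(ctx, sign=True)
    ac_data = AgentConnectorDataType(
        "env_0", "agent_0", {SampleBatch.REWARDS: np.array([2.5, -0.3, 0.0])}
    )
    out = connector.transform(ac_data)
    # Rewards are reduced to their signs: [ 1. -1.  0.]
    print(out.data[SampleBatch.REWARDS])

    # to_state()/from_state() round-trips the connector's configuration.
    name, params = connector.to_state()
    restored = ClipRewardAgentConnector.from_state(ctx, params)
    print(name, restored.sign, restored.limit)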