import logging

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.contrib.bandits.agents.policy import BanditPolicy

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
TS_CONFIG = with_common_config({
    # No remote workers by default.
    "num_workers": 0,
    "use_pytorch": True,

    # Do online learning one step at a time.
    "rollout_fragment_length": 1,
    "train_batch_size": 1,

    # Bandits can't afford to do one timestep per iteration, as that is
    # extremely slow because of metrics-collection overhead. With this
    # setting, the agent is trained 100 times in one RLlib iteration.
    "timesteps_per_iteration": 100,

    "exploration_config": {
        "type": "ray.rllib.contrib.bandits.exploration.ThompsonSampling"
    }
})
# __sphinx_doc_end__
# yapf: enable


def get_stats(trainer):
    # Merge the environment metrics collected by the trainer with the
    # optimizer's learner statistics into a single stats dict.
    env_metrics = trainer.collect_metrics()
    stats = trainer.optimizer.stats()
    # Uncomment if the regret at each time step is needed.
    # stats.update({"all_regrets": trainer.get_policy().regrets})
    return dict(env_metrics, **stats)


LinTSTrainer = build_trainer(
    name="LinTS",
    default_config=TS_CONFIG,
    default_policy=BanditPolicy,
    collect_metrics_fn=get_stats)
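# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream module): it shows how
# the LinTSTrainer built above might be run with the default TS_CONFIG.
# It assumes a contextual-bandit environment class such as WheelBanditEnv is
# importable from ray.rllib.contrib.bandits.envs; that import and the
# iteration count below are assumptions for demonstration only.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import ray
    from ray.rllib.contrib.bandits.envs import WheelBanditEnv

    ray.init()
    # Uses TS_CONFIG (the default_config above) unchanged.
    trainer = LinTSTrainer(env=WheelBanditEnv)
    for i in range(5):
        result = trainer.train()
        print("Iteration {}: mean reward = {}".format(
            i, result["episode_reward_mean"]))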