Source code for egta.savesched
"""Module for a scheduler that saves all profile data"""
import numpy as np
from gameanalysis import rsgame
from gameanalysis import paygame
from egta import profsched
class _SaveScheduler(profsched._Scheduler):  # pylint: disable=protected-access
    """Scheduler wrapper that records every payoff sample it returns

    Sampling is delegated to the wrapped scheduler. Each profile / payoff pair
    that comes back is retained so that it can later be retrieved as a game
    through `get_game`.

    Parameters
    ----------
    sched : Scheduler
        The base scheduler to sample from and to save payoffs from.
    """

    def __init__(self, sched):
        super().__init__(
            sched.role_names, sched.strat_names, sched.num_role_players
        )
        self._sched = sched
        # Samples gathered since the last `get_game` call, kept as two
        # parallel lists (index i of one matches index i of the other).
        self._profiles = []
        self._payoffs = []
        # Accumulated data starts as a sample-game copy of an empty copy of
        # this scheduler.
        blank = rsgame.empty_copy(self)
        self._game = paygame.samplegame_copy(blank)

    async def sample_payoffs(self, profile):
        """Sample payoffs from the wrapped scheduler and record the result"""
        observed = await self._sched.sample_payoffs(profile)
        # Both appends run after the await with nothing in between, so the
        # two lists never get out of step.
        self._profiles.append(profile)
        self._payoffs.append(observed)
        return observed

    def get_game(self):
        """Get the game with the observed data

        Pending samples are folded into the stored game before it is
        returned; with nothing pending the stored game is returned as is.
        """
        if not self._profiles:
            return self._game

        base = self._game
        merged_profs = np.concatenate(
            (base.flat_profiles(), np.stack(self._profiles))
        )
        merged_pays = np.concatenate(
            (base.flat_payoffs(), np.stack(self._payoffs))
        )
        # Only drop the pending samples once both merged arrays exist.
        self._profiles.clear()
        self._payoffs.clear()
        self._game = paygame.samplegame_replace_flat(
            base, merged_profs, merged_pays
        )
        return self._game

    def __str__(self):
        # Present as the wrapped scheduler.
        return str(self._sched)
def savesched(sched):
    """Create a save scheduler

    Parameters
    ----------
    sched : Scheduler
        The base scheduler to wrap; it is passed straight to
        `_SaveScheduler`.

    Returns
    -------
    _SaveScheduler
        A new save scheduler constructed from `sched`.
    """
    return _SaveScheduler(sched)