# Source code for egta.savesched
"""Module for a scheduler that saves all profile data"""
import numpy as np
from gameanalysis import rsgame
from gameanalysis import paygame
from egta import profsched
class _SaveScheduler(profsched._Scheduler):  # pylint: disable=protected-access
    """A scheduler that saves all of the payoff data for output later

    Every profile sampled through this scheduler is forwarded to the wrapped
    scheduler, and the resulting profile/payoff pair is recorded so that all
    observations can later be retrieved as a sample game via `get_game`.

    Parameters
    ----------
    sched : Scheduler
        The base scheduler to save payoffs from. Its role names, strategy
        names, and role player counts define this scheduler's structure.
    """

    def __init__(self, sched):
        super().__init__(sched.role_names, sched.strat_names, sched.num_role_players)
        self._sched = sched
        # Sample game holding every observation already folded in by
        # `get_game`; starts empty with the same structure as this scheduler.
        self._game = paygame.samplegame_copy(rsgame.empty_copy(self))
        # Observations recorded since the last `get_game` call. Both lists are
        # appended together (with no await in between), so they stay aligned.
        self._profiles = []
        self._payoffs = []

    async def sample_payoffs(self, profile):
        """Sample payoffs from the base scheduler and record the observation

        Parameters
        ----------
        profile : array-like
            The profile to sample; forwarded unchanged to the base scheduler.

        Returns
        -------
        payoff
            Exactly the object returned by the base scheduler.
        """
        payoff = await self._sched.sample_payoffs(profile)
        # Save copies: the caller (or the base scheduler) may reuse or modify
        # these arrays in place later, which would otherwise silently change
        # the data stacked into the saved game.
        self._profiles.append(np.array(profile))
        self._payoffs.append(np.array(payoff))
        return payoff

    def get_game(self):
        """Get the game with the observed data

        Any observations recorded since the previous call are merged into the
        stored sample game, and the pending buffers are cleared.

        Returns
        -------
        game
            The sample game containing every observation made so far.
        """
        if self._profiles:
            new_profs = np.concatenate(
                [self._game.flat_profiles(), np.stack(self._profiles)]
            )
            new_pays = np.concatenate(
                [self._game.flat_payoffs(), np.stack(self._payoffs)]
            )
            self._profiles.clear()
            self._payoffs.clear()
            self._game = paygame.samplegame_replace_flat(
                self._game, new_profs, new_pays
            )
        return self._game

    def __str__(self):
        return str(self._sched)
def savesched(sched):
    """Create a save scheduler

    Parameters
    ----------
    sched : Scheduler
        The base scheduler whose sampled payoffs should be saved.

    Returns
    -------
    _SaveScheduler
        A scheduler that forwards sampling to `sched` and records every
        observation; call its `get_game` method to retrieve the saved data.
    """
    return _SaveScheduler(sched)