import wandb
import time
from torch.utils.tensorboard import SummaryWriter

# Per-target round counters shared at module level, keyed by target name.
# Nothing in this file increments them; presumably the training/env loop
# updates them elsewhere — TODO confirm.
total_rounds = {"Free": 0, "Go": 0, "Attack": 0}
win_rounds = {"Free": 0, "Go": 0, "Attack": 0}


# class for wandb recording
class WandbRecorder:
    """Record training metrics to TensorBoard, optionally mirrored to Weights & Biases.

    Scalars are written through a ``SummaryWriter`` under ``runs/<run_name>``.
    When ``_args.wandb_track`` is truthy, a wandb run is started first with
    ``sync_tensorboard=True`` so the TensorBoard scalars are synced to wandb.
    """

    def __init__(self, game_name: str, game_type: str, run_name: str, _args) -> None:
        """Start the optional wandb run and open the TensorBoard writer.

        Args:
            game_name: Used as the wandb project name.
            game_type: Stored on the instance; not otherwise used in this class.
            run_name: wandb run name and TensorBoard sub-directory under ``runs/``.
            _args: Argument namespace (must support ``vars()``). ``wandb_track``
                and ``wandb_entity`` are read from it, and all of its attributes
                are logged as hyperparameters.
        """
        self.game_name = game_name
        self.game_type = game_type
        self._args = _args
        self.run_name = run_name
        hyperparams = vars(self._args)

        # NOTE: wandb.init runs before the SummaryWriter is created; the
        # sync_tensorboard integration expects that order (see wandb docs).
        if self._args.wandb_track:
            wandb.init(
                project=self.game_name,
                entity=self._args.wandb_entity,
                sync_tensorboard=True,
                config=hyperparams,
                name=self.run_name,
                monitor_gym=True,
                save_code=True,
            )
        self.writer = SummaryWriter(f"runs/{self.run_name}")

        # Hyperparameters rendered as a two-column markdown table.
        rows = "\n".join(f"|{key}|{value}|" for key, value in hyperparams.items())
        self.writer.add_text("hyperparameters", "|param|value|\n|-|-|\n" + rows)

    def add_target_scalar(
        self,
        target_name,
        thisT,
        v_loss,
        dis_pg_loss,
        con_pg_loss,
        loss,
        entropy_loss,
        target_reward_mean,
        target_steps,
    ):
        """Log per-target losses, mean reward and win ratio.

        Every scalar goes under ``Target<target_name>/`` at step
        ``target_steps[thisT]``.

        Args:
            target_name: Key into the module-level round counters
                (e.g. "Free", "Go", "Attack").
            thisT: Index into ``target_steps`` selecting this target's step.
            v_loss, dis_pg_loss, con_pg_loss, loss, entropy_loss: Loss values
                exposing ``.item()`` (presumably 0-d torch tensors — confirm).
            target_reward_mean: Scalar mean reward for this target.
            target_steps: Indexable collection of per-target step counters.
        """
        step = target_steps[thisT]
        prefix = f"Target{target_name}"
        scalars = (
            ("value_loss", v_loss.item()),
            ("dis_policy_loss", dis_pg_loss.item()),
            ("con_policy_loss", con_pg_loss.item()),
            ("total_loss", loss.item()),
            ("entropy_loss", entropy_loss.item()),
            ("Reward", target_reward_mean),
        )
        for tag, value in scalars:
            self.writer.add_scalar(f"{prefix}/{tag}", value, step)
        self._log_win_ratio(target_name, step)

    def add_global_scalar(
        self,
        total_reward_mean,
        learning_rate,
        total_steps,
    ):
        """Log the overall mean reward and the current learning rate at ``total_steps``."""
        self.writer.add_scalar("GlobalCharts/TotalRewardMean", total_reward_mean, total_steps)
        self.writer.add_scalar("GlobalCharts/learning_rate", learning_rate, total_steps)

    def add_win_ratio(self, target_name, target_steps):
        """Log only the win ratio for ``target_name``.

        Unlike ``add_target_scalar``, ``target_steps`` is used directly as the
        step value here rather than being indexed.
        """
        self._log_win_ratio(target_name, target_steps)

    def _log_win_ratio(self, target_name, step):
        """Write ``Target<name>/WinRatio`` at ``step`` if any rounds have been recorded."""
        total = total_rounds[target_name]
        if total == 0:
            # No finished rounds yet, so the ratio is undefined. Skip the point
            # instead of raising ZeroDivisionError or logging a misleading 0.
            return
        self.writer.add_scalar(
            f"Target{target_name}/WinRatio",
            win_rounds[target_name] / total,
            step,
        )