Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved battle_against API #657

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions src/poke_env/player/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@
while self._keep_challenging:
opponent = self._get_opponent()
if isinstance(opponent, Player):
await self.agent.battle_against(opponent, 1)
await self.agent.battle_against(opponent, n_battles=1)

Check warning on line 461 in src/poke_env/player/openai_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/openai_api.py#L461

Added line #L461 was not covered by tests
else:
await self.agent.send_challenges(opponent, 1)
if callback and self.current_battle is not None:
Expand All @@ -467,7 +467,7 @@
for _ in range(n_challenges):
opponent = self._get_opponent()
if isinstance(opponent, Player):
await self.agent.battle_against(opponent, 1)
await self.agent.battle_against(opponent, n_battles=1)

Check warning on line 470 in src/poke_env/player/openai_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/openai_api.py#L470

Added line #L470 was not covered by tests
else:
await self.agent.send_challenges(opponent, 1)
if callback and self.current_battle is not None:
Expand Down
41 changes: 29 additions & 12 deletions src/poke_env/player/player.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""This module defines a base class for players.
"""

from __future__ import annotations

import asyncio
import random
from abc import ABC, abstractmethod
from asyncio import Condition, Event, Queue, Semaphore
from logging import Logger
from time import perf_counter
from typing import Any, Awaitable, Dict, List, Optional, Union
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

import orjson

Expand Down Expand Up @@ -692,7 +694,9 @@ async def _ladder(self, n_games: int):
perf_counter() - start_time,
)

async def battle_against(self, opponent: "Player", n_battles: int = 1):
async def battle_against(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the docstring?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this make n_battles a required named argument? If so i'd prefer opponent to be either a player or a list of players, and keep n_battles from requiring to be named.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated docstring: 411a469
No need to worry, n_battles is still an optional argument - I'm pretty sure the only requirement is that all the opponents are listed before n_battles is specified, and you need to explicitly say "n_battles=..." in order to specify the integer from the Player objects.

self, *opponents: Player, n_battles: int = 1
) -> Dict[str, Tuple[float, float]]:
"""Make the player play n_battles against opponent.

This function is a wrapper around send_challenges and accept challenges.
Expand All @@ -702,17 +706,30 @@ async def battle_against(self, opponent: "Player", n_battles: int = 1):
:param n_battles: The number of games to play. Defaults to 1.
:type n_battles: int
"""
await handle_threaded_coroutines(self._battle_against(opponent, n_battles))

async def _battle_against(self, opponent: "Player", n_battles: int):
await asyncio.gather(
self.send_challenges(
to_id_str(opponent.username),
n_battles,
to_wait=opponent.ps_client.logged_in,
),
opponent.accept_challenges(to_id_str(self.username), n_battles),
result = await handle_threaded_coroutines(
self._battle_against(*opponents, n_battles=n_battles)
)
return result

async def _battle_against(
self, *opponents: Player, n_battles: int
) -> Dict[str, Tuple[float, float]]:
results: Dict[str, Tuple[float, float]] = {}
for opponent in opponents:
await asyncio.gather(
self.send_challenges(
to_id_str(opponent.username),
n_battles,
to_wait=opponent.ps_client.logged_in,
),
opponent.accept_challenges(
to_id_str(self.username), n_battles, opponent.next_team
),
)
results[opponent.username] = (self.win_rate, opponent.win_rate)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should bot be part of this method - reporting results is unrelated to starting battles. Similarly, resetting the stored battles after the battles are done would be counterintuitive and make for a bad API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Here you go: f17c415

self.reset_battles()
opponent.reset_battles()
return results

async def send_challenges(
self, opponent: str, n_challenges: int, to_wait: Optional[Event] = None
Expand Down
34 changes: 9 additions & 25 deletions src/poke_env/player/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Dict, List, Optional, Tuple

from poke_env.concurrency import POKE_LOOP
from poke_env.data import to_id_str
from poke_env.player.baselines import MaxBasePowerPlayer, SimpleHeuristicsPlayer
from poke_env.player.player import Player
from poke_env.player.random_player import RandomPlayer
Expand All @@ -31,29 +30,14 @@
players: List[Player], n_challenges: int
) -> Dict[str, Dict[str, Optional[float]]]:
results: Dict[str, Dict[str, Optional[float]]] = {
p_1.username: {p_2.username: None for p_2 in players} for p_1 in players
p1.username: {p2.username: None for p2 in players} for p1 in players
}
for i, p_1 in enumerate(players):
for j, p_2 in enumerate(players):
if j <= i:
continue
await asyncio.gather(
p_1.send_challenges(
opponent=to_id_str(p_2.username),
n_challenges=n_challenges,
to_wait=p_2.ps_client.logged_in,
),
p_2.accept_challenges(
opponent=to_id_str(p_1.username),
n_challenges=n_challenges,
packed_team=p_2.next_team,
),
)
results[p_1.username][p_2.username] = p_1.win_rate
results[p_2.username][p_1.username] = p_2.win_rate

p_1.reset_battles()
p_2.reset_battles()
for i, p1 in enumerate(players):
results[p1.username][p1.username] = None
r = await p1.battle_against(*players[i + 1 :], n_battles=n_challenges)
for p2, (win_rate, lose_rate) in r.items():
results[p1.username][p2] = win_rate
results[p2][p1.username] = lose_rate
return results


Expand Down Expand Up @@ -170,7 +154,7 @@
baselines = [p(max_concurrent_battles=n_battles) for p in _EVALUATION_RATINGS] # type: ignore

for p in baselines:
await p.battle_against(player, n_placement_battles)
await p.battle_against(player, n_battles=n_placement_battles)

Check warning on line 157 in src/poke_env/player/utils.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/utils.py#L157

Added line #L157 was not covered by tests

# Select the best opponent for evaluation
best_opp = min(
Expand All @@ -179,7 +163,7 @@

# Performing the main evaluation
remaining_battles = n_battles - len(_EVALUATION_RATINGS) * n_placement_battles
await best_opp.battle_against(player, remaining_battles)
await best_opp.battle_against(player, n_battles=remaining_battles)

Check warning on line 166 in src/poke_env/player/utils.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/utils.py#L166

Added line #L166 was not covered by tests

return _estimate_strength_from_results(
best_opp.n_finished_battles,
Expand Down
20 changes: 10 additions & 10 deletions unit_tests/player/test_player_misc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections import namedtuple
from unittest.mock import MagicMock, patch

import pytest

from poke_env import AccountConfiguration
from poke_env.environment import AbstractBattle, Battle, DoubleBattle, Move, PokemonType
from poke_env.player import BattleOrder, Player, RandomPlayer, cross_evaluate
from poke_env.stats import _raw_hp, _raw_stat
Expand All @@ -13,7 +13,13 @@ def choose_move(self, battle: AbstractBattle) -> BattleOrder:
return self.choose_random_move(battle)


class FixedWinRatePlayer:
class FixedWinRatePlayer(Player):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def choose_move(self, battle: AbstractBattle) -> BattleOrder:
return self.choose_random_move(battle)

async def accept_challenges(self, *args, **kwargs):
pass

Expand All @@ -31,10 +37,6 @@ def win_rate(self):
def next_team(self):
return None

@property
def ps_client(self):
return namedtuple("PSClient", "logged_in")(logged_in=None)


def test_player_default_order():
assert SimplePlayer().choose_default_move().message == "/choose default"
Expand Down Expand Up @@ -208,11 +210,9 @@ async def test_basic_challenge_handling():

@pytest.mark.asyncio
async def test_cross_evaluate():
p1 = FixedWinRatePlayer()
p2 = FixedWinRatePlayer()
p1 = FixedWinRatePlayer(account_configuration=AccountConfiguration("p1", None))
p2 = FixedWinRatePlayer(account_configuration=AccountConfiguration("p2", None))

p1.username = "p1"
p2.username = "p2"
cross_evaluation = await cross_evaluate([p1, p2], 10)
assert cross_evaluation == {
"p1": {"p1": None, "p2": 0.5},
Expand Down
Loading