Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved battle_against API #657

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions src/poke_env/player/gymnasium_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@
while self._keep_challenging:
opponent = self._get_opponent()
if isinstance(opponent, Player):
await self.agent.battle_against(opponent, 1)
await self.agent.battle_against(opponent, n_battles=1)

Check warning on line 468 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L468

Added line #L468 was not covered by tests
else:
await self.agent.send_challenges(opponent, 1)
if callback and self.current_battle is not None:
Expand All @@ -474,7 +474,7 @@
for _ in range(n_challenges):
opponent = self._get_opponent()
if isinstance(opponent, Player):
await self.agent.battle_against(opponent, 1)
await self.agent.battle_against(opponent, n_battles=1)

Check warning on line 477 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L477

Added line #L477 was not covered by tests
else:
await self.agent.send_challenges(opponent, 1)
if callback and self.current_battle is not None:
Expand Down
35 changes: 20 additions & 15 deletions src/poke_env/player/player.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""This module defines a base class for players.
"""

from __future__ import annotations

import asyncio
import random
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -699,28 +701,31 @@ async def _ladder(self, n_games: int):
perf_counter() - start_time,
)

async def battle_against(self, opponent: "Player", n_battles: int = 1):
"""Make the player play n_battles against opponent.
async def battle_against(self, *opponents: Player, n_battles: int = 1):
"""Make the player play n_battles against the given opponents.

This function is a wrapper around send_challenges and accept challenges.
This function is a wrapper around send_challenges and accept_challenges.

:param opponent: The opponent to play against.
:type opponent: Player
:param opponents: The opponents to play against.
:type opponents: Player
:param n_battles: The number of games to play. Defaults to 1.
:type n_battles: int
"""
await handle_threaded_coroutines(self._battle_against(opponent, n_battles))

async def _battle_against(self, opponent: "Player", n_battles: int):
await asyncio.gather(
self.send_challenges(
to_id_str(opponent.username),
n_battles,
to_wait=opponent.ps_client.logged_in,
),
opponent.accept_challenges(to_id_str(self.username), n_battles),
await handle_threaded_coroutines(
self._battle_against(*opponents, n_battles=n_battles)
)

async def _battle_against(self, *opponents: Player, n_battles: int):
for opponent in opponents:
await asyncio.gather(
self.send_challenges(
to_id_str(opponent.username),
n_battles,
to_wait=opponent.ps_client.logged_in,
),
opponent.accept_challenges(to_id_str(self.username), n_battles),
)

async def send_challenges(
self, opponent: str, n_challenges: int, to_wait: Optional[Event] = None
):
Expand Down
33 changes: 10 additions & 23 deletions src/poke_env/player/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Dict, List, Optional, Tuple

from poke_env.concurrency import POKE_LOOP
from poke_env.data import to_id_str
from poke_env.player.baselines import MaxBasePowerPlayer, SimpleHeuristicsPlayer
from poke_env.player.player import Player
from poke_env.player.random_player import RandomPlayer
Expand All @@ -31,29 +30,17 @@
players: List[Player], n_challenges: int
) -> Dict[str, Dict[str, Optional[float]]]:
results: Dict[str, Dict[str, Optional[float]]] = {
p_1.username: {p_2.username: None for p_2 in players} for p_1 in players
p1.username: {p2.username: None for p2 in players} for p1 in players
}
for i, p_1 in enumerate(players):
for j, p_2 in enumerate(players):
for i, p1 in enumerate(players):
for j, p2 in enumerate(players):
if j <= i:
continue
await asyncio.gather(
p_1.send_challenges(
opponent=to_id_str(p_2.username),
n_challenges=n_challenges,
to_wait=p_2.ps_client.logged_in,
),
p_2.accept_challenges(
opponent=to_id_str(p_1.username),
n_challenges=n_challenges,
packed_team=p_2.next_team,
),
)
results[p_1.username][p_2.username] = p_1.win_rate
results[p_2.username][p_1.username] = p_2.win_rate

p_1.reset_battles()
p_2.reset_battles()
await p1.battle_against(p2, n_battles=n_challenges)
results[p1.username][p2.username] = p1.win_rate
results[p2.username][p1.username] = p2.win_rate
p1.reset_battles()
p2.reset_battles()
return results


Expand Down Expand Up @@ -170,7 +157,7 @@
baselines = [p(max_concurrent_battles=n_battles) for p in _EVALUATION_RATINGS] # type: ignore

for p in baselines:
await p.battle_against(player, n_placement_battles)
await p.battle_against(player, n_battles=n_placement_battles)

Check warning on line 160 in src/poke_env/player/utils.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/utils.py#L160

Added line #L160 was not covered by tests

# Select the best opponent for evaluation
best_opp = min(
Expand All @@ -179,7 +166,7 @@

# Performing the main evaluation
remaining_battles = n_battles - len(_EVALUATION_RATINGS) * n_placement_battles
await best_opp.battle_against(player, remaining_battles)
await best_opp.battle_against(player, n_battles=remaining_battles)

Check warning on line 169 in src/poke_env/player/utils.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/utils.py#L169

Added line #L169 was not covered by tests

return _estimate_strength_from_results(
best_opp.n_finished_battles,
Expand Down
24 changes: 10 additions & 14 deletions unit_tests/player/test_player_misc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections import namedtuple
from unittest.mock import MagicMock, patch

import pytest

from poke_env import AccountConfiguration
from poke_env.environment import AbstractBattle, Battle, DoubleBattle, Move, PokemonType
from poke_env.player import BattleOrder, Player, RandomPlayer, cross_evaluate
from poke_env.stats import _raw_hp, _raw_stat
Expand All @@ -13,7 +13,13 @@ def choose_move(self, battle: AbstractBattle) -> BattleOrder:
return self.choose_random_move(battle)


class FixedWinRatePlayer:
class FixedWinRatePlayer(Player):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def choose_move(self, battle: AbstractBattle) -> BattleOrder:
return self.choose_random_move(battle)

async def accept_challenges(self, *args, **kwargs):
pass

Expand All @@ -27,14 +33,6 @@ def reset_battles(self):
def win_rate(self):
return 0.5

@property
def next_team(self):
return None

@property
def ps_client(self):
return namedtuple("PSClient", "logged_in")(logged_in=None)


def test_player_default_order():
assert SimplePlayer().choose_default_move().message == "/choose default"
Expand Down Expand Up @@ -208,11 +206,9 @@ async def test_basic_challenge_handling():

@pytest.mark.asyncio
async def test_cross_evaluate():
p1 = FixedWinRatePlayer()
p2 = FixedWinRatePlayer()
p1 = FixedWinRatePlayer(account_configuration=AccountConfiguration("p1", None))
p2 = FixedWinRatePlayer(account_configuration=AccountConfiguration("p2", None))

p1.username = "p1"
p2.username = "p2"
cross_evaluation = await cross_evaluate([p1, p2], 10)
assert cross_evaluation == {
"p1": {"p1": None, "p2": 0.5},
Expand Down
Loading