-
Notifications
You must be signed in to change notification settings - Fork 5
/
policy.py
52 lines (48 loc) · 1.89 KB
/
policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np
from pyglet.window import key
# individual agent policy
class Policy(object):
    """Base class for the policy of one agent.

    Subclasses are expected to override `action`; the version defined
    here always raises.
    """

    def __init__(self):
        """Create the policy; the base class keeps no state."""

    def action(self, obs):
        """Return the action to take for observation `obs`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
# interactive policy based on keyboard input
# hard-coded to deal only with movement, not communication
class InteractivePolicy(Policy):
    """Policy whose actions come from the arrow keys of a viewer window.

    Only movement is handled; the communication part of every returned
    action is all zeros.
    """

    def __init__(self, env, agent_index):
        """Store `env` and attach this policy's key handlers to a viewer window.

        Args:
            env: environment object; this class reads `env.world.dim_c`,
                `env.discrete_action_input` and `env.viewers`.
            agent_index: index into `env.viewers` selecting the viewer whose
                window key events drive this policy.
        """
        super(InteractivePolicy, self).__init__()
        self.env = env
        # pressed-state flags, indexed 0=LEFT, 1=RIGHT, 2=UP, 3=DOWN
        self.move = [False] * 4
        # one flag per communication channel; never modified by this class
        self.comm = [False] * env.world.dim_c
        # register keyboard events with this environment's window
        window = env.viewers[agent_index].window
        window.on_key_press = self.key_press
        window.on_key_release = self.key_release

    def action(self, obs):
        """Build an action from the currently held arrow keys.

        The observation `obs` is ignored.

        Returns:
            A 1-D numpy array: the movement part followed by
            `env.world.dim_c` zeros.

            * If `env.discrete_action_input` is true, the movement part is a
              single element holding an index: 0 = no key, 1 = LEFT,
              2 = RIGHT, 3 = DOWN, 4 = UP.
            * Otherwise the movement part has five elements; slot 0 is set
              when no key is held and slots 1-4 correspond to LEFT, RIGHT,
              DOWN, UP.
        """
        if self.env.discrete_action_input:
            # when several keys are held, the last matching check below wins
            u = 0
            if self.move[0]:
                u = 1
            if self.move[1]:
                u = 2
            if self.move[2]:
                u = 4
            if self.move[3]:
                u = 3
            # np.concatenate raises ValueError on zero-dimensional input, so
            # the scalar index must become a one-element array first
            u = np.atleast_1d(u)
        else:
            u = np.zeros(5)  # 5-d because of no-move action
            if self.move[0]:
                u[1] += 1.0
            if self.move[1]:
                u[2] += 1.0
            if self.move[3]:
                u[3] += 1.0
            if self.move[2]:
                u[4] += 1.0
            if True not in self.move:
                u[0] += 1.0
        return np.concatenate([u, np.zeros(self.env.world.dim_c)])

    # keyboard event callbacks
    def _set_move(self, k, pressed):
        """Set the move flag for arrow key `k` to `pressed`; ignore other keys."""
        slots = {key.LEFT: 0, key.RIGHT: 1, key.UP: 2, key.DOWN: 3}
        idx = slots.get(k)
        if idx is not None:
            self.move[idx] = pressed

    def key_press(self, k, mod):
        """Key-press handler: mark the arrow key `k` as held (`mod` is unused)."""
        self._set_move(k, True)

    def key_release(self, k, mod):
        """Key-release handler: mark the arrow key `k` as released (`mod` is unused)."""
        self._set_move(k, False)