# Source: connect_four_game.py (forked from NTT123/a0-jax)
# 164 lines (135 loc) · 5.08 KB
"""Connect-Four game mechanics"""
from typing import Tuple
import chex
import jax.numpy as jnp
import numpy as np
import pax
from env import Enviroment
from utils import select_tree
class Connect4WinChecker(pax.Module):
    """Check who won the game.

    We use a conv2d for scanning the whole board: every output channel sums
    the four cells lying on one line inside a 4x4 window, so a channel value
    of +4 / -4 means four equal pieces in a row.
    (We can do better by locating the recent move.)

    Filters to scan for winning patterns:

        1 0 0 0    1 1 1 1    1 0 0 0
        1 0 0 0    0 0 0 0    0 1 0 0
        1 0 0 0    0 0 0 0    0 0 1 0
        1 0 0 0    0 0 0 0    0 0 0 1

        0 0 0 1    0 0 0 0    0 0 0 1
        0 0 0 1    0 0 0 0    0 0 1 0
        0 0 0 1    0 0 0 0    0 1 0 0
        0 0 0 1    1 1 1 1    1 0 0 0
    """

    def __init__(self):
        super().__init__()
        conv = pax.Conv2D(1, 6, 4, padding="VALID")
        # One 4x4 filter per output channel; the shape is checked against the
        # conv layer's own weight below.
        weight = np.zeros((4, 4, 1, 6), dtype=np.float32)
        weight[0, :, :, 0] = 1  # first row of the window
        weight[:, 0, :, 1] = 1  # first column of the window
        weight[-1, :, :, 2] = 1  # last row of the window
        weight[:, -1, :, 3] = 1  # last column of the window
        for i in range(4):
            weight[i, i, :, 4] = 1  # main diagonal
            weight[i, 3 - i, :, 5] = 1  # anti-diagonal
        assert weight.shape == conv.weight.shape
        self.conv = conv.replace(weight=weight)

    def __call__(self, board):
        """Return 1 or -1 if that player has four in a row, otherwise 0.

        Args:
            board: 2-D array of cell values. Connect4Game stores 1 / -1 for
                the two players and 0 for an empty cell.

        Returns:
            A scalar array: the sign of the line that sums to +4 / -4, or 0
            when no filter response reaches magnitude 4.
        """
        num_rows, num_cols = board.shape
        # The row/column filters sit on the edges of the 4x4 window.  Without
        # padding, a horizontal line is only seen in rows 0..R-4 (first-row
        # filter) and 3..R-1 (last-row filter), which leaves the middle rows
        # unchecked when R < 6; the same holds for vertical lines and columns.
        # Zero-padding the far side by 3 lets the first-row / first-column
        # filters anchor on every cell.  Zeros never contribute to a sum, so
        # no false win can appear.  Shapes are static, so this branch is
        # resolved at trace time and boards of 6 or more cells per side keep
        # the original un-padded computation.
        pad_rows = 0 if num_rows >= 6 else 3
        pad_cols = 0 if num_cols >= 6 else 3
        if pad_rows or pad_cols:
            board = jnp.pad(board, ((0, pad_rows), (0, pad_cols)))

        # Add batch and channel axes for the convolution.
        board = board[None, :, :, None].astype(jnp.float32)
        x = self.conv(board)
        m = jnp.max(jnp.abs(x))
        # Sign of the strongest response: +1 if it is the positive maximum.
        m1 = jnp.where(m == jnp.max(x), 1, -1)
        return jnp.where(m == 4, m1, 0)
class Connect4Game(Enviroment):
    """Connect-Four game environment.

    State held on the instance:
      * board: (num_rows, num_cols) int32 grid, 0 for empty cells; a move
        writes the mover's value (1 or -1).
      * who_play: value of the player to move, starts at 1 and is negated
        after every step.
      * col_counts: number of moves recorded per column.
      * terminated / winner: end-of-game flag and the win checker's verdict.
    """

    board: chex.Array
    who_play: chex.Array
    terminated: chex.Array
    winner: chex.Array
    num_cols: int
    num_rows: int

    def __init__(self, num_cols: int = 7, num_rows: int = 6):
        """Create an empty board with `num_rows` rows and `num_cols` columns."""
        super().__init__()
        self.winner_checker = Connect4WinChecker()
        grid_shape = (num_rows, num_cols)
        self.board = jnp.zeros(grid_shape, dtype=jnp.int32)
        self.who_play = jnp.array(1, dtype=jnp.int32)
        self.col_counts = jnp.zeros((num_cols,), dtype=jnp.int32)
        self.terminated = jnp.array(False, dtype=jnp.bool_)
        self.winner = jnp.array(0, dtype=jnp.int32)
        self.num_cols = num_cols
        self.num_rows = num_rows
        self.reset()

    def num_actions(self):
        """Number of available actions: one per column."""
        return self.num_cols

    def invalid_actions(self) -> chex.Array:
        """Boolean mask over columns, True where the column is already full."""
        return self.col_counts >= self.num_rows

    def reset(self):
        """Put the game back into its initial, empty state."""
        grid_shape = (self.num_rows, self.num_cols)
        self.board = jnp.zeros(grid_shape, dtype=jnp.int32)
        self.who_play = jnp.array(1, dtype=jnp.int32)
        self.col_counts = jnp.zeros((self.num_cols,), dtype=jnp.int32)
        self.terminated = jnp.array(False, dtype=jnp.bool_)
        self.winner = jnp.array(0, dtype=jnp.int32)

    @pax.pure
    def step(self, action: chex.Array) -> Tuple["Connect4Game", chex.Array]:
        """Play one move in column `action`.

        An invalid move (the chosen column is already full) terminates the
        game with reward -1.

        Returns:
            The updated game and the reward, taken from the point of view of
            the player who made this move.
        """
        mover = self.who_play
        landing_row = self.col_counts[action]
        column_full = landing_row >= self.num_rows

        # Once the game is over the board is frozen; otherwise drop the piece.
        candidate = self.board.at[landing_row, action].set(mover)
        self.board = select_tree(self.terminated, self.board, candidate)

        self.winner = self.winner_checker(self.board)
        reward = self.winner * mover

        # Bookkeeping: count the move and hand the turn to the other player.
        self.col_counts = self.col_counts.at[action].set(landing_row + 1)
        self.who_play = -mover

        # The game ends on a win, a full board, or an invalid move.
        board_full = jnp.sum(self.col_counts) >= self.num_cols * self.num_rows
        finished = jnp.logical_or(self.terminated, reward != 0)
        finished = jnp.logical_or(finished, board_full)
        self.terminated = jnp.logical_or(finished, column_full)

        reward = jnp.where(column_full, -1.0, reward)
        return self, reward

    def render(self) -> None:
        """Render the game on screen.

        Prints a header of column indices, then the rows from the highest
        row index down to row 0, using "X" for 1, "O" for -1 and "." for
        anything else, followed by a blank line.
        """
        symbols = {1: "X", -1: "O"}
        print("".join(f"{col} " for col in range(self.num_cols)))
        for row in range(self.num_rows - 1, -1, -1):
            cells = (
                symbols.get(self.board[row, col].item(), ".")
                for col in range(self.num_cols)
            )
            print("".join(f"{cell} " for cell in cells))
        print()

    def observation(self) -> chex.Array:
        """The raw board."""
        return self.board

    def canonical_observation(self) -> chex.Array:
        """The board multiplied by `who_play`, i.e. seen by the player to move."""
        return self.board * self.who_play

    def is_terminated(self):
        """Whether the game has ended."""
        return self.terminated

    def max_num_steps(self) -> int:
        """Upper bound on the number of moves: one per board cell."""
        return self.num_cols * self.num_rows

    def symmetries(self, state, action_weights):
        """Return the given (state, action_weights) pair and its mirror image.

        The mirror flips `state` along axis 1 and reverses `action_weights`.
        """
        mirrored = (np.flip(state, axis=1), np.flip(action_weights))
        return [(state, action_weights), mirrored]
if __name__ == "__main__":
    # Scripted demo: the two players alternate through a fixed move list in
    # which the first player ends up with four pieces stacked in column 6.
    demo_moves = (6, 1, 1, 2, 6, 2, 2, 4, 6, 5, 6)

    game = Connect4Game()
    game.render()
    for column in demo_moves:
        game, reward = game.step(column)
    game.render()
    print("Reward", reward)