-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcp-index_sat_2mols.py
161 lines (143 loc) · 5.63 KB
/
cp-index_sat_2mols.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Created on: 6 Feb 2024
# Author: Oleg Zaikin
# E-mail: [email protected]
#
# For a given order n, produce a CNF that encodes the search for 2 MOLS
# (a pair of orthogonal Latin squares) of order n. The SAT encoding is CP-index
# proposed in
# Noah Rubin, Curtis Bright, Brett Stevens, Kevin Cheung. Integer and
# Constraint Programming Revisited for Mutually Orthogonal Latin Squares
# // In AAAI 2022.
# For the mentioned paper, the CP-index SAT encoding was implemented in C++ in
# https://github.com/noahrubin333/CP-IP
#=============================================================================
import sys
# Name and version of this script, reported on the console when it runs.
script, version = "cp-index_sat_2mols.py", "0.0.1"
# Return clauses that encode the AtMostOne constraint:
def at_most_one_clauses(vars : list):
    """Return the pairwise CNF encoding of AtMostOne(vars).

    For every unordered pair of variables a, b in ``vars`` (in list order) the
    clause [-a, -b] is emitted, giving len(vars)*(len(vars)-1)/2 clauses.
    AtMostOne over an empty or one-variable list is trivially satisfied and
    yields no clauses.

    Args:
        vars: distinct variable IDs. The name shadows the ``vars`` builtin but
            is kept so keyword callers keep working.

    Returns:
        A new list of two-literal clauses.
    """
    # A repeated variable would yield a clause [-a, -a], silently forcing a false.
    assert(len(set(vars)) == len(vars))
    return [[-a, -b] for idx, a in enumerate(vars) for b in vars[idx + 1:]]
# Get clauses encoding the ExactlyOne constraint for a set of variables:
def exactly_one_clauses(vars : list):
    """Return CNF clauses encoding ExactlyOne(vars).

    The result is the pairwise AtMostOne clauses followed by one AtLeastOne
    clause that contains every variable. A one-variable list yields only the
    unit clause [v].

    Args:
        vars: non-empty list of distinct variable IDs. The name shadows the
            ``vars`` builtin but is kept for keyword-argument compatibility.

    Returns:
        A new list of clauses. The AtLeastOne clause is a copy of ``vars``, so
        later changes to the caller's list cannot alter the generated CNF.
    """
    # ExactlyOne over nothing is unsatisfiable and would emit an empty clause.
    assert(len(vars) >= 1)
    assert(len(set(vars)) == len(vars))
    # A single variable needs no pairwise exclusion clauses.
    res_clauses = at_most_one_clauses(vars) if len(vars) > 1 else []
    res_clauses.append(list(vars))  # AtLeastOne constraint
    return res_clauses
# Latin square constraints:
def latin_square_clauses(T : list, is_diag : bool):
    """Return ExactlyOne clauses forcing the variable array T to be a Latin square.

    T[r][c][v] is the variable for "cell (r, c) holds value v", and the order
    n is len(T). For every pair of indices the clauses state three things: a
    cell holds exactly one value, each value occurs exactly once in every row,
    and each value occurs exactly once in every column. When is_diag is true,
    each value must also occur exactly once on the main diagonal and exactly
    once on the main antidiagonal.
    """
    n = len(T)
    assert(n > 0)
    span = range(n)
    clauses = []
    for a in span:
        for b in span:
            # Cell (a, b) takes exactly one value.
            clauses.extend(exactly_one_clauses([T[a][b][v] for v in span]))
            # Value b appears exactly once in row a.
            clauses.extend(exactly_one_clauses([T[a][c][b] for c in span]))
            # Value b appears exactly once in column a.
            clauses.extend(exactly_one_clauses([T[r][a][b] for r in span]))
    if not is_diag:
        return clauses
    # Optional diagonal constraints: every value on the main diagonal first,
    # then every value on the antidiagonal.
    for v in span:
        clauses.extend(exactly_one_clauses([T[d][d][v] for d in span]))
    for v in span:
        clauses.extend(exactly_one_clauses([T[n - 1 - d][d][v] for d in span]))
    return clauses
# Orthogonality of X and Y is encoded with the element-indexing constraint
# Z_i[X_ij] = Y_ij: if cell (i, j) of X holds value k and cell (i, j) of Y
# holds value l, then row i of Z holds value l in column k.
# Each of the three variables is implied by the other two:
#   (X[i][j][k] & Y[i][j][l]) -> Z[i][k][l]
#   (Y[i][j][l] & Z[i][k][l]) -> X[i][j][k]
#   (X[i][j][k] & Z[i][k][l]) -> Y[i][j][l]
# which gives, for every (i, j, k, l), the three clauses
#   [-X[i][j][k], -Y[i][j][l], Z[i][k][l]]
#   [-Y[i][j][l], -Z[i][k][l], X[i][j][k]]
#   [-X[i][j][k], -Z[i][k][l], Y[i][j][l]]
# The AllDifferent constraint on Z's columns ensures orthogonality; the
# AllDifferent constraint on Z's rows is redundant but may help a solver.
def orthogonality_clauses(X : list, Y : list, Z : list):
    """Return the CP-index clauses that link squares X and Y via square Z.

    All three arguments are n x n x n arrays of variable IDs, indexed
    [row][column][value]. The result holds 3 * n**4 three-literal clauses.
    """
    n = len(X)
    assert(n > 0 and n == len(Y) and n == len(Z))
    clauses = []
    for row in range(n):
        for col in range(n):
            for x_val in range(n):
                x_lit = X[row][col][x_val]
                z_row = Z[row][x_val]
                for y_val in range(n):
                    y_lit = Y[row][col][y_val]
                    z_lit = z_row[y_val]
                    clauses += [
                        [-x_lit, -y_lit, z_lit],
                        [-y_lit, -z_lit, x_lit],
                        [-x_lit, -z_lit, y_lit],
                    ]
    return clauses
### Main function:
def _new_cube(order : int, first_var : int):
    """Return an order x order x order array of consecutive variable IDs.

    Element [i][j][k] equals first_var + (i*order + j)*order + k, so IDs are
    assigned with i varying slowest and k fastest.
    """
    return [[[first_var + (i * order + j) * order + k for k in range(order)]
             for j in range(order)]
            for i in range(order)]


def _print_usage():
    """Print the command-line synopsis."""
    print('Usage : ls-order cnf-name [--diag]')
    print(' ls-order : order of Latin squares')
    print(' cnf-name : name of output CNF')
    print(' --diag : if given, then both Latin squares are diagonal')
    print(' -v : print the script version and exit')


def main():
    """Build the CP-index CNF for a pair of orthogonal Latin squares and write it.

    Command line: ``ls-order cnf-name [--diag]``, or ``-v`` to print the
    version. Exits with status 1 on a usage error.
    """
    args = sys.argv[1:]
    # Check -v before the argument-count test; otherwise a lone "-v" would
    # only print the usage text.
    if args and args[0] == '-v':
        print('Script ' + script + ' of version ' + version + ' is running')
        sys.exit(0)
    if len(args) < 2:
        _print_usage()
        sys.exit(1)
    print('Script ' + script + ' of version ' + version + ' is running')
    # Parse input parameters:
    try:
        ls_order = int(args[0])
    except ValueError:
        ls_order = 0  # reported as invalid just below
    if ls_order < 1:
        print('ls-order must be a positive integer, got ' + args[0], file=sys.stderr)
        sys.exit(1)
    cnf_name = args[1]
    is_diag = '--diag' in args[2:]
    print('ls_order : ' + str(ls_order))
    print('cnf_name : ' + cnf_name)
    print('is_diag : ' + str(is_diag))
    # Variables 1..n^3 encode X, the next n^3 encode Y, and the last n^3
    # encode the auxiliary Latin square Z that ties X and Y together.
    cube_size = ls_order ** 3
    X = _new_cube(ls_order, 1)
    Y = _new_cube(ls_order, cube_size + 1)
    Z = _new_cube(ls_order, 2 * cube_size + 1)
    vars_num = 3 * cube_size
    clauses = []
    # Latin square constraints for all Latin squares:
    clauses += latin_square_clauses(X, is_diag)
    clauses += latin_square_clauses(Y, is_diag)
    # Diagonal constraints are not needed for Latin square Z:
    clauses += latin_square_clauses(Z, False)
    ls_clauses_num = len(clauses)
    print(str(ls_clauses_num) + ' clauses encode Latin squares constraints')
    # Orthogonality constraints:
    clauses += orthogonality_clauses(X, Y, Z)
    print(str(len(clauses) - ls_clauses_num) + ' clauses encode orthogonality constraints')
    print(str(len(clauses)) + ' clauses in total')
    print('Writing to file ' + cnf_name + ' ...')
    with open(cnf_name, 'w') as ofile:
        ofile.write('p cnf ' + str(vars_num) + ' ' + str(len(clauses)) + '\n')
        # DIMACS clause lines: space-separated literals terminated by 0.
        ofile.writelines(' '.join(map(str, cla)) + ' 0\n' for cla in clauses)


if __name__ == '__main__':
    main()