Skip to content

Commit

Permalink
haar matrix, cleanest tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fmndantas committed Jan 30, 2020
1 parent b6f9f66 commit 76b2ba3
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
8 changes: 4 additions & 4 deletions mcw2_4.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-1
10
111
-1.9
1
1
1
-1
9 changes: 5 additions & 4 deletions wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ def _allocate_mcw(self, message_length):
if self.mcw is None:
self.mcw = np.zeros((self.m, message_length + self.mg - self.m))

def _displace_mcw(self, pad):
self.mcw = np.roll(self.mcw, pad)
def check_mcw_existence(self):
if self.mcw is None:
raise ValueError("MCW is not known")

def _get_encoded_model(self):
self.check_mcw_existence()
return np.zeros((self.mcw.shape[1],))

def mcw_from_coefficients(self, file, message_length):
Expand All @@ -55,8 +57,7 @@ def encode(self, message):
return encoded

def _get_decoded_model(self):
if self.mcw is None:
raise ValueError("MCW is not known")
self.check_mcw_existence()
return np.zeros((self.mcw.shape[1] - self.mg + self.m,))

def decode(self, encoded):
Expand Down
49 changes: 27 additions & 22 deletions wavelet_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ def test_vstack(self):
ans = np.array([[0, 1], [2, 3]])
self.assertTrue(np.array_equal(ret, ans))

def test_elementwise_multiplication(self):
arr1 = np.array([10, -10, 1, -1, 1, -1])
arr2 = np.array([[1, 1, 0, 0, 0, 0], [1, -1, 0, 0, 0, 0]])
ans = np.array([[10, -10, 0, 0, 0, 0], [10, 10, 0, 0, 0, 0]])
self.assertTrue(np.array_equal(ans, arr1 * arr2))


class WaveletShortTests(unittest.TestCase):
def setUp(self):
Expand All @@ -25,8 +31,8 @@ def test_mg(self):

def test_get_raw_4_lines(self):
l0, l1 = list(self.w._get_raw_lines(self.file))
self.assertTrue(np.array_equal(l0, np.array([-1, 10])))
self.assertTrue(np.array_equal(l1, np.array([111, -1.9])))
self.assertTrue(np.array_equal(l0, np.array([1, 1])))
self.assertTrue(np.array_equal(l1, np.array([1, -1])))

def test_allocate_matrix(self):
self.assertEqual(self.w.A, None)
Expand All @@ -36,7 +42,7 @@ def test_allocate_matrix(self):

def test_set_a_coefficients(self):
self.w = Wavelet(m=2, g=1)
ans = np.array([[-1, 10], [111, -1.9]])
ans = np.array([[1, 1], [1, -1]])
self.w._set_a_coefficients(self.file)
self.assertTrue(np.array_equal(ans, self.w.A))

Expand All @@ -54,33 +60,26 @@ def test_mcw_from_coefficients(self):
self.assertEqual(message_length + self.w.mg - self.w.m, mcw_columns)
self.assertTrue(np.array_equal(self.w.mcw[:self.w.m, :self.w.mg], self.w.A))

def test_mcw_displacement(self):
message_length = 6
self.w.mcw_from_coefficients(self.file, message_length)
original_mcw = np.copy(self.w.mcw)
# Displace by one m
self.w._displace_mcw(1)
ans = np.array([
[0, -1, 10, 0, 0, 0],
[0, 111, -1.9, 0, 0, 0],
])
self.assertTrue(np.array_equal(ans, self.w.mcw))

self.w._displace_mcw(-1)
self.assertTrue(np.array_equal(original_mcw, self.w.mcw))

def test_get_encoded_model(self):
message_length = 4
self.w.mcw_from_coefficients(self.file, message_length)
encoded_output = self.w._get_encoded_model()
self.assertTrue(np.array_equal(np.zeros(4, ), encoded_output))
self.assertTrue(np.array_equal(np.zeros(message_length, ), encoded_output))

def test_encoding(self):
self.w.mcw_from_coefficients(self.file, np.size(self.message))
message = np.array([-1, -1, -1, -1, -1, -1])
self.w.mcw_from_coefficients(self.file, np.size(message))

ans = np.matmul(self.message, np.array([[-1, 10, 0, 0], [111, -1.9, 0, 0], [0, 0, -1, 10], [0, 0, 111, -1.9]]))
ans = np.matmul(message, np.array([
[1, 1, 0, 0, 0, 0],
[1, -1, 0, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, -1, 0, 0],
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 1, -1]
]))

encoded_output = self.w.encode(self.message)
encoded_output = self.w.encode(message)
self.assertTrue(np.array_equal(ans, encoded_output))

def test_get_decoded_model(self):
Expand Down Expand Up @@ -146,7 +145,13 @@ def test_huge_encoding_and_decoding(self):
message = np.array([np.random.choice([-1, 1]) for _ in range(50000)])
self.w.mcw_from_coefficients(self.file, np.size(message))
encoded = self.w.encode(message)
encoded_elements = set(encoded)
ans_elements = [-self.w.mg + k for k in range(2, 2 * self.w.mg + 1)]
for el in encoded_elements:
self.assertTrue(el in ans_elements)
decoded = self.w.decode(encoded)
self.assertEqual(len(set(np.abs(decoded))), 1)
self.assertEqual(set(np.abs(decoded)), {self.w.mg})
decoded[decoded >= 0] = 1
decoded[decoded < 0] = -1
self.assertTrue(np.array_equal(message, decoded))
Expand Down

0 comments on commit 76b2ba3

Please sign in to comment.