diff --git a/mcw2_4.txt b/mcw2_4.txt index f6da52a..3744bbc 100644 --- a/mcw2_4.txt +++ b/mcw2_4.txt @@ -1,4 +1,4 @@ --1 -10 -111 --1.9 \ No newline at end of file +1 +1 +1 +-1 \ No newline at end of file diff --git a/wavelet.py b/wavelet.py index 3c9db19..21f6385 100644 --- a/wavelet.py +++ b/wavelet.py @@ -35,10 +35,12 @@ def _allocate_mcw(self, message_length): if self.mcw is None: self.mcw = np.zeros((self.m, message_length + self.mg - self.m)) - def _displace_mcw(self, pad): - self.mcw = np.roll(self.mcw, pad) + def check_mcw_existence(self): + if self.mcw is None: + raise ValueError("MCW is not known") def _get_encoded_model(self): + self.check_mcw_existence() return np.zeros((self.mcw.shape[1],)) def mcw_from_coefficients(self, file, message_length): @@ -55,8 +57,7 @@ def encode(self, message): return encoded def _get_decoded_model(self): - if self.mcw is None: - raise ValueError("MCW is not known") + self.check_mcw_existence() return np.zeros((self.mcw.shape[1] - self.mg + self.m,)) def decode(self, encoded): diff --git a/wavelet_tests.py b/wavelet_tests.py index 1b52251..76f8183 100644 --- a/wavelet_tests.py +++ b/wavelet_tests.py @@ -13,6 +13,12 @@ def test_vstack(self): ans = np.array([[0, 1], [2, 3]]) self.assertTrue(np.array_equal(ret, ans)) + def test_elementwise_multiplication(self): + arr1 = np.array([10, -10, 1, -1, 1, -1]) + arr2 = np.array([[1, 1, 0, 0, 0, 0], [1, -1, 0, 0, 0, 0]]) + ans = np.array([[10, -10, 0, 0, 0, 0], [10, 10, 0, 0, 0, 0]]) + self.assertTrue(np.array_equal(ans, arr1 * arr2)) + class WaveletShortTests(unittest.TestCase): def setUp(self): @@ -25,8 +31,8 @@ def test_mg(self): def test_get_raw_4_lines(self): l0, l1 = list(self.w._get_raw_lines(self.file)) - self.assertTrue(np.array_equal(l0, np.array([-1, 10]))) - self.assertTrue(np.array_equal(l1, np.array([111, -1.9]))) + self.assertTrue(np.array_equal(l0, np.array([1, 1]))) + self.assertTrue(np.array_equal(l1, np.array([1, -1]))) def test_allocate_matrix(self): self.assertEqual(self.w.A, None) @@ -36,7 +42,7 @@ def test_allocate_matrix(self): def test_set_a_coefficients(self): self.w = Wavelet(m=2, g=1) - ans = np.array([[-1, 10], [111, -1.9]]) + ans = np.array([[1, 1], [1, -1]]) self.w._set_a_coefficients(self.file) self.assertTrue(np.array_equal(ans, self.w.A)) @@ -54,33 +60,26 @@ def test_mcw_from_coefficients(self): self.assertEqual(message_length + self.w.mg - self.w.m, mcw_columns) self.assertTrue(np.array_equal(self.w.mcw[:self.w.m, :self.w.mg], self.w.A)) - def test_mcw_displacement(self): - message_length = 6 - self.w.mcw_from_coefficients(self.file, message_length) - original_mcw = np.copy(self.w.mcw) - # Displace by one m - self.w._displace_mcw(1) - ans = np.array([ - [0, -1, 10, 0, 0, 0], - [0, 111, -1.9, 0, 0, 0], - ]) - self.assertTrue(np.array_equal(ans, self.w.mcw)) - - self.w._displace_mcw(-1) - self.assertTrue(np.array_equal(original_mcw, self.w.mcw)) - def test_get_encoded_model(self): message_length = 4 self.w.mcw_from_coefficients(self.file, message_length) encoded_output = self.w._get_encoded_model() - self.assertTrue(np.array_equal(np.zeros(4, ), encoded_output)) + self.assertTrue(np.array_equal(np.zeros(message_length, ), encoded_output)) def test_encoding(self): - self.w.mcw_from_coefficients(self.file, np.size(self.message)) + message = np.array([-1, -1, -1, -1, -1, -1]) + self.w.mcw_from_coefficients(self.file, np.size(message)) - ans = np.matmul(self.message, np.array([[-1, 10, 0, 0], [111, -1.9, 0, 0], [0, 0, -1, 10], [0, 0, 111, -1.9]])) + ans = np.matmul(message, np.array([ + [1, 1, 0, 0, 0, 0], + [1, -1, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 1, -1, 0, 0], + [0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 1, -1] + ])) - encoded_output = self.w.encode(self.message) + encoded_output = self.w.encode(message) self.assertTrue(np.array_equal(ans, encoded_output)) def test_get_decoded_model(self): @@ -146,7 +145,13 @@ def test_huge_encoding_and_decoding(self): message = np.array([np.random.choice([-1, 1]) for _ in range(50000)]) self.w.mcw_from_coefficients(self.file, np.size(message)) encoded = self.w.encode(message) + encoded_elements = set(encoded) + ans_elements = [-self.w.mg + k for k in range(2, 2 * self.w.mg + 1)] + for el in encoded_elements: + self.assertTrue(el in ans_elements) decoded = self.w.decode(encoded) + self.assertEqual(len(set(np.abs(decoded))), 1) + self.assertEqual(set(np.abs(decoded)), {self.w.mg}) decoded[decoded >= 0] = 1 decoded[decoded < 0] = -1 self.assertTrue(np.array_equal(message, decoded))