-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #57 from worldcoin/yichen/fragile-bit-masking
Yichen/fragile bit masking
- Loading branch information
Showing
22 changed files
with
136 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 65 additions & 30 deletions
95
src/iris/nodes/iris_response_refinement/fragile_bits_refinement.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,75 +1,110 @@ | ||
from typing import Literal, Tuple | ||
from enum import Enum | ||
from typing import List, Tuple | ||
|
||
import numpy as np | ||
from pydantic import confloat | ||
|
||
from iris.callbacks.callback_interface import Callback | ||
from iris.io.class_configs import Algorithm | ||
from iris.io.dataclasses import IrisFilterResponse | ||
|
||
|
||
class FragileType(str, Enum): | ||
"""Makes wrapper for params.""" | ||
|
||
cartesian = "cartesian" | ||
polar = "polar" | ||
|
||
|
||
class FragileBitRefinement(Algorithm): | ||
"""Refining mask by masking out fragile bits. | ||
"""Calculate fragile bits for mask. | ||
Algorithm: | ||
Thresholding by the given parameter value_threshold at each bit, set the corresponding mask response to 0 if iris response is below the threshold. | ||
Thresholding by the given parameter value_threshold at each bit, set the corresponding mask response to 0 if iris response is outside the thresholds. | ||
""" | ||
|
||
class Parameters(Algorithm.Parameters): | ||
"""RegularProbeSchema parameters.""" | ||
|
||
value_threshold: Tuple[confloat(ge=0), confloat(ge=0)] | ||
fragile_type: Literal["cartesian", "polar"] | ||
value_threshold: Tuple[confloat(ge=0), confloat(ge=0), confloat(ge=0)] | ||
fragile_type: FragileType | ||
|
||
__parameters_type__ = Parameters | ||
|
||
def __init__( | ||
self, | ||
value_threshold: Tuple[confloat(ge=0), confloat(ge=0)], | ||
fragile_type: Literal["cartesian", "polar"] = "polar", | ||
value_threshold: Tuple[confloat(ge=0), confloat(ge=0), confloat(ge=0)], | ||
fragile_type: FragileType = FragileType.polar, | ||
callbacks: List[Callback] = [], | ||
) -> None: | ||
"""Create Fragile Bit Refinement object. | ||
Args: | ||
value_threshold (Tuple[confloat(ge=0), confloat(ge=0)]): Thresholding iris response values. | ||
fragile_type (Literal["cartesian", "polar"], optional): The Fragile bits can be either | ||
value_threshold (Tuple[confloat(ge=0), confloat(ge=0), confloat(ge=0)]): Threshold at which response | ||
is strong enough such that the bit is valued as 1 in the mask. | ||
fragile_type (FragileType, optional): The Fragile bits can be either | ||
calculated in cartesian or polar coordinates. In the first, the values | ||
of value_threshold denote to x and y axis, in the case of polar coordinates, | ||
the values denote to radius and angle. Defaults to "polar". | ||
the values denote to radius and angle. Defaults to FragileType.polar. | ||
""" | ||
super().__init__(value_threshold=value_threshold, fragile_type=fragile_type) | ||
super().__init__(value_threshold=value_threshold, fragile_type=fragile_type, callbacks=callbacks) | ||
|
||
def run(self, iris_filter_response: IrisFilterResponse) -> IrisFilterResponse: | ||
def run(self, response: IrisFilterResponse) -> IrisFilterResponse: | ||
"""Generate refined IrisFilterResponse. | ||
Args: | ||
iris_filter_response (IrisFilterResponse): Filter bank response. | ||
response (IrisFilterResponse): Filter bank response | ||
Returns: | ||
IrisFilterResponse: Filter bank response. | ||
""" | ||
fragile_masks = [] | ||
for iris_response, iris_mask in zip(iris_filter_response.iris_responses, iris_filter_response.mask_responses): | ||
if self.params.fragile_type == "cartesian": | ||
mask_value_real = np.abs(np.real(iris_response)) >= self.params.value_threshold[0] | ||
mask_value_imaginary = np.abs(np.imag(iris_response)) >= self.params.value_threshold[1] | ||
mask_value = mask_value_real * mask_value_imaginary | ||
|
||
if self.params.fragile_type == "polar": | ||
for iris_response, iris_mask in zip(response.iris_responses, response.mask_responses): | ||
if self.params.fragile_type == FragileType.cartesian: | ||
mask_value_real = ( | ||
np.logical_and( | ||
np.abs(iris_response.real) >= self.params.value_threshold[0], | ||
np.abs(iris_response.real) <= self.params.value_threshold[1], | ||
) | ||
* iris_mask.real | ||
) | ||
mask_value_imag = ( | ||
np.logical_and( | ||
np.abs(iris_response.imag) >= self.params.value_threshold[0], | ||
np.abs(iris_response.imag) <= self.params.value_threshold[2], | ||
) | ||
* iris_mask.imag | ||
) | ||
|
||
if self.params.fragile_type == FragileType.polar: | ||
# transform from cartesian to polar system | ||
# radius | ||
iris_response_r = np.abs(iris_response) | ||
# angle | ||
iris_response_phi = np.angle(iris_response) | ||
|
||
mask_value_r = iris_response_r >= self.params.value_threshold[0] | ||
|
||
cos_mask = np.abs(np.cos(iris_response_phi)) <= np.abs(np.cos(self.params.value_threshold[1])) | ||
sine_mask = np.abs(np.sin(iris_response_phi)) <= np.abs(np.cos(self.params.value_threshold[1])) | ||
mask_value_phi = cos_mask * sine_mask | ||
mask_value = mask_value_r * mask_value_phi | ||
|
||
mask_value = mask_value * iris_mask | ||
# min radius | ||
mask_value_r = np.logical_and( | ||
iris_response_r >= self.params.value_threshold[0], iris_response_r <= self.params.value_threshold[1] | ||
) | ||
# min angle away from the coordinate lines | ||
|
||
# cosine requirement: makes sure that angle is different enough from x-axis | ||
cos_mask = np.abs(np.cos(iris_response_phi)) <= np.abs(np.cos(self.params.value_threshold[2])) | ||
# sine requirement: makes sure that angle is different enough from y-axis | ||
sine_mask = np.abs(np.sin(iris_response_phi)) <= np.abs(np.cos(self.params.value_threshold[2])) | ||
# combine | ||
mask_value_real = mask_value_r * sine_mask * iris_mask.real | ||
# combine with radius | ||
mask_value_imag = mask_value_r * cos_mask * iris_mask.imag | ||
|
||
# combine with mask for response | ||
mask_value = mask_value_real + 1j * mask_value_imag | ||
fragile_masks.append(mask_value) | ||
|
||
return IrisFilterResponse( | ||
iris_responses=iris_filter_response.iris_responses, | ||
iris_responses=response.iris_responses, | ||
mask_responses=fragile_masks, | ||
iris_code_version=iris_filter_response.iris_code_version, | ||
iris_code_version=response.iris_code_version, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file modified
BIN
-30 Bytes
(100%)
tests/e2e_tests/nodes/encoder/mocks/iris_encoder/iris_response.pickle
Binary file not shown.
Binary file modified
BIN
+30 Bytes
(100%)
tests/e2e_tests/nodes/iris_response/mocks/conv_filter_bank/e2e_expected_result.pickle
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file modified
BIN
+7.49 KB
(120%)
...es/iris_response_refinement/mocks/fragile_bits/artificial_iris_responses_cartesian.pickle
Binary file not shown.
Binary file modified
BIN
+4.52 KB
(140%)
.../nodes/iris_response_refinement/mocks/fragile_bits/artificial_iris_responses_polar.pickle
Binary file not shown.
Binary file modified
BIN
-32 Bytes
(100%)
...inement/mocks/fragile_bits/artificial_mask_responses_cartesian_expected_refinement.pickle
Binary file not shown.
Binary file modified
BIN
+4.21 KB
(630%)
..._refinement/mocks/fragile_bits/artificial_mask_responses_polar_expected_refinement.pickle
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file modified
BIN
-127 KB
(98%)
tests/e2e_tests/orchestration/mocks/expected_iris_pipeline_debug_output.pickle
Binary file not shown.
Binary file modified
BIN
-320 KB
(1.9%)
tests/e2e_tests/orchestration/mocks/expected_iris_pipeline_orb_output.pickle
Binary file not shown.
Binary file modified
BIN
-320 KB
(9.4%)
tests/e2e_tests/orchestration/mocks/expected_iris_pipeline_simple_output.pickle
Binary file not shown.
Binary file modified
BIN
-123 KB
(98%)
tests/e2e_tests/orchestration/mocks/mock_iris_pipeline_call_trace.pickle
Binary file not shown.
Binary file modified
BIN
+1.23 MB
(120%)
tests/e2e_tests/pipelines/mocks/outputs/expected_iris_debug_pipeline_output.pickle
Binary file not shown.
Binary file modified
BIN
+0 Bytes
(100%)
tests/e2e_tests/pipelines/mocks/outputs/expected_iris_orb_pipeline_output.pickle
Binary file not shown.
21 changes: 12 additions & 9 deletions
21
tests/unit_tests/nodes/iris_response_refinement/test_fragile_bits_refinement.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,25 @@ | ||
from typing import Literal, Tuple | ||
from typing import Tuple | ||
|
||
import pytest | ||
from pydantic import ValidationError, confloat | ||
|
||
from iris.nodes.iris_response_refinement.fragile_bits_refinement import FragileBitRefinement | ||
from iris.nodes.iris_response_refinement.fragile_bits_refinement import FragileBitRefinement, FragileType | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"value_threshold,fragile_type", | ||
# given | ||
"value_threshold, fragile_type", | ||
[ | ||
pytest.param([-0.6, -0.3], "cartesian"), | ||
pytest.param([-0.2, -0.5], "polar"), | ||
pytest.param([0, 0], "elliptical"), | ||
pytest.param([-0.6, 0, -0.3], FragileType.cartesian), | ||
pytest.param([-0.2, -0.5, 1], FragileType.polar), | ||
pytest.param([0, 0, 1], "elliptical"), | ||
], | ||
ids=["error_threshold_cartesian", "error_threshold_polar", "error_fragile_type"], | ||
ids=["error_theshold_cartesian", "error_theshold_polar", "error_fragile_type"], | ||
) | ||
def test_iris_encoder_threshold_raises_an_exception( | ||
value_threshold: Tuple[confloat(ge=0), confloat(ge=0)], fragile_type: Literal["cartesian", "polar"] | ||
value_threshold: Tuple[confloat(ge=0), confloat(ge=0), confloat(ge=0)], | ||
fragile_type: FragileType, | ||
) -> None: | ||
with pytest.raises(ValidationError): | ||
# when | ||
with pytest.raises((ValidationError)): | ||
_ = FragileBitRefinement(value_threshold, fragile_type) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters