Skip to content

Commit

Permalink
Merge pull request #57 from worldcoin/yichen/fragile-bit-masking
Browse files Browse the repository at this point in the history
Yichen/fragile bit masking
  • Loading branch information
ycbiometrics authored Dec 23, 2024
2 parents 8f4a692 + 64dba88 commit 37ad778
Show file tree
Hide file tree
Showing 22 changed files with 136 additions and 68 deletions.
3 changes: 1 addition & 2 deletions src/iris/nodes/encoder/iris_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,9 @@ def run(self, response: IrisFilterResponse) -> IrisTemplate:
mask_codes: List[np.ndarray] = []

for iris_response, mask_response in zip(response.iris_responses, response.mask_responses):
mask_code = mask_response >= self.params.mask_threshold

iris_code = np.stack([iris_response.real > 0, iris_response.imag > 0], axis=-1)
mask_code = np.stack([mask_code, mask_code], axis=-1)
mask_code = np.stack([mask_response.real >= self.params.mask_threshold, mask_response.imag >= self.params.mask_threshold], axis=-1)

iris_codes.append(iris_code)
mask_codes.append(mask_code)
Expand Down
35 changes: 18 additions & 17 deletions src/iris/nodes/iris_response/conv_filter_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ def _convolve(
p_rows = k_rows // 2
p_cols = k_cols // 2
iris_response = np.zeros((probe_schema.params.n_rows, probe_schema.params.n_cols), dtype=np.complex64)
mask_response = np.zeros((probe_schema.params.n_rows, probe_schema.params.n_cols))
mask_response = np.zeros((probe_schema.params.n_rows, probe_schema.params.n_cols), dtype=np.complex64)

padded_iris = polar_img_padding(normalization_output.normalized_image, 0, p_cols)
padded_mask = polar_img_padding(normalization_output.normalized_mask, 0, p_cols)
padded_iris = polar_img_padding(normalization_output.normalized_image, p_rows, p_cols)
padded_mask = polar_img_padding(normalization_output.normalized_mask, p_rows, p_cols)

for i in range(probe_schema.params.n_rows):
for j in range(probe_schema.params.n_cols):
Expand All @@ -147,20 +147,21 @@ def _convolve(
c_probe = min(round(probe_schema.phis[pos] * i_cols), i_cols - 1)

# Get patch from image centered at [i,j] probed pixel position.
rtop = max(0, r_probe - p_rows)
rbot = min(r_probe + p_rows + 1, i_rows - 1)
iris_patch = padded_iris[rtop:rbot, c_probe : c_probe + k_cols]
mask_patch = padded_mask[rtop:rbot, c_probe : c_probe + k_cols]

# Perform convolution at [i,j] probed pixel position.
ktop = p_rows - iris_patch.shape[0] // 2
iris_response[i][j] = (
(iris_patch * img_filter.kernel_values[ktop : ktop + iris_patch.shape[0], :]).sum()
/ iris_patch.shape[0]
/ k_cols
)
mask_response[i][j] = (
0 if iris_response[i][j] == 0 else (mask_patch.sum() / iris_patch.shape[0] / k_cols)
iris_patch = padded_iris[r_probe : r_probe + k_rows, c_probe : c_probe + k_cols]
mask_patch = padded_mask[r_probe : r_probe + k_rows, c_probe : c_probe + k_cols]

# Compute normalization term by excluding zero-padded pixels
non_padded_k_rows = (
k_rows
if np.logical_and(r_probe > p_rows, r_probe <= i_rows - p_rows)
else (k_rows - max(p_rows - r_probe, r_probe + p_rows - i_rows))
)
# Perform convolution at [i,j] probed pixel position.
iris_response[i][j] = (iris_patch * img_filter.kernel_values).sum() / non_padded_k_rows / k_cols
mask_response[i][j] = 0 if iris_response[i][j] == 0 else (mask_patch.sum() / non_padded_k_rows / k_cols)

iris_response.real = iris_response.real / img_filter.kernel_norm.real
iris_response.imag = iris_response.imag / img_filter.kernel_norm.imag
mask_response.imag = mask_response.real

return iris_response, mask_response
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ def __init__(self, **kwargs: Any) -> None:
"""Init function."""
super().__init__(**kwargs)
self.__kernel_values = self.compute_kernel_values()
self.__kernel_norm = np.linalg.norm(self.__kernel_values.real, ord="fro") + np.linalg.norm(self.__kernel_values.imag, ord="fro")*1j

@property
def kernel_norm(self) -> float:
"""Get kernel norm.
Returns:
float: Filter kernel norm.
"""
return self.__kernel_norm

@property
def kernel_values(self) -> np.ndarray:
Expand Down
95 changes: 65 additions & 30 deletions src/iris/nodes/iris_response_refinement/fragile_bits_refinement.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,110 @@
from typing import Literal, Tuple
from enum import Enum
from typing import List, Tuple

import numpy as np
from pydantic import confloat

from iris.callbacks.callback_interface import Callback
from iris.io.class_configs import Algorithm
from iris.io.dataclasses import IrisFilterResponse


class FragileType(str, Enum):
"""Makes wrapper for params."""

cartesian = "cartesian"
polar = "polar"


class FragileBitRefinement(Algorithm):
"""Refining mask by masking out fragile bits.
"""Calculate fragile bits for mask.
Algorithm:
Thresholding by the given parameter value_threshold at each bit, set the corresponding mask response to 0 if iris response is below the threshold.
Thresholding by the given parameter value_threshold at each bit, set the corresponding mask response to 0 if iris response is outside the thresholds.
"""

class Parameters(Algorithm.Parameters):
"""RegularProbeSchema parameters."""

value_threshold: Tuple[confloat(ge=0), confloat(ge=0)]
fragile_type: Literal["cartesian", "polar"]
value_threshold: Tuple[confloat(ge=0), confloat(ge=0), confloat(ge=0)]
fragile_type: FragileType

__parameters_type__ = Parameters

def __init__(
self,
value_threshold: Tuple[confloat(ge=0), confloat(ge=0)],
fragile_type: Literal["cartesian", "polar"] = "polar",
value_threshold: Tuple[confloat(ge=0), confloat(ge=0), confloat(ge=0)],
fragile_type: FragileType = FragileType.polar,
callbacks: List[Callback] = [],
) -> None:
"""Create Fragile Bit Refinement object.
Args:
value_threshold (Tuple[confloat(ge=0), confloat(ge=0)]): Thresholding iris response values.
fragile_type (Literal["cartesian", "polar"], optional): The Fragile bits can be either
value_threshold (Tuple[confloat(ge=0), confloat(ge=0), confloat(ge=0)]): Threshold at which response
is strong enough such that the bit is valued as 1 in the mask.
fragile_type (FragileType, optional): The Fragile bits can be either
calculated in cartesian or polar coordinates. In the first, the values
of value_threshold denote to x and y axis, in the case of polar coordinates,
the values denote to radius and angle. Defaults to "polar".
the values denote to radius and angle. Defaults to FragileType.polar.
"""
super().__init__(value_threshold=value_threshold, fragile_type=fragile_type)
super().__init__(value_threshold=value_threshold, fragile_type=fragile_type, callbacks=callbacks)

def run(self, iris_filter_response: IrisFilterResponse) -> IrisFilterResponse:
def run(self, response: IrisFilterResponse) -> IrisFilterResponse:
"""Generate refined IrisFilterResponse.
Args:
iris_filter_response (IrisFilterResponse): Filter bank response.
response (IrisFilterResponse): Filter bank response
Returns:
IrisFilterResponse: Filter bank response.
"""
fragile_masks = []
for iris_response, iris_mask in zip(iris_filter_response.iris_responses, iris_filter_response.mask_responses):
if self.params.fragile_type == "cartesian":
mask_value_real = np.abs(np.real(iris_response)) >= self.params.value_threshold[0]
mask_value_imaginary = np.abs(np.imag(iris_response)) >= self.params.value_threshold[1]
mask_value = mask_value_real * mask_value_imaginary

if self.params.fragile_type == "polar":
for iris_response, iris_mask in zip(response.iris_responses, response.mask_responses):
if self.params.fragile_type == FragileType.cartesian:
mask_value_real = (
np.logical_and(
np.abs(iris_response.real) >= self.params.value_threshold[0],
np.abs(iris_response.real) <= self.params.value_threshold[1],
)
* iris_mask.real
)
mask_value_imag = (
np.logical_and(
np.abs(iris_response.imag) >= self.params.value_threshold[0],
np.abs(iris_response.imag) <= self.params.value_threshold[2],
)
* iris_mask.imag
)

if self.params.fragile_type == FragileType.polar:
# transform from cartesian to polar system
# radius
iris_response_r = np.abs(iris_response)
# angle
iris_response_phi = np.angle(iris_response)

mask_value_r = iris_response_r >= self.params.value_threshold[0]

cos_mask = np.abs(np.cos(iris_response_phi)) <= np.abs(np.cos(self.params.value_threshold[1]))
sine_mask = np.abs(np.sin(iris_response_phi)) <= np.abs(np.cos(self.params.value_threshold[1]))
mask_value_phi = cos_mask * sine_mask
mask_value = mask_value_r * mask_value_phi

mask_value = mask_value * iris_mask
# min radius
mask_value_r = np.logical_and(
iris_response_r >= self.params.value_threshold[0], iris_response_r <= self.params.value_threshold[1]
)
# min angle away from the coordinate lines

# cosine requirement: makes sure that angle is different enough from x-axis
cos_mask = np.abs(np.cos(iris_response_phi)) <= np.abs(np.cos(self.params.value_threshold[2]))
# sine requirement: makes sure that angle is different enough from y-axis
sine_mask = np.abs(np.sin(iris_response_phi)) <= np.abs(np.cos(self.params.value_threshold[2]))
# combine
mask_value_real = mask_value_r * sine_mask * iris_mask.real
# combine with radius
mask_value_imag = mask_value_r * cos_mask * iris_mask.imag

# combine with mask for response
mask_value = mask_value_real + 1j * mask_value_imag
fragile_masks.append(mask_value)

return IrisFilterResponse(
iris_responses=iris_filter_response.iris_responses,
iris_responses=response.iris_responses,
mask_responses=fragile_masks,
iris_code_version=iris_filter_response.iris_code_version,
iris_code_version=response.iris_code_version,
)
2 changes: 2 additions & 0 deletions src/iris/orchestration/output_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def build_simple_debugging_output(call_trace: PipelineCallTraceStorage) -> Dict[
extrapolated_polygons = call_trace["geometry_estimation"]
normalized_iris = call_trace["normalization"]
iris_response = call_trace["filter_bank"]
iris_response_refined = call_trace["iris_response_refinement"]

return {
"iris_template": iris_template,
Expand All @@ -82,6 +83,7 @@ def build_simple_debugging_output(call_trace: PipelineCallTraceStorage) -> Dict[
"extrapolated_polygons": __safe_serialize(extrapolated_polygons),
"normalized_iris": __safe_serialize(normalized_iris),
"iris_response": __safe_serialize(iris_response),
"iris_response_refined": __safe_serialize(iris_response_refined),
"error": error,
}

Expand Down
13 changes: 12 additions & 1 deletion src/iris/pipelines/confs/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,24 @@ pipeline:
source_node: normalization
callbacks:

- name: iris_response_refinement
algorithm:
class_name: iris.nodes.iris_response_refinement.fragile_bits_refinement.FragileBitRefinement
params:
value_threshold: [0.0001, 0.275, 0.08726646259971647]
fragile_type: "polar"
inputs:
- name: response
source_node: filter_bank
callbacks:

- name: encoder
algorithm:
class_name: iris.IrisEncoder
params: {}
inputs:
- name: response
source_node: filter_bank
source_node: iris_response_refinement
callbacks:
- class_name: iris.nodes.validators.object_validators.IsMaskTooSmallValidator
params:
Expand Down
Binary file not shown.
Binary file not shown.
15 changes: 10 additions & 5 deletions tests/e2e_tests/nodes/iris_response/test_e2e_conv_filter_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,14 @@ def test_convfilterbank_constructor(

for i_iris_response, i_mask_response in zip(filter_responses.iris_responses, filter_responses.mask_responses):
assert i_iris_response.shape == i_mask_response.shape
assert np.iscomplexobj(i_iris_response) and not np.iscomplexobj(i_mask_response)
assert np.max(i_mask_response) <= 1 and np.min(i_mask_response) >= 0
assert np.max(i_mask_response) > np.min(i_mask_response)
assert np.max(i_iris_response.real) > np.min(i_iris_response.real)
assert np.max(i_iris_response.imag) > np.min(i_iris_response.imag)
assert np.iscomplexobj(i_iris_response)
assert np.iscomplexobj(i_mask_response)
assert i_mask_response.real.max() <= 1
assert i_mask_response.real.min() >= 0
assert i_mask_response.imag.max() <= 1
assert i_mask_response.imag.min() >= 0
assert i_mask_response.real.max() > i_mask_response.real.min()
assert i_mask_response.imag.max() > i_mask_response.imag.min()
assert i_iris_response.real.max() > i_iris_response.real.min()
assert i_iris_response.imag.max() > i_iris_response.imag.min()
assert filter_responses.iris_code_version == "v0.1"
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import numpy as np
import pytest
from pydantic import confloat

from iris.nodes.iris_response_refinement.fragile_bits_refinement import FragileBitRefinement
from iris.nodes.iris_response_refinement.fragile_bits_refinement import FragileBitRefinement, FragileType


def load_mock_pickle(name: str) -> Any:
Expand All @@ -19,13 +20,13 @@ def load_mock_pickle(name: str) -> Any:
@pytest.mark.parametrize(
"value_threshold,fragile_type",
[
pytest.param([0.5, 0.5], "cartesian"),
pytest.param([0.49, np.pi / 8], "polar"),
pytest.param([0, 0.5, 0.5], FragileType.cartesian),
pytest.param([0, 0.49, np.pi / 8], FragileType.polar),
],
ids=["cartesian", "polar"],
)
def test_fragile_bits_dummy_responses(
value_threshold: Tuple[float, float], fragile_type: Literal["cartesian", "polar"]
value_threshold: Tuple[confloat(ge=0), confloat(ge=0), confloat(ge=0)], fragile_type: FragileType
) -> None:
iris_filter_response = load_mock_pickle(f"artificial_iris_responses_{fragile_type}")

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
from typing import Literal, Tuple
from typing import Tuple

import pytest
from pydantic import ValidationError, confloat

from iris.nodes.iris_response_refinement.fragile_bits_refinement import FragileBitRefinement
from iris.nodes.iris_response_refinement.fragile_bits_refinement import FragileBitRefinement, FragileType


@pytest.mark.parametrize(
"value_threshold,fragile_type",
# given
"value_threshold, fragile_type",
[
pytest.param([-0.6, -0.3], "cartesian"),
pytest.param([-0.2, -0.5], "polar"),
pytest.param([0, 0], "elliptical"),
pytest.param([-0.6, 0, -0.3], FragileType.cartesian),
pytest.param([-0.2, -0.5, 1], FragileType.polar),
pytest.param([0, 0, 1], "elliptical"),
],
ids=["error_threshold_cartesian", "error_threshold_polar", "error_fragile_type"],
ids=["error_theshold_cartesian", "error_theshold_polar", "error_fragile_type"],
)
def test_iris_encoder_threshold_raises_an_exception(
value_threshold: Tuple[confloat(ge=0), confloat(ge=0)], fragile_type: Literal["cartesian", "polar"]
value_threshold: Tuple[confloat(ge=0), confloat(ge=0), confloat(ge=0)],
fragile_type: FragileType,
) -> None:
with pytest.raises(ValidationError):
# when
with pytest.raises((ValidationError)):
_ = FragileBitRefinement(value_threshold, fragile_type)
1 change: 1 addition & 0 deletions tests/unit_tests/pipelines/test_iris_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def test_error_manager(input: str, env: Environment, expectation, request: Fixtu
"extrapolated_polygons",
"normalized_iris",
"iris_response",
"iris_response_refined",
"landmarks",
"iris_template",
"status",
Expand Down

0 comments on commit 37ad778

Please sign in to comment.