diff --git a/homeassistant/components/unifi/.translations/en.json b/homeassistant/components/unifi/.translations/en.json index f1f96b3c363075..9ac01e514bf46a 100644 --- a/homeassistant/components/unifi/.translations/en.json +++ b/homeassistant/components/unifi/.translations/en.json @@ -6,7 +6,8 @@ }, "error": { "faulty_credentials": "Bad user credentials", - "service_unavailable": "No service available" + "service_unavailable": "No service available", + "unknown_client_mac": "No client available on that MAC address" }, "step": { "user": { @@ -34,15 +35,26 @@ "track_wired_clients": "Include wired network clients" }, "description": "Configure device tracking", - "title": "UniFi options" + "title": "UniFi options 1/3" + }, + "client_control": { + "data": { + "block_client": "Network access controlled clients", + "new_client": "Add new client (MAC) for network access control" + }, + "description": "Configure client controls\n\nCreate switches for serial numbers you want to control network access for.", + "title": "UniFi options 2/3" }, "statistics_sensors": { "data": { "allow_bandwidth_sensors": "Bandwidth usage sensors for network clients" }, "description": "Configure statistics sensors", - "title": "UniFi options" + "title": "UniFi options 3/3" } + }, + "error": { + "unknown_client_mac": "No client available in UniFi on that MAC address" } } } \ No newline at end of file diff --git a/homeassistant/components/unifi/config_flow.py b/homeassistant/components/unifi/config_flow.py index 36fa7489e81500..e0bb1c3bb9fcf3 100644 --- a/homeassistant/components/unifi/config_flow.py +++ b/homeassistant/components/unifi/config_flow.py @@ -16,6 +16,7 @@ from .const import ( CONF_ALLOW_BANDWIDTH_SENSORS, + CONF_BLOCK_CLIENT, CONF_CONTROLLER, CONF_DETECTION_TIME, CONF_SITE_ID, @@ -30,6 +31,7 @@ from .controller import get_controller from .errors import AlreadyConfigured, AuthenticationRequired, CannotConnect +CONF_NEW_CLIENT = "new_client" DEFAULT_PORT = 8443 DEFAULT_SITE_ID = "default" DEFAULT_VERIFY_SSL = False @@ -171,61 +173,117 @@ def __init__(self, config_entry): """Initialize UniFi options flow.""" self.config_entry = config_entry self.options = dict(config_entry.options) + self.controller = None async def async_step_init(self, user_input=None): """Manage the UniFi options.""" + self.controller = get_controller_from_config_entry(self.hass, self.config_entry) + self.options[CONF_BLOCK_CLIENT] = self.controller.option_block_clients return await self.async_step_device_tracker() async def async_step_device_tracker(self, user_input=None): """Manage the device tracker options.""" if user_input is not None: self.options.update(user_input) - return await self.async_step_statistics_sensors() + return await self.async_step_client_control() - controller = get_controller_from_config_entry(self.hass, self.config_entry) - - ssid_filter = {wlan: wlan for wlan in controller.api.wlans} + ssid_filter = {wlan: wlan for wlan in self.controller.api.wlans} return self.async_show_form( step_id="device_tracker", data_schema=vol.Schema( { vol.Optional( - CONF_TRACK_CLIENTS, default=controller.option_track_clients, + CONF_TRACK_CLIENTS, + default=self.controller.option_track_clients, ): bool, vol.Optional( CONF_TRACK_WIRED_CLIENTS, - default=controller.option_track_wired_clients, + default=self.controller.option_track_wired_clients, ): bool, vol.Optional( - CONF_TRACK_DEVICES, default=controller.option_track_devices, + CONF_TRACK_DEVICES, + default=self.controller.option_track_devices, ): bool, vol.Optional( - CONF_SSID_FILTER, default=controller.option_ssid_filter + CONF_SSID_FILTER, default=self.controller.option_ssid_filter ): cv.multi_select(ssid_filter), vol.Optional( CONF_DETECTION_TIME, - default=int(controller.option_detection_time.total_seconds()), + default=int( + self.controller.option_detection_time.total_seconds() + ), ): int, } ), ) + async def async_step_client_control(self, user_input=None): + """Manage configuration of network access controlled clients.""" + errors = {} + + if user_input is not None: + new_client = user_input.pop(CONF_NEW_CLIENT, None) + self.options.update(user_input) + + if new_client: + if ( + new_client in self.controller.api.clients + or new_client in self.controller.api.clients_all + ): + self.options[CONF_BLOCK_CLIENT].append(new_client) + + else: + errors["base"] = "unknown_client_mac" + + else: + return await self.async_step_statistics_sensors() + + clients_to_block = {} + + for mac in self.options[CONF_BLOCK_CLIENT]: + + name = None + + for clients in [ + self.controller.api.clients, + self.controller.api.clients_all, + ]: + if mac in clients: + name = f"{clients[mac].name or clients[mac].hostname} ({mac})" + break + + if not name: + name = mac + + clients_to_block[mac] = name + + return self.async_show_form( + step_id="client_control", + data_schema=vol.Schema( + { + vol.Optional( + CONF_BLOCK_CLIENT, default=self.options[CONF_BLOCK_CLIENT] + ): cv.multi_select(clients_to_block), + vol.Optional(CONF_NEW_CLIENT): str, + } + ), + errors=errors, + ) + async def async_step_statistics_sensors(self, user_input=None): """Manage the statistics sensors options.""" if user_input is not None: self.options.update(user_input) return await self._update_options() - controller = get_controller_from_config_entry(self.hass, self.config_entry) - return self.async_show_form( step_id="statistics_sensors", data_schema=vol.Schema( { vol.Optional( CONF_ALLOW_BANDWIDTH_SENSORS, - default=controller.option_allow_bandwidth_sensors, + default=self.controller.option_allow_bandwidth_sensors, ): bool } ), diff --git a/homeassistant/components/unifi/const.py b/homeassistant/components/unifi/const.py index d82b7b49d4515d..341364063f26d9 100644 --- a/homeassistant/components/unifi/const.py +++ b/homeassistant/components/unifi/const.py @@ -25,11 +25,9 @@ CONF_DONT_TRACK_WIRED_CLIENTS = "dont_track_wired_clients" DEFAULT_ALLOW_BANDWIDTH_SENSORS = False -DEFAULT_BLOCK_CLIENTS = [] DEFAULT_TRACK_CLIENTS = True DEFAULT_TRACK_DEVICES = True DEFAULT_TRACK_WIRED_CLIENTS = True DEFAULT_DETECTION_TIME = 300 -DEFAULT_SSID_FILTER = [] ATTR_MANUFACTURER = "Ubiquiti Networks" diff --git a/homeassistant/components/unifi/controller.py b/homeassistant/components/unifi/controller.py index b7cd8e8b6a13b4..7da36131058726 100644 --- a/homeassistant/components/unifi/controller.py +++ b/homeassistant/components/unifi/controller.py @@ -31,9 +31,7 @@ CONF_TRACK_WIRED_CLIENTS, CONTROLLER_ID, DEFAULT_ALLOW_BANDWIDTH_SENSORS, - DEFAULT_BLOCK_CLIENTS, DEFAULT_DETECTION_TIME, - DEFAULT_SSID_FILTER, DEFAULT_TRACK_CLIENTS, DEFAULT_TRACK_DEVICES, DEFAULT_TRACK_WIRED_CLIENTS, @@ -99,7 +97,7 @@ def option_allow_bandwidth_sensors(self): @property def option_block_clients(self): """Config entry option with list of clients to control network access.""" - return self.config_entry.options.get(CONF_BLOCK_CLIENT, DEFAULT_BLOCK_CLIENTS) + return self.config_entry.options.get(CONF_BLOCK_CLIENT, []) @property def option_track_clients(self): @@ -130,7 +128,7 @@ def option_detection_time(self): @property def option_ssid_filter(self): """Config entry option listing what SSIDs are being used to track clients.""" - return self.config_entry.options.get(CONF_SSID_FILTER, DEFAULT_SSID_FILTER) + return self.config_entry.options.get(CONF_SSID_FILTER, []) @property def mac(self): diff --git a/homeassistant/components/unifi/strings.json b/homeassistant/components/unifi/strings.json index e652b60ee32a6b..58728225de7df9 100644 --- a/homeassistant/components/unifi/strings.json +++ b/homeassistant/components/unifi/strings.json @@ -16,7 +16,8 @@ }, "error": { "faulty_credentials": "Bad user credentials", - "service_unavailable": "No service available" + "service_unavailable": "No service available", + "unknown_client_mac": "No client available on that MAC address" }, "abort": { "already_configured": "Controller site is already configured", @@ -37,15 +38,26 @@ "track_wired_clients": "Include wired network clients" }, "description": "Configure device tracking", - "title": "UniFi options" + "title": "UniFi options 1/3" + }, + "client_control": { + "data": { + "block_client": "Network access controlled clients", + "new_client": "Add new client for network access control" + }, + "description": "Configure client controls\n\nCreate switches for serial numbers you want to control network access for.", + "title": "UniFi options 2/3" }, "statistics_sensors": { "data": { "allow_bandwidth_sensors": "Bandwidth usage sensors for network clients" }, "description": "Configure statistics sensors", - "title": "UniFi options" + "title": "UniFi options 3/3" } } + }, + "error": { + "unknown_client_mac": "No client available in UniFi on that MAC address" } } \ No newline at end of file diff --git a/homeassistant/components/unifi/switch.py b/homeassistant/components/unifi/switch.py index 941f4f8ab84d8f..84e85188ededaf 100644 --- a/homeassistant/components/unifi/switch.py +++ b/homeassistant/components/unifi/switch.py @@ -4,7 +4,6 @@ from homeassistant.components.switch import SwitchDevice from homeassistant.components.unifi.config_flow import get_controller_from_config_entry from homeassistant.core import callback -from homeassistant.helpers import entity_registry from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.restore_state import RestoreEntity @@ -30,10 +29,12 @@ async def async_setup_entry(hass, config_entry, async_add_entities): switches = {} switches_off = [] - registry = await entity_registry.async_get_registry(hass) + option_block_clients = controller.option_block_clients + + entity_registry = await hass.helpers.entity_registry.async_get_registry() # Restore clients that is not a part of active clients list. - for entity in registry.entities.values(): + for entity in entity_registry.entities.values(): if ( entity.config_entry_id == config_entry.entry_id @@ -61,6 +62,43 @@ def update_controller(): async_dispatcher_connect(hass, controller.signal_update, update_controller) ) + @callback + def options_updated(): + """Manage entities affected by config entry options.""" + nonlocal option_block_clients + + update = set() + remove = set() + + if option_block_clients != controller.option_block_clients: + option_block_clients = controller.option_block_clients + + for block_client_id, entity in switches.items(): + if not isinstance(entity, UniFiBlockClientSwitch): + continue + + if entity.client.mac in option_block_clients: + update.add(block_client_id) + else: + remove.add(block_client_id) + + for block_client_id in remove: + entity = switches.pop(block_client_id) + + if entity_registry.async_is_registered(entity.entity_id): + entity_registry.async_remove(entity.entity_id) + + hass.async_create_task(entity.async_remove()) + + if len(update) != len(option_block_clients): + update_controller() + + controller.listeners.append( + async_dispatcher_connect( + hass, controller.signal_options_update, options_updated + ) + ) + update_controller() switches_off.clear() @@ -74,15 +112,21 @@ def add_entities(controller, async_add_entities, switches, switches_off): # block client for client_id in controller.option_block_clients: + client = None block_client_id = f"block-{client_id}" if block_client_id in switches: continue - if client_id not in controller.api.clients_all: + if client_id in controller.api.clients: + client = controller.api.clients[client_id] + + elif client_id in controller.api.clients_all: + client = controller.api.clients_all[client_id] + + if not client: continue - client = controller.api.clients_all[client_id] switches[block_client_id] = UniFiBlockClientSwitch(client, controller) new_switches.append(switches[block_client_id]) diff --git a/tests/components/unifi/test_config_flow.py b/tests/components/unifi/test_config_flow.py index 64d1ab9775e66b..9a280ffe9e69b8 100644 --- a/tests/components/unifi/test_config_flow.py +++ b/tests/components/unifi/test_config_flow.py @@ -5,7 +5,18 @@ from homeassistant import data_entry_flow from homeassistant.components import unifi from homeassistant.components.unifi import config_flow -from homeassistant.components.unifi.const import CONF_CONTROLLER, CONF_SITE_ID +from homeassistant.components.unifi.config_flow import CONF_NEW_CLIENT +from homeassistant.components.unifi.const import ( + CONF_ALLOW_BANDWIDTH_SENSORS, + CONF_BLOCK_CLIENT, + CONF_CONTROLLER, + CONF_DETECTION_TIME, + CONF_SITE_ID, + CONF_SSID_FILTER, + CONF_TRACK_CLIENTS, + CONF_TRACK_DEVICES, + CONF_TRACK_WIRED_CLIENTS, +) from homeassistant.const import ( CONF_HOST, CONF_PASSWORD, @@ -18,6 +29,8 @@ from tests.common import MockConfigEntry +CLIENTS = [{"mac": "00:00:00:00:00:01"}] + WLANS = [{"name": "SSID 1"}, {"name": "SSID 2"}] @@ -28,7 +41,7 @@ async def test_flow_works(hass, aioclient_mock, mock_discovery): config_flow.DOMAIN, context={"source": "user"} ) - assert result["type"] == "form" + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "user" assert result["data_schema"]({CONF_USERNAME: "", CONF_PASSWORD: ""}) == { CONF_HOST: "unifi", @@ -64,7 +77,7 @@ async def test_flow_works(hass, aioclient_mock, mock_discovery): }, ) - assert result["type"] == "create_entry" + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result["title"] == "Site name" assert result["data"] == { CONF_CONTROLLER: { @@ -84,7 +97,7 @@ async def test_flow_works_multiple_sites(hass, aioclient_mock): config_flow.DOMAIN, context={"source": "user"} ) - assert result["type"] == "form" + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "user" aioclient_mock.post( @@ -116,7 +129,7 @@ async def test_flow_works_multiple_sites(hass, aioclient_mock): }, ) - assert result["type"] == "form" + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "site" assert result["data_schema"]({"site": "site name"}) assert result["data_schema"]({"site": "site2 name"}) @@ -133,7 +146,7 @@ async def test_flow_fails_site_already_configured(hass, aioclient_mock): config_flow.DOMAIN, context={"source": "user"} ) - assert result["type"] == "form" + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "user" aioclient_mock.post( @@ -162,7 +175,7 @@ async def test_flow_fails_site_already_configured(hass, aioclient_mock): }, ) - assert result["type"] == "abort" + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT async def test_flow_fails_user_credentials_faulty(hass, aioclient_mock): @@ -171,7 +184,7 @@ async def test_flow_fails_user_credentials_faulty(hass, aioclient_mock): config_flow.DOMAIN, context={"source": "user"} ) - assert result["type"] == "form" + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "user" with patch("aiounifi.Controller.login", side_effect=aiounifi.errors.Unauthorized): @@ -186,7 +199,7 @@ async def test_flow_fails_user_credentials_faulty(hass, aioclient_mock): }, ) - assert result["type"] == "form" + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["errors"] == {"base": "faulty_credentials"} @@ -196,7 +209,7 @@ async def test_flow_fails_controller_unavailable(hass, aioclient_mock): config_flow.DOMAIN, context={"source": "user"} ) - assert result["type"] == "form" + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "user" with patch("aiounifi.Controller.login", side_effect=aiounifi.errors.RequestError): @@ -211,7 +224,7 @@ async def test_flow_fails_controller_unavailable(hass, aioclient_mock): }, ) - assert result["type"] == "form" + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["errors"] == {"base": "service_unavailable"} @@ -221,7 +234,7 @@ async def test_flow_fails_unknown_problem(hass, aioclient_mock): config_flow.DOMAIN, context={"source": "user"} ) - assert result["type"] == "form" + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "user" with patch("aiounifi.Controller.login", side_effect=Exception): @@ -236,12 +249,14 @@ async def test_flow_fails_unknown_problem(hass, aioclient_mock): }, ) - assert result["type"] == "abort" + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT async def test_option_flow(hass): """Test config flow options.""" - controller = await setup_unifi_integration(hass, wlans_response=WLANS) + controller = await setup_unifi_integration( + hass, clients_response=CLIENTS, wlans_response=WLANS + ) result = await hass.config_entries.options.async_init( controller.config_entry.entry_id @@ -253,27 +268,64 @@ async def test_option_flow(hass): result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={ - config_flow.CONF_TRACK_CLIENTS: False, - config_flow.CONF_TRACK_WIRED_CLIENTS: False, - config_flow.CONF_TRACK_DEVICES: False, - config_flow.CONF_SSID_FILTER: ["SSID 1"], - config_flow.CONF_DETECTION_TIME: 100, + CONF_TRACK_CLIENTS: False, + CONF_TRACK_WIRED_CLIENTS: False, + CONF_TRACK_DEVICES: False, + CONF_SSID_FILTER: ["SSID 1"], + CONF_DETECTION_TIME: 100, + }, + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "client_control" + + clients_to_block = hass.config_entries.options._progress[result["flow_id"]].options[ + CONF_BLOCK_CLIENT + ] + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={ + CONF_BLOCK_CLIENT: clients_to_block, + CONF_NEW_CLIENT: "00:00:00:00:00:01", }, ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "client_control" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={ + CONF_BLOCK_CLIENT: clients_to_block, + CONF_NEW_CLIENT: "00:00:00:00:00:02", + }, + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "client_control" + assert result["errors"] == {"base": "unknown_client_mac"} + + clients_to_block = hass.config_entries.options._progress[result["flow_id"]].options[ + CONF_BLOCK_CLIENT + ] + result = await hass.config_entries.options.async_configure( + result["flow_id"], user_input={CONF_BLOCK_CLIENT: clients_to_block}, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "statistics_sensors" result = await hass.config_entries.options.async_configure( - result["flow_id"], user_input={config_flow.CONF_ALLOW_BANDWIDTH_SENSORS: True} + result["flow_id"], user_input={CONF_ALLOW_BANDWIDTH_SENSORS: True} ) assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result["data"] == { - config_flow.CONF_TRACK_CLIENTS: False, - config_flow.CONF_TRACK_WIRED_CLIENTS: False, - config_flow.CONF_TRACK_DEVICES: False, - config_flow.CONF_DETECTION_TIME: 100, - config_flow.CONF_SSID_FILTER: ["SSID 1"], - config_flow.CONF_ALLOW_BANDWIDTH_SENSORS: True, + CONF_TRACK_CLIENTS: False, + CONF_TRACK_WIRED_CLIENTS: False, + CONF_TRACK_DEVICES: False, + CONF_DETECTION_TIME: 100, + CONF_SSID_FILTER: ["SSID 1"], + CONF_BLOCK_CLIENT: ["00:00:00:00:00:01"], + CONF_ALLOW_BANDWIDTH_SENSORS: True, } diff --git a/tests/components/unifi/test_controller.py b/tests/components/unifi/test_controller.py index daec8cddf5dec6..8bf2225d1f1286 100644 --- a/tests/components/unifi/test_controller.py +++ b/tests/components/unifi/test_controller.py @@ -166,7 +166,7 @@ async def test_controller_setup(hass): controller.option_allow_bandwidth_sensors == unifi.const.DEFAULT_ALLOW_BANDWIDTH_SENSORS ) - assert controller.option_block_clients == unifi.const.DEFAULT_BLOCK_CLIENTS + assert isinstance(controller.option_block_clients, list) assert controller.option_track_clients == unifi.const.DEFAULT_TRACK_CLIENTS assert controller.option_track_devices == unifi.const.DEFAULT_TRACK_DEVICES assert ( @@ -175,7 +175,7 @@ async def test_controller_setup(hass): assert controller.option_detection_time == timedelta( seconds=unifi.const.DEFAULT_DETECTION_TIME ) - assert controller.option_ssid_filter == unifi.const.DEFAULT_SSID_FILTER + assert isinstance(controller.option_ssid_filter, list) assert controller.mac is None @@ -235,7 +235,7 @@ async def test_reset_after_successful_setup(hass): """Calling reset when the entry has been setup.""" controller = await setup_unifi_integration(hass) - assert len(controller.listeners) == 5 + assert len(controller.listeners) == 6 result = await controller.async_reset() await hass.async_block_till_done() diff --git a/tests/components/unifi/test_switch.py b/tests/components/unifi/test_switch.py index a2b609078deeb2..bc30161b77f34c 100644 --- a/tests/components/unifi/test_switch.py +++ b/tests/components/unifi/test_switch.py @@ -4,6 +4,11 @@ from homeassistant import config_entries from homeassistant.components import unifi import homeassistant.components.switch as switch +from homeassistant.components.unifi.const import ( + CONF_BLOCK_CLIENT, + CONF_TRACK_CLIENTS, + CONF_TRACK_DEVICES, +) from homeassistant.helpers import entity_registry from homeassistant.setup import async_setup_component @@ -200,11 +205,7 @@ async def test_platform_manually_configured(hass): async def test_no_clients(hass): """Test the update_clients function when no clients are found.""" controller = await setup_unifi_integration( - hass, - options={ - unifi.const.CONF_TRACK_CLIENTS: False, - unifi.const.CONF_TRACK_DEVICES: False, - }, + hass, options={CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False}, ) assert len(controller.mock_requests) == 4 @@ -215,10 +216,7 @@ async def test_controller_not_client(hass): """Test that the controller doesn't become a switch.""" controller = await setup_unifi_integration( hass, - options={ - unifi.const.CONF_TRACK_CLIENTS: False, - unifi.const.CONF_TRACK_DEVICES: False, - }, + options={CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False}, clients_response=[CONTROLLER_HOST], devices_response=[DEVICE_1], ) @@ -235,10 +233,7 @@ async def test_not_admin(hass): sites["Site name"]["role"] = "not admin" controller = await setup_unifi_integration( hass, - options={ - unifi.const.CONF_TRACK_CLIENTS: False, - unifi.const.CONF_TRACK_DEVICES: False, - }, + options={CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False}, sites=sites, clients_response=[CLIENT_1], devices_response=[DEVICE_1], @@ -253,9 +248,9 @@ async def test_switches(hass): controller = await setup_unifi_integration( hass, options={ - unifi.CONF_BLOCK_CLIENT: [BLOCKED["mac"], UNBLOCKED["mac"]], - unifi.const.CONF_TRACK_CLIENTS: False, - unifi.const.CONF_TRACK_DEVICES: False, + CONF_BLOCK_CLIENT: [BLOCKED["mac"], UNBLOCKED["mac"]], + CONF_TRACK_CLIENTS: False, + CONF_TRACK_DEVICES: False, }, clients_response=[CLIENT_1, CLIENT_4], devices_response=[DEVICE_1], @@ -284,59 +279,100 @@ async def test_switches(hass): assert unblocked is not None assert unblocked.state == "on" + await hass.services.async_call( + "switch", "turn_off", {"entity_id": "switch.block_client_1"}, blocking=True + ) + assert len(controller.mock_requests) == 5 + assert controller.mock_requests[4] == { + "json": {"mac": "00:00:00:00:01:01", "cmd": "block-sta"}, + "method": "post", + "path": "s/{site}/cmd/stamgr/", + } + + await hass.services.async_call( + "switch", "turn_on", {"entity_id": "switch.block_client_1"}, blocking=True + ) + assert len(controller.mock_requests) == 6 + assert controller.mock_requests[5] == { + "json": {"mac": "00:00:00:00:01:01", "cmd": "unblock-sta"}, + "method": "post", + "path": "s/{site}/cmd/stamgr/", + } + async def test_new_client_discovered_on_block_control(hass): """Test if 2nd update has a new client.""" controller = await setup_unifi_integration( hass, options={ - unifi.CONF_BLOCK_CLIENT: [BLOCKED["mac"]], - unifi.const.CONF_TRACK_CLIENTS: False, - unifi.const.CONF_TRACK_DEVICES: False, + CONF_BLOCK_CLIENT: [BLOCKED["mac"]], + CONF_TRACK_CLIENTS: False, + CONF_TRACK_DEVICES: False, }, - clients_all_response=[BLOCKED], ) assert len(controller.mock_requests) == 4 - assert len(hass.states.async_all()) == 2 + assert len(hass.states.async_all()) == 1 + + blocked = hass.states.get("switch.block_client_1") + assert blocked is None controller.api.websocket._data = { "meta": {"message": "sta:sync"}, "data": [BLOCKED], } controller.api.session_handler("data") + await hass.async_block_till_done() - # Calling a service will trigger the updates to run - await hass.services.async_call( - "switch", "turn_off", {"entity_id": "switch.block_client_1"}, blocking=True + assert len(hass.states.async_all()) == 2 + blocked = hass.states.get("switch.block_client_1") + assert blocked is not None + + +async def test_option_block_clients(hass): + """Test the changes to option reflects accordingly.""" + controller = await setup_unifi_integration( + hass, + options={CONF_BLOCK_CLIENT: [BLOCKED["mac"]]}, + clients_all_response=[BLOCKED, UNBLOCKED], ) - assert len(controller.mock_requests) == 5 assert len(hass.states.async_all()) == 2 - assert controller.mock_requests[4] == { - "json": {"mac": "00:00:00:00:01:01", "cmd": "block-sta"}, - "method": "post", - "path": "s/{site}/cmd/stamgr/", - } - await hass.services.async_call( - "switch", "turn_on", {"entity_id": "switch.block_client_1"}, blocking=True + # Add a second switch + hass.config_entries.async_update_entry( + controller.config_entry, + options={CONF_BLOCK_CLIENT: [BLOCKED["mac"], UNBLOCKED["mac"]]}, ) - assert len(controller.mock_requests) == 6 - assert controller.mock_requests[5] == { - "json": {"mac": "00:00:00:00:01:01", "cmd": "unblock-sta"}, - "method": "post", - "path": "s/{site}/cmd/stamgr/", - } + await hass.async_block_till_done() + assert len(hass.states.async_all()) == 3 + + # Remove the second switch again + hass.config_entries.async_update_entry( + controller.config_entry, options={CONF_BLOCK_CLIENT: [BLOCKED["mac"]]}, + ) + await hass.async_block_till_done() + assert len(hass.states.async_all()) == 2 + + # Enable one and remove another one + hass.config_entries.async_update_entry( + controller.config_entry, options={CONF_BLOCK_CLIENT: [UNBLOCKED["mac"]]}, + ) + await hass.async_block_till_done() + assert len(hass.states.async_all()) == 2 + + # Remove one + hass.config_entries.async_update_entry( + controller.config_entry, options={CONF_BLOCK_CLIENT: []}, + ) + await hass.async_block_till_done() + assert len(hass.states.async_all()) == 1 async def test_new_client_discovered_on_poe_control(hass): """Test if 2nd update has a new client.""" controller = await setup_unifi_integration( hass, - options={ - unifi.const.CONF_TRACK_CLIENTS: False, - unifi.const.CONF_TRACK_DEVICES: False, - }, + options={CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False}, clients_response=[CLIENT_1], devices_response=[DEVICE_1], ) @@ -435,9 +471,9 @@ async def test_restoring_client(hass): controller = await setup_unifi_integration( hass, options={ - unifi.CONF_BLOCK_CLIENT: ["random mac"], - unifi.const.CONF_TRACK_CLIENTS: False, - unifi.const.CONF_TRACK_DEVICES: False, + CONF_BLOCK_CLIENT: ["random mac"], + CONF_TRACK_CLIENTS: False, + CONF_TRACK_DEVICES: False, }, clients_response=[CLIENT_2], devices_response=[DEVICE_1],