diff --git a/requirements.txt b/requirements.txt index 1187fe9..4ff19b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ httpx omegaconf pandas tabpfn +password-strength # for testing respx \ No newline at end of file diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index aacb1aa..96892e0 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -218,13 +218,43 @@ def try_authenticate(self, access_token) -> bool: return is_authenticated + def validate_email(self, email: str) -> tuple[bool, str]: + """ + Send entered email to server that checks if it is valid and not already in use. + + Parameters + ---------- + email : str + + Returns + ------- + is_valid : bool + True if the email is valid. + message : str + The message returned from the server. + """ + response = self.httpx_client.post( + self.server_endpoints.validate_email.path, + params={"email": email} + ) + + self._validate_response(response, "validate_email", only_version_check=True) + if response.status_code == 200: + is_valid = True + message = "" + else: + is_valid = False + message = response.json()["detail"] + + return is_valid, message + def register( self, email: str, password: str, password_confirm: str, validation_link: str - ) -> (bool, str): + ) -> tuple[bool, str]: """ Register a new user with the provided credentials. @@ -245,7 +275,8 @@ def register( response = self.httpx_client.post( self.server_endpoints.register.path, - params={"email": email, "password": password, "password_confirm": password_confirm, "validation_link": validation_link} + params={"email": email, "password": password, "password_confirm": password_confirm, + "validation_link": validation_link} ) self._validate_response(response, "register", only_version_check=True) @@ -258,7 +289,7 @@ def register( return is_created, message - def login(self, email: str, password: str) -> str | None: + def login(self, email: str, password: str) -> tuple[str, str]: """ Login with the provided credentials and return the access token if successful. @@ -271,6 +302,8 @@ def login(self, email: str, password: str) -> str | None: ------- access_token : str | None The access token returned from the server. Return None if login fails. + message : str + The message returned from the server. """ access_token = None @@ -279,11 +312,14 @@ def login(self, email: str, password: str) -> str | None: data=common_utils.to_oauth_request_form(email, password) ) - self._validate_response(response, "login", only_version_check=False) + self._validate_response(response, "login", only_version_check=True) if response.status_code == 200: access_token = response.json()["access_token"] + message = "" + else: + message = response.json()["detail"] - return access_token + return access_token, message def get_password_policy(self) -> {}: """ @@ -302,6 +338,27 @@ def get_password_policy(self) -> {}: return response.json()["requirements"] + def add_user_information( + self, company: str | None, role: str | None, use_case: str | None, contact_via_email: bool + ): + """ + Send additional user information to the server. + """ + information = {"contact_via_email": contact_via_email} + if company: + information["company"] = company + if role: + information["role"] = role + if use_case: + information["use_case"] = use_case + + response = self.httpx_client.post( + self.server_endpoints.add_user_information.path, + json=information + ) + + self._validate_response(response, "add_user_information") + def retrieve_greeting_messages(self) -> list[str]: """ Retrieve greeting messages that are new for the user. diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index b88de66..432fd34 100644 --- a/tabpfn_client/prompt_agent.py +++ b/tabpfn_client/prompt_agent.py @@ -1,5 +1,6 @@ import textwrap import getpass +from password_strength import PasswordPolicy class PromptAgent: @@ -9,6 +10,20 @@ def indent(text: str): indent_str = " " * indent_factor return textwrap.indent(text, indent_str) + @staticmethod + def password_req_to_policy(password_req: list[str]): + """ + Small function that receives password requirements as a list of + strings like "Length(8)" and returns a corresponding + PasswordPolicy object. + """ + requirements = {} + for req in password_req: + word_part, number_part = req.split('(') + number = int(number_part[:-1]) + requirements[word_part.lower()] = number + return PasswordPolicy.from_names(**requirements) + @classmethod def prompt_welcome(cls): prompt = "\n".join([ @@ -23,6 +38,7 @@ def prompt_welcome(cls): @classmethod def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): + # Choose between registration and login prompt = "\n".join([ "Please choose one of the following options:", "(1) Create a TabPFN account", @@ -30,16 +46,23 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): "", "Please enter your choice: ", ]) + choice = cls._choice_with_retries(prompt, ["1", "2"]) - choice = input(cls.indent(prompt)) - + # Registration if choice == "1": - #validation_link = input(cls.indent("Please enter your secret code: ")) + # validation_link = input(cls.indent("Please enter your secret code: ")) validation_link = "tabpfn-2023" - # create account - email = input(cls.indent("Please enter your email: ")) + while True: + email = input(cls.indent("Please enter your email: ")) + # Send request to server to check if email is valid and not already taken. + is_valid, message = user_auth_handler.validate_email(email) + if is_valid: + break + else: + print(cls.indent(message + "\n")) password_req = user_auth_handler.get_password_policy() + password_policy = cls.password_req_to_policy(password_req) password_req_prompt = "\n".join([ "", "Password requirements (minimum):", @@ -47,26 +70,39 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): "", "Please enter your password: ", ]) - - password = getpass.getpass(cls.indent(password_req_prompt)) - password_confirm = getpass.getpass(cls.indent("Please confirm your password: ")) - - user_auth_handler.set_token_by_registration(email, password, password_confirm, validation_link) - + while True: + password = getpass.getpass(cls.indent(password_req_prompt)) + password_req_prompt = "Please enter your password: " + if len(password_policy.test(password)) != 0: + print(cls.indent("Password requirements not satisfied.\n")) + continue + + password_confirm = getpass.getpass(cls.indent("Please confirm your password: ")) + if password == password_confirm: + break + else: + print(cls.indent("Entered password and confirmation password do not match, please try again.\n")) + + is_created, message = user_auth_handler.set_token_by_registration( + email, password, password_confirm, validation_link) + if not is_created: + raise RuntimeError("User registration failed: " + message + "\n") + cls.prompt_add_user_information(user_auth_handler) print(cls.indent("Account created successfully!") + "\n") + # Login elif choice == "2": # login to account - email = input(cls.indent("Please enter your email: ")) - password = getpass.getpass(cls.indent("Please enter your password: ")) - - user_auth_handler.set_token_by_login(email, password) - + while True: + email = input(cls.indent("Please enter your email: ")) + password = getpass.getpass(cls.indent("Please enter your password: ")) + + successful, message = user_auth_handler.set_token_by_login(email, password) + if successful: + break + print(cls.indent("Login failed: " + message) + "\n") print(cls.indent("Login successful!") + "\n") - else: - raise RuntimeError("Invalid choice") - @classmethod def prompt_terms_and_cond(cls) -> bool: t_and_c = "\n".join([ @@ -74,22 +110,23 @@ def prompt_terms_and_cond(cls) -> bool: "By using TabPFN, you agree to the following terms and conditions:", "Do you agree to the above terms and conditions? (y/n): ", ]) + choice = cls._choice_with_retries(t_and_c, ["y", "n"]) + return choice == "y" - choice = input(cls.indent(t_and_c)) - - # retry for 3 attempts until valid choice is made - is_valid_choice = False - for _ in range(3): - if choice.lower() not in ["y", "n"]: - choice = input(cls.indent("Invalid choice, please enter 'y' or 'n': ")) - else: - is_valid_choice = True - break + @classmethod + def prompt_add_user_information(cls, user_auth_handler: "UserAuthenticationClient"): + print(cls.indent("To help us tailor our support and services to your needs, we have a few optional questions. " + "Feel free to skip any question by leaving it blank.") + "\n") + company = input(cls.indent("Where do you work? ")) + role = input(cls.indent("What is your role? ")) + use_case = input(cls.indent("What do you want to use TabPFN for? ")) - if not is_valid_choice: - raise RuntimeError("Invalid choice") + choice_contact = cls._choice_with_retries( + "Can we reach out to you via email to support you? (y/n):", ["y", "n"] + ) + contact_via_email = True if choice_contact == "y" else False - return choice.lower() == "y" + user_auth_handler.add_user_information(company, role, use_case, contact_via_email) @classmethod def prompt_reusing_existing_token(cls): @@ -115,3 +152,22 @@ def prompt_confirm_password_for_user_account_deletion(cls) -> str: @classmethod def prompt_account_deleted(cls): print(cls.indent("Your account has been deleted.")) + + @classmethod + def _choice_with_retries(cls, prompt: str, choices: list) -> str: + """ + Prompt text and give user infinitely many attempts to select one of the possible choices. If valid choice + is selected, return choice in lowercase. + """ + assert all(c.lower() == c for c in choices), "Choices need to be lower case." + choice = input(cls.indent(prompt)) + + # retry until valid choice is made + while True: + if choice.lower() not in choices: + choices_str = ", ".join(f"'{item}'" for item in choices[:-1]) + f" or '{choices[-1]}'" + choice = input(cls.indent(f"Invalid choice, please enter {choices_str}: ")) + else: + break + + return choice.lower() diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index 68efae2..ee05c1e 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -18,6 +18,11 @@ endpoints: methods: [ "GET" ] description: "Password policy" + validate_email: + path: "/auth/validate_email/" + methods: [ "POST" ] + description: "Validate email" + register: path: "/auth/register/" methods: [ "POST" ] @@ -33,6 +38,11 @@ endpoints: methods: [ "GET" ] description: "Retrieve new greeting messages" + add_user_information: + path: "/add_user_information/" + methods: [ "POST" ] + description: "Add additional user information to database" + protected_root: path: "/protected/" methods: [ "GET" ] diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index e3d9e47..3773d17 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -30,30 +30,32 @@ def set_token(self, access_token: str): self.CACHED_TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True) self.CACHED_TOKEN_FILE.write_text(access_token) + def validate_email(self, email: str) -> tuple[bool, str]: + is_valid, message = self.service_client.validate_email(email) + return is_valid, message + def set_token_by_registration( self, email: str, password: str, password_confirm: str, validation_link: str - ) -> None: - if password != password_confirm: - raise ValueError("Password and password_confirm must be the same.") + ) -> tuple[bool, str]: is_created, message = self.service_client.register(email, password, password_confirm, validation_link) - if not is_created: - raise RuntimeError(f"Failed to register user: {message}") - - # login after registration - self.set_token_by_login(email, password) + if is_created: + # login after registration + self.set_token_by_login(email, password) + return is_created, message - def set_token_by_login(self, email: str, password: str) -> None: - access_token = self.service_client.login(email, password) + def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]: + access_token, message = self.service_client.login(email, password) if access_token is None: - raise RuntimeError("Failed to login, please check your email and password.") + return False, message self.set_token(access_token) + return True, message def try_reuse_existing_token(self) -> bool: if self.service_client.access_token is None: @@ -78,6 +80,11 @@ def try_reuse_existing_token(self) -> bool: def get_password_policy(self): return self.service_client.get_password_policy() + def add_user_information( + self, company: str | None, role: str | None, use_case: str | None, contact_via_email: bool + ): + self.service_client.add_user_information(company, role, use_case, contact_via_email) + def reset_cache(self): self._reset_token() diff --git a/tabpfn_client/tests/unit/test_client.py b/tabpfn_client/tests/unit/test_client.py index 0f21738..ab2a850 100644 --- a/tabpfn_client/tests/unit/test_client.py +++ b/tabpfn_client/tests/unit/test_client.py @@ -36,6 +36,17 @@ def test_try_connection_with_outdated_client_raises_runtime_error(self, mock_ser self.client.try_connection() self.assertTrue(str(cm.exception).startswith("Client version too old.")) + @with_mock_server() + def test_validate_email(self, mock_server): + mock_server.router.post(mock_server.endpoints.validate_email.path).respond(200, json={"message": "dummy_message"}) + self.assertTrue(self.client.validate_email("dummy_email")[0]) + + @with_mock_server() + def test_validate_email_invalid(self, mock_server): + mock_server.router.post(mock_server.endpoints.validate_email.path).respond(401, json={"detail": "dummy_message"}) + self.assertFalse(self.client.validate_email("dummy_email")[0]) + self.assertEqual("dummy_message", self.client.validate_email("dummy_email")[1]) + @with_mock_server() def test_register_user(self, mock_server): mock_server.router.post(mock_server.endpoints.register.path).respond(200, json={"message": "dummy_message"}) @@ -90,6 +101,18 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server): ) self.assertTrue(np.array_equal(pred, dummy_result["y_pred"])) + @with_mock_server() + def test_add_user_information(self, mock_server): + mock_server.router.post(mock_server.endpoints.add_user_information.path).respond(200) + self.assertIsNone(self.client.add_user_information( + "company", "dev", "", True)) + + @with_mock_server() + def test_add_user_information_raises_runtime_error(self, mock_server): + mock_server.router.post(mock_server.endpoints.add_user_information.path).respond(500) + with self.assertRaises(RuntimeError): + self.client.add_user_information("company", "dev", "", True) + def test_validate_response_no_error(self): response = Mock() response.status_code = 200 diff --git a/tabpfn_client/tests/unit/test_prompt_agent.py b/tabpfn_client/tests/unit/test_prompt_agent.py new file mode 100644 index 0000000..14b2b7b --- /dev/null +++ b/tabpfn_client/tests/unit/test_prompt_agent.py @@ -0,0 +1,54 @@ +import unittest +from unittest.mock import patch, MagicMock +import respx +from httpx import Response +from tabpfn_client.prompt_agent import PromptAgent +from tabpfn_client.tests.mock_tabpfn_server import with_mock_server +from tabpfn_client.service_wrapper import UserAuthenticationClient, ServiceClient + + +class TestPromptAgent(unittest.TestCase): + def test_password_req_to_policy(self): + password_req = ["Length(8)", "Uppercase(1)", "Numbers(1)", "Special(1)"] + password_policy = PromptAgent.password_req_to_policy(password_req) + requirements = [repr(req) for req in password_policy.test("")] + self.assertEqual(password_req, requirements) + + @with_mock_server() + @patch('getpass.getpass', side_effect=['Password123!', 'Password123!']) + @patch('builtins.input', side_effect=['1', 'user@example.com', 'test', 'test', 'test', 'y']) + def test_prompt_and_set_token_registration(self, mock_input, mock_getpass, mock_server): + mock_auth_client = MagicMock() + mock_auth_client.get_password_policy.return_value = ['Length(8)', 'Uppercase(1)', 'Numbers(1)', 'Special(1)'] + mock_auth_client.set_token_by_registration.return_value = (True, 'Registration successful') + mock_auth_client.validate_email.return_value = (True, '') + PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client) + mock_auth_client.set_token_by_registration.assert_called_once() + + @patch('getpass.getpass', side_effect = ['password123']) + @patch('builtins.input', side_effect=['2', 'user@example.com']) + def test_prompt_and_set_token_login(self, mock_input, mock_getpass): + mock_auth_client = MagicMock() + mock_auth_client.set_token_by_login.return_value = (True, 'Login successful') + PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client) + mock_auth_client.set_token_by_login.assert_called_once() + + @patch('builtins.input', return_value='y') + def test_prompt_terms_and_cond_returns_true(self, mock_input): + result = PromptAgent.prompt_terms_and_cond() + self.assertTrue(result) + + @patch('builtins.input', return_value='n') + def test_prompt_terms_and_cond_returns_false(self, mock_input): + result = PromptAgent.prompt_terms_and_cond() + self.assertFalse(result) + + @patch('builtins.input', return_value='1') + def test_choice_with_retries_valid_first_try(self, mock_input): + result = PromptAgent._choice_with_retries("Please enter your choice: ", ["1", "2"]) + self.assertEqual(result, '1') + + @patch('builtins.input', side_effect=['3', '3', '1']) + def test_choice_with_retries_valid_third_try(self, mock_input): + result = PromptAgent._choice_with_retries("Please enter your choice: ", ["1", "2"]) + self.assertEqual(result, '1') diff --git a/tabpfn_client/tests/unit/test_service_wrapper.py b/tabpfn_client/tests/unit/test_service_wrapper.py index 72ea1fa..66aa76a 100644 --- a/tabpfn_client/tests/unit/test_service_wrapper.py +++ b/tabpfn_client/tests/unit/test_service_wrapper.py @@ -29,8 +29,8 @@ def test_set_token_by_valid_login(self, mock_server): json={"access_token": dummy_token} ) - # assert no exception is raised - UserAuthenticationClient(ServiceClient()).set_token_by_login("dummy_email", "dummy_password") + self.assertTrue(UserAuthenticationClient(ServiceClient()).set_token_by_login( + "dummy_email", "dummy_password")[0]) # assert token is set self.assertEqual(dummy_token, ServiceClient().access_token) @@ -38,13 +38,11 @@ def test_set_token_by_valid_login(self, mock_server): @with_mock_server() def test_set_token_by_invalid_login(self, mock_server): # mock invalid login response - mock_server.router.post(mock_server.endpoints.login.path).respond(400) - - # assert exception is raised - self.assertRaises( - RuntimeError, - UserAuthenticationClient(ServiceClient()).set_token_by_login, - "dummy_email", "dummy_password" + mock_server.router.post(mock_server.endpoints.login.path).respond(401, json={ + "detail": "Incorrect email or password"}) + self.assertEqual( + (False, "Incorrect email or password"), + UserAuthenticationClient(ServiceClient()).set_token_by_login("dummy_email", "dummy_password") ) # assert token is not set @@ -87,9 +85,10 @@ def test_set_token_by_valid_registration(self, mock_server): json={"access_token": dummy_token} ) - # assert no exception is raised - UserAuthenticationClient(ServiceClient()).set_token_by_registration( - "dummy_email", "dummy_password", "dummy_password", "dummy_validation" + self.assertTrue( + UserAuthenticationClient(ServiceClient()).set_token_by_registration( + "dummy_email", "dummy_password", "dummy_password", "dummy_validation" + )[0] ) # assert token is set @@ -98,16 +97,13 @@ def test_set_token_by_valid_registration(self, mock_server): @with_mock_server() def test_set_token_by_invalid_registration(self, mock_server): # mock invalid registration response - mock_server.router.post(mock_server.endpoints.register.path).respond( - 400, - json={"detail": "doesn't matter"} - ) - - # assert exception is raised - self.assertRaises( - RuntimeError, - UserAuthenticationClient(ServiceClient()).set_token_by_registration, - "dummy_email", "dummy_password", "dummy_password", "dummy_validation" + mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={ + "detail": "Password mismatch"}) + self.assertEqual( + (False, "Password mismatch"), + UserAuthenticationClient(ServiceClient()).set_token_by_registration( + "dummy_email", "dummy_password", "dummy_password", + "dummy_validation") ) # assert token is not set