Skip to content

Commit

Permalink
Update user registration (#17)
Browse files Browse the repository at this point in the history
* Allow retry during registration and login

* Ask user additional questions during registration

* Fix errors and improve prompt agent

* Add tests for adding additional user information

* Simplify code of prompt agent and update tests

* Add unit tests for prompt agent

* Make small changes in test_prompt_agent.py

* Add seperate checks and retry loops during registration

* Only send additional user information that are non-empty
  • Loading branch information
davidotte authored Mar 31, 2024
1 parent 87fa78f commit 7487032
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 70 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ httpx
omegaconf
pandas
tabpfn
password-strength

# for testing
respx
67 changes: 62 additions & 5 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,43 @@ def try_authenticate(self, access_token) -> bool:

return is_authenticated

def validate_email(self, email: str) -> tuple[bool, str]:
"""
Send entered email to server that checks if it is valid and not already in use.
Parameters
----------
email : str
Returns
-------
is_valid : bool
True if the email is valid.
message : str
The message returned from the server.
"""
response = self.httpx_client.post(
self.server_endpoints.validate_email.path,
params={"email": email}
)

self._validate_response(response, "validate_email", only_version_check=True)
if response.status_code == 200:
is_valid = True
message = ""
else:
is_valid = False
message = response.json()["detail"]

return is_valid, message

def register(
self,
email: str,
password: str,
password_confirm: str,
validation_link: str
) -> (bool, str):
) -> tuple[bool, str]:
"""
Register a new user with the provided credentials.
Expand All @@ -245,7 +275,8 @@ def register(

response = self.httpx_client.post(
self.server_endpoints.register.path,
params={"email": email, "password": password, "password_confirm": password_confirm, "validation_link": validation_link}
params={"email": email, "password": password, "password_confirm": password_confirm,
"validation_link": validation_link}
)

self._validate_response(response, "register", only_version_check=True)
Expand All @@ -258,7 +289,7 @@ def register(

return is_created, message

def login(self, email: str, password: str) -> str | None:
def login(self, email: str, password: str) -> tuple[str, str]:
"""
Login with the provided credentials and return the access token if successful.
Expand All @@ -271,6 +302,8 @@ def login(self, email: str, password: str) -> str | None:
-------
access_token : str | None
The access token returned from the server. Return None if login fails.
message : str
The message returned from the server.
"""

access_token = None
Expand All @@ -279,11 +312,14 @@ def login(self, email: str, password: str) -> str | None:
data=common_utils.to_oauth_request_form(email, password)
)

self._validate_response(response, "login", only_version_check=False)
self._validate_response(response, "login", only_version_check=True)
if response.status_code == 200:
access_token = response.json()["access_token"]
message = ""
else:
message = response.json()["detail"]

return access_token
return access_token, message

def get_password_policy(self) -> {}:
"""
Expand All @@ -302,6 +338,27 @@ def get_password_policy(self) -> {}:

return response.json()["requirements"]

def add_user_information(
self, company: str | None, role: str | None, use_case: str | None, contact_via_email: bool
):
"""
Send additional user information to the server.
"""
information = {"contact_via_email": contact_via_email}
if company:
information["company"] = company
if role:
information["role"] = role
if use_case:
information["use_case"] = use_case

response = self.httpx_client.post(
self.server_endpoints.add_user_information.path,
json=information
)

self._validate_response(response, "add_user_information")

def retrieve_greeting_messages(self) -> list[str]:
"""
Retrieve greeting messages that are new for the user.
Expand Down
120 changes: 88 additions & 32 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import textwrap
import getpass
from password_strength import PasswordPolicy


class PromptAgent:
Expand All @@ -9,6 +10,20 @@ def indent(text: str):
indent_str = " " * indent_factor
return textwrap.indent(text, indent_str)

@staticmethod
def password_req_to_policy(password_req: list[str]):
"""
Small function that receives password requirements as a list of
strings like "Length(8)" and returns a corresponding
PasswordPolicy object.
"""
requirements = {}
for req in password_req:
word_part, number_part = req.split('(')
number = int(number_part[:-1])
requirements[word_part.lower()] = number
return PasswordPolicy.from_names(**requirements)

@classmethod
def prompt_welcome(cls):
prompt = "\n".join([
Expand All @@ -23,73 +38,95 @@ def prompt_welcome(cls):

@classmethod
def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
# Choose between registration and login
prompt = "\n".join([
"Please choose one of the following options:",
"(1) Create a TabPFN account",
"(2) Login to your TabPFN account",
"",
"Please enter your choice: ",
])
choice = cls._choice_with_retries(prompt, ["1", "2"])

choice = input(cls.indent(prompt))

# Registration
if choice == "1":
#validation_link = input(cls.indent("Please enter your secret code: "))
# validation_link = input(cls.indent("Please enter your secret code: "))
validation_link = "tabpfn-2023"
# create account
email = input(cls.indent("Please enter your email: "))
while True:
email = input(cls.indent("Please enter your email: "))
# Send request to server to check if email is valid and not already taken.
is_valid, message = user_auth_handler.validate_email(email)
if is_valid:
break
else:
print(cls.indent(message + "\n"))

password_req = user_auth_handler.get_password_policy()
password_policy = cls.password_req_to_policy(password_req)
password_req_prompt = "\n".join([
"",
"Password requirements (minimum):",
"\n".join([f". {req}" for req in password_req]),
"",
"Please enter your password: ",
])

password = getpass.getpass(cls.indent(password_req_prompt))
password_confirm = getpass.getpass(cls.indent("Please confirm your password: "))

user_auth_handler.set_token_by_registration(email, password, password_confirm, validation_link)

while True:
password = getpass.getpass(cls.indent(password_req_prompt))
password_req_prompt = "Please enter your password: "
if len(password_policy.test(password)) != 0:
print(cls.indent("Password requirements not satisfied.\n"))
continue

password_confirm = getpass.getpass(cls.indent("Please confirm your password: "))
if password == password_confirm:
break
else:
print(cls.indent("Entered password and confirmation password do not match, please try again.\n"))

is_created, message = user_auth_handler.set_token_by_registration(
email, password, password_confirm, validation_link)
if not is_created:
raise RuntimeError("User registration failed: " + message + "\n")
cls.prompt_add_user_information(user_auth_handler)
print(cls.indent("Account created successfully!") + "\n")

# Login
elif choice == "2":
# login to account
email = input(cls.indent("Please enter your email: "))
password = getpass.getpass(cls.indent("Please enter your password: "))

user_auth_handler.set_token_by_login(email, password)

while True:
email = input(cls.indent("Please enter your email: "))
password = getpass.getpass(cls.indent("Please enter your password: "))

successful, message = user_auth_handler.set_token_by_login(email, password)
if successful:
break
print(cls.indent("Login failed: " + message) + "\n")
print(cls.indent("Login successful!") + "\n")

else:
raise RuntimeError("Invalid choice")

@classmethod
def prompt_terms_and_cond(cls) -> bool:
t_and_c = "\n".join([
"Please refer to our terms and conditions at: https://www.priorlabs.ai/terms-eu-en "
"By using TabPFN, you agree to the following terms and conditions:",
"Do you agree to the above terms and conditions? (y/n): ",
])
choice = cls._choice_with_retries(t_and_c, ["y", "n"])
return choice == "y"

choice = input(cls.indent(t_and_c))

# retry for 3 attempts until valid choice is made
is_valid_choice = False
for _ in range(3):
if choice.lower() not in ["y", "n"]:
choice = input(cls.indent("Invalid choice, please enter 'y' or 'n': "))
else:
is_valid_choice = True
break
@classmethod
def prompt_add_user_information(cls, user_auth_handler: "UserAuthenticationClient"):
print(cls.indent("To help us tailor our support and services to your needs, we have a few optional questions. "
"Feel free to skip any question by leaving it blank.") + "\n")
company = input(cls.indent("Where do you work? "))
role = input(cls.indent("What is your role? "))
use_case = input(cls.indent("What do you want to use TabPFN for? "))

if not is_valid_choice:
raise RuntimeError("Invalid choice")
choice_contact = cls._choice_with_retries(
"Can we reach out to you via email to support you? (y/n):", ["y", "n"]
)
contact_via_email = True if choice_contact == "y" else False

return choice.lower() == "y"
user_auth_handler.add_user_information(company, role, use_case, contact_via_email)

@classmethod
def prompt_reusing_existing_token(cls):
Expand All @@ -115,3 +152,22 @@ def prompt_confirm_password_for_user_account_deletion(cls) -> str:
@classmethod
def prompt_account_deleted(cls):
print(cls.indent("Your account has been deleted."))

@classmethod
def _choice_with_retries(cls, prompt: str, choices: list) -> str:
"""
Prompt text and give user infinitely many attempts to select one of the possible choices. If valid choice
is selected, return choice in lowercase.
"""
assert all(c.lower() == c for c in choices), "Choices need to be lower case."
choice = input(cls.indent(prompt))

# retry until valid choice is made
while True:
if choice.lower() not in choices:
choices_str = ", ".join(f"'{item}'" for item in choices[:-1]) + f" or '{choices[-1]}'"
choice = input(cls.indent(f"Invalid choice, please enter {choices_str}: "))
else:
break

return choice.lower()
10 changes: 10 additions & 0 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ endpoints:
methods: [ "GET" ]
description: "Password policy"

validate_email:
path: "/auth/validate_email/"
methods: [ "POST" ]
description: "Validate email"

register:
path: "/auth/register/"
methods: [ "POST" ]
Expand All @@ -33,6 +38,11 @@ endpoints:
methods: [ "GET" ]
description: "Retrieve new greeting messages"

add_user_information:
path: "/add_user_information/"
methods: [ "POST" ]
description: "Add additional user information to database"

protected_root:
path: "/protected/"
methods: [ "GET" ]
Expand Down
29 changes: 18 additions & 11 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,32 @@ def set_token(self, access_token: str):
self.CACHED_TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True)
self.CACHED_TOKEN_FILE.write_text(access_token)

def validate_email(self, email: str) -> tuple[bool, str]:
is_valid, message = self.service_client.validate_email(email)
return is_valid, message

def set_token_by_registration(
self,
email: str,
password: str,
password_confirm: str,
validation_link: str
) -> None:
if password != password_confirm:
raise ValueError("Password and password_confirm must be the same.")
) -> tuple[bool, str]:

is_created, message = self.service_client.register(email, password, password_confirm, validation_link)
if not is_created:
raise RuntimeError(f"Failed to register user: {message}")

# login after registration
self.set_token_by_login(email, password)
if is_created:
# login after registration
self.set_token_by_login(email, password)
return is_created, message

def set_token_by_login(self, email: str, password: str) -> None:
access_token = self.service_client.login(email, password)
def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]:
access_token, message = self.service_client.login(email, password)

if access_token is None:
raise RuntimeError("Failed to login, please check your email and password.")
return False, message

self.set_token(access_token)
return True, message

def try_reuse_existing_token(self) -> bool:
if self.service_client.access_token is None:
Expand All @@ -78,6 +80,11 @@ def try_reuse_existing_token(self) -> bool:
def get_password_policy(self):
return self.service_client.get_password_policy()

def add_user_information(
self, company: str | None, role: str | None, use_case: str | None, contact_via_email: bool
):
self.service_client.add_user_information(company, role, use_case, contact_via_email)

def reset_cache(self):
self._reset_token()

Expand Down
Loading

0 comments on commit 7487032

Please sign in to comment.