From e5220d4bee3ae87d18784ea5eed942d6cccdfad9 Mon Sep 17 00:00:00 2001 From: Arkaprava De Date: Tue, 9 Apr 2024 18:56:22 +0000 Subject: [PATCH] Updated code to support ssh into sagemaker space apps - JupyterLab and CodeEditor 1. Added Dockerfile to build custom image 2. Updated script to ssh into sagemaker space apps --- sagemaker_ssh_helper/ide.py | 24 ++++++++++++- sagemaker_ssh_helper/manager.py | 13 +++++++ sagemaker_ssh_helper/sm-connect-ssh-proxy | 18 +++++----- sagemaker_ssh_helper/sm-helper-functions | 10 ++++-- sagemaker_ssh_helper/sm-init-ssm | 6 +++- sagemaker_ssh_helper/sm-local-ssh-ide | 29 +++++++++++++-- sagemaker_ssh_helper/sm-ssh-ide | 33 ++++++++++------- .../Dockerfile.codeeditor.internet_free | 35 +++++++++++++++++++ 8 files changed, 141 insertions(+), 27 deletions(-) create mode 100644 tests/byoi_studio/Dockerfile.codeeditor.internet_free diff --git a/sagemaker_ssh_helper/ide.py b/sagemaker_ssh_helper/ide.py index 08e57fe..eb782d8 100644 --- a/sagemaker_ssh_helper/ide.py +++ b/sagemaker_ssh_helper/ide.py @@ -46,9 +46,10 @@ def __init__(self, arn, version_arn) -> None: class SSHIDE: logger = logging.getLogger('sagemaker-ssh-helper:SSHIDE') - def __init__(self, domain_id: str, user: str, region_name: str = None): + def __init__(self, domain_id: str, user: str, region_name: str = None, space: str = None): self.user = user self.domain_id = domain_id + self.space = space self.current_region = region_name or boto3.session.Session().region_name self.client = boto3.client('sagemaker', region_name=self.current_region) self.ssh_log = SSHLog(region_name=self.current_region) @@ -201,6 +202,9 @@ def resolve_sagemaker_kernel_image_arn(self, image_name): def print_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0): print(self.get_kernel_instance_ids(app_name, timeout_in_sec)[index]) + def print_space_instance_id(self, app_name, timeout_in_sec, index: int = 0): + print(self.get_space_instance_ids(app_name, timeout_in_sec)[index]) + def 
get_kernel_instance_ids(self, app_name, timeout_in_sec): self.logger.info("Resolving IDE instance IDs through SSM tags") self.log_urls(app_name) @@ -218,6 +222,24 @@ def get_kernel_instance_ids(self, app_name, timeout_in_sec): result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec) return result + def get_space_instance_ids(self, app_name, timeout_in_sec): + self.logger.info("Resolving IDE instance IDs through SSM tags") + self.log_urls(app_name) + if self.domain_id and self.space: + result = SSMManager().get_studio_space_app_instance_ids(self.domain_id, self.space, app_name, timeout_in_sec) + elif self.space: + self.logger.warning(f"Domain ID is not set. Will attempt to connect to the latest " + f"active space app with the name {app_name} in the region {self.current_region} " + f"for space {self.space}") + result = SSMManager().get_studio_space_app_instance_ids("", self.space, app_name, + timeout_in_sec) + else: + self.logger.warning(f"Domain ID or space name are not set. 
Will attempt to connect to the latest " + f"active space app with the name {app_name} in the region {self.current_region}") + result = SSMManager().get_studio_app_instance_ids(app_name, timeout_in_sec) + return result + + def log_urls(self, app_name): self.logger.info(f"Remote logs are at {self.get_cloudwatch_url(app_name)}") if self.domain_id and self.user: diff --git a/sagemaker_ssh_helper/manager.py b/sagemaker_ssh_helper/manager.py index bf52be7..838e12c 100644 --- a/sagemaker_ssh_helper/manager.py +++ b/sagemaker_ssh_helper/manager.py @@ -113,10 +113,23 @@ def get_studio_user_kgw_instance_ids(self, domain_id, user_profile_name, kgw_nam return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec, arn_filter_regex=arn_filter) + def get_studio_space_app_instance_ids(self, domain_id, space_name, app_name, timeout_in_sec=0): + self.logger.info(f"Querying SSM instance IDs for SageMaker Studio space app {app_name}") + if not domain_id: + arn_filter = f":app/.*/{space_name}/" + else: + arn_filter = f":app/{domain_id}/{space_name}/" + return self.get_instance_ids('app', f"{app_name}", timeout_in_sec, + arn_filter_regex=arn_filter) + def get_studio_kgw_instance_ids(self, kgw_name, timeout_in_sec=0): self.logger.info(f"Querying SSM instance IDs for SageMaker Studio kernel gateway {kgw_name}") return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec) + def get_studio_app_instance_ids(self, app_name, timeout_in_sec=0): + self.logger.info(f"Querying SSM instance IDs for SageMaker Studio space app {app_name}") + return self.get_instance_ids('app', f"{app_name}", timeout_in_sec) + def get_notebook_instance_ids(self, instance_name, timeout_in_sec=0): self.logger.info(f"Querying SSM instance IDs for SageMaker notebook instance {instance_name}") return self.get_instance_ids('notebook-instance', f"{instance_name}", timeout_in_sec) diff --git a/sagemaker_ssh_helper/sm-connect-ssh-proxy b/sagemaker_ssh_helper/sm-connect-ssh-proxy index 32f634d..8b986e3 100644 --- 
a/sagemaker_ssh_helper/sm-connect-ssh-proxy +++ b/sagemaker_ssh_helper/sm-connect-ssh-proxy @@ -57,7 +57,7 @@ send_command=$(aws ssm send-command \ 'cat /etc/ssh/authorized_keys.d/* > /etc/ssh/authorized_keys', 'ls -la /etc/ssh/authorized_keys' ]" \ - --no-cli-pager --no-paginate \ + --no-paginate \ --output json) json_value_regexp='s/^[^"]*".*": \"\(.*\)\"[^"]*/\1/' @@ -75,7 +75,7 @@ for i in $(seq 1 15); do command_output=$(aws ssm get-command-invocation \ --instance-id "${INSTANCE_ID}" \ --command-id "${command_id}" \ - --no-cli-pager --no-paginate \ + --no-paginate \ --output json) command_output=$(echo "$command_output" | python -m json.tool) command_status=$(echo "$command_output" | grep '"Status":' | sed -e "$json_value_regexp") @@ -110,10 +110,10 @@ proxy_command="aws ssm start-session\ --document-name AWS-StartSSHSession\ --parameters portNumber=%p" -# shellcheck disable=SC2086 -ssh -4 -o User=root -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \ - -o ProxyCommand="$proxy_command" \ - -o ServerAliveInterval=15 -o ServerAliveCountMax=3 \ - -o PasswordAuthentication=no \ - -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ - $PORT_FWD_ARGS "$INSTANCE_ID" + # shellcheck disable=SC2086 + ssh -4 -o User=sagemaker-user -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \ + -o ProxyCommand="$proxy_command" \ + -o ServerAliveInterval=15 -o ServerAliveCountMax=3 \ + -o PasswordAuthentication=no \ + -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ + $PORT_FWD_ARGS "$INSTANCE_ID" diff --git a/sagemaker_ssh_helper/sm-helper-functions b/sagemaker_ssh_helper/sm-helper-functions index 7743e93..ddc194b 100644 --- a/sagemaker_ssh_helper/sm-helper-functions +++ b/sagemaker_ssh_helper/sm-helper-functions @@ -170,7 +170,13 @@ function _print_sm_domain_id() { # shellcheck disable=SC2001 function _print_sm_user_profile_name() { sm_resource_metadata_json=$(tr -d "\n" < /opt/ml/metadata/resource-metadata.json) - echo -n "$sm_resource_metadata_json" | 
sed -e 's/^.*"UserProfileName":\"\([^"]*\)\".*$/\1/' + echo -n "$sm_resource_metadata_json" | jq -r '.UserProfileName' +} + +# shellcheck disable=SC2001 +function _print_sm_space_name() { + sm_resource_metadata_json=$(tr -d "\n" < /opt/ml/metadata/resource-metadata.json) + echo -n "$sm_resource_metadata_json" | jq -r '.SpaceName' } function _print_sm_studio_python() { @@ -216,6 +222,6 @@ function _start_sshd() { if _is_centos; then /usr/sbin/sshd else - service ssh start || (echo "ERROR: Failed to start sshd service" && exit 255) + sudo service ssh start || (echo "ERROR: Failed to start sshd service" && exit 255) fi } diff --git a/sagemaker_ssh_helper/sm-init-ssm b/sagemaker_ssh_helper/sm-init-ssm index 74fc43a..8a0b5ca 100644 --- a/sagemaker_ssh_helper/sm-init-ssm +++ b/sagemaker_ssh_helper/sm-init-ssm @@ -53,4 +53,8 @@ response=$(aws ssm create-activation \ acode=$(echo $response | jq --raw-output '.ActivationCode') aid=$(echo $response | jq --raw-output '.ActivationId') -echo Yes | amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION" +if [[ -n $(_print_sm_user_profile_name) && $(_print_sm_user_profile_name) != "null" ]]; then + echo Yes | amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION" +else + echo Yes | sudo amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION" +fi diff --git a/sagemaker_ssh_helper/sm-local-ssh-ide b/sagemaker_ssh_helper/sm-local-ssh-ide index 8f24d74..0a08fbc 100644 --- a/sagemaker_ssh_helper/sm-local-ssh-ide +++ b/sagemaker_ssh_helper/sm-local-ssh-ide @@ -42,13 +42,29 @@ if [[ "$COMMAND" == "connect" ]]; then echo "sm-local-ssh-ide: WARNING: SageMaker Studio user profile name is not set."\ "Run 'sm-local-ssh-ide set-user-profile-name' to override." 
fi + SPACE_NAME="" + if [ -f ~/.sm-studio-space-name ]; then + SPACE_NAME="$(cat ~/.sm-studio-space-name)" + else + echo "sm-local-ssh-ide: WARNING: SageMaker Studio space name is not set."\ + "Run 'sm-local-ssh-ide set-space-name' to override." + fi - INSTANCE_ID=$(python < ~/.sm-studio-user-profile-name +elif [[ "$COMMAND" == "set-space-name" ]]; then + SPACE_NAME="$(echo "$2" | tr '[:upper:]' '[:lower:]')" + if [[ "$SPACE_NAME" == "" ]]; then + echo "sm-local-ssh-ide: ERROR: argument is expected" + exit 1 + fi + echo "sm-local-ssh-ide: Saving SageMaker Studio space name into ~/.sm-studio-space-name" + echo "$SPACE_NAME" > ~/.sm-studio-space-name + elif [[ "$COMMAND" == "run-command" ]]; then shift @@ -116,4 +141,4 @@ else echo "sm-local-ssh-ide connect $*" # shellcheck disable=SC2048 $0 connect $* -fi \ No newline at end of file +fi diff --git a/sagemaker_ssh_helper/sm-ssh-ide b/sagemaker_ssh_helper/sm-ssh-ide index 7e69744..0e08926 100644 --- a/sagemaker_ssh_helper/sm-ssh-ide +++ b/sagemaker_ssh_helper/sm-ssh-ide @@ -31,10 +31,10 @@ if [[ "$1" == "configure" ]]; then cat >/etc/profile.d/sm-ssh-ide.sh <