Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated code to support ssh into sagemaker space apps - JupyterLab and CodeEditor #57

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion sagemaker_ssh_helper/ide.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ def __init__(self, arn, version_arn) -> None:
class SSHIDE:
logger = logging.getLogger('sagemaker-ssh-helper:SSHIDE')

def __init__(self, domain_id: str, user: str, region_name: str = None):
def __init__(self, domain_id: str, user: str, region_name: str = None, space: str = None):
self.user = user
self.domain_id = domain_id
self.space = space
self.current_region = region_name or boto3.session.Session().region_name
self.client = boto3.client('sagemaker', region_name=self.current_region)
self.ssh_log = SSHLog(region_name=self.current_region)
Expand Down Expand Up @@ -201,6 +202,9 @@ def resolve_sagemaker_kernel_image_arn(self, image_name):
def print_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0):
    """
    Print one of the instance IDs resolved for a kernel gateway app to stdout.

    :param app_name: app name, forwarded to get_kernel_instance_ids()
    :param timeout_in_sec: timeout in seconds, forwarded to get_kernel_instance_ids()
    :param index: position of the ID to print in the resolved sequence (first one by default)
    """
    instance_ids = self.get_kernel_instance_ids(app_name, timeout_in_sec)
    selected_id = instance_ids[index]
    print(selected_id)

def print_space_instance_id(self, app_name, timeout_in_sec, index: int = 0):
    """
    Print one of the instance IDs resolved for a space app to stdout.

    :param app_name: app name, forwarded to get_space_instance_ids()
    :param timeout_in_sec: timeout in seconds, forwarded to get_space_instance_ids()
    :param index: position of the ID to print in the resolved sequence (first one by default)
    """
    instance_ids = self.get_space_instance_ids(app_name, timeout_in_sec)
    selected_id = instance_ids[index]
    print(selected_id)

def get_kernel_instance_ids(self, app_name, timeout_in_sec):
self.logger.info("Resolving IDE instance IDs through SSM tags")
self.log_urls(app_name)
Expand All @@ -218,6 +222,24 @@ def get_kernel_instance_ids(self, app_name, timeout_in_sec):
result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec)
return result

def get_space_instance_ids(self, app_name, timeout_in_sec):
    """
    Resolve the SSM instance IDs of a SageMaker Studio space app.

    The lookup is narrowed by whatever is configured on this object:

    * space and domain ID set - filter by both;
    * only space set - filter by the space name in any domain (a warning is logged);
    * space not set - filter by the app name only (a warning is logged).

    :param app_name: name of the app to look up
    :param timeout_in_sec: timeout in seconds, forwarded to SSMManager
    :return: the instance IDs as returned by SSMManager
    """
    self.logger.info("Resolving IDE instance IDs through SSM tags")
    self.log_urls(app_name)

    if not self.space:
        # Without a space name there is nothing to filter the ARN by, whether or not the domain ID is set.
        self.logger.warning(f"Space name is not set. Will attempt to connect to the latest "
                            f"active app with the name {app_name} in the region {self.current_region}")
        return SSMManager().get_studio_app_instance_ids(app_name, timeout_in_sec)

    if not self.domain_id:
        self.logger.warning(f"Domain ID is not set. Will attempt to connect to the latest "
                            f"active app with the name {app_name} in the region {self.current_region} "
                            f"for space {self.space}")

    # An empty domain ID tells the manager to match the space in any domain.
    return SSMManager().get_studio_space_app_instance_ids(self.domain_id or "", self.space, app_name,
                                                          timeout_in_sec)


def log_urls(self, app_name):
self.logger.info(f"Remote logs are at {self.get_cloudwatch_url(app_name)}")
if self.domain_id and self.user:
Expand Down
13 changes: 13 additions & 0 deletions sagemaker_ssh_helper/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,23 @@ def get_studio_user_kgw_instance_ids(self, domain_id, user_profile_name, kgw_nam
return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec,
arn_filter_regex=arn_filter)

def get_studio_space_app_instance_ids(self, domain_id, space_name, app_name, timeout_in_sec=0):
    """
    Query the SSM instance IDs of an app that runs in a SageMaker Studio space.

    :param domain_id: SageMaker domain ID; if empty or None, the space is matched in any domain
    :param space_name: name of the space, used to build the ARN filter
    :param app_name: name of the app, forwarded to get_instance_ids()
    :param timeout_in_sec: timeout in seconds, forwarded to get_instance_ids()
    :return: the result of get_instance_ids()
    """
    self.logger.info(f"Querying SSM instance IDs for SageMaker Studio space {space_name} app {app_name}")
    # The wildcard must cover exactly one path segment (the domain ID). A greedy `.*` can span `/`,
    # so the space name could then match a later ARN segment such as the app type.
    domain_pattern = domain_id if domain_id else "[^/]+"
    arn_filter = f":app/{domain_pattern}/{space_name}/"
    return self.get_instance_ids('app', f"{app_name}", timeout_in_sec,
                                 arn_filter_regex=arn_filter)

def get_studio_kgw_instance_ids(self, kgw_name, timeout_in_sec=0):
    """
    Query the SSM instance IDs of a SageMaker Studio kernel gateway by its name only.

    :param kgw_name: kernel gateway name, forwarded to get_instance_ids()
    :param timeout_in_sec: timeout in seconds, forwarded to get_instance_ids()
    :return: the result of get_instance_ids()
    """
    log_message = f"Querying SSM instance IDs for SageMaker Studio kernel gateway {kgw_name}"
    self.logger.info(log_message)
    instance_ids = self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec)
    return instance_ids

def get_studio_app_instance_ids(self, app_name, timeout_in_sec=0):
    """
    Query the SSM instance IDs of a SageMaker Studio app by its name only (no ARN filter).

    :param app_name: app name, forwarded to get_instance_ids()
    :param timeout_in_sec: timeout in seconds, forwarded to get_instance_ids()
    :return: the result of get_instance_ids()
    """
    log_message = f"Querying SSM instance IDs for SageMaker Studio space {app_name}"
    self.logger.info(log_message)
    instance_ids = self.get_instance_ids('app', f"{app_name}", timeout_in_sec)
    return instance_ids

def get_notebook_instance_ids(self, instance_name, timeout_in_sec=0):
    """
    Query the SSM instance IDs of a SageMaker notebook instance by its name.

    :param instance_name: notebook instance name, forwarded to get_instance_ids()
    :param timeout_in_sec: timeout in seconds, forwarded to get_instance_ids()
    :return: the result of get_instance_ids()
    """
    log_message = f"Querying SSM instance IDs for SageMaker notebook instance {instance_name}"
    self.logger.info(log_message)
    instance_ids = self.get_instance_ids('notebook-instance', f"{instance_name}", timeout_in_sec)
    return instance_ids
Expand Down
18 changes: 9 additions & 9 deletions sagemaker_ssh_helper/sm-connect-ssh-proxy
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ send_command=$(aws ssm send-command \
'cat /etc/ssh/authorized_keys.d/* > /etc/ssh/authorized_keys',
'ls -la /etc/ssh/authorized_keys'
]" \
--no-cli-pager --no-paginate \
--no-paginate \
--output json)

json_value_regexp='s/^[^"]*".*": \"\(.*\)\"[^"]*/\1/'
Expand All @@ -75,7 +75,7 @@ for i in $(seq 1 15); do
command_output=$(aws ssm get-command-invocation \
--instance-id "${INSTANCE_ID}" \
--command-id "${command_id}" \
--no-cli-pager --no-paginate \
--no-paginate \
--output json)
command_output=$(echo "$command_output" | python -m json.tool)
command_status=$(echo "$command_output" | grep '"Status":' | sed -e "$json_value_regexp")
Expand Down Expand Up @@ -110,10 +110,10 @@ proxy_command="aws ssm start-session\
--document-name AWS-StartSSHSession\
--parameters portNumber=%p"

# shellcheck disable=SC2086
ssh -4 -o User=root -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \
-o ProxyCommand="$proxy_command" \
-o ServerAliveInterval=15 -o ServerAliveCountMax=3 \
-o PasswordAuthentication=no \
-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
$PORT_FWD_ARGS "$INSTANCE_ID"
# shellcheck disable=SC2086
ssh -4 -o User=sagemaker-user -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \
-o ProxyCommand="$proxy_command" \
-o ServerAliveInterval=15 -o ServerAliveCountMax=3 \
-o PasswordAuthentication=no \
-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
$PORT_FWD_ARGS "$INSTANCE_ID"
10 changes: 8 additions & 2 deletions sagemaker_ssh_helper/sm-helper-functions
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,13 @@ function _print_sm_domain_id() {
# shellcheck disable=SC2001
function _print_sm_user_profile_name() {
sm_resource_metadata_json=$(tr -d "\n" < /opt/ml/metadata/resource-metadata.json)
echo -n "$sm_resource_metadata_json" | sed -e 's/^.*"UserProfileName":\"\([^"]*\)\".*$/\1/'
echo -n "$sm_resource_metadata_json" | jq -r '.UserProfileName'
}

# Prints the value of the "SpaceName" key found in the SageMaker resource metadata file
# (/opt/ml/metadata/resource-metadata.json). With `jq -r`, a missing key is printed as `null`.
function _print_sm_space_name() {
    sm_resource_metadata_json=$(tr -d "\n" < /opt/ml/metadata/resource-metadata.json)
    jq -r '.SpaceName' <<< "$sm_resource_metadata_json"
}

function _print_sm_studio_python() {
Expand Down Expand Up @@ -216,6 +222,6 @@ function _start_sshd() {
if _is_centos; then
/usr/sbin/sshd
else
service ssh start || (echo "ERROR: Failed to start sshd service" && exit 255)
sudo service ssh start || (echo "ERROR: Failed to start sshd service" && exit 255)
fi
}
6 changes: 5 additions & 1 deletion sagemaker_ssh_helper/sm-init-ssm
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,8 @@ response=$(aws ssm create-activation \
acode=$(echo $response | jq --raw-output '.ActivationCode')
aid=$(echo $response | jq --raw-output '.ActivationId')

echo Yes | amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION"
if [[ -n $(_print_sm_user_profile_name) && $(_print_sm_user_profile_name) != "null" ]]; then
echo Yes | amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION"
else
echo Yes | sudo amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION"
fi
29 changes: 27 additions & 2 deletions sagemaker_ssh_helper/sm-local-ssh-ide
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,29 @@ if [[ "$COMMAND" == "connect" ]]; then
echo "sm-local-ssh-ide: WARNING: SageMaker Studio user profile name is not set."\
"Run 'sm-local-ssh-ide set-user-profile-name' to override."
fi
SPACE_NAME=""
if [ -f ~/.sm-studio-space-name ]; then
SPACE_NAME="$(cat ~/.sm-studio-space-name)"
else
echo "sm-local-ssh-ide: WARNING: SageMaker Studio space name is not set."\
"Run 'sm-local-ssh-ide set-space-name' to override."
fi

INSTANCE_ID=$(python <<EOF
if [ -z "$SPACE_NAME" ]; then
INSTANCE_ID=$(python <<EOF
import sagemaker; from sagemaker_ssh_helper.ide import SSHIDE;
import logging; logging.basicConfig(level=logging.INFO);
SSHIDE("$DOMAIN_ID", "$USER_PROFILE_NAME").print_kernel_instance_id("$SM_STUDIO_KGW_NAME", timeout_in_sec=300)
EOF
)
else
INSTANCE_ID=$(python <<EOF
import sagemaker; from sagemaker_ssh_helper.ide import SSHIDE;
import logging; logging.basicConfig(level=logging.INFO);
SSHIDE("$DOMAIN_ID", "$USER_PROFILE_NAME", space="$SPACE_NAME").print_space_instance_id("$SM_STUDIO_KGW_NAME", timeout_in_sec=300)
EOF
)
fi

if [[ "$OPTIONS" == "--ssh-only" ]]; then
echo "sm-local-ssh-ide: Connecting only SSH to local port 10022 (got the flag --ssh-only)"
Expand Down Expand Up @@ -100,6 +116,15 @@ elif [[ "$COMMAND" == "set-user-profile-name" ]]; then
echo "sm-local-ssh-ide: Saving SageMaker Studio user profile name into ~/.sm-studio-user-profile-name"
echo "$USER_PROFILE_NAME" > ~/.sm-studio-user-profile-name

elif [[ "$COMMAND" == "set-space-name" ]]; then
    # Persist the (lower-cased) space name so that the 'connect' command can read it
    # back from ~/.sm-studio-space-name.
    SPACE_NAME="$(echo "$2" | tr '[:upper:]' '[:lower:]')"
    if [[ "$SPACE_NAME" == "" ]]; then
        echo "sm-local-ssh-ide: ERROR: <space-name> argument is expected"
        exit 1
    fi
    echo "sm-local-ssh-ide: Saving SageMaker Studio space name into ~/.sm-studio-space-name"
    echo "$SPACE_NAME" > ~/.sm-studio-space-name

elif [[ "$COMMAND" == "run-command" ]]; then

shift
Expand All @@ -116,4 +141,4 @@ else
echo "sm-local-ssh-ide connect $*"
# shellcheck disable=SC2048
$0 connect $*
fi
fi
33 changes: 21 additions & 12 deletions sagemaker_ssh_helper/sm-ssh-ide
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ if [[ "$1" == "configure" ]]; then
cat >/etc/profile.d/sm-ssh-ide.sh <<EOF
export XAUTHORITY="/tmp/.Xauthority-\$USER"
export ICEAUTHORITY="/tmp/.ICEauthority-\$USER"
touch "/tmp/.Xauthority-\$USER"
touch "/tmp/.ICEauthority-\$USER"
chmod 600 "/tmp/.Xauthority-\$USER"
chmod 600 "/tmp/.ICEauthority-\$USER"
sudo touch "/tmp/.Xauthority-\$USER"
sudo touch "/tmp/.ICEauthority-\$USER"
sudo chmod 600 "/tmp/.Xauthority-\$USER"
sudo chmod 600 "/tmp/.ICEauthority-\$USER"
EOF
source /etc/profile.d/sm-ssh-ide.sh

Expand Down Expand Up @@ -151,12 +151,21 @@ elif [[ "$1" == "init-ssm" ]]; then
LOCAL_USER_ID="$(cat ~/.sm-ssh-owner)"
echo "sm-ssh-ide: Will use local user ID: $LOCAL_USER_ID"

user_profile_json=$(aws sagemaker describe-user-profile \
--domain-id "$(_print_sm_domain_id)" \
--user-profile-name "$(_print_sm_user_profile_name)" \
--output json \
| tr -d "\n")
execution_role=$(echo "$user_profile_json" | grep "ExecutionRole" \
if [[ -n $(_print_sm_user_profile_name) && $(_print_sm_user_profile_name) != "null" ]]; then
json=$(aws sagemaker describe-user-profile \
--domain-id "$(_print_sm_domain_id)" \
--user-profile-name "$(_print_sm_user_profile_name)" \
--output json \
| tr -d "\n")
else
json=$(aws sagemaker describe-space \
--domain-id "$(_print_sm_domain_id)" \
--space-name "$(_print_sm_space_name)" \
--output json \
| tr -d "\n")
fi

execution_role=$(echo "$json" | grep "ExecutionRole" \
| sed -e 's/^.*"ExecutionRole": \"\([^"]*\)\".*$/\1/')

SSH_SSM_ROLE=$(echo "$execution_role" | sed -e 's/^.*:role\/\(.*\)/\1/')
Expand Down Expand Up @@ -192,7 +201,7 @@ elif [[ "$1" == "start" ]]; then
touch /tmp/.ssh-ide-local-lock

echo "sm-ssh-ide: Saving env variables for remote SSH interpreter"
sm-save-env
sudo sm-save-env

echo "sm-ssh-ide: Generating UTF-8 locales"
_locale_gen
Expand Down Expand Up @@ -244,7 +253,7 @@ elif [[ "$1" == "start" ]]; then
elif [[ "$1" == "ssm-agent" ]]; then

echo "sm-ssh-ide: Starting SSM agent"
/usr/bin/amazon-ssm-agent
sudo /usr/bin/amazon-ssm-agent

elif [[ "$1" == "status" ]]; then

Expand Down
35 changes: 35 additions & 0 deletions tests/byoi_studio/Dockerfile.codeeditor.internet_free
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Custom SageMaker Studio Code Editor image with SageMaker SSH Helper pre-installed,
# built on top of a SageMaker Distribution image pinned by digest.
FROM public.ecr.aws/sagemaker/sagemaker-distribution:1.6-cpu@sha256:d5148872a9e35b62054fbd82991541592b0ea5edb7b343e579a2daf3b50c2f6b

# Run the installation steps below as root; the unprivileged user is restored at the end of the file.
USER root
# Install SageMaker SSH Helper for the Internet-free setup
ARG SAGEMAKER_SSH_HELPER_DIR="/opt/sagemaker-ssh-helper"
RUN mkdir -p $SAGEMAKER_SSH_HELPER_DIR

# See tests/test_ide.py::test_studio_internet_free_mode

# Log the kernel specs (disabled)
# The kernel name needs to match SageMaker Image config
# RUN jupyter-kernelspec list

# NOTE(review): the reason for removing awscli from the base image is not stated here -
# presumably to avoid a conflict with the version the helper needs; confirm.
RUN pip3 uninstall -y -q awscli

# Install official release (for users):
#RUN \
# pip3 install --no-cache-dir sagemaker-ssh-helper

# Install dev release from source (for developers).
# The build context must be the root of the source tree, because it is copied as a whole:
COPY ./ $SAGEMAKER_SSH_HELPER_DIR/src/
RUN \
pip3 --no-cache-dir install wheel && \
pip3 --no-cache-dir install $SAGEMAKER_SSH_HELPER_DIR/src/

# Pre-configure the container at build time with the packages that would otherwise be installed from the Internet.
# The `--ssh-only` flag is already passed below and the following RUN command is already commented out;
# keep them this way if you don't plan to connect to the VNC server or to the Jupyter notebook.
# RUN apt-get update -y && apt-get upgrade -y

RUN sm-ssh-ide configure --ssh-only

# Drop root privileges and start the Code Editor entry point.
USER $MAMBA_USER
WORKDIR "/home/${NB_USER}"
ENTRYPOINT ["entrypoint-code-editor"]