Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove extension limitation in files #1559

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 3 additions & 20 deletions gateway/api/services/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class WorkingDir(Enum):
PROVIDER_STORAGE = 2


SUPPORTED_FILE_EXTENSIONS = [".tar", ".h5"]

logger = logging.getLogger("gateway")


Expand All @@ -45,21 +43,6 @@ class FileStorage: # pylint: disable=too-few-public-methods
provider_name (str | None): name of the provider in caseis needed to build the path
"""

@staticmethod
def is_valid_extension(file_name: str) -> bool:
"""
This method verifies if the extension of the file is valid.

Args:
file_name (str): file name to verify

Returns:
bool: True or False if it is valid or not
"""
return any(
file_name.endswith(extension) for extension in SUPPORTED_FILE_EXTENSIONS
)

def __init__(
self,
username: str,
Expand Down Expand Up @@ -121,8 +104,8 @@ def __get_provider_path(self, function_title: str, provider_name: str) -> str:
def get_files(self) -> list[str]:
"""
This method returns a list of file names following the next rules:
- Only files with supported extensions are listed
- It returns only files from a user or a provider file storage
- Directories are excluded

Returns:
list[str]: list of file names
Expand All @@ -138,8 +121,8 @@ def get_files(self) -> list[str]:

return [
os.path.basename(path)
for extension in SUPPORTED_FILE_EXTENSIONS
for path in glob.glob(f"{self.file_path}/*{extension}")
for path in glob.glob(f"{self.file_path}/*")
if os.path.isfile(path)
]

def get_file(self, file_name: str) -> Optional[Tuple[FileWrapper, str, int]]:
Expand Down
24 changes: 3 additions & 21 deletions gateway/api/views/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from rest_framework.decorators import action
from rest_framework.response import Response

from api.services.file_storage import SUPPORTED_FILE_EXTENSIONS, FileStorage, WorkingDir
from api.services.file_storage import FileStorage, WorkingDir
from api.utils import sanitize_file_name, sanitize_name
from api.models import Provider, Program

Expand Down Expand Up @@ -309,15 +309,6 @@ def download(self, request):
status=status.HTTP_400_BAD_REQUEST,
)

if not FileStorage.is_valid_extension(requested_file_name):
extensions = ", ".join(SUPPORTED_FILE_EXTENSIONS)
return Response(
{
"message": f"File name needs to have a valid extension: {extensions}"
},
status=status.HTTP_400_BAD_REQUEST,
)

function = self.get_function(
user=request.user,
function_title=function_title,
Expand Down Expand Up @@ -381,15 +372,6 @@ def provider_download(self, request):
status=status.HTTP_400_BAD_REQUEST,
)

if not FileStorage.is_valid_extension(requested_file_name):
extensions = ", ".join(SUPPORTED_FILE_EXTENSIONS)
return Response(
{
"message": f"File name needs to have a valid extension: {extensions}"
},
status=status.HTTP_400_BAD_REQUEST,
)

if not self.user_has_provider_access(request.user, provider_name):
return Response(
{"message": f"Provider {provider_name} doesn't exist."},
Expand Down Expand Up @@ -431,7 +413,7 @@ def provider_download(self, request):
return response

@action(methods=["DELETE"], detail=False)
def delete(self, request): # pylint: disable=invalid-name
def delete(self, request):
"""Deletes file uploaded or produced by the programs,"""
# default response for file not found, overwritten if file is found
tracer = trace.get_tracer("gateway.tracer")
Expand Down Expand Up @@ -484,7 +466,7 @@ def delete(self, request): # pylint: disable=invalid-name
)

@action(methods=["DELETE"], detail=False, url_path="provider/delete")
def provider_delete(self, request): # pylint: disable=invalid-name
def provider_delete(self, request):
"""Deletes file uploaded or produced by the programs,"""
# default response for file not found, overwritten if file is found
tracer = trace.get_tracer("gateway.tracer")
Expand Down
Loading