From 82a28c9c4ce41733cc83d1cdf9f9d2e6e9d1e63b Mon Sep 17 00:00:00 2001 From: Luca Capra Date: Mon, 5 Jun 2023 12:13:15 +0200 Subject: [PATCH] add BASE_URL support --- app/webservice.py | 78 +++++++++++++++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 27 deletions(-) diff --git a/app/webservice.py b/app/webservice.py index cf333da..8597a24 100644 --- a/app/webservice.py +++ b/app/webservice.py @@ -5,7 +5,7 @@ import numpy as np import ffmpeg -from fastapi import FastAPI, File, UploadFile, Query, applications +from fastapi import FastAPI, File, UploadFile, Query, applications, APIRouter from fastapi.responses import StreamingResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.openapi.docs import get_swagger_ui_html @@ -17,8 +17,10 @@ else: from .openai_whisper.core import transcribe, language_detection -SAMPLE_RATE=16000 -LANGUAGE_CODES=sorted(list(tokenizer.LANGUAGES.keys())) +SAMPLE_RATE = 16000 +LANGUAGE_CODES = sorted(list(tokenizer.LANGUAGES.keys())) + +BASE_URL = os.getenv("BASE_URL", "") projectMetadata = importlib.metadata.metadata('whisper-asr-webservice') app = FastAPI( @@ -32,56 +34,77 @@ license_info={ "name": "MIT License", "url": projectMetadata['License'] - } + }, + docs_url=f"{BASE_URL}/docs", + openapi_url=f"{BASE_URL}/openapi.json" ) +router = APIRouter(prefix=BASE_URL) + + assets_path = os.getcwd() + "/swagger-ui-assets" if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"): - app.mount("/assets", StaticFiles(directory=assets_path), name="static") + app.mount(f"{BASE_URL}/assets", + StaticFiles(directory=assets_path), name="static") + def swagger_monkey_patch(*args, **kwargs): + kwargs["openapi_url"] = f"{BASE_URL}/openapi.json" return get_swagger_ui_html( *args, **kwargs, swagger_favicon_url="", - swagger_css_url="/assets/swagger-ui.css", - swagger_js_url="/assets/swagger-ui-bundle.js", + swagger_css_url=f"{BASE_URL}/assets/swagger-ui.css", + 
swagger_js_url=f"{BASE_URL}/assets/swagger-ui-bundle.js", ) applications.get_swagger_ui_html = swagger_monkey_patch -@app.get("/", response_class=RedirectResponse, include_in_schema=False) + +@router.get("/", response_class=RedirectResponse, include_in_schema=False) async def index(): - return "/docs" + return f"{BASE_URL}/docs" -@app.post("/asr", tags=["Endpoints"]) + +@router.post("/asr", tags=["Endpoints"]) def asr( - task : Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]), + task: Union[str, None] = Query(default="transcribe", enum=[ + "transcribe", "translate"]), language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES), initial_prompt: Union[str, None] = Query(default=None), audio_file: UploadFile = File(...), - encode : bool = Query(default=True, description="Encode audio first through ffmpeg"), - output : Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]), - word_timestamps : bool = Query( - default=False, - description="World level timestamps", + encode: bool = Query( + default=True, description="Encode audio first through ffmpeg"), + output: Union[str, None] = Query( + default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]), + word_timestamps: bool = Query( + default=False, + description="Word level timestamps", include_in_schema=(True if ASR_ENGINE == "faster_whisper" else False) ) ): - result = transcribe(load_audio(audio_file.file, encode), task, language, initial_prompt, word_timestamps, output) + result = transcribe(load_audio(audio_file.file, encode), + task, language, initial_prompt, word_timestamps, output) return StreamingResponse( - result, - media_type="text/plain", + result, + media_type="text/plain", headers={ - 'Asr-Engine': ASR_ENGINE, - 'Content-Disposition': f'attachment; filename="{audio_file.filename}.{output}"' - }) + 'Asr-Engine': ASR_ENGINE, + 'Content-Disposition': f'attachment; filename="{audio_file.filename}.{output}"' + }) + 
-@app.post("/detect-language", tags=["Endpoints"]) +@router.post("/detect-language", tags=["Endpoints"]) def detect_language( audio_file: UploadFile = File(...), - encode : bool = Query(default=True, description="Encode audio first through ffmpeg") + encode: bool = Query( + default=True, description="Encode audio first through ffmpeg") ): - detected_lang_code = language_detection(load_audio(audio_file.file, encode)) - return { "detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code" : detected_lang_code } + detected_lang_code = language_detection( + load_audio(audio_file.file, encode)) + return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code} + + +app.include_router(router) + def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE): """ @@ -109,7 +132,8 @@ def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE): .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=file.read()) ) except ffmpeg.Error as e: - raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + raise RuntimeError( + f"Failed to load audio: {e.stderr.decode()}") from e else: out = file.read()