diff --git a/app/internal/model_factory.py b/app/internal/model_factory.py index 4c8821ab99c1e430b0823ef4b9c1864e40427be2..cb91d11486239936ca71819f49d42d525ca7aaf1 100644 --- a/app/internal/model_factory.py +++ b/app/internal/model_factory.py @@ -34,7 +34,8 @@ def discover_models(): return models def get_model_instance(model_id): - return get_vips_models().get(model_id)() + model_class = get_vips_models().get(model_id) + return model_class() if model_class is not None else None def get_vips_model_packages(): global _packages diff --git a/app/main.py b/app/main.py index be6ad2e17a966de494790c06b55e036a81399398..7c8f073d5cb893308270b87a662cb3cf80783204 100644 --- a/app/main.py +++ b/app/main.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI +from fastapi import FastAPI, Response, status from fastapi.responses import PlainTextResponse from app.internal.model_factory import * from vipscore_common.vips_model import VIPSModel @@ -49,12 +49,17 @@ async def print_model_list(language: str) -> str: ####### Model running endpoints ####### @app.post("/models/run/", name="Run a model") -async def run_model_from_config_only(model_configuration:ModelConfiguration): - return _run_model(model_configuration.model_id, model_configuration) +async def run_model_from_config_only(model_configuration:ModelConfiguration, response:Response): + result = _run_model(model_configuration.model_id, model_configuration) + return _result_or_404(result, model_id=model_id, response=response) @app.post("/models/{model_id}/run/", name="Run a model") -async def run_model_from_config_only(model_id, model_configuration:ModelConfiguration): - return _run_model(model_id, model_configuration) +async def run_model_from_config_only(model_id, model_configuration:ModelConfiguration, response:Response): + result = _run_model(model_id, model_configuration) + return _result_or_404(result, model_id=model_id, response=response) + + + ####### Helper functions ####### @@ -66,8 +71,10 @@ def _run_model(model_id:str, model_configuration:ModelConfiguration): Return a list of Result objects """ requested_model = get_model_instance(model_id) - requested_model.set_configuration(model_configuration) - return requested_model.get_result() + if requested_model is not None: + requested_model.set_configuration(model_configuration) + return requested_model.get_result() + return None def _print_model_list(language:str): @@ -91,4 +98,16 @@ def _print_model_list_json(language:str): "modelId" : model_id, "modelName" : get_model_instance(model_id).get_model_name() }) - return model_list \ No newline at end of file + return model_list + + + +def _result_or_404(result, model_id:str, response:Response): + """ + DRY + """ + if result is not None: + return result + else: + response.status_code = status.HTTP_404_NOT_FOUND + return {"ERROR":"VIPS model with id=%s could not be found." % model_id} \ No newline at end of file