Q&A 4 How do you serve saved models as prediction endpoints using FastAPI?
4.1 Explanation
Once you’ve saved your trained models, the next step is to create an API that loads those models and makes them available for real-time prediction. FastAPI, a lightweight, high-performance Python web framework, is well suited to this task.
In this Q&A, we define a FastAPI app that:
- Loads all `.joblib` models from the `models/` folder (see the sketch after this list)
- Defines a prediction route `/predict/{model_name}`
- Accepts JSON input using a `pydantic` schema
- Returns a prediction as a JSON response
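For context, the app assumes the `models/` folder already contains scikit-learn estimators serialized with joblib, e.g. from an earlier training step. Here is a minimal sketch of producing such a file; the estimator choice, the stand-in data, and the file name `logistic_regression.joblib` are illustrative assumptions, not part of the app itself:

```python
import os

import joblib
from sklearn.linear_model import LogisticRegression

# Tiny stand-in data with the five features the API schema below expects:
# [Pclass, Sex, Age, Fare, Embarked]
X = [[3, 0, 22.0, 7.25, 0], [1, 1, 38.0, 71.28, 1]]
y = [0, 1]

model = LogisticRegression().fit(X, y)

os.makedirs("models", exist_ok=True)
joblib.dump(model, "models/logistic_regression.joblib")  # hypothetical name
# The file name (minus .joblib) becomes model_name in /predict/{model_name}
```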
4.2 Python Code (Define FastAPI App)
```python
# scripts/model_api.py
import os

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Load models dynamically
MODEL_DIR = "models"
models = {}
for fname in os.listdir(MODEL_DIR):
    if fname.endswith(".joblib"):
        model_name = fname.replace(".joblib", "")
        model_path = os.path.join(MODEL_DIR, fname)
        models[model_name] = joblib.load(model_path)

# Create FastAPI app
app = FastAPI()

# Define input schema
class InputData(BaseModel):
    Pclass: int
    Sex: int
    Age: float
    Fare: float
    Embarked: int

# Define output schema
class PredictionOutput(BaseModel):
    model: str
    prediction: int

# Route to list available models
@app.get("/models")
def list_models():
    return {"available_models": list(models.keys())}

# Route to predict using any loaded model
@app.post("/predict/{model_name}", response_model=PredictionOutput)
def predict(model_name: str, input_data: InputData):
    if model_name not in models:
        raise HTTPException(status_code=404, detail="Model not found.")
    input_df = pd.DataFrame([input_data.dict()])
    model = models[model_name]
    try:
        prediction = model.predict(input_df)[0]
        return PredictionOutput(model=model_name, prediction=int(prediction))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
```
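With the app defined, you can serve it locally and exercise both routes. The sketch below assumes the server runs with uvicorn on the default host and port, and that a model named `logistic_regression` exists in `models/`; both the address and the model name are illustrative, so check `GET /models` for the names actually loaded:

```python
# Start the server first (from the project root), e.g.:
#   uvicorn scripts.model_api:app --reload
# Then query it. The model name below is hypothetical.
import requests

BASE = "http://127.0.0.1:8000"

# See which models the API loaded from models/
print(requests.get(f"{BASE}/models").json())

# Request a prediction from one of them
payload = {"Pclass": 3, "Sex": 0, "Age": 22.0, "Fare": 7.25, "Embarked": 0}
resp = requests.post(f"{BASE}/predict/logistic_regression", json=payload)
print(resp.json())  # e.g. {"model": "logistic_regression", "prediction": 0}
```

FastAPI also auto-generates interactive documentation at `/docs`, which is a convenient way to test these routes by hand.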