Accelerate ML Model Serving with FastAPI and Redis Caching


Redis, an open-source, in-memory data structure store, is an excellent choice for caching in machine learning applications. Its speed, durability, and support for various data structures make it ideal for handling the high-throughput demands of real-time inference tasks.

In this tutorial, we will explore the importance of Redis caching in machine learning workflows. We will demonstrate how to build a robust machine learning application using FastAPI and Redis. The tutorial will cover the installation of Redis on Windows, running it locally, and integrating it into the machine learning project. Finally, we will test the application by sending both duplicate and unique requests to verify that the Redis caching system is functioning correctly.

Why Use Redis Caching in Machine Learning?

In today's fast-paced digital landscape, users expect instant results from machine learning applications. For instance, consider an e-commerce platform that uses a recommendation model to suggest products to users. By implementing Redis for caching repeated requests, the platform can dramatically reduce response times.

When a user requests product recommendations, the system first checks whether the request has been cached. If it has, the cached response is returned almost immediately (a Redis lookup typically takes around a millisecond or less), providing a seamless experience. If not, the model processes the request, generates the recommendations, and stores the result in Redis for future requests. This approach improves user satisfaction and saves server resources, because repeated requests no longer trigger model inference, so the same hardware can handle more requests.
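The check-then-compute flow described above is commonly called the cache-aside pattern. The sketch below illustrates it with a plain Python dictionary standing in for Redis, so it runs without a Redis server. `InMemoryTTLCache`, `cached_predict`, and `fake_model` are hypothetical names used only for illustration, and the one-hour default expiry simply mirrors the value used later in this tutorial.

```python
import time
from typing import Any, Callable, Dict, Optional, Tuple


class InMemoryTTLCache:
    """Hypothetical stand-in for Redis: stores values together with an expiry time."""

    def __init__(self) -> None:
        self._store: Dict[str, Tuple[float, Any]] = {}

    def get(self, key: str) -> Optional[Any]:
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() >= expires_at:
            # An expired entry behaves like a cache miss, similar to a Redis key set with SETEX.
            del self._store[key]
            return None
        return value

    def setex(self, key: str, ttl_seconds: int, value: Any) -> None:
        self._store[key] = (time.monotonic() + ttl_seconds, value)


def cached_predict(
    cache: InMemoryTTLCache,
    predict: Callable[[str], Any],
    text: str,
    ttl_seconds: int = 3600,
) -> Any:
    key = f"prediction:{text}"
    cached = cache.get(key)
    if cached is not None:
        return cached  # cache hit: the model is not called
    result = predict(text)  # cache miss: run the (expensive) model
    cache.setex(key, ttl_seconds, result)
    return result


# Usage: count how often the hypothetical "model" actually runs.
calls = []


def fake_model(text: str) -> str:
    calls.append(text)
    return "Phishing Email" if "click here" in text else "Safe Email"


cache = InMemoryTTLCache()
first = cached_predict(cache, fake_model, "click here to reset your password")
second = cached_predict(cache, fake_model, "click here to reset your password")
print(first, second, len(calls))  # Phishing Email Phishing Email 1
```

The second call returns the stored result without invoking the model. A real deployment replaces the dictionary with a Redis client so that the cache is shared across worker processes and survives application restarts.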

Building the Phishing Email Classification App with Redis

In this project, we will build a phishing email classification app. The process involves loading and processing a dataset from Kaggle, training a machine learning model on the processed data, evaluating its performance, saving the trained model, and finally building a FastAPI application with Redis integration.

1. Setting Up

  • Download the Phishing Email Detection dataset from Kaggle and place it into the data/ directory.
  • Install the Python packages used in this tutorial, including the Redis Python client (the Redis server itself is installed in the next steps):

pip install redis pandas scikit-learn joblib fastapi uvicorn tabulate

  • If you are on Windows and do not have Windows Subsystem for Linux (WSL) installed, follow Microsoft's guide to enable WSL and install a Linux distribution (e.g., Ubuntu) from the Microsoft Store.
  • Once WSL is set up, open your WSL terminal and execute the following commands to install Redis:

sudo apt update

sudo apt install redis-server

  • To start the Redis server, run:

sudo service redis-server start

You should see a confirmation message indicating that redis-server has started successfully.

2. Model Training

The training script (train.py) loads the dataset, processes the data, trains the model, and saves it locally.

import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline


def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary

    # The dataset has the columns "Email Text" and "Email Type"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]

    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Create a pipeline with TF-IDF and Logistic Regression
    pipeline = Pipeline(
        [
            ("tfidf", TfidfVectorizer(stop_words="english")),
            ("clf", LogisticRegression(solver="liblinear")),
        ]
    )

    # Train the model on the training split only
    pipeline.fit(X_train, y_train)

    # Save the whole trained pipeline (vectorizer + classifier) to a file
    joblib.dump(pipeline, "phishing_model.pkl")
    print("Model trained and saved as phishing_model.pkl")


if __name__ == "__main__":
    main()

python train.py

Model trained and saved as phishing_model.pkl
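Because the saved object is a complete scikit-learn Pipeline, it takes raw email strings and applies TF-IDF vectorization before the classifier, so the serving code does not need a separate vectorizer. As a minimal sketch of that behavior, the snippet below fits the same pipeline on six hypothetical emails (invented for illustration, not taken from the Kaggle dataset).

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Hypothetical toy emails standing in for the "Email Text" / "Email Type" columns.
texts = [
    "urgent verify your account password now click here",
    "you have won a prize click here to claim",
    "your account is suspended click the link to verify",
    "meeting moved to friday please review the agenda",
    "attached are the quarterly figures for review",
    "lunch with the team tomorrow at noon",
]
labels = ["Phishing Email"] * 3 + ["Safe Email"] * 3

# Same two-step pipeline as train.py: TF-IDF features, then logistic regression.
pipeline = Pipeline(
    [
        ("tfidf", TfidfVectorizer(stop_words="english")),
        ("clf", LogisticRegression(solver="liblinear")),
    ]
)
pipeline.fit(texts, labels)

# The fitted pipeline accepts raw strings directly.
new_emails = ["click here to verify your password", "agenda for the friday meeting"]
predictions = pipeline.predict(new_emails)
probabilities = pipeline.predict_proba(new_emails)

# Highest class probability for the first email, as serve.py computes it.
confidence = float(probabilities[0].max())

print(list(pipeline.classes_))
print(predictions.tolist())
print(confidence)
```

With so little data the predictions themselves are not meaningful; the point is the interface. predict returns class labels, and predict_proba returns one probability per class in the order of classes_.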

3. Model Evaluation

The evaluation script (validate.py) loads the dataset and the saved model file and evaluates the model on the held-out test set.

import joblib
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split


def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary

    # The dataset has the columns "Email Text" and "Email Type"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]

    # Recreate the same split as train.py (same test_size and random_state),
    # so the evaluation only uses emails the model did not see during training
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Load the trained model
    model = joblib.load("phishing_model.pkl")

    # Make predictions on the test set
    y_pred = model.predict(X_test)

    # Evaluate the model
    print("Accuracy: ", accuracy_score(y_test, y_pred))
    print("Classification Report:")
    print(classification_report(y_test, y_pred))


if __name__ == "__main__":
    main()

The model reaches about 97% accuracy on the held-out test set, with F1 scores of 0.96 for phishing emails and 0.98 for safe emails.

python validate.py

Accuracy: 0.9723860589812332
Classification Report:
                precision    recall  f1-score   support

Phishing Email       0.96      0.97      0.96      1457
    Safe Email       0.98      0.97      0.98      2273

      accuracy                           0.97      3730
     macro avg       0.97      0.97      0.97      3730
  weighted avg       0.97      0.97      0.97      3730
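The averaged rows in the report can be checked by hand. The macro average is the unweighted mean of the per-class scores, while the weighted average weights each class by its support (the number of test emails in that class). The sketch below recomputes both from the rounded per-class values shown above, so it agrees with the report only to two decimal places; `report`, `macro_avg`, and `weighted_avg` are illustrative names.

```python
# Per-class values copied from the classification report above (rounded to 2 decimals).
report = {
    "Phishing Email": {"precision": 0.96, "recall": 0.97, "f1": 0.96, "support": 1457},
    "Safe Email": {"precision": 0.98, "recall": 0.97, "f1": 0.98, "support": 2273},
}


def macro_avg(metric: str) -> float:
    # Unweighted mean: every class counts equally.
    return sum(r[metric] for r in report.values()) / len(report)


def weighted_avg(metric: str) -> float:
    # Support-weighted mean: classes with more test emails count more.
    total = sum(r["support"] for r in report.values())
    return sum(r[metric] * r["support"] for r in report.values()) / total


for metric in ("precision", "recall", "f1"):
    # Each line prints macro=0.97 weighted=0.97
    print(f"{metric:>9}: macro={macro_avg(metric):.2f} weighted={weighted_avg(metric):.2f}")
```

The weighted F1 (about 0.972) sits slightly above the macro F1 (0.97) because the larger Safe Email class also has the higher F1 score.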

4. Model Serving with Redis

To serve the model, we will use FastAPI to create a REST API and integrate Redis for caching predictions. Save the following as serve.py:

import asyncio
import json

import joblib
import redis.asyncio as redis
from fastapi import FastAPI
from pydantic import BaseModel

# Create an asynchronous Redis client (make sure Redis is running on localhost:6379)
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

# Load the trained model (synchronously, once at startup)
model = joblib.load("phishing_model.pkl")

app = FastAPI()


# Define the request and response data models
class PredictionRequest(BaseModel):
    text: str


class PredictionResponse(BaseModel):
    prediction: str
    probability: float


@app.post("/predict", response_model=PredictionResponse)
async def predict_email(data: PredictionRequest):
    # Use the email text as the cache key
    cache_key = f"prediction:{data.text}"
    cached = await redis_client.get(cache_key)
    if cached:
        # Cache hit: return the stored JSON without running the model
        return json.loads(cached)

    # Cache miss: run model inference in a thread to avoid blocking the event loop
    pred = await asyncio.to_thread(model.predict, [data.text])
    prob = await asyncio.to_thread(lambda: model.predict_proba([data.text])[0].max())
    result = {"prediction": str(pred[0]), "probability": float(prob)}

    # Cache the result for 1 hour (3600 seconds)
    await redis_client.setex(cache_key, 3600, json.dumps(result))
    return result


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

python serve.py

INFO: Started server process [17640]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)

You can view the interactive REST API documentation at http://localhost:8000/docs.


How Redis Caching Works in Machine Learning Applications

Here is a step-by-step explanation of how Redis caching operates in our machine learning application, along with a diagram to illustrate the process:

[Diagram: how Redis caching works in the application]

  • The client submits input data to request a prediction from the machine learning model.
  • A cache key is generated from the input data to check whether a prediction already exists (in serve.py, the key is "prediction:" followed by the email text).
  • The system queries the Redis cache using the generated key to look for a previously stored prediction.
    1. If a cached prediction is found (a cache hit), it is retrieved and returned in a JSON response.
    2. If no cached prediction is found (a cache miss), the input data is passed to the machine learning model to generate a new prediction.
  • The newly generated prediction is stored in the Redis cache for future use (with a one-hour expiry in serve.py).
  • The final result is returned to the client in JSON format.
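The serve.py endpoint uses the raw email text as the cache key, which keeps check_redis.py simple but makes each key as long as the email itself. A common alternative, sketched below as an assumption rather than part of this app, is to hash a normalized version of the input and include a model version in the key. `make_cache_key` and its `model_version` parameter are hypothetical; collapsing whitespace is safe here only because the default TF-IDF tokenizer already ignores runs of whitespace.

```python
import hashlib


def make_cache_key(text: str, model_version: str = "v1") -> str:
    # Collapse runs of whitespace so trivially different inputs share one cache entry.
    normalized = " ".join(text.split())
    # SHA-256 gives a fixed-length (64 hex characters) key regardless of email length.
    digest = hashlib.sha256(normalized.encode("utf-8")).hexdigest()
    # The version prefix keeps a retrained model from serving stale cached predictions.
    return f"prediction:{model_version}:{digest}"


key_a = make_cache_key("urgent action required:  click here")
key_b = make_cache_key("urgent action required: click here")
print(key_a == key_b)  # True
```

With hashed keys, check_redis.py could no longer recover the email text from the key, so you would need to store the text (or a short preview) in the cached value instead.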

Testing the Phishing Email Classification App

After building our phishing email classification application, it's time to test its functionality. In this section, we will send several email texts to the API with cURL and inspect the responses. We will then check the Redis database to confirm that the caching system is working as expected.

Testing the API Using cURL

To test the API, we will send five requests to the /predict endpoint: three with unique email texts and two that duplicate earlier emails. This lets us check both the predictions and the caching mechanism. Save the following commands as test.sh:

echo "\n===== Testing API Endpoint with 5 Requests =====\n"

# First unique email
echo "\n----- Request 1 (First unique email) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'

# Second unique email
echo "\n\n----- Request 2 (Second unique email) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'

# First duplicate (same as first email)
echo "\n\n----- Request 3 (Duplicate of first email - should be cached) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'

# Third unique email
echo "\n\n----- Request 4 (Third unique email) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "congratulations you have won a free iphone, click here to claim your prize now before it expires"
}'

# Second duplicate (same as second email)
echo "\n\n----- Request 5 (Duplicate of second email - should be cached) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'

echo "\n\n===== Test Complete =====\n"
echo "Now run 'python check_redis.py' to verify the Redis cache entries"

When you run the script, the API returns a prediction for each email. The two duplicate requests are answered from the Redis cache, so they return exactly the same prediction and probability as the original requests without running the model again.

sh test.sh

\n===== Testing API Endpoint with 5 Requests =====\n
\n----- Request 1 (First unique email) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 2 (Second unique email) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n----- Request 3 (Duplicate of first email - should be cached) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 4 (Third unique email) -----
{"prediction":"Phishing Email","probability":0.9169092144856761}\n\n----- Request 5 (Duplicate of second email - should be cached) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n===== Test Complete =====\n
Now run 'python check_redis.py' to verify the Redis cache entries

Verify the Redis Cache

To confirm that the caching system is working correctly, we will use a Python script, check_redis.py, to inspect the Redis database. It retrieves the cached predictions and displays them as a table.

import json

import redis
from tabulate import tabulate


def main():
    # Connect to Redis (ensure Redis is running on localhost:6379)
    redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

    # Retrieve all keys that start with "prediction:"
    # (KEYS is fine for a small local demo; prefer SCAN on a production instance)
    keys = redis_client.keys("prediction:*")
    total_entries = len(keys)
    print(f"Total number of cached prediction entries: {total_entries}\n")

    table_data = []
    # Process only the first 5 entries
    for key in keys[:5]:
        # Remove the 'prediction:' prefix to get the original email text
        email_text = key.replace("prediction:", "", 1)

        # Retrieve the cached value (None if the entry has expired in the meantime)
        value = redis_client.get(key)
        try:
            data = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            data = {}
        prediction = data.get("prediction", "N/A")

        # Display only the first 7 words of the email text
        words = email_text.split()
        truncated_text = " ".join(words[:7]) + ("..." if len(words) > 7 else "")
        table_data.append([truncated_text, prediction])

    # Print the table using tabulate (two columns)
    headers = ["Email Text (First 7 Words)", "Prediction"]
    print(tabulate(table_data, headers=headers, tablefmt="pretty"))


if __name__ == "__main__":
    main()

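When you run the check_redis.py script, it displays the number of cached entries and the cached predictions as a table. The five requests produce only three entries, because the two duplicate requests reused existing keys instead of creating new ones.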

python check_redis.py

Total number of cached prediction entries: 3

+--------------------------------------------------+----------------+
|            Email Text (First 7 Words)            |   Prediction   |
+--------------------------------------------------+----------------+
|  congratulations you have won a free iphone,...  | Phishing Email |
| urgent action required: your account has been... | Phishing Email |
|       todays floor meeting you may get a...      |   Safe Email   |
+--------------------------------------------------+----------------+

Final Thoughts

By testing the phishing email classification app with several requests, we confirmed that the API returns predictions for new emails and that duplicate requests are served from the Redis cache instead of being recomputed. Caching removes redundant inference for repeated inputs, which is especially valuable in real-world applications where APIs handle high volumes of traffic and the same inputs recur.

Although this was a relatively simple machine learning model, the benefits of caching become even more pronounced with larger and more complex models, such as image recognition models. For instance, if you were deploying a large-scale image classification model, caching predictions for frequently processed inputs could save substantial compute and noticeably improve response times.