Accelerate ML Model Serving with FastAPI and Redis Caching

Redis, an open-source, in-memory data structure store, is an excellent choice for caching in machine learning applications. Its speed, optional persistence, and support for various data structures make it well suited to the high-throughput demands of real-time inference tasks.
In this tutorial, we will explore the importance of Redis caching in machine learning workflows. We will demonstrate how to build a robust machine learning application using FastAPI and Redis. The tutorial will cover the installation of Redis on Windows, running it locally, and integrating it into the machine learning project. Finally, we will test the application by sending both duplicate and unique requests to verify that the Redis caching system is functioning correctly.
Why Use Redis Caching in Machine Learning?
In today's fast-paced digital landscape, users expect instant results from machine learning applications. For instance, consider an e-commerce platform that uses a recommendation model to suggest products to users. By implementing Redis for caching repeated requests, the platform can dramatically reduce response times.
When a user requests product recommendations, the system first checks whether the request has been cached. If it has, the cached response is returned without running the model, typically in around a millisecond or less when Redis runs on the same host or network. If not, the model processes the request, generates the recommendations, and stores the result in Redis for future requests. This approach not only improves the user experience but also saves server resources, because the model only runs for inputs it has not seen recently.
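The check-then-compute flow described above is often called the cache-aside pattern. The following is a minimal sketch of it, not the implementation used later in this tutorial. To keep it runnable without a Redis server, it uses a small in-process class, `TTLCache`, as a stand-in that mimics only the `get` and `setex` calls. The names `TTLCache`, `get_recommendations`, and the `recs:` key prefix are hypothetical and chosen for illustration.

```python
import json
import time


class TTLCache:
    """In-process stand-in for Redis GET/SETEX (illustration only)."""

    def __init__(self):
        self._store = {}

    def get(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.monotonic() >= expires_at:
            # Entry has expired: drop it and report a miss
            del self._store[key]
            return None
        return value

    def setex(self, key, ttl_seconds, value):
        # Store the value together with its expiry time
        self._store[key] = (value, time.monotonic() + ttl_seconds)


def get_recommendations(cache, user_id, compute, ttl_seconds=3600):
    """Return (result, was_cache_hit) using the cache-aside pattern."""
    key = f"recs:{user_id}"  # hypothetical key scheme
    cached = cache.get(key)
    if cached is not None:
        return json.loads(cached), True  # cache hit: skip the model
    result = compute(user_id)  # cache miss: run the (expensive) model
    cache.setex(key, ttl_seconds, json.dumps(result))
    return result, False
```

Calling `get_recommendations` twice with the same `user_id` runs `compute` only once; the second call is served from the cache until the entry expires.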
Building the Phishing Email Classification App with Redis
In this project, we will build a phishing email classification app. The process involves loading and processing a dataset from Kaggle, training a machine learning model on the processed data, evaluating its performance, saving the trained model, and finally building a FastAPI application with Redis integration.
1. Setting Up
- Download the Phishing Email Detection dataset from Kaggle and place it into the data/ directory.
- Install the Redis Python client by running the following command in your terminal:
pip install redis
- If you are on Windows and do not have Windows Subsystem for Linux (WSL) installed, follow Microsoft's guide to enable WSL and install a Linux distribution (e.g., Ubuntu) from the Microsoft Store.
- Once WSL is set up, open your WSL terminal and execute the following commands to install Redis:
sudo apt update
sudo apt install redis-server
sudo service redis-server start
You should see a confirmation message indicating that redis-server has started successfully.
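Before moving on, it can help to confirm that the server is actually reachable. Running `redis-cli ping` in the WSL terminal should answer `PONG`. As a rough Python alternative, the sketch below opens a plain TCP connection and sends an inline `PING`. The function name `redis_is_up` is hypothetical, and the check assumes a default local install without password authentication (a server that requires a password replies with an error, so the function would return False).

```python
import socket


def redis_is_up(host="localhost", port=6379, timeout=1.0):
    """Return True if a Redis server answers PING on host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout) as sock:
            sock.sendall(b"PING\r\n")  # inline command understood by Redis
            return sock.recv(16).startswith(b"+PONG")
    except OSError:
        # Connection refused, timed out, or host unreachable
        return False
```

For example, `print(redis_is_up())` gives a quick yes/no answer before you start the API.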
2. Model Training
The training script loads the dataset, processes the data, trains the model, and saves it locally.
import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline


def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary

    # The dataset has the columns "Email Text" and "Email Type"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]

    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Create a pipeline with TF-IDF and Logistic Regression
    pipeline = Pipeline(
        [
            ("tfidf", TfidfVectorizer(stop_words="english")),
            ("clf", LogisticRegression(solver="liblinear")),
        ]
    )

    # Train the model
    pipeline.fit(X_train, y_train)

    # Save the trained model to a file
    joblib.dump(pipeline, "phishing_model.pkl")
    print("Model trained and saved as phishing_model.pkl")


if __name__ == "__main__":
    main()
python train.py
Model trained and saved as phishing_model.pkl
3. Model Evaluation
The evaluation script loads the dataset and the saved model file to perform model evaluations.
import joblib
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split


def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary

    # The dataset has the columns "Email Text" and "Email Type"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]

    # Split the dataset (same seed as training, so the test set matches)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Load the trained model
    model = joblib.load("phishing_model.pkl")

    # Make predictions on the test set
    y_pred = model.predict(X_test)

    # Evaluate the model
    print("Accuracy: ", accuracy_score(y_test, y_pred))
    print("Classification Report:")
    print(classification_report(y_test, y_pred))


if __name__ == "__main__":
    main()
The model reaches about 97% accuracy on the held-out test set, with F1 scores of 0.96 for phishing emails and 0.98 for safe emails.
python validate.py
Accuracy: 0.9723860589812332
Classification Report:
                precision    recall  f1-score   support

Phishing Email       0.96      0.97      0.96      1457
    Safe Email       0.98      0.97      0.98      2273

      accuracy                           0.97      3730
     macro avg       0.97      0.97      0.97      3730
  weighted avg       0.97      0.97      0.97      3730
4. Model Serving with Redis
To serve the model, we will use FastAPI to create a REST API and integrate Redis for caching predictions.
import asyncio
import json

import joblib
from fastapi import FastAPI
from pydantic import BaseModel
import redis.asyncio as redis

# Create an asynchronous Redis client (make sure Redis is running on localhost:6379)
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

# Load the trained model (synchronously)
model = joblib.load("phishing_model.pkl")

app = FastAPI()


# Define the request and response data models
class PredictionRequest(BaseModel):
    text: str


class PredictionResponse(BaseModel):
    prediction: str
    probability: float


@app.post("/predict", response_model=PredictionResponse)
async def predict_email(data: PredictionRequest):
    # Use the email text as a cache key
    cache_key = f"prediction:{data.text}"
    cached = await redis_client.get(cache_key)
    if cached:
        return json.loads(cached)

    # Run model inference in a thread to avoid blocking the event loop
    pred = await asyncio.to_thread(model.predict, [data.text])
    prob = await asyncio.to_thread(lambda: model.predict_proba([data.text])[0].max())
    result = {"prediction": str(pred[0]), "probability": float(prob)}

    # Cache the result for 1 hour (3600 seconds)
    await redis_client.setex(cache_key, 3600, json.dumps(result))
    return result


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
python serve.py
INFO: Started server process [17640]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
You can browse the interactive REST API documentation at http://localhost:8000/docs.
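One detail of the endpoint worth noting is that it runs the model twice per cache miss: once for the label and once for the probability. A possible refinement, sketched below, derives both from a single `predict_proba` call. This assumes that the predicted label is the class with the highest probability, which holds for typical scikit-learn classifiers but is worth verifying for your own model. The helper name `top_class` is hypothetical.

```python
def top_class(classes, probabilities):
    """Return (label, probability) for the most likely class."""
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return str(classes[best]), float(probabilities[best])


# Hypothetical use inside the endpoint (one inference call instead of two):
#   probs = await asyncio.to_thread(model.predict_proba, [data.text])
#   label, prob = top_class(model.classes_, probs[0])
#   result = {"prediction": label, "probability": prob}
```

The helper works on any pair of equal-length sequences, so it can be tested without loading the model.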

How Redis Caching Works in Machine Learning Applications
Here is a step-by-step explanation of how Redis caching operates in our machine learning application:
1. The client submits input data to request a prediction from the machine learning model.
2. A unique identifier (the cache key) is generated from the input data and used to check whether a prediction already exists in Redis.
   - If a cached prediction is found, it is retrieved and returned in a JSON response.
   - If no cached prediction is found, the input data is passed to the machine learning model to generate a new prediction.
3. On a cache miss, the newly generated prediction is stored in the Redis cache, with a one-hour expiry, for future use.
4. The final result is returned to the client in JSON format.
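The serving code in this tutorial uses the raw email text as the unique identifier. For long emails that produces very large keys, so a common alternative is to hash the input. The sketch below shows one way to do this with SHA-256. The function name `make_cache_key` and the `model_version` component are assumptions added for illustration; they are not part of the application above.

```python
import hashlib


def make_cache_key(text, model_version="v1"):
    """Build a fixed-length cache key from the email text."""
    # Identical texts always hash to the same digest, so duplicates hit the cache
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
    # Including a model version avoids serving stale predictions after retraining
    return f"prediction:{model_version}:{digest}"
```

One trade-off of this design is that the original text can no longer be recovered from the key, so any script that reads the email text back out of the key name would need the text stored in the cached value instead.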
Testing the Phishing Email Classification App
After building our phishing email classification application, it's time to test its functionality. In this section, we will evaluate the app by sending multiple email texts using the cURL command and analyzing the responses. Additionally, we will verify the Redis database to ensure that the caching system is working as expected.
Testing the API Using the cURL Command
To test the API, we will send five requests to the /predict endpoint. Among these, three requests will contain unique email texts, while the other two will be duplicates of previously sent emails. This will allow us to verify both the prediction accuracy and the caching mechanism.
echo "\n===== Testing API Endpoint with 5 Requests =====\n"
# First unique email
echo "\n----- Request 1 (First unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'
# Second unique email
echo "\n\n----- Request 2 (Second unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'
# First duplicate (same as first email)
echo "\n\n----- Request 3 (Duplicate of first email - should be cached) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'
# Third unique email
echo "\n\n----- Request 4 (Third unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "congratulations you have won a free iphone, click here to claim your prize now before it expires"
}'
# Second duplicate (same as second email)
echo "\n\n----- Request 5 (Duplicate of second email - should be cached) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'
echo "\n\n===== Test Complete =====\n"
echo "Now run 'python check_redis.py' to verify the Redis cache entries"
When you run the above script, the API should return predictions for each email. For duplicate requests, the response should be retrieved from the Redis cache, ensuring faster response times.
sh test.sh
\n===== Testing API Endpoint with 5 Requests =====\n
\n----- Request 1 (First unique email) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 2 (Second unique email) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n----- Request 3 (Duplicate of first email - should be cached) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 4 (Third unique email) -----
{"prediction":"Phishing Email","probability":0.9169092144856761}\n\n----- Request 5 (Duplicate of second email - should be cached) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n===== Test Complete =====\n
Now run 'python check_redis.py' to verify the Redis cache entries
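The responses above show that duplicates return identical results, but they do not show how much faster the cached path is. The following sketch is one way to measure that from Python using only the standard library. It assumes the API from this tutorial is running at http://localhost:8000/predict; the function names `post_prediction` and `compare_latency` are hypothetical.

```python
import json
import time
import urllib.request

API_URL = "http://localhost:8000/predict"  # endpoint from this tutorial


def post_prediction(text, url=API_URL):
    """Send one email text to the API and return (response_dict, seconds)."""
    body = json.dumps({"text": text}).encode("utf-8")
    request = urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    start = time.perf_counter()
    with urllib.request.urlopen(request, timeout=10) as response:
        payload = json.load(response)
    return payload, time.perf_counter() - start


def compare_latency(text, send=post_prediction):
    """Send the same text twice; the second call should be a cache hit."""
    first, t_first = send(text)
    second, t_second = send(text)
    return {"same_result": first == second, "first_s": t_first, "second_s": t_second}
```

Calling `compare_latency("some new email text")` with a text that has not been sent before should show a smaller `second_s` than `first_s`. If the text is already cached from an earlier run, both calls are cache hits and the two timings will be similar.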
Verify the Redis Cache
To confirm that the caching system is working correctly, we will use a Python script, check_redis.py, to inspect the Redis database. This script retrieves cached predictions and displays them in a tabular format.
import json

import redis
from tabulate import tabulate


def main():
    # Connect to Redis (ensure Redis is running on localhost:6379)
    redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

    # Retrieve all keys that start with "prediction:"
    keys = redis_client.keys("prediction:*")
    total_entries = len(keys)
    print(f"Total number of cached prediction entries: {total_entries}\n")

    table_data = []
    # Process only the first 5 entries
    for key in keys[:5]:
        # Remove the 'prediction:' prefix to get the original email text
        email_text = key.replace("prediction:", "", 1)

        # Retrieve the cached value (None if the key expired in the meantime)
        value = redis_client.get(key)
        try:
            data = json.loads(value)
        except (TypeError, json.JSONDecodeError):
            data = {}
        prediction = data.get("prediction", "N/A")

        # Display only the first 7 words of the email text
        words = email_text.split()
        truncated_text = " ".join(words[:7]) + ("..." if len(words) > 7 else "")
        table_data.append([truncated_text, prediction])

    # Print table using tabulate (two columns: text and prediction)
    headers = ["Email Text (First 7 Words)", "Prediction"]
    print(tabulate(table_data, headers=headers, tablefmt="pretty"))


if __name__ == "__main__":
    main()
When you run the check_redis.py script, it will display the number of cache entries and the cached predictions in a table format.
python check_redis.py
Total number of cached prediction entries: 3
+--------------------------------------------------+----------------+
|            Email Text (First 7 Words)            |   Prediction   |
+--------------------------------------------------+----------------+
|  congratulations you have won a free iphone,...  | Phishing Email |
| urgent action required: your account has been... | Phishing Email |
|      todays floor meeting you may get a...       |   Safe Email   |
+--------------------------------------------------+----------------+
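The script above lists keys with the `KEYS` command, which is fine for a handful of entries but scans the whole keyspace in one blocking call. For larger caches, the Redis documentation recommends the incremental `SCAN` command, which redis-py exposes as `scan_iter`. The sketch below shows how the key lookup could be written that way. The function name `sample_cached_keys` is hypothetical, and `client` is assumed to be a `redis.Redis` instance such as the one created in check_redis.py.

```python
from itertools import islice


def sample_cached_keys(client, pattern="prediction:*", limit=5, batch=100):
    """Return up to `limit` keys matching `pattern` using incremental SCAN."""
    # scan_iter walks the keyspace in small batches instead of one blocking call
    return list(islice(client.scan_iter(match=pattern, count=batch), limit))
```

Note that, unlike `keys`, this does not give the total number of matching entries; counting them would require iterating over the full scan.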
Final Thoughts
By testing the phishing email classification app with multiple requests, we successfully demonstrated that the API can accurately identify phishing emails while efficiently caching duplicate requests using Redis. This caching mechanism significantly enhances performance by reducing redundant computations for repeated inputs, which is especially beneficial in real-world applications where APIs handle high volumes of traffic.
Although this was a relatively simple machine learning model, the benefits of caching become even more pronounced when working with larger and more complex models, such as image recognition models. For instance, if you were deploying a large-scale image classification model, caching predictions for frequently processed inputs could save substantial computational resources and drastically improve response times.