Accelerating Machine Learning Predictions with FastAPI and Redis Caching
Overcoming Latency Issues in ML Model Serving
What is FastAPI?
What is Redis?
Why Combine FastAPI and Redis?
Implementation Steps
- Loading a Pre-trained Model
- Creating a FastAPI Endpoint for Predictions
- Adding Redis Caching for Predictions
- Testing and Measuring Performance Gains
Conclusion
Overcoming Latency Issues in ML Model Serving
Ever waited too long for a machine learning model to return predictions? We’ve all been there. Large, complex models can be painfully slow to serve in real time, yet users expect instant feedback. This is where latency becomes a crucial concern. A significant contributor to the problem is redundant computation: when the same input is run through the model’s slow inference step again and again, the delays add up. But fear not! In this post, I’ll show you how to fix that by building a FastAPI-based ML service with Redis caching, so repeated predictions come back in milliseconds.
What is FastAPI?
FastAPI is a modern high-performance web framework for building APIs with Python. It leverages Python’s type hints for data validation and auto-generates interactive API documentation (thanks to Swagger UI and ReDoc). Built on top of Starlette and Pydantic, FastAPI supports asynchronous programming, rivaling Node.js and Go in performance. Its design promotes rapid development of robust, production-ready APIs, making it an ideal choice for deploying machine learning models as scalable RESTful services.
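To get a feel for how little boilerplate this takes, here is a minimal sketch (the /square endpoint and its parameter are purely illustrative, not part of the service we build below). FastAPI reads the type hint on x, validates the incoming query parameter against it, and documents the endpoint automatically at /docs:
from fastapi import FastAPI
app = FastAPI()
@app.get("/square")
def square(x: float):
    """Return the square of x; FastAPI rejects requests where x is not a valid float."""
    return {"result": x * x}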
What is Redis?
Redis (Remote Dictionary Server) is an open-source, in-memory data structure store that functions as a database, cache, and message broker. By storing data in memory, Redis provides ultra-low latency for read and write operations, which is perfect for caching frequent or computationally intensive tasks like ML model predictions. It supports various data structures and features such as key expiration (TTL) for efficient cache management.
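As a quick illustration of the client API we will rely on later (a minimal sketch that assumes a Redis server is already running locally on the default port), the redis Python library exposes simple set/get calls, with an optional ex argument that attaches a TTL to a key:
import redis
# Connect to a local Redis server on the default port
r = redis.Redis(host="localhost", port=6379, db=0)
# Store a value that expires after 60 seconds, then read it back
r.set("greeting", "hello", ex=60)
print(r.get("greeting"))   # b'hello' -- redis-py returns bytes by default
print(r.ttl("greeting"))   # seconds remaining before the key expires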
Why Combine FastAPI and Redis?
Integrating FastAPI with Redis creates a responsive and efficient system. FastAPI serves as a quick and reliable interface for handling API requests, while Redis acts as a caching layer for storing results from previous computations. When the same input is received again, results can be retrieved instantly from Redis, avoiding recomputation. This approach minimizes latency, decreases computational load, and bolsters the scalability of your application. In distributed environments, Redis serves as a centralized cache accessible by multiple FastAPI instances, making it suitable for production-grade ML deployments.
Now, let’s walk through the implementation of a FastAPI application that serves ML model predictions and utilizes Redis caching.
Step 1: Loading a Pre-trained Model
Before we dive into the FastAPI implementation, let’s ensure we have a pre-trained machine learning model. Most models are trained offline, saved to disk, and subsequently loaded into a serving application. In this example, we’ll train a simple scikit-learn classifier on the famous Iris flower dataset and save it using Joblib.
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import joblib
# Load dataset and train a model (Iris classification)
X, y = load_iris(return_X_y=True)
model = RandomForestClassifier().fit(X, y)
# Save the trained model to disk
joblib.dump(model, "model.joblib")
After training, we’ll load the model to ensure it’s ready for serving:
model = joblib.load("model.joblib")
print("Model loaded and ready to serve predictions.")
Step 2: Creating a FastAPI Prediction Endpoint
Now that we have the model, let’s set up the API. FastAPI simplifies the process of defining an endpoint and mapping request parameters to Python function arguments.
from fastapi import FastAPI
import joblib
app = FastAPI()
model = joblib.load("model.joblib") # Load model at startup
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
""" Predict the Iris flower species from input measurements. """
features = [[sepal_length, sepal_width, petal_length, petal_width]]
prediction = model.predict(features)[0] # Returns the predicted class
return {"prediction": str(prediction)}
Run the FastAPI app using uvicorn:
uvicorn main:app --reload
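With the server running (assuming the default http://localhost:8000), you can try the endpoint from the browser, from the auto-generated docs at http://localhost:8000/docs, or with a quick curl call:
curl "http://localhost:8000/predict?sepal_length=5.1&sepal_width=3.5&petal_length=1.4&petal_width=0.2"
The response is a small JSON object containing the predicted class, for example {"prediction": "0"}.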
Step 3: Adding Redis Caching for Predictions
To cache the model output, we will introduce Redis. Ensure that the Redis server is running (you can run it via Docker or install it locally). We will use the Python redis library to communicate with the server.
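If you don’t have Redis installed, one convenient option (assuming Docker is available on your machine) is to start it in a container on the default port:
docker run -d -p 6379:6379 --name redis-cache redis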
First, make sure to install the Redis library:
pip install redis
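To verify that your application can actually reach the server (a small check assuming Redis is listening on localhost:6379), a single ping is enough:
import redis
cache = redis.Redis(host="localhost", port=6379, db=0)
print(cache.ping())  # True if the connection to Redis succeeds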
Now, let’s enhance our FastAPI app by replacing the /predict endpoint from Step 2 with a cached version. The imports, app instance, and loaded model stay the same; only the endpoint logic changes.
import redis  # Redis client library

# Connect to a local Redis server
cache = redis.Redis(host="localhost", port=6379, db=0)

@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
    """Predict the species, with caching to speed up repeated predictions."""
    # Build a cache key from the input values
    cache_key = f"{sepal_length}:{sepal_width}:{petal_length}:{petal_width}"
    cached_val = cache.get(cache_key)
    if cached_val is not None:
        # Cache hit: return the stored prediction without touching the model
        return {"prediction": cached_val.decode("utf-8")}
    # Cache miss: run the model and store the result for next time
    features = [[sepal_length, sepal_width, petal_length, petal_width]]
    prediction = model.predict(features)[0]
    cache.set(cache_key, str(prediction))  # Pass ex=<seconds> here if you want entries to expire (TTL)
    return {"prediction": str(prediction)}
Step 4: Testing and Measuring Performance Gains
With our FastAPI app now connected to Redis, let’s test the latency improvements we’ve achieved through caching.
import time
import requests

params = {
    "sepal_length": 5.1,
    "sepal_width": 3.5,
    "petal_length": 1.4,
    "petal_width": 0.2,
}

# First request (expected cache miss)
start = time.time()
response1 = requests.get("http://localhost:8000/predict", params=params)
elapsed1 = time.time() - start
print("First response:", response1.json(), f"(Time: {elapsed1:.4f} seconds)")

# Second request (expected cache hit)
start = time.time()
response2 = requests.get("http://localhost:8000/predict", params=params)
elapsed2 = time.time() - start
print("Second response:", response2.json(), f"(Time: {elapsed2:.6f} seconds)")
You should observe significantly reduced response times on repeated requests, showcasing caching’s performance benefits.
Comparison
Let’s summarize the gains:
- Without caching: every request, including identical ones, hits the model, so the full inference cost is paid again and again.
- With caching: the first request pays the full computation cost, while subsequent identical requests need only a quick Redis lookup. The resulting speedup is substantial, especially for expensive models.
Conclusion
In this post, we’ve seen how to serve machine learning predictions efficiently using FastAPI and Redis. FastAPI provides a clean, fast API layer, while Redis significantly reduces latency and computational load for repeated predictions. By adding caching, we’ve improved both the responsiveness and the scalability of the application.
Happy coding, and may your machine learning deployments be swift and efficient!