RateLimiterFilter.java
package com.siddharthgawas.apigateway.ratelimiter;
import com.siddharthgawas.apigateway.ratelimiter.dto.RateLimitProps;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.NonNull;
import org.springframework.http.HttpStatus;
import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
import java.util.Objects;
import java.util.function.Function;
/**
 * Filter for rate limiting requests based on a specified strategy.
 * <p>
 * Only requests accepted by the configured {@link RequestMatcher} are processed; all other
 * requests bypass this filter. For a processed request, a rate-limit key is derived with the
 * key extractor (the client's remote address by default) and passed, together with the
 * request URI, to the {@link RateLimitStrategy}:
 * <ul>
 *   <li>if the extracted key is {@code null} or empty, the request is rejected with
 *       {@code 401 Unauthorized};</li>
 *   <li>if the strategy reports the quota as exceeded, the request is rejected with
 *       {@code 429 Too Many Requests};</li>
 *   <li>otherwise the request continues down the filter chain.</li>
 * </ul>
 * Error bodies are short plain-text messages and are only written if the response has not
 * already been committed.
 */
public class RateLimiterFilter extends OncePerRequestFilter {

    /** Media type (with explicit charset) used for the plain-text error bodies. */
    private static final String ERROR_CONTENT_TYPE = "text/plain;charset=UTF-8";

    private final RateLimitStrategy rateLimitStrategy;
    private final RequestMatcher requestMatcher;
    private final Function<HttpServletRequest, String> keyExtractor;

    /**
     * Constructs a RateLimiterFilter that keys requests by the client's remote address
     * ({@link ServletRequest#getRemoteAddr()}).
     *
     * @param rateLimitStrategy the strategy to use for rate limiting; must not be {@code null}
     * @param requestMatcher    the matcher to determine which requests to filter; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null}
     */
    public RateLimiterFilter(final RateLimitStrategy rateLimitStrategy, final RequestMatcher requestMatcher) {
        this(rateLimitStrategy, requestMatcher, ServletRequest::getRemoteAddr);
    }

    /**
     * Constructs a RateLimiterFilter with the specified rate limit strategy, request matcher,
     * and key extractor function.
     *
     * @param rateLimitStrategy the strategy to use for rate limiting; must not be {@code null}
     * @param requestMatcher    the matcher to determine which requests to filter; must not be {@code null}
     * @param keyExtractor      function to extract the rate-limit key from the request; must not be
     *                          {@code null}. Returning {@code null} or an empty string causes the
     *                          request to be rejected with {@code 401 Unauthorized}.
     * @throws NullPointerException if any argument is {@code null}
     */
    public RateLimiterFilter(final RateLimitStrategy rateLimitStrategy,
                             final RequestMatcher requestMatcher,
                             final Function<HttpServletRequest, String> keyExtractor) {
        // Fail fast at wiring time instead of with an NPE on the first request.
        this.rateLimitStrategy = Objects.requireNonNull(rateLimitStrategy, "rateLimitStrategy");
        this.requestMatcher = Objects.requireNonNull(requestMatcher, "requestMatcher");
        this.keyExtractor = Objects.requireNonNull(keyExtractor, "keyExtractor");
    }

    /**
     * Filters incoming requests to apply rate limiting.
     * <p>
     * Derives the rate-limit key from the request. A missing key yields {@code 401 Unauthorized}.
     * Otherwise the strategy is consulted with the key and the request URI. An exhausted quota
     * yields {@code 429 Too Many Requests}; in every other case the request is passed on to the
     * rest of the chain.
     *
     * @param request     the HttpServletRequest to filter
     * @param response    the HttpServletResponse to write the response to
     * @param filterChain the filter chain to continue processing the request
     * @throws ServletException if an error occurs during filtering
     * @throws IOException      if an I/O error occurs
     */
    @Override
    protected void doFilterInternal(final HttpServletRequest request,
                                    final HttpServletResponse response,
                                    final FilterChain filterChain) throws ServletException, IOException {
        final String key = keyExtractor.apply(request);
        if (!StringUtils.hasLength(key)) {
            writeError(response, HttpStatus.UNAUTHORIZED,
                    "Unauthorized access. Please provide valid credentials.");
            return;
        }

        final var props = new RateLimitProps(key, request.getRequestURI());
        if (rateLimitStrategy.isQuotaExceeded(props)) {
            writeError(response, HttpStatus.TOO_MANY_REQUESTS,
                    "Rate limit exceeded. Please try again later.");
            return;
        }

        filterChain.doFilter(request, response);
    }

    /**
     * Determines whether the filter should be applied to the request.
     * <p>
     * This method checks if the request matches the specified request matcher.
     *
     * @param request the HttpServletRequest to check
     * @return true if the filter should not be applied, false otherwise
     * @throws ServletException if an error occurs during filtering
     */
    @Override
    protected boolean shouldNotFilter(final HttpServletRequest request) throws ServletException {
        return !requestMatcher.matches(request);
    }

    /**
     * Writes a plain-text error response with the given status.
     * <p>
     * Does nothing if the response has already been committed, because the status and headers
     * can no longer be changed at that point.
     */
    private static void writeError(final HttpServletResponse response,
                                   final HttpStatus status,
                                   final String message) throws IOException {
        if (response.isCommitted()) {
            return;
        }
        response.setStatus(status.value());
        response.setContentType(ERROR_CONTENT_TYPE);
        response.getWriter().write(message);
    }
}