// TokenBucketRateLimitStrategy.java

package com.siddharthgawas.apigateway.ratelimiter.impl;

import com.siddharthgawas.apigateway.ratelimiter.RateLimitStrategy;
import com.siddharthgawas.apigateway.ratelimiter.dto.RateLimitProps;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;

import java.util.List;
import java.util.Objects;

/**
 * TokenBucketRateLimitStrategy implements a token bucket algorithm for rate limiting.
 * <p>
 * Each bucket (one per key and request path) holds at most {@code maxTokenPerMinute} tokens.
 * It is refilled continuously at {@code maxTokenPerMinute / 60} tokens per second. A client
 * can therefore burst up to the full capacity and then sustain {@code maxTokenPerMinute}
 * requests per minute. The token count and the time of the last refill are stored in Redis
 * and updated atomically by a single Lua script.
 */
@Slf4j
public class TokenBucketRateLimitStrategy implements RateLimitStrategy {

    /*
     * Refills the bucket and tries to consume one token, atomically.
     *
     * KEYS[1] = token count (may be fractional), KEYS[2] = last refill time (epoch seconds).
     * ARGV[1] = current time (epoch seconds), ARGV[2] = capacity, ARGV[3] = window (seconds).
     * Returns 1 if the request must be rejected, 0 if a token was consumed.
     *
     * Tokens are added in proportion to the elapsed time (capacity / window per second) and
     * capped at capacity. If the clock appears to go backwards, nothing is added. Both keys
     * are rewritten with a TTL of one window on every call. After a full window of inactivity
     * the bucket would be full anyway, so letting the keys expire and re-initialising them at
     * capacity gives the same result.
     */
    private static final String LUA_SCRIPT = """
            local now = tonumber(ARGV[1])
            local capacity = tonumber(ARGV[2])
            local window = tonumber(ARGV[3])
            local tokens = tonumber(redis.call('get', KEYS[1]))
            local lastRefill = tonumber(redis.call('get', KEYS[2]))
            if not tokens or not lastRefill then
                tokens = capacity
                lastRefill = now
            end
            if now > lastRefill then
                tokens = tokens + (now - lastRefill) * capacity / window
                lastRefill = now
            end
            tokens = math.min(capacity, tokens)
            local exceeded = 1
            if tokens >= 1 then
                tokens = tokens - 1
                exceeded = 0
            end
            redis.call('set', KEYS[1], tostring(tokens), 'EX', ARGV[3])
            redis.call('set', KEYS[2], tostring(lastRefill), 'EX', ARGV[3])
            return exceeded
            """;

    private static final int WINDOW_SIZE = 60; // seconds

    // Built once and reused so the script's SHA1 digest is not recomputed on every request.
    private static final DefaultRedisScript<Long> SCRIPT = new DefaultRedisScript<>(LUA_SCRIPT, Long.class);

    private final RedisTemplate<String, Object> redisTemplate;

    private final Long maxTokenPerMinute;


    /**
     * Creates a token bucket strategy.
     *
     * @param redisTemplate     template used to run the rate limiting script; must not be null.
     * @param maxTokenPerMinute bucket capacity, which is also the sustained number of requests
     *                          allowed per minute; must not be null.
     * @throws NullPointerException if either argument is null.
     */
    public TokenBucketRateLimitStrategy(RedisTemplate<String, Object> redisTemplate, Long maxTokenPerMinute) {
        this.redisTemplate = Objects.requireNonNull(redisTemplate, "redisTemplate");
        this.maxTokenPerMinute = Objects.requireNonNull(maxTokenPerMinute, "maxTokenPerMinute");
    }

    /**
     * Checks if the quota is exceeded for the given rate limit properties. If it is not
     * exceeded, one token is consumed from the bucket.
     *
     * @param rateLimitProps The properties containing the key and request path for rate limiting.
     * @return true if the quota is exceeded, false otherwise. Also returns false when Redis
     *         gives no script result (fail open).
     */
    @Override
    public Boolean isQuotaExceeded(final RateLimitProps rateLimitProps) {
        final String key = rateLimitProps.getKey() + ":" + rateLimitProps.getRequestPath();
        // NOTE(review): in Redis Cluster both script keys must hash to the same slot
        // (e.g. wrap the shared part in a {hash tag}); confirm the deployment topology.
        final String tokenCountKey = key + ":tokens";
        final String lastRefillKey = key + ":lastRefill";
        final long currentTime = System.currentTimeMillis() / 1000; // Current time in seconds
        final Long result = redisTemplate.execute(SCRIPT,
                List.of(tokenCountKey, lastRefillKey),
                currentTime,
                maxTokenPerMinute,
                WINDOW_SIZE);
        if (result == null) {
            // RedisTemplate may return null (e.g. inside a pipeline or transaction).
            // Avoid an NPE on unboxing and let the request through.
            log.warn("Rate limit script returned no result for key {}; allowing request", key);
            return false;
        }
        return result == 1L;
    }

}