You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
178 lines
6.5 KiB
C#
178 lines
6.5 KiB
C#
using System;
using System.Collections.Concurrent;
using System.Globalization;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Haoliang.Core.Services;
using Haoliang.Models.Common;
|
|
|
|
namespace Haoliang.Api.Middleware
|
|
{
|
|
/// <summary>
/// ASP.NET Core middleware that applies a fixed-window request limit per client and request path.
/// Once a client exceeds <see cref="RateLimitSettings.MaxRequests"/> requests to a path within
/// <see cref="RateLimitSettings.TimeWindow"/> seconds, further requests in that window are rejected
/// with HTTP 429 and a JSON <see cref="ApiResponse{T}"/> body.
/// </summary>
/// <remarks>
/// Counters live in this instance's memory, so limits are per process and reset on restart.
/// </remarks>
public class RateLimitMiddleware
{
    // NOTE(review): the settings expose only an on/off flag for health checks, not their paths.
    // These are the conventional endpoints; confirm against the app's health-check mappings.
    private static readonly string[] HealthCheckPaths = { "/health", "/healthz" };

    private readonly RequestDelegate _next;
    private readonly ILogger<RateLimitMiddleware> _logger;
    private readonly ILoggingService _loggingService;
    private readonly RateLimitSettings _settings;
    private readonly PathString[] _excludedPaths;

    // Per "{clientId}:{endpoint}" key: when its current window started and how many requests it has seen.
    // Both live in one value so AddOrUpdate changes them atomically. Separate dictionaries drift apart
    // under concurrent requests.
    private readonly ConcurrentDictionary<string, (DateTime WindowStart, int Count)> _windows =
        new ConcurrentDictionary<string, (DateTime WindowStart, int Count)>();

    // UTC ticks of the last sweep for expired windows. Limits sweeps to about one per time window
    // instead of one full scan per request.
    private long _lastCleanupTicks;

    public RateLimitMiddleware(
        RequestDelegate next,
        ILogger<RateLimitMiddleware> logger,
        ILoggingService loggingService,
        IOptions<RateLimitSettings> settings)
    {
        _next = next ?? throw new ArgumentNullException(nameof(next));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _loggingService = loggingService ?? throw new ArgumentNullException(nameof(loggingService));
        _settings = (settings ?? throw new ArgumentNullException(nameof(settings))).Value ?? new RateLimitSettings();
        _excludedPaths = BuildExcludedPaths(_settings);
    }

    /// <summary>
    /// Counts the request against its client/path window. Then either forwards it to the next
    /// middleware or short-circuits with HTTP 429.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task Invoke(HttpContext context)
    {
        if (!_settings.EnableRateLimiting || IsExcluded(context.Request.Path))
        {
            await _next(context);
            return;
        }

        var clientId = GetClientId(context);
        var endpoint = GetEndpoint(context);
        var window = RegisterRequest(clientId, endpoint);
        var resetTimeUtc = window.WindowStart.AddSeconds(_settings.TimeWindow);

        // Headers must be set before the rest of the pipeline runs. Once downstream code starts
        // writing the body, response headers are read-only and assigning them throws.
        SetRateLimitHeaders(context.Response, window.Count, resetTimeUtc);

        if (window.Count <= _settings.MaxRequests)
        {
            await _next(context);
            return;
        }

        await LogRateLimitExceeded(clientId, endpoint);

        var secondsUntilReset = (int)Math.Ceiling((resetTimeUtc - DateTime.UtcNow).TotalSeconds);
        context.Response.StatusCode = StatusCodes.Status429TooManyRequests;
        context.Response.Headers["Retry-After"] = Math.Max(1, secondsUntilReset).ToString(CultureInfo.InvariantCulture);
        context.Response.ContentType = "application/json";

        var response = new ApiResponse<object>
        {
            Success = false,
            Message = "Rate limit exceeded. Please try again later.",
            ErrorCode = 429,
            Timestamp = DateTime.Now
        };

        await context.Response.WriteAsync(System.Text.Json.JsonSerializer.Serialize(response), context.RequestAborted);
    }

    // Records one request for the key and returns the key's window after the update.
    // A Count above MaxRequests means this request must be rejected.
    private (DateTime WindowStart, int Count) RegisterRequest(string clientId, string endpoint)
    {
        var key = $"{clientId}:{endpoint}";
        var now = DateTime.UtcNow;
        var maxRequests = _settings.MaxRequests;
        var timeWindow = _settings.TimeWindow;

        CleanupOldEntries(now.AddSeconds(-timeWindow));

        // AddOrUpdate retries the update delegate if another thread changed the value first,
        // so the stored count is never computed from a stale read.
        return _windows.AddOrUpdate(
            key,
            (now, 1),
            (_, current) =>
            {
                // The window is anchored at its first request and not refreshed on later ones.
                // Otherwise a steadily active client would never see its counter reset.
                if (now >= current.WindowStart.AddSeconds(timeWindow))
                {
                    return (now, 1);
                }

                // Stop counting once over the limit. One past MaxRequests already means "rejected",
                // and this keeps a flood of rejected requests from growing Count unboundedly.
                return current.Count > maxRequests
                    ? current
                    : (current.WindowStart, current.Count + 1);
            });
    }

    // Removes windows that started at or before windowStart, i.e. windows that have fully elapsed.
    // Runs at most once per time window and by only one thread at a time.
    private void CleanupOldEntries(DateTime windowStart)
    {
        var lastSweepTicks = Interlocked.Read(ref _lastCleanupTicks);
        if (lastSweepTicks > windowStart.Ticks)
        {
            return; // Already swept within the current window.
        }

        if (Interlocked.CompareExchange(ref _lastCleanupTicks, DateTime.UtcNow.Ticks, lastSweepTicks) != lastSweepTicks)
        {
            return; // Another request claimed this sweep.
        }

        foreach (var entry in _windows)
        {
            if (entry.Value.WindowStart <= windowStart)
            {
                // Removes only if the value is still the expired one. A window that a concurrent
                // request just reset is left alone.
                _windows.TryRemove(entry);
            }
        }
    }

    // Writes the X-RateLimit-* headers. Reset is the end of the current window as Unix seconds (UTC),
    // formatted culture-invariantly so clients can parse it.
    private void SetRateLimitHeaders(HttpResponse response, int count, DateTime resetTimeUtc)
    {
        var remaining = Math.Max(0, _settings.MaxRequests - count);
        response.Headers["X-RateLimit-Limit"] = _settings.MaxRequests.ToString(CultureInfo.InvariantCulture);
        response.Headers["X-RateLimit-Remaining"] = remaining.ToString(CultureInfo.InvariantCulture);
        response.Headers["X-RateLimit-Reset"] = new DateTimeOffset(resetTimeUtc).ToUnixTimeSeconds().ToString(CultureInfo.InvariantCulture);
    }

    // True when the path matches, by whole segments and case-insensitively, any configured exclusion.
    private bool IsExcluded(PathString path)
    {
        foreach (var excluded in _excludedPaths)
        {
            if (path.StartsWithSegments(excluded, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }
        }

        return false;
    }

    // Precomputes the exempt path prefixes: the health-check paths when ExcludeHealthChecks is set,
    // plus the non-blank ExcludedPaths entries.
    private static PathString[] BuildExcludedPaths(RateLimitSettings settings)
    {
        var healthChecks = settings.ExcludeHealthChecks ? HealthCheckPaths : Array.Empty<string>();
        var configured = (settings.ExcludedPaths ?? Array.Empty<string>())
            .Where(path => !string.IsNullOrWhiteSpace(path))
            .Select(path => path.Trim())
            // PathString requires a leading '/'. Normalise entries such as "api/status" instead of throwing per request.
            .Select(path => path.StartsWith("/", StringComparison.Ordinal) ? path : "/" + path);

        return healthChecks.Concat(configured).Select(path => new PathString(path)).ToArray();
    }

    // Logs the rejection to the structured logger and to the audit logging service.
    private async Task LogRateLimitExceeded(string clientId, string endpoint)
    {
        _logger.LogWarning("Rate limit exceeded for client {ClientId} on endpoint {Endpoint}", clientId, endpoint);

        var logData = new
        {
            ClientId = clientId,
            Endpoint = endpoint,
            Timestamp = DateTime.Now,
            MaxRequests = _settings.MaxRequests,
            TimeWindow = _settings.TimeWindow
        };

        try
        {
            await _loggingService.LogWarningAsync($"Rate limit exceeded: {System.Text.Json.JsonSerializer.Serialize(logData)}");
        }
        catch (Exception ex)
        {
            // Audit logging is best-effort. A failing sink must not turn the 429 into a 500.
            _logger.LogError(ex, "Failed to record rate-limit event for client {ClientId}", clientId);
        }
    }

    // Identifies the caller: the authenticated user name when present, otherwise the remote IP.
    // Nameless authenticated users fall back to IP instead of sharing one global bucket.
    // NOTE(review): behind a reverse proxy RemoteIpAddress is the proxy unless ForwardedHeaders
    // middleware runs first. Confirm the pipeline order.
    private static string GetClientId(HttpContext context)
    {
        var userName = context.User?.Identity?.IsAuthenticated == true
            ? context.User.Identity.Name
            : null;

        if (!string.IsNullOrEmpty(userName))
        {
            return userName;
        }

        return context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
    }

    // Request path, lower-cased so "/API/x" and "/api/x" share one counter.
    // ASP.NET routing treats them as the same endpoint.
    private static string GetEndpoint(HttpContext context)
    {
        return context.Request.Path.ToString().ToLowerInvariant();
    }
}
|
|
|
|
/// <summary>
/// Options for the rate-limiting middleware. They are consumed through
/// <c>IOptions&lt;RateLimitSettings&gt;</c> or built by <c>UseRateLimiting(...)</c>.
/// </summary>
public class RateLimitSettings
{
    /// <summary>Creates an instance populated with the default limits.</summary>
    public RateLimitSettings()
    {
        MaxRequests = 100;
        TimeWindow = 60;
        EnableRateLimiting = true;
        ExcludeHealthChecks = true;
        ExcludedPaths = Array.Empty<string>();
    }

    /// <summary>Maximum requests allowed per client and path within one window (default 100).</summary>
    public int MaxRequests { get; set; }

    /// <summary>Length of the counting window, in seconds (default 60).</summary>
    public int TimeWindow { get; set; }

    /// <summary>Master switch intended to turn rate limiting on or off (default <c>true</c>).</summary>
    public bool EnableRateLimiting { get; set; }

    /// <summary>Flag intended to exempt health-check endpoints from limiting (default <c>true</c>).</summary>
    public bool ExcludeHealthChecks { get; set; }

    /// <summary>Request paths intended to be exempt from limiting (default: none).</summary>
    public string[] ExcludedPaths { get; set; }
}
|
|
|
|
/// <summary>Pipeline registration helpers for the rate-limiting middleware.</summary>
public static class RateLimitMiddlewareExtensions
{
    /// <summary>
    /// Adds <c>RateLimitMiddleware</c> to the request pipeline.
    /// </summary>
    /// <param name="builder">The application builder to extend.</param>
    /// <param name="configure">
    /// Optional callback that configures a fresh <see cref="RateLimitSettings"/> instance, which is then
    /// passed to the middleware explicitly. When omitted, the middleware resolves
    /// <c>IOptions&lt;RateLimitSettings&gt;</c> from DI, so settings registered with
    /// <c>services.Configure&lt;RateLimitSettings&gt;(...)</c> are honored instead of being silently
    /// replaced by defaults. Unconfigured options still yield the defaults, assuming the host has
    /// registered the options services (the ASP.NET Core hosts do).
    /// </param>
    /// <returns>The same <paramref name="builder"/>, for chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="builder"/> is null.</exception>
    public static IApplicationBuilder UseRateLimiting(this IApplicationBuilder builder, Action<RateLimitSettings> configure = null)
    {
        if (builder == null)
        {
            throw new ArgumentNullException(nameof(builder));
        }

        if (configure == null)
        {
            return builder.UseMiddleware<RateLimitMiddleware>();
        }

        var settings = new RateLimitSettings();
        configure(settings);

        return builder.UseMiddleware<RateLimitMiddleware>(Options.Create(settings));
    }
}
|
|
} |