You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

179 lines
6.5 KiB
C#

using System;
using System.Collections.Concurrent;
using System.Globalization;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Haoliang.Core.Services;
using Haoliang.Models.Common;
namespace Haoliang.Api.Middleware
{
/// <summary>
/// ASP.NET Core middleware that limits how many requests each client may make to each
/// endpoint within a fixed time window, answering excess requests with HTTP 429.
/// </summary>
/// <remarks>
/// <para>
/// A client is the authenticated user name or, for anonymous requests, the remote IP address;
/// an endpoint is the request path. Each client/endpoint pair gets a window that opens with its
/// first request and lasts <c>RateLimitSettings.TimeWindow</c> seconds; at most
/// <c>RateLimitSettings.MaxRequests</c> requests are admitted per window.
/// </para>
/// <para>
/// State is held in this instance's memory, so when several server instances run, each one
/// enforces its own limit independently.
/// </para>
/// </remarks>
public class RateLimitMiddleware
{
    private readonly RequestDelegate _next;
    private readonly ILogger<RateLimitMiddleware> _logger;
    private readonly ILoggingService _loggingService;
    private readonly RateLimitSettings _settings;

    // One immutable window per "client:endpoint" key. Replacing whole windows through
    // ConcurrentDictionary.AddOrUpdate keeps the count and start time consistent with each
    // other under concurrent requests. Keys compare case-insensitively because ASP.NET Core
    // routing matches paths case-insensitively by default; otherwise a client could multiply
    // its quota by varying the case of the path.
    private readonly ConcurrentDictionary<string, RateWindow> _windows =
        new ConcurrentDictionary<string, RateWindow>(StringComparer.OrdinalIgnoreCase);

    // UTC ticks before which no expired-window sweep is attempted; accessed via Interlocked.
    private long _nextCleanupTicks;

    public RateLimitMiddleware(
        RequestDelegate next,
        ILogger<RateLimitMiddleware> logger,
        ILoggingService loggingService,
        IOptions<RateLimitSettings> settings)
    {
        _next = next;
        _logger = logger;
        _loggingService = loggingService;
        _settings = settings.Value;
    }

    /// <summary>
    /// Admits or rejects the current request. Requests subject to limiting always receive the
    /// X-RateLimit-Limit / -Remaining / -Reset headers; rejected ones get a 429 JSON body.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task Invoke(HttpContext context)
    {
        if (!_settings.EnableRateLimiting || IsExcluded(context))
        {
            await _next(context);
            return;
        }

        var clientId = GetClientId(context);
        var endpoint = GetEndpoint(context);
        var now = DateTime.UtcNow;

        var window = Acquire($"{clientId}:{endpoint}", now);
        var resetTime = window.Start.AddSeconds(_settings.TimeWindow);

        // Headers are set before the pipeline continues: once a downstream component starts the
        // response, the header collection becomes read-only and assigning to it throws.
        var headers = context.Response.Headers;
        headers["X-RateLimit-Limit"] = _settings.MaxRequests.ToString(CultureInfo.InvariantCulture);
        headers["X-RateLimit-Remaining"] = Math.Max(0, _settings.MaxRequests - window.Count).ToString(CultureInfo.InvariantCulture);
        // Unix epoch seconds (UTC) at which the current window closes.
        headers["X-RateLimit-Reset"] = new DateTimeOffset(resetTime).ToUnixTimeSeconds().ToString(CultureInfo.InvariantCulture);

        if (window.Count <= _settings.MaxRequests)
        {
            await _next(context);
            return;
        }

        await LogRateLimitExceeded(clientId, endpoint);

        var secondsUntilReset = (int)Math.Ceiling((resetTime - now).TotalSeconds);
        context.Response.StatusCode = 429;
        context.Response.ContentType = "application/json";
        headers["Retry-After"] = Math.Max(1, secondsUntilReset).ToString(CultureInfo.InvariantCulture);

        var response = new ApiResponse<object>
        {
            Success = false,
            Message = "Rate limit exceeded. Please try again later.",
            ErrorCode = 429,
            Timestamp = DateTime.Now
        };
        await context.Response.WriteAsync(JsonSerializer.Serialize(response));
    }

    /// <summary>
    /// Records one request for <paramref name="key"/> and returns the resulting window.
    /// The request is admitted when the returned <see cref="RateWindow.Count"/> is at most
    /// <c>MaxRequests</c>.
    /// </summary>
    /// <param name="key">The "client:endpoint" key.</param>
    /// <param name="now">Current UTC time.</param>
    private RateWindow Acquire(string key, DateTime now)
    {
        var windowLength = TimeSpan.FromSeconds(_settings.TimeWindow);
        var maxRequests = _settings.MaxRequests;

        CleanupExpiredWindows(now, windowLength);

        // The update delegate may run more than once under contention, but only one result is
        // stored, and the value returned here is exactly the one stored.
        return _windows.AddOrUpdate(
            key,
            _ => new RateWindow(now, 1),
            (_, current) =>
            {
                if (now - current.Start >= windowLength)
                {
                    // Previous window has ended: this request opens a new one.
                    return new RateWindow(now, 1);
                }

                if (current.Count > maxRequests)
                {
                    // Already over the limit; keep the count capped at MaxRequests + 1.
                    return current;
                }

                return new RateWindow(current.Start, current.Count + 1);
            });
    }

    /// <summary>
    /// Removes expired windows, at most once per window length, so memory does not grow with
    /// every client/endpoint pair ever seen and requests do not each pay for a full scan.
    /// </summary>
    private void CleanupExpiredWindows(DateTime now, TimeSpan windowLength)
    {
        var due = Interlocked.Read(ref _nextCleanupTicks);
        if (now.Ticks < due)
        {
            return;
        }

        // Only the thread that successfully advances the deadline performs the sweep.
        if (Interlocked.CompareExchange(ref _nextCleanupTicks, now.Ticks + windowLength.Ticks, due) != due)
        {
            return;
        }

        foreach (var entry in _windows)
        {
            if (now - entry.Value.Start >= windowLength)
            {
                // Value-matching removal: a window another request has just restarted is a
                // different object and therefore stays in place.
                _windows.TryRemove(entry);
            }
        }
    }

    /// <summary>
    /// Returns true when the request path is exempt from limiting according to the settings.
    /// Configured paths are matched as path-segment prefixes (e.g. "/api/public" also covers
    /// "/api/public/items" but not "/api/publication").
    /// </summary>
    private bool IsExcluded(HttpContext context)
    {
        var path = context.Request.Path;

        // NOTE(review): assumes health checks are served under "/health" — confirm against the
        // application's health-check endpoint mapping.
        if (_settings.ExcludeHealthChecks && path.StartsWithSegments("/health"))
        {
            return true;
        }

        var excludedPaths = _settings.ExcludedPaths;
        if (excludedPaths == null)
        {
            return false;
        }

        foreach (var excluded in excludedPaths)
        {
            var trimmed = excluded?.TrimEnd('/');
            if (string.IsNullOrEmpty(trimmed))
            {
                continue;
            }

            // Request paths always begin with '/', so configured entries are normalized to match.
            var prefix = trimmed[0] == '/' ? trimmed : "/" + trimmed;
            if (path.StartsWithSegments(prefix))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// Writes a warning for a rejected request to both the project logging service and the
    /// framework logger.
    /// </summary>
    private async Task LogRateLimitExceeded(string clientId, string endpoint)
    {
        var logData = new
        {
            ClientId = clientId,
            Endpoint = endpoint,
            Timestamp = DateTime.Now,
            MaxRequests = _settings.MaxRequests,
            TimeWindow = _settings.TimeWindow
        };
        await _loggingService.LogWarningAsync($"Rate limit exceeded: {JsonSerializer.Serialize(logData)}");
        _logger.LogWarning("Rate limit exceeded for client {ClientId} on endpoint {Endpoint}", clientId, endpoint);
    }

    /// <summary>
    /// Identifies the caller: the authenticated user name (or "authenticated" if it has none),
    /// otherwise the remote IP address, otherwise "unknown".
    /// </summary>
    private string GetClientId(HttpContext context)
    {
        if (context.User?.Identity?.IsAuthenticated == true)
        {
            return context.User.Identity.Name ?? "authenticated";
        }

        return context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
    }

    /// <summary>Returns the request path used as the endpoint part of the limiter key.</summary>
    private string GetEndpoint(HttpContext context)
    {
        return context.Request.Path.ToString();
    }

    /// <summary>Immutable snapshot of one key's current rate-limit window.</summary>
    private sealed class RateWindow
    {
        public RateWindow(DateTime start, int count)
        {
            Start = start;
            Count = count;
        }

        /// <summary>UTC time at which this window opened.</summary>
        public DateTime Start { get; }

        /// <summary>Requests seen in this window, capped at MaxRequests + 1.</summary>
        public int Count { get; }
    }
}
/// <summary>
/// Options for the rate-limiting middleware, typically bound from configuration or set through
/// the <c>UseRateLimiting</c> configure callback.
/// </summary>
public class RateLimitSettings
{
    private const int DefaultMaxRequests = 100;
    private const int DefaultTimeWindowSeconds = 60;

    /// <summary>
    /// Maximum number of requests a single client may make to a single endpoint within one
    /// <see cref="TimeWindow"/>. Defaults to 100.
    /// </summary>
    public int MaxRequests { get; set; } = DefaultMaxRequests;

    /// <summary>Length of the rate-limit window, in seconds. Defaults to 60.</summary>
    public int TimeWindow { get; set; } = DefaultTimeWindowSeconds;

    /// <summary>Flag for switching rate limiting on or off. Defaults to <c>true</c>.</summary>
    public bool EnableRateLimiting { get; set; } = true;

    /// <summary>Flag for exempting health-check requests from limiting. Defaults to <c>true</c>.</summary>
    public bool ExcludeHealthChecks { get; set; } = true;

    /// <summary>Request paths to exempt from limiting. Defaults to an empty array.</summary>
    public string[] ExcludedPaths { get; set; } = Array.Empty<string>();
}
/// <summary>
/// Pipeline registration helpers for the rate-limiting middleware.
/// </summary>
public static class RateLimitMiddlewareExtensions
{
    /// <summary>
    /// Adds <c>RateLimitMiddleware</c> to the request pipeline.
    /// </summary>
    /// <param name="builder">The application pipeline builder.</param>
    /// <param name="configure">
    /// Optional callback applied to a fresh <see cref="RateLimitSettings"/> instance, which is then
    /// passed to the middleware explicitly. When omitted, the middleware's
    /// <c>IOptions&lt;RateLimitSettings&gt;</c> is resolved from dependency injection instead, so
    /// values bound with <c>services.Configure&lt;RateLimitSettings&gt;(...)</c> take effect. This
    /// assumes the options services are registered (the standard ASP.NET Core hosts call
    /// <c>AddOptions</c>); unconfigured options resolve to the defaults.
    /// </param>
    /// <returns>The same <paramref name="builder"/>, for chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="builder"/> is null.</exception>
    public static IApplicationBuilder UseRateLimiting(this IApplicationBuilder builder, Action<RateLimitSettings> configure = null)
    {
        if (builder == null)
        {
            throw new ArgumentNullException(nameof(builder));
        }

        if (configure == null)
        {
            // Defer to DI-registered options rather than overriding them with hard defaults.
            return builder.UseMiddleware<RateLimitMiddleware>();
        }

        var settings = new RateLimitSettings();
        configure(settings);
        return builder.UseMiddleware<RateLimitMiddleware>(Options.Create(settings));
    }
}
}