using System;
using System.Net;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Haoliang.Core.Services;
using Haoliang.Models.Common;

namespace Haoliang.Api.Middleware
{
    /// <summary>
    /// Catches exceptions thrown further down the request pipeline and converts them into a
    /// JSON-serialized <see cref="ApiResponse"/> with an HTTP status code chosen from the
    /// exception type.
    /// </summary>
    public class ExceptionMiddleware
    {
        private readonly RequestDelegate _next;
        private readonly ILoggingService _loggingService;

        /// <param name="next">The next delegate in the request pipeline.</param>
        /// <param name="loggingService">Sink for warning/error log entries.</param>
        public ExceptionMiddleware(RequestDelegate next, ILoggingService loggingService)
        {
            _next = next;
            _loggingService = loggingService;
        }

        /// <summary>
        /// Runs the rest of the pipeline and, if it throws, writes an error response instead.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        public async Task Invoke(HttpContext context)
        {
            try
            {
                await _next(context);
            }
            catch (Exception ex)
            {
                if (context.Response.HasStarted)
                {
                    // Status code and headers have already been sent, so they can no longer be
                    // changed; attempting to write an error response would throw and mask the
                    // original exception. Log it and rethrow so the server aborts the response.
                    await _loggingService.LogErrorAsync(
                        $"Unhandled exception after the response started (TraceId {context.TraceIdentifier}): {ex.Message}",
                        ex);
                    throw;
                }

                await HandleExceptionAsync(context, ex);
            }
        }

        /// <summary>
        /// Maps <paramref name="exception"/> to a status code and public message, logs it,
        /// and writes the JSON error body.
        /// </summary>
        /// <param name="context">The current HTTP context; its response must not have started.</param>
        /// <param name="exception">The exception caught from the pipeline.</param>
        private async Task HandleExceptionAsync(HttpContext context, Exception exception)
        {
            var response = new ApiResponse
            {
                // UTC avoids ambiguity for API consumers in other time zones.
                Timestamp = DateTime.UtcNow,
                Success = false
            };

            HttpStatusCode statusCode;
            string message;
            switch (exception)
            {
                case UnauthorizedAccessException _:
                    statusCode = HttpStatusCode.Unauthorized;
                    message = "Unauthorized access";
                    break;
                case ForbiddenException _:
                    statusCode = HttpStatusCode.Forbidden;
                    message = "Access forbidden";
                    break;
                case NotFoundException _:
                    statusCode = HttpStatusCode.NotFound;
                    message = "Resource not found";
                    break;
                case BadRequestException _:
                    statusCode = HttpStatusCode.BadRequest;
                    message = "Bad request";
                    break;
                case ValidationException validationException:
                    statusCode = HttpStatusCode.BadRequest;
                    message = "Validation failed";
                    response.Data = validationException.Errors;
                    break;
                case ConflictException _:
                    statusCode = HttpStatusCode.Conflict;
                    message = "Resource conflict";
                    break;
                default:
                    statusCode = HttpStatusCode.InternalServerError;
                    message = "An unexpected error occurred";
                    // Never echo exception messages or stack traces to clients (information
                    // disclosure). Return a correlation id so the full details in the server
                    // log can be located.
                    response.Data = new { TraceId = context.TraceIdentifier };
                    break;
            }

            response.Message = message;
            response.ErrorCode = (int)statusCode;

            if (statusCode == HttpStatusCode.InternalServerError)
            {
                await _loggingService.LogErrorAsync(
                    $"Unhandled exception (TraceId {context.TraceIdentifier}): {exception.Message}",
                    exception);
            }
            else
            {
                // Mapped client errors are logged as warnings, prefixed with the public message.
                await _loggingService.LogWarningAsync($"{message}: {exception.Message}");
            }

            context.Response.StatusCode = (int)statusCode;
            context.Response.ContentType = "application/json";

            var jsonResponse = JsonSerializer.Serialize(response);
            await context.Response.WriteAsync(jsonResponse);
        }
    }

    // Custom exception classes

    /// <summary>Thrown when the caller is authenticated but not allowed to perform the action; mapped to 403.</summary>
    public class ForbiddenException : Exception
    {
        public ForbiddenException(string message) : base(message) { }

        public ForbiddenException(string message, Exception innerException) : base(message, innerException) { }
    }

    /// <summary>Thrown when a requested resource does not exist; mapped to 404.</summary>
    public class NotFoundException : Exception
    {
        public NotFoundException(string message) : base(message) { }

        public NotFoundException(string message, Exception innerException) : base(message, innerException) { }
    }

    /// <summary>Thrown for malformed or otherwise invalid requests; mapped to 400.</summary>
    public class BadRequestException : Exception
    {
        public BadRequestException(string message) : base(message) { }

        public BadRequestException(string message, Exception innerException) : base(message, innerException) { }
    }

    /// <summary>
    /// Thrown when request validation fails; mapped to 400. <see cref="Errors"/> is returned
    /// to the client as the response <c>Data</c>.
    /// </summary>
    public class ValidationException : Exception
    {
        /// <summary>Validation error details sent to the client; may be null.</summary>
        public object Errors { get; set; }

        public ValidationException(string message, object errors = null) : base(message)
        {
            Errors = errors;
        }

        public ValidationException(string message, object errors, Exception innerException) : base(message, innerException)
        {
            Errors = errors;
        }
    }

    /// <summary>Thrown when the request conflicts with the current state of a resource; mapped to 409.</summary>
    public class ConflictException : Exception
    {
        public ConflictException(string message) : base(message) { }

        public ConflictException(string message, Exception innerException) : base(message, innerException) { }
    }
}