浏览代码

refactor(wechat): 重构用户信息接口以使用安全工具类获取用户ID

- 移除重复的 userId 提取逻辑
- 引入 SecurityUtils 统一处理 currentUserId 获取
- 添加对 Long、Integer、String 类型的支持
- 使用 CustomException 处理无效用户ID情况
- 简化控制器方法中的用户认证逻辑
mcbaiyun 2 月之前
父节点
当前提交
f34568f87c

+ 8 - 26
src/main/java/work/baiyun/chronicdiseaseapp/controller/WeChatController.java

@@ -94,19 +94,10 @@ public class WeChatController {
     @Operation(summary = "获取用户信息", description = "根据 token 返回当前用户信息(支持 Authorization/X-Token/token)")
     @PostMapping(path = "/user_info", consumes = MediaType.APPLICATION_JSON_VALUE, produces = MediaType.APPLICATION_JSON_VALUE)
     public R<?> getUserInfo(@RequestBody(required = false) Map<String, String> body, HttpServletRequest request) {
-        // 使用拦截器放入的 currentUserId(AuthInterceptor 已验证 X-Token)
-        Object attr = request.getAttribute("currentUserId");
-        Long userId = null;
-
-        if (attr instanceof Long) {
-            userId = (Long) attr;
-        } else if (attr instanceof Integer) {
-            // 有时框架可能将数字解析为 Integer
-            userId = ((Integer) attr).longValue();
-        }
-
-        // 如果拦截器没有提供 userId,则401
-        if (userId == null) {
+        Long userId;
+        try {
+            userId = work.baiyun.chronicdiseaseapp.util.SecurityUtils.getCurrentUserId();
+        } catch (work.baiyun.chronicdiseaseapp.exception.CustomException e) {
             return R.fail(401, "No valid userId");
         }
 
@@ -133,19 +124,10 @@ public class WeChatController {
     @PostMapping(path = "/update_user_info", consumes = MediaType.APPLICATION_JSON_VALUE, produces = MediaType.APPLICATION_JSON_VALUE)
     public R<?> updateUserInfo(@RequestBody(required = false) work.baiyun.chronicdiseaseapp.model.vo.UpdateUserInfoRequest req,
                                HttpServletRequest request) {
-        // 使用拦截器放入的 currentUserId(AuthInterceptor 已验证 X-Token)
-        Object attr = request.getAttribute("currentUserId");
-        Long userId = null;
-
-        if (attr instanceof Long) {
-            userId = (Long) attr;
-        } else if (attr instanceof Integer) {
-            // 有时框架可能将数字解析为 Integer
-            userId = ((Integer) attr).longValue();
-        }
-
-        // 如果拦截器没有提供 userId,则401
-        if (userId == null) {
+        Long userId;
+        try {
+            userId = work.baiyun.chronicdiseaseapp.util.SecurityUtils.getCurrentUserId();
+        } catch (work.baiyun.chronicdiseaseapp.exception.CustomException e) {
             return R.fail(401, "No valid userId");
         }
 

+ 32 - 0
src/main/java/work/baiyun/chronicdiseaseapp/util/SecurityUtils.java

@@ -0,0 +1,32 @@
+package work.baiyun.chronicdiseaseapp.util;
+
+/**
+ * 安全工具类:统一从 request attribute 中读取 currentUserId
+ */
+public class SecurityUtils {
+
+    /**
+     * 从当前请求上下文读取拦截器设置的 currentUserId
+     * 支持 Long、Integer、String 三种类型;读取失败抛出 CustomException
+     */
+    public static Long getCurrentUserId() {
+        org.springframework.web.context.request.RequestAttributes attrs = org.springframework.web.context.request.RequestContextHolder.getRequestAttributes();
+        if (attrs == null || !(attrs instanceof org.springframework.web.context.request.ServletRequestAttributes)) {
+            throw new work.baiyun.chronicdiseaseapp.exception.CustomException("No valid userId");
+        }
+        jakarta.servlet.http.HttpServletRequest request = ((org.springframework.web.context.request.ServletRequestAttributes) attrs).getRequest();
+        Object attr = request.getAttribute("currentUserId");
+        Long userId = null;
+        if (attr instanceof Long) {
+            userId = (Long) attr;
+        } else if (attr instanceof Integer) {
+            userId = ((Integer) attr).longValue();
+        } else if (attr instanceof String) {
+            try { userId = Long.parseLong((String) attr); } catch (NumberFormatException ignored) {}
+        }
+        if (userId == null) {
+            throw new work.baiyun.chronicdiseaseapp.exception.CustomException("No valid userId");
+        }
+        return userId;
+    }
+}