Просмотр исходного кода

feat(acl): 实现高级 ACL IP 管理接口

- 新增 get_all_test_ips 接口用于获取所有测试 IP 列表
- 实现 add_ip 接口支持向规则中添加单个 IP 地址
- 实现 del_ip 接口支持从规则中删除指定 IP 地址
- 添加线程锁机制确保并发操作安全
- 支持按备注前缀分组管理 ACL 规则
- 提供集成测试用例演示完整使用流程
- 优化规则匹配与排序逻辑提升查找效率
mcbaiyun 2 месяцев назад
Родитель
Сommit
8132cbf883
2 измененных файлов с 268 добавлено и 1 удалено
  1. 217 0
      advanced_acl.py
  2. 51 1
      main.py

+ 217 - 0
advanced_acl.py

@@ -0,0 +1,217 @@
+"""高级 ACL 操作:按备注前缀管理分组的 dst_addr 集合。
+
+实现 get_all_test_ips, add_ip, del_ip 三个高级接口,依赖低级接口
+在 `acl_actions` 中提供的函数,以及 `auth_session.load_config`。
+
+实现注意点:
+- 使用线程锁串行化写操作以降低竞态;写回失败则返回错误信息。
+- 不对 IP 段展开,直接以字符串方式处理 dst_addr 的逗号分隔项。
+"""
+from __future__ import annotations
+
+import threading
+from typing import List, Dict, Tuple, Optional, Set
+
+from log_util import get_logger
+from auth_session import load_config
+import acl_actions
+
+logger = get_logger("advanced_acl")
+
+# 本地锁:序列化对 ACL 的修改请求以降低竞态
+_lock = threading.Lock()
+
+
+def _parse_dst_addr(dst: str) -> List[str]:
+    if not dst:
+        return []
+    parts = [p.strip() for p in dst.split(",") if p.strip()]
+    return parts
+
+
+def _join_dst_addr(parts: List[str]) -> str:
+    return ",".join(parts)
+
+
+def _load_settings() -> Tuple[str, int]:
+    cfg = load_config()
+    prefix = cfg.get("test_prefix", "Test_")
+    limit = int(cfg.get("rule_ip_limit", 1000))
+    return prefix, limit
+
+
+def get_all_test_ips() -> List[str]:
+    """返回所有以 test_prefix 前缀命名规则的合并去重 IP 列表(字符串形式)。"""
+    prefix, _ = _load_settings()
+    resp, data = acl_actions.get_acl_rules()
+    ips: Set[str] = set()
+
+    if not data:
+        return []
+
+    # Data.data is expected to be a list of rules; support a couple of shapes
+    rules = data.get("Data", {}).get("data") if isinstance(data.get("Data"), dict) else data.get("Data")
+    if rules is None:
+        # try common structure
+        rules = data.get("data") or []
+
+    for rule in rules:
+        comment = rule.get("comment", "")
+        if not comment or not comment.startswith(prefix):
+            continue
+        dst = rule.get("dst_addr", "")
+        for ip in _parse_dst_addr(dst):
+            ips.add(ip)
+
+    return sorted(ips)
+
+
+def _collect_prefixed_rules() -> List[Dict]:
+    """Return list of rule dicts that have comment starting with prefix, sorted by numeric suffix."""
+    prefix, _ = _load_settings()
+    resp, data = acl_actions.get_acl_rules()
+    if not data:
+        return []
+
+    rules = data.get("Data", {}).get("data") if isinstance(data.get("Data"), dict) else data.get("Data")
+    if rules is None:
+        rules = data.get("data") or []
+
+    matched = []
+    for rule in rules:
+        comment = rule.get("comment", "")
+        if comment and comment.startswith(prefix):
+            matched.append(rule)
+
+    # sort by numeric suffix if present, else lexicographically
+    def _key(r: Dict):
+        c = r.get("comment", "")
+        s = c[len(prefix):]
+        try:
+            return int(s)
+        except Exception:
+            return float("inf")
+
+    matched.sort(key=_key)
+    return matched
+
+
+def add_ip(ip: str) -> Dict:
+    """Add a single IP to the Test_* grouped rules.
+
+    返回 dict: {"added": bool, "rule": str|None, "row_id": int|None, "message": str}
+    """
+    prefix, limit = _load_settings()
+    # 快速存在性检查
+    current = set(get_all_test_ips())
+    if ip in current:
+        return {"added": False, "message": "already exists", "rule": None, "row_id": None}
+
+    # 串行化修改操作
+    with _lock:
+        # 重新获取并解析规则,避免竞态
+        rules = _collect_prefixed_rules()
+
+        # 尝试放入已有规则
+        for rule in rules:
+            rid = rule.get("id") or rule.get("ID") or rule.get("RowId")
+            comment = rule.get("comment", "")
+            dst = rule.get("dst_addr", "")
+            parts = _parse_dst_addr(dst)
+            if len(parts) + 1 <= limit:
+                # 将 IP 附加并编辑
+                if ip in parts:
+                    return {"added": False, "message": "already exists after recheck", "rule": comment, "row_id": rid}
+                parts.append(ip)
+                new_dst = _join_dst_addr(parts)
+                try:
+                    resp, data = acl_actions.edit_acl_rule(rule_id=rid, dst_addr=new_dst, comment=comment)
+                except Exception as e:
+                    logger.exception("edit_acl_rule 调用失败")
+                    return {"added": False, "message": f"edit failed: {e}", "rule": comment, "row_id": rid}
+
+                # 检查响应是否成功(保守判断:HTTP 200 且返回非空 JSON)
+                if resp.status_code == 200:
+                    return {"added": True, "rule": comment, "row_id": (data.get("RowId") if isinstance(data, dict) else None), "message": "ok"}
+                else:
+                    return {"added": False, "message": f"edit returned {resp.status_code}", "rule": comment, "row_id": rid}
+
+        # 如果没有可放的规则,则创建新的规则
+        # 确定新规则编号
+        if not rules:
+            new_idx = 1
+        else:
+            # 尝试解析最大编号
+            last = rules[-1].get("comment", "")
+            try:
+                last_idx = int(last[len(prefix):])
+                new_idx = last_idx + 1
+            except Exception:
+                new_idx = len(rules) + 1
+
+        new_comment = f"{prefix}{new_idx}"
+        try:
+            resp, data = acl_actions.add_acl_rule(dst_addr=ip, comment=new_comment)
+        except Exception as e:
+            logger.exception("add_acl_rule 调用失败")
+            return {"added": False, "message": f"add failed: {e}", "rule": new_comment, "row_id": None}
+
+        if resp.status_code == 200:
+            row_id = data.get("RowId") if isinstance(data, dict) else None
+            return {"added": True, "rule": new_comment, "row_id": row_id, "message": "ok"}
+        else:
+            return {"added": False, "message": f"add returned {resp.status_code}", "rule": new_comment, "row_id": None}
+
+
+def del_ip(ip: str) -> Dict:
+    """Delete IP from all Test_* rules.
+
+    返回 dict: {"deleted": bool, "affected": [ {"rule_id": id, "comment": str, "action": "edited"|"deleted"} ], "message": str}
+    """
+    prefix, _ = _load_settings()
+    affected = []
+
+    with _lock:
+        rules = _collect_prefixed_rules()
+        found = False
+        for rule in rules:
+            rid = rule.get("id") or rule.get("ID") or rule.get("RowId")
+            comment = rule.get("comment", "")
+            dst = rule.get("dst_addr", "")
+            parts = _parse_dst_addr(dst)
+            if ip in parts:
+                found = True
+                parts = [p for p in parts if p != ip]
+                if parts:
+                    new_dst = _join_dst_addr(parts)
+                    try:
+                        resp, data = acl_actions.edit_acl_rule(rule_id=rid, dst_addr=new_dst, comment=comment)
+                    except Exception as e:
+                        logger.exception("edit_acl_rule 调用失败")
+                        affected.append({"rule_id": rid, "comment": comment, "action": "error", "message": str(e)})
+                        continue
+
+                    if resp.status_code == 200:
+                        affected.append({"rule_id": rid, "comment": comment, "action": "edited"})
+                    else:
+                        affected.append({"rule_id": rid, "comment": comment, "action": "error", "message": f"edit returned {resp.status_code}"})
+                else:
+                    # 删除该规则
+                    try:
+                        resp, data = acl_actions.del_acl_rule(rule_id=rid)
+                    except Exception as e:
+                        logger.exception("del_acl_rule 调用失败")
+                        affected.append({"rule_id": rid, "comment": comment, "action": "error", "message": str(e)})
+                        continue
+
+                    if resp.status_code == 200:
+                        affected.append({"rule_id": rid, "comment": comment, "action": "deleted"})
+                    else:
+                        affected.append({"rule_id": rid, "comment": comment, "action": "error", "message": f"del returned {resp.status_code}"})
+
+        if not found:
+            return {"deleted": False, "affected": [], "message": "not found"}
+        return {"deleted": True, "affected": affected, "message": "ok"}
+
+
+__all__ = ["get_all_test_ips", "add_ip", "del_ip"]

+ 51 - 1
main.py

@@ -3,6 +3,7 @@
 from auth_session import login, set_sess_key
 from log_util import get_logger
 from acl_actions import get_acl_rules, add_acl_rule, edit_acl_rule, del_acl_rule
+from advanced_acl import get_all_test_ips, add_ip, del_ip
 
 
 def main():
@@ -25,7 +26,56 @@ def main():
 		set_sess_key(sess_cookie)
 		print(f"sess_key: {sess_cookie}")
 
-		# # 测试ACL接口
+		# # 集成测试:演示高级接口的使用流程
+		# test_ip = "1.5.6.7"
+
+		# # 1) 列出当前集合
+		# try:
+		# 	ips = get_all_test_ips()
+		# 	logger.info(f"当前 Test_* IP 数量: {len(ips)}")
+		# 	print(f"当前 Test_* IP 数量: {len(ips)}")
+		# except Exception as e:
+		# 	logger.exception(f"获取 Test IP 列表失败: {e}")
+		# 	print(f"获取 Test IP 列表失败: {e}")
+
+		# # 2) 添加测试 IP
+		# try:
+		# 	res = add_ip(test_ip)
+		# 	logger.info(f"add_ip({test_ip}) -> {res}")
+		# 	print(f"add_ip -> {res}")
+		# except Exception as e:
+		# 	logger.exception(f"添加测试 IP 失败: {e}")
+		# 	print(f"添加测试 IP 失败: {e}")
+
+		# # 3) 再次列出以验证添加
+		# try:
+		# 	ips_after = get_all_test_ips()
+		# 	logger.info(f"添加后 Test_* IP 数量: {len(ips_after)}")
+		# 	print(f"添加后 Test_* IP 数量: {len(ips_after)}")
+		# except Exception as e:
+		# 	logger.exception(f"获取添加后列表失败: {e}")
+		# 	print(f"获取添加后列表失败: {e}")
+
+		# # 4) 删除测试 IP
+		# try:
+		# 	res_del = del_ip(test_ip)
+		# 	logger.info(f"del_ip({test_ip}) -> {res_del}")
+		# 	print(f"del_ip -> {res_del}")
+		# except Exception as e:
+		# 	logger.exception(f"删除测试 IP 失败: {e}")
+		# 	print(f"删除测试 IP 失败: {e}")
+
+		# # 5) 最终验证
+		# try:
+		# 	ips_final = get_all_test_ips()
+		# 	logger.info(f"最终 Test_* IP 数量: {len(ips_final)}")
+		# 	print(f"最终 Test_* IP 数量: {len(ips_final)}")
+		# except Exception as e:
+		# 	logger.exception(f"获取最终列表失败: {e}")
+		# 	print(f"获取最终列表失败: {e}")
+
+		# 旧的低级测试保留(按需)
+		# 测试ACL接口
 		# try:
 		# 	acl_resp, acl_data = get_acl_rules()
 		# 	logger.info(f"已调用: get_acl_rules() | 状态: {acl_resp.status_code}")