Просмотр исходного кода

feat(protector): 实现后台防护与GUI配置管理

- 新增 ProtectorRunner 后台服务,定时检查并阻断异常连接
- 添加 protector_list.json 配置文件管理防护规则
- 实现 tkinter GUI 界面用于编辑配置和管理防护列表
- 增加 find_high_connections_for_src_ports 方法支持指定源端口分析
- 支持通过 --gui 参数启动图形界面和后台防护服务
- advanced_acl.add_ip 方法新增 comment 参数支持自定义注释
- 修复 monitor_lanip 接口调用示例代码的缩进问题
- 添加 test.bat 测试脚本用于快速验证连接
mcbaiyun 2 месяцев назад
Родитель
Сommit
e62d4e6942
8 измененных файлов с 472 добавлено и 14 удалено
  1. 6 2
      advanced_acl.py
  2. 58 0
      analysis_connections.py
  3. 202 0
      gui.py
  4. 26 12
      main.py
  5. 98 0
      protector.py
  6. 8 0
      protector_list.json
  7. 73 0
      protector_list.py
  8. 1 0
      test.bat

+ 6 - 2
advanced_acl.py

@@ -96,9 +96,12 @@ def _collect_prefixed_rules() -> List[Dict]:
     return matched
     return matched
 
 
 
 
-def add_ip(ip: str) -> Dict:
+def add_ip(ip: str, comment: str | None = None) -> Dict:
     """Add a single IP to the Test_* grouped rules.
     """Add a single IP to the Test_* grouped rules.
 
 
+    If `comment` is provided it will be used as the comment when creating a new rule.
+    When adding to an existing Test_* rule, the existing rule's comment is preserved.
+
     返回 dict: {"added": bool, "rule": str|None, "row_id": int|None, "message": str}
     返回 dict: {"added": bool, "rule": str|None, "row_id": int|None, "message": str}
     """
     """
     prefix, limit = _load_settings()
     prefix, limit = _load_settings()
@@ -149,7 +152,8 @@ def add_ip(ip: str) -> Dict:
             except Exception:
             except Exception:
                 new_idx = len(rules) + 1
                 new_idx = len(rules) + 1
 
 
-        new_comment = f"{prefix}{new_idx}"
+        # new_comment: by default use prefix-based Test_n, but allow caller to override by passing comment
+        new_comment = comment if comment else f"{prefix}{new_idx}"
         try:
         try:
             resp, data = acl_actions.add_acl_rule(dst_addr=ip, comment=new_comment)
             resp, data = acl_actions.add_acl_rule(dst_addr=ip, comment=new_comment)
         except Exception as e:
         except Exception as e:

+ 58 - 0
analysis_connections.py

@@ -220,3 +220,61 @@ def find_high_connection_pairs(ip: str, threshold: int | None = None, limit: str
 	result["matches"].sort(key=lambda x: x["count"], reverse=True)
 	result["matches"].sort(key=lambda x: x["count"], reverse=True)
 	return result
 	return result
 
 
+
+def find_high_connections_for_src_ports(ip: str, src_ports, threshold: int | None = None, limit: str = "0,100000") -> dict:
+	"""Analyze only the specified src_ports (single int or iterable) for given ip.
+
+	Returns dict: {"threshold": n, "total_reported": x, "fetched": y, "by_src_port": {port: [{dst_addr, count, samples}, ...]}}
+	"""
+	from collections import Counter
+
+	# normalize src_ports to a set of ints/strings
+	if isinstance(src_ports, (int, str)):
+		ports_set = {int(src_ports)}
+	else:
+		ports_set = {int(p) for p in src_ports}
+
+	cfg = load_config()
+	cfg_threshold = cfg.get("conn_threshold", 20)
+	if threshold is None:
+		threshold = int(cfg_threshold)
+
+	resp, data = monitor_lanip(ip, limit=limit)
+	result = {"threshold": threshold, "total_reported": None, "fetched": 0, "by_src_port": {}}
+
+	if not data:
+		return result
+
+	total_reported = data.get("Data", {}).get("conn_num")
+	conns = data.get("Data", {}).get("conn", []) or []
+	result["total_reported"] = total_reported
+	result["fetched"] = len(conns)
+
+	# group by src_port then count dst_addr
+	grouped: dict[int, Counter] = {}
+	samples: dict[tuple, list] = {}
+	for entry in conns:
+		try:
+			src_port = int(entry.get("src_port"))
+		except Exception:
+			continue
+		if src_port not in ports_set:
+			continue
+		dst = entry.get("dst_addr")
+		key = (src_port, dst)
+		grouped.setdefault(src_port, Counter())[dst] += 1
+		samples.setdefault(key, []).append(entry)
+
+	for port in ports_set:
+		counter = grouped.get(port, Counter())
+		matches = []
+		for dst, cnt in counter.items():
+			if cnt > threshold:
+				sample_list = samples.get((port, dst), [])[:5]
+				matches.append({"dst_addr": dst, "count": cnt, "samples": sample_list})
+		# sort
+		matches.sort(key=lambda x: x["count"], reverse=True)
+		result["by_src_port"][port] = matches
+
+	return result
+

+ 202 - 0
gui.py

@@ -0,0 +1,202 @@
+"""A simple tkinter GUI to edit config.json and manage protector list.
+
+Features:
+- Edit base_url, conn_threshold, scan_interval (saved to config.json)
+- Manage protector_list entries (add/edit/remove) stored in protector_list.json
+- On saving base_url, attempt login via auth_session.login() and store sess_key
+"""
+from __future__ import annotations
+
+import json
+import threading
+import tkinter as tk
+from tkinter import messagebox
+from pathlib import Path
+
+from auth_session import load_config, get_base_url, login, set_sess_key
+import protector_list
+
+ROOT = Path(__file__).resolve().parent
+CONFIG_PATH = ROOT / "config.json"
+
+
+def load_cfg():
+    return load_config()
+
+
+def save_cfg(cfg: dict):
+    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
+        json.dump(cfg, f, ensure_ascii=False, indent=2)
+
+
+class ProtectorGUI(tk.Tk):
+    def __init__(self):
+        super().__init__()
+        self.title("Protector GUI")
+        self.geometry("700x500")
+
+        self.cfg = load_cfg()
+
+        # Config frame
+        cf = tk.LabelFrame(self, text="Config")
+        cf.pack(fill="x", padx=8, pady=6)
+
+        tk.Label(cf, text="base_url:").grid(row=0, column=0, sticky="w")
+        self.base_entry = tk.Entry(cf, width=60)
+        self.base_entry.grid(row=0, column=1, padx=4, pady=2)
+
+        tk.Label(cf, text="conn_threshold:").grid(row=1, column=0, sticky="w")
+        self.th_entry = tk.Entry(cf, width=10)
+        self.th_entry.grid(row=1, column=1, sticky="w", padx=4, pady=2)
+
+        tk.Label(cf, text="scan_interval (s):").grid(row=2, column=0, sticky="w")
+        self.interval_entry = tk.Entry(cf, width=10)
+        self.interval_entry.grid(row=2, column=1, sticky="w", padx=4, pady=2)
+
+        tk.Button(cf, text="Save Config", command=self.save_config).grid(row=3, column=1, sticky="w", pady=6)
+
+        # Protector list frame
+        lf = tk.LabelFrame(self, text="Protector List")
+        lf.pack(fill="both", expand=True, padx=8, pady=6)
+
+        self.listbox = tk.Listbox(lf)
+        self.listbox.pack(side="left", fill="both", expand=True, padx=4, pady=4)
+        self.listbox.bind("<<ListboxSelect>>", self.on_select)
+
+        right = tk.Frame(lf)
+        right.pack(side="right", fill="y", padx=4)
+
+        tk.Label(right, text="target_ip:").pack(anchor="w")
+        self.ip_entry = tk.Entry(right)
+        self.ip_entry.pack(fill="x")
+        tk.Label(right, text="src_port:").pack(anchor="w")
+        self.port_entry = tk.Entry(right)
+        self.port_entry.pack(fill="x")
+        tk.Label(right, text="threshold (optional):").pack(anchor="w")
+        self.entry_threshold = tk.Entry(right)
+        self.entry_threshold.pack(fill="x")
+
+        tk.Button(right, text="Add", command=self.add_entry).pack(fill="x", pady=4)
+        tk.Button(right, text="Update", command=self.update_entry).pack(fill="x", pady=4)
+        tk.Button(right, text="Remove", command=self.remove_entry).pack(fill="x", pady=4)
+
+        self.load_values()
+
+    def load_values(self):
+        self.cfg = load_cfg()
+        self.base_entry.delete(0, tk.END)
+        self.base_entry.insert(0, self.cfg.get("base_url", ""))
+        self.th_entry.delete(0, tk.END)
+        self.th_entry.insert(0, str(self.cfg.get("conn_threshold", "")))
+        self.interval_entry.delete(0, tk.END)
+        self.interval_entry.insert(0, str(self.cfg.get("scan_interval", 60)))
+
+        self.reload_listbox()
+
+    def reload_listbox(self):
+        self.listbox.delete(0, tk.END)
+        items = protector_list.load_list()
+        for i in items:
+            self.listbox.insert(tk.END, f"{i['id']}: {i['target_ip']}:{i['src_port']} (th={i.get('threshold')})")
+
+    def save_config(self):
+        try:
+            self.cfg["base_url"] = self.base_entry.get().strip()
+            self.cfg["conn_threshold"] = int(self.th_entry.get().strip())
+            self.cfg["scan_interval"] = int(self.interval_entry.get().strip())
+        except Exception as e:
+            messagebox.showerror("错误", f"配置输入无效: {e}")
+            return
+
+        save_cfg(self.cfg)
+
+        # try login on a background thread
+        def _login():
+            try:
+                resp, sess_cookie = login()
+                if sess_cookie:
+                    set_sess_key(sess_cookie)
+                messagebox.showinfo("登录", f"登录完成, 状态: {resp.status_code}")
+            except Exception as e:
+                messagebox.showerror("登录失败", str(e))
+
+        threading.Thread(target=_login, daemon=True).start()
+
+    def on_select(self, evt):
+        sel = self.listbox.curselection()
+        if not sel:
+            return
+        idx = sel[0]
+        items = protector_list.load_list()
+        if idx >= len(items):
+            return
+        item = items[idx]
+        self.ip_entry.delete(0, tk.END)
+        self.ip_entry.insert(0, item.get("target_ip", ""))
+        self.port_entry.delete(0, tk.END)
+        self.port_entry.insert(0, str(item.get("src_port", "")))
+        self.entry_threshold.delete(0, tk.END)
+        self.entry_threshold.insert(0, str(item.get("threshold", "")))
+
+    def add_entry(self):
+        ip = self.ip_entry.get().strip()
+        port = self.port_entry.get().strip()
+        th = self.entry_threshold.get().strip() or None
+        if not ip or not port:
+            messagebox.showerror("错误", "ip 与 port 为必填")
+            return
+        try:
+            entry = protector_list.add_entry(ip, int(port), int(th) if th else None)
+            self.reload_listbox()
+            messagebox.showinfo("添加", f"已添加: {entry}")
+        except Exception as e:
+            messagebox.showerror("错误", str(e))
+
+    def update_entry(self):
+        sel = self.listbox.curselection()
+        if not sel:
+            messagebox.showerror("错误", "先选择一条")
+            return
+        idx = sel[0]
+        items = protector_list.load_list()
+        if idx >= len(items):
+            return
+        eid = items[idx]["id"]
+        ip = self.ip_entry.get().strip()
+        port = self.port_entry.get().strip()
+        th = self.entry_threshold.get().strip() or None
+        try:
+            updated = protector_list.update_entry(eid, target_ip=ip, src_port=int(port), threshold=(int(th) if th else None))
+            if updated is None:
+                messagebox.showerror("错误", "更新失败")
+            else:
+                self.reload_listbox()
+                messagebox.showinfo("更新", f"已更新: {updated}")
+        except Exception as e:
+            messagebox.showerror("错误", str(e))
+
+    def remove_entry(self):
+        sel = self.listbox.curselection()
+        if not sel:
+            messagebox.showerror("错误", "先选择一条")
+            return
+        idx = sel[0]
+        items = protector_list.load_list()
+        if idx >= len(items):
+            return
+        eid = items[idx]["id"]
+        ok = protector_list.remove_entry(eid)
+        if ok:
+            self.reload_listbox()
+            messagebox.showinfo("删除", "已删除")
+        else:
+            messagebox.showerror("错误", "删除失败")
+
+
+def run_gui():
+    app = ProtectorGUI()
+    app.mainloop()
+
+
+if __name__ == "__main__":
+    run_gui()

+ 26 - 12
main.py

@@ -1,17 +1,31 @@
 """Entry point to call various API actions."""
 """Entry point to call various API actions."""
 
 
+import sys
 from auth_session import login, set_sess_key
 from auth_session import login, set_sess_key
 from log_util import get_logger
 from log_util import get_logger
 from acl_actions import get_acl_rules, add_acl_rule, edit_acl_rule, del_acl_rule
 from acl_actions import get_acl_rules, add_acl_rule, edit_acl_rule, del_acl_rule
 import json
 import json
 from advanced_acl import get_all_test_ips, add_ip, del_ip
 from advanced_acl import get_all_test_ips, add_ip, del_ip
 from analysis_connections import monitor_lanip
 from analysis_connections import monitor_lanip
+import gui
+from protector import ProtectorRunner
 
 
 
 
 def main():
 def main():
 	logger = get_logger("main")
 	logger = get_logger("main")
 	logger.info("开始 main()")
 	logger.info("开始 main()")
 
 
+	# If --gui passed, launch GUI and start protector runner
+	if "--gui" in sys.argv:
+		# start background protector
+		pr = ProtectorRunner()
+		pr.start()
+		# launch GUI (blocking)
+		gui.run_gui()
+		# when GUI exits, stop protector
+		pr.stop()
+		return
+
 	try:
 	try:
 		resp, sess_cookie = login()
 		resp, sess_cookie = login()
 	except FileNotFoundError as e:
 	except FileNotFoundError as e:
@@ -134,18 +148,18 @@ def main():
 		# 	logger.exception(f"测试删除ACL规则失败: {e}")
 		# 	logger.exception(f"测试删除ACL规则失败: {e}")
 		# 	print(f"测试删除ACL规则失败: {e}")
 		# 	print(f"测试删除ACL规则失败: {e}")
 
 
-		# # 6) 演示 monitor_lanip 接口
-		# try:
-		# 	conn_ip = "10.8.7.2"
-		# 	resp_conn, data_conn = monitor_lanip(conn_ip)
-		# 	logger.info(f"已调用: monitor_lanip({conn_ip}) | 状态: {resp_conn.status_code}")
-		# 	if data_conn:
-		# 		print(json.dumps(data_conn, ensure_ascii=False, indent=2))
-		# 	else:
-		# 		print(resp_conn.text)
-		# except Exception as e:
-		# 	logger.exception(f"测试 monitor_lanip 失败: {e}")
-		# 	print(f"测试 monitor_lanip 失败: {e}")
+		# 6) 演示 monitor_lanip 接口
+		try:
+			conn_ip = "10.8.7.2"
+			resp_conn, data_conn = monitor_lanip(conn_ip)
+			logger.info(f"已调用: monitor_lanip({conn_ip}) | 状态: {resp_conn.status_code}")
+			if data_conn:
+				print(json.dumps(data_conn, ensure_ascii=False, indent=2))
+			else:
+				print(resp_conn.text)
+		except Exception as e:
+			logger.exception(f"测试 monitor_lanip 失败: {e}")
+			print(f"测试 monitor_lanip 失败: {e}")
 	else:
 	else:
 		logger.warning("未在响应中找到 sess_key")
 		logger.warning("未在响应中找到 sess_key")
 		print("未在响应中找到 sess_key")
 		print("未在响应中找到 sess_key")

+ 98 - 0
protector.py

@@ -0,0 +1,98 @@
+"""Background protector: periodically checks protector_list entries and blocks offending dst_addr via advanced_acl.add_ip
+"""
+from __future__ import annotations
+
+import threading
+import time
+from typing import Optional
+
+import advanced_acl
+from auth_session import load_config
+import protector_list
+from analysis_connections import find_high_connections_for_src_ports
+from log_util import get_logger
+
+logger = get_logger("protector")
+
+
+class ProtectorRunner:
+    def __init__(self):
+        cfg = load_config()
+        self.interval = int(cfg.get("scan_interval", 60))
+        self._stop = threading.Event()
+
+    def run_once(self):
+        items = protector_list.load_list()
+        for entry in items:
+            target = entry.get("target_ip")
+            src_port = int(entry.get("src_port"))
+            threshold = entry.get("threshold")
+            try:
+                res = find_high_connections_for_src_ports(target, src_port, threshold=threshold)
+            except Exception as e:
+                logger.exception(f"检查失败 {target}:{src_port} - {e}")
+                continue
+
+            matches = res.get("by_src_port", {}).get(int(src_port), [])
+            if not matches:
+                logger.info(f"{target}:{src_port} - 无异常连接")
+                continue
+
+            for m in matches:
+                dst = m.get("dst_addr")
+                # Call advanced_acl.add_ip to block
+                try:
+                    # Do not pass custom comment here so add_ip will place the IP into Test_* groups
+                    # (comment was causing new rules to be named AutoBlock_..., not grouped).
+                    result = advanced_acl.add_ip(dst)
+                    # Support both legacy (resp, data) tuple and new dict result
+                    if isinstance(result, dict):
+                        added = result.get("added")
+                        rule = result.get("rule")
+                        row_id = result.get("row_id")
+                        msg = result.get("message")
+                        logger.info(f"已尝试阻断 {dst}, added={added}, rule={rule}, row_id={row_id}, msg={msg}")
+                    else:
+                        # assume (resp, data)
+                        resp, data = result
+                        logger.info(f"已尝试阻断 {dst}, 状态: {resp.status_code}, 返回: {data}")
+                except Exception as e:
+                    logger.exception(f"阻断失败 {dst}: {e}")
+
+    def start(self):
+        self._stop.clear()
+        def _loop():
+            while not self._stop.is_set():
+                try:
+                    self.run_once()
+                except Exception:
+                    logger.exception("运行一次保护检查失败")
+                # reload interval in case config changed
+                try:
+                    cfg = load_config()
+                    self.interval = int(cfg.get("scan_interval", self.interval))
+                except Exception:
+                    pass
+                time.sleep(self.interval)
+
+        t = threading.Thread(target=_loop, daemon=True)
+        t.start()
+        return t
+
+    def stop(self):
+        self._stop.set()
+
+
+def run_protector_blocking_once():
+    pr = ProtectorRunner()
+    pr.run_once()
+
+
+if __name__ == "__main__":
+    r = ProtectorRunner()
+    r.start()
+    try:
+        while True:
+            time.sleep(1)
+    except KeyboardInterrupt:
+        r.stop()

+ 8 - 0
protector_list.json

@@ -0,0 +1,8 @@
+[
+  {
+    "id": 1,
+    "target_ip": "10.8.7.2",
+    "src_port": 443,
+    "threshold": 20
+  }
+]

+ 73 - 0
protector_list.py

@@ -0,0 +1,73 @@
+"""Simple manager for the protector list stored in protector_list.json.
+
+Each entry is a dict with keys:
+- id: int (unique)
+- target_ip: str
+- src_port: int
+- threshold: int | None
+
+This module provides simple load/save/add/remove/update helpers.
+"""
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import List, Optional, Dict
+
+ROOT = Path(__file__).resolve().parent
+LIST_PATH = ROOT / "protector_list.json"
+
+
+def load_list() -> List[Dict]:
+    if not LIST_PATH.exists():
+        return []
+    with open(LIST_PATH, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+
+def save_list(items: List[Dict]) -> None:
+    with open(LIST_PATH, "w", encoding="utf-8") as f:
+        json.dump(items, f, ensure_ascii=False, indent=2)
+
+
+def _next_id(items: List[Dict]) -> int:
+    if not items:
+        return 1
+    return max(int(i.get("id", 0)) for i in items) + 1
+
+
+def add_entry(target_ip: str, src_port: int, threshold: Optional[int] = None) -> Dict:
+    items = load_list()
+    entry = {"id": _next_id(items), "target_ip": target_ip, "src_port": int(src_port), "threshold": (int(threshold) if threshold is not None else None)}
+    items.append(entry)
+    save_list(items)
+    return entry
+
+
+def remove_entry(entry_id: int) -> bool:
+    items = load_list()
+    new_items = [i for i in items if int(i.get("id", -1)) != int(entry_id)]
+    if len(new_items) == len(items):
+        return False
+    save_list(new_items)
+    return True
+
+
+def update_entry(entry_id: int, **kwargs) -> Optional[Dict]:
+    items = load_list()
+    found = None
+    for i in items:
+        if int(i.get("id", -1)) == int(entry_id):
+            i.update({k: v for k, v in kwargs.items() if v is not None})
+            found = i
+            break
+    if found is None:
+        return None
+    save_list(items)
+    return found
+
+
+if __name__ == "__main__":
+    # quick demo
+    print("Current protector list:")
+    print(load_list())

+ 1 - 0
test.bat

@@ -0,0 +1 @@
+tcping -w 0.5 -i 0.5 -t toobee.top 443