فهرست منبع

refactor(auth): 统一使用 Optional 类型注解并优化配置加载逻辑

- 将所有 dict | None 和 str | None 类型注解统一改为 Optional[dict] 和 Optional[str]
- 重构配置文件加载逻辑,支持运行时目录和项目默认配置
- 添加 get_app_dir() 函数用于确定应用运行目录
- 优化日志和保护列表文件路径生成逻辑
- 在 login 函数中添加 base_url 占位符检查和虚拟响应处理
- 添加 --debug 参数支持在 Windows 上分配控制台显示日志
- 修复 find_high_connections_for_src_ports 函数缩进问题
mcbaiyun 2 ماه پیش
والد
کامیت
e43cb3881e
9فایلهای تغییر یافته به همراه155 افزوده شده و 30 حذف شده
  1. 1 1
      acl_actions.py
  2. 1 1
      advanced_acl.py
  3. 3 3
      analysis_connections.py
  4. 104 17
      auth_session.py
  5. 1 0
      build_exe.bat
  6. 14 1
      log_util.py
  7. 17 6
      main.py
  8. 1 1
      protector.py
  9. 13 0
      protector_list.py

+ 1 - 1
acl_actions.py

@@ -31,7 +31,7 @@ def _make_cookies_from_sess_key(sess_key: str) -> dict:
     return {"sess_key": val}
 
 
-def get_acl_rules(payload: dict | None = None, timeout: int = 10) -> Tuple[requests.Response, Optional[dict]]:
+def get_acl_rules(payload: Optional[dict] = None, timeout: int = 10) -> Tuple[requests.Response, Optional[dict]]:
     base = get_base_url()
     url = f"{base}/Action/call"
     data = payload or DEFAULT_SHOW_PAYLOAD

+ 1 - 1
advanced_acl.py

@@ -96,7 +96,7 @@ def _collect_prefixed_rules() -> List[Dict]:
     return matched
 
 
-def add_ip(ip: str, comment: str | None = None) -> Dict:
+def add_ip(ip: str, comment: Optional[str] = None) -> Dict:
     """Add a single IP to the Test_* grouped rules.
 
     If `comment` is provided it will be used as the comment when creating a new rule.

+ 3 - 3
analysis_connections.py

@@ -111,7 +111,7 @@ if __name__ == "__main__":
 	main()
 
 
-def find_high_connection_pairs(ip: str, threshold: int | None = None, limit: str = "0,100000") -> dict:
+def find_high_connection_pairs(ip: str, threshold: Optional[int] = None, limit: str = "0,100000") -> dict:
 	"""Fetch connections for `ip` and return pairs (src_port, dst_addr) whose count > threshold.
 
 	Returns a dict with keys: threshold, total_reported, fetched, matches (list of dicts).
@@ -134,7 +134,7 @@ def find_high_connection_pairs(ip: str, threshold: int | None = None, limit: str
 		return result
 
 
-	def find_high_connections_for_src_ports(ip: str, src_ports, threshold: int | None = None, limit: str = "0,100000") -> dict:
+	def find_high_connections_for_src_ports(ip: str, src_ports, threshold: Optional[int] = None, limit: str = "0,100000") -> dict:
 		"""Analyze only the specified src_ports (single int or iterable) for given ip.
 
 		Returns dict: {"threshold": n, "total_reported": x, "fetched": y, "by_src_port": {port: [{dst_addr, count, samples}, ...]}}
@@ -221,7 +221,7 @@ def find_high_connection_pairs(ip: str, threshold: int | None = None, limit: str
 	return result
 
 
-def find_high_connections_for_src_ports(ip: str, src_ports, threshold: int | None = None, limit: str = "0,100000") -> dict:
+def find_high_connections_for_src_ports(ip: str, src_ports, threshold: Optional[int] = None, limit: str = "0,100000") -> dict:
 	"""Analyze only the specified src_ports (single int or iterable) for given ip.
 
 	Returns dict: {"threshold": n, "total_reported": x, "fetched": y, "by_src_port": {port: [{dst_addr, count, samples}, ...]}}

+ 104 - 17
auth_session.py

@@ -18,6 +18,7 @@ import json
 import re
 import sys
 from pathlib import Path
+import os
 from threading import Lock
 from typing import Optional, Tuple
 
@@ -26,8 +27,38 @@ import requests
 from log_util import get_logger, log_request
 
 
+def get_app_dir() -> Path:
+    """Return the directory where runtime files should be read/written.
+
+    Logic:
+    - If running as a PyInstaller one-file bundle, prefer the directory of the
+      executable (Path(sys.executable).parent).
+    - Otherwise prefer the current working directory (Path.cwd()).
+    - This makes writes (logs, config, protector_list) appear next to the exe
+      when distributed.
+    """
+    try:
+        if getattr(sys, "frozen", False):
+            return Path(sys.executable).resolve().parent
+    except Exception:
+        pass
+
+    try:
+        return Path.cwd()
+    except Exception:
+        return Path(__file__).resolve().parent
+
+
 ROOT = Path(__file__).resolve().parent
-CONFIG_PATH = ROOT / "config.json"
+CONFIG_PATH = get_app_dir() / "config.json"
+
+
+DEFAULT_PAYLOAD = {
+    "username": "xiaobai",
+    "passwd": "dc81b4427df07fd6b3ebcb05a7b34daf",
+    "pass": "c2FsdF8xMXhpYW9iYWku",
+    "remember_password": "",
+}
 
 
 # --- simple thread-safe in-memory session state ---
@@ -57,10 +88,54 @@ def clear_sess_key() -> None:
 
 # --- login/request helpers ---
 def load_config() -> dict:
-    if not CONFIG_PATH.exists():
-        raise FileNotFoundError(f"配置文件未找到: {CONFIG_PATH}")
-    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
-        return json.load(f)
+    """Load configuration.
+
+    If the runtime config (CONFIG_PATH) does not exist, attempt to create it by
+    copying the project's default `config.json` (ROOT / 'config.json') or using
+    reasonable defaults. The created file will be written next to the exe or in
+    the current working directory so that packaged exe creates files in its
+    directory.
+    """
+    # If config exists in runtime location, load it.
+    if CONFIG_PATH.exists():
+        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
+            return json.load(f)
+
+    # Otherwise try to obtain a project default from source tree (useful during
+    # development and when shipping defaults inside the package).
+    project_default = ROOT / "config.json"
+    default_cfg = {
+        "base_url": "http://ip:port/",
+        "conn_threshold": 20,
+        "scan_interval": 60,
+        "test_prefix": "Test_",
+        "rule_ip_limit": 1000,
+    }
+    if project_default.exists():
+        try:
+            with open(project_default, "r", encoding="utf-8") as f:
+                default_cfg = json.load(f)
+        except Exception:
+            # ignore and fall back to embedded defaults
+            pass
+
+    # Ensure runtime directory exists and write config atomically
+    try:
+        CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
+        tmp = CONFIG_PATH.with_suffix(".tmp")
+        with open(tmp, "w", encoding="utf-8") as f:
+            json.dump(default_cfg, f, ensure_ascii=False, indent=2)
+        try:
+            tmp.replace(CONFIG_PATH)
+        except Exception:
+            # fallback rename
+            tmp.rename(CONFIG_PATH)
+    except Exception as e:
+        # If we cannot write, surface a FileNotFoundError to preserve original
+        # callers' behavior.
+        raise FileNotFoundError(f"无法创建配置文件 {CONFIG_PATH}: {e}")
+
+    return default_cfg
 
 
 def get_base_url() -> str:
@@ -72,7 +147,6 @@ def get_base_url() -> str:
 def get_host() -> str:
     """Return host:port (netloc) parsed from base_url.
 
-    Example: 'http://8.210.76.176:65531/' -> '8.210.76.176:65531'
     """
     from urllib.parse import urlparse
 
@@ -81,16 +155,6 @@ def get_host() -> str:
         return ""
     parsed = urlparse(base)
     return parsed.netloc
-
-
-DEFAULT_PAYLOAD = {
-    "username": "xiaobai",
-    "passwd": "dc81b4427df07fd6b3ebcb05a7b34daf",
-    "pass": "c2FsdF8xMXhpYW9iYWku",
-    "remember_password": "",
-}
-
-
 def _extract_sess_cookie_from_response(resp: requests.Response) -> Optional[str]:
     """Try to extract sess_key cookie string from response.
 
@@ -109,13 +173,36 @@ def _extract_sess_cookie_from_response(resp: requests.Response) -> Optional[str]
     return None
 
 
-def login(payload: dict | None = None, timeout: int = 10) -> Tuple[requests.Response, Optional[str]]:
+def login(payload: Optional[dict] = None, timeout: int = 10) -> Tuple[requests.Response, Optional[str]]:
     """Send login POST. Returns (response, sess_cookie_or_none).
 
     The function will also set the in-memory sess_key if it can be extracted.
     """
     cfg = load_config()
     base = cfg.get("base_url", "").rstrip("/")
+    # Validate base URL to avoid passing placeholders like "http://ip:port/" to requests
+    from urllib.parse import urlparse
+
+    parsed = urlparse(base)
+    # If base_url is a placeholder (intentionally invalid), do not raise an
+    # exception here. Return a lightweight dummy response so the app can
+    # continue (GUI can prompt the user to fix config.json on first run).
+    if not base or not parsed.scheme or not parsed.netloc or "ip:port" in base:
+        logger = get_logger("login_post")
+        logger.warning(
+            "base_url appears to be a placeholder; skipping network login. GUI can be used to set a real base_url."
+        )
+
+        class _DummyResp:
+            def __init__(self):
+                self.status_code = 0
+                self.text = "base_url placeholder - no network request performed"
+
+            def json(self):
+                raise ValueError("No JSON available")
+
+        return _DummyResp(), None
+
     url = f"{base}/Action/login"
     data = payload or DEFAULT_PAYLOAD
     logger = get_logger("login_post")

+ 1 - 0
build_exe.bat

@@ -0,0 +1 @@
+pyinstaller --onefile --windowed --noconfirm --clean --add-data "config.json;." --name main main.py

+ 14 - 1
log_util.py

@@ -4,9 +4,22 @@ import json
 from datetime import datetime, timedelta
 import glob
 import os
+import sys
 
 
-LOG_DIR = Path(__file__).resolve().parent / "logs"
+def _get_app_dir() -> Path:
+    try:
+        if getattr(sys, "frozen", False):
+            return Path(sys.executable).resolve().parent
+    except Exception:
+        pass
+    try:
+        return Path.cwd()
+    except Exception:
+        return Path(__file__).resolve().parent
+
+
+LOG_DIR = _get_app_dir() / "logs"
 LOG_DIR.mkdir(parents=True, exist_ok=True)
 
 

+ 17 - 6
main.py

@@ -15,14 +15,25 @@ def main():
 	logger = get_logger("main")
 	logger.info("开始 main()")
 
-	# If --gui passed, launch GUI and start protector runner
-	if "--gui" in sys.argv:
-		# create protector runner but DO NOT start it yet.
-		# GUI will start the protector after initial countdown to avoid immediate runs.
+	# If --debug passed, allocate a console so logs and prints show up.
+	# Default behavior: start GUI. If user passes --nogui, run headless flow.
+	if "--debug" in sys.argv:
+		# allocate a fresh console on Windows
+		try:
+			import ctypes
+			ctypes.windll.kernel32.AllocConsole()
+			# reopen std streams to the new console
+			import sys as _sys
+			_sys.stdout = open("CONOUT$", "w", encoding="utf-8", buffering=1)
+			_sys.stderr = open("CONOUT$", "w", encoding="utf-8", buffering=1)
+		except Exception:
+			# best-effort: continue if console allocation fails
+			pass
+
+	# Default: launch GUI unless explicitly requested not to
+	if "--nogui" not in sys.argv:
 		pr = ProtectorRunner()
-		# launch GUI (blocking) and pass protector runner so GUI can control its start
 		gui.run_gui(protector_runner=pr)
-		# when GUI exits, stop protector
 		pr.stop()
 		return
 

+ 1 - 1
protector.py

@@ -140,7 +140,7 @@ class ProtectorRunner:
         except Exception:
             return self.interval
 
-    def get_next_run_time(self) -> float | None:
+    def get_next_run_time(self) -> Optional[float]:
         """Return the epoch timestamp of the next scheduled run, or None if unknown."""
         try:
             interval = self.get_interval()

+ 13 - 0
protector_list.py

@@ -13,8 +13,21 @@ from __future__ import annotations
 import json
 from pathlib import Path
 from typing import List, Optional, Dict
+import sys
 
 ROOT = Path(__file__).resolve().parent
+def _get_app_dir() -> Path:
+    try:
+        if getattr(sys, "frozen", False):
+            return Path(sys.executable).resolve().parent
+    except Exception:
+        pass
+    try:
+        return Path.cwd()
+    except Exception:
+        return Path(__file__).resolve().parent
+
+ROOT = _get_app_dir()
 LIST_PATH = ROOT / "protector_list.json"