protector.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
"""Background protector.

Periodically checks every entry in ``protector_list`` and blocks each
offending ``dst_addr`` via ``advanced_acl.add_ip``.
"""
  3. from __future__ import annotations
  4. import threading
  5. import time
  6. from typing import Optional
  7. import advanced_acl
  8. from auth_session import load_config
  9. import protector_list
  10. from analysis_connections import find_high_connections_for_src_ports
  11. from log_util import get_logger
  12. logger = get_logger("protector")
  13. class ProtectorRunner:
  14. def __init__(self):
  15. cfg = load_config()
  16. self.interval = int(cfg.get("scan_interval", 60))
  17. self._stop = threading.Event()
  18. self.last_run_time: Optional[float] = None
  19. def run_once(self):
  20. # record start time to help external coordinators avoid duplicate runs
  21. try:
  22. self.last_run_time = time.time()
  23. except Exception:
  24. pass
  25. items = protector_list.load_list()
  26. for entry in items:
  27. target = entry.get("target_ip")
  28. src_port = int(entry.get("src_port"))
  29. threshold = entry.get("threshold")
  30. try:
  31. res = find_high_connections_for_src_ports(target, src_port, threshold=threshold)
  32. except Exception as e:
  33. logger.exception(f"检查失败 {target}:{src_port} - {e}")
  34. continue
  35. matches = res.get("by_src_port", {}).get(int(src_port), [])
  36. if not matches:
  37. logger.info(f"{target}:{src_port} - 无异常连接")
  38. continue
  39. for m in matches:
  40. dst = m.get("dst_addr")
  41. # Call advanced_acl.add_ip to block
  42. try:
  43. # Do not pass custom comment here so add_ip will place the IP into Test_* groups
  44. # (comment was causing new rules to be named AutoBlock_..., not grouped).
  45. result = advanced_acl.add_ip(dst)
  46. # Support both legacy (resp, data) tuple and new dict result
  47. if isinstance(result, dict):
  48. added = result.get("added")
  49. rule = result.get("rule")
  50. row_id = result.get("row_id")
  51. msg = result.get("message")
  52. logger.info(f"已尝试阻断 {dst}, added={added}, rule={rule}, row_id={row_id}, msg={msg}")
  53. else:
  54. # assume (resp, data)
  55. resp, data = result
  56. logger.info(f"已尝试阻断 {dst}, 状态: {resp.status_code}, 返回: {data}")
  57. except Exception as e:
  58. logger.exception(f"阻断失败 {dst}: {e}")
  59. def start(self):
  60. self._stop.clear()
  61. def _loop():
  62. while not self._stop.is_set():
  63. try:
  64. self.run_once()
  65. except Exception:
  66. logger.exception("运行一次保护检查失败")
  67. # reload interval in case config changed
  68. try:
  69. cfg = load_config()
  70. self.interval = int(cfg.get("scan_interval", self.interval))
  71. except Exception:
  72. pass
  73. time.sleep(self.interval)
  74. t = threading.Thread(target=_loop, daemon=True)
  75. t.start()
  76. return t
  77. def stop(self):
  78. self._stop.set()
  79. def get_interval(self) -> int:
  80. """Return current scan interval in seconds."""
  81. try:
  82. cfg = load_config()
  83. return int(cfg.get("scan_interval", self.interval))
  84. except Exception:
  85. return self.interval
  86. def run_protector_blocking_once():
  87. pr = ProtectorRunner()
  88. pr.run_once()
  89. if __name__ == "__main__":
  90. r = ProtectorRunner()
  91. r.start()
  92. try:
  93. while True:
  94. time.sleep(1)
  95. except KeyboardInterrupt:
  96. r.stop()