protector.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. """Background protector: periodically checks protector_list entries and blocks offending dst_addr via advanced_acl.add_ip
  2. """
  3. from __future__ import annotations
  4. import threading
  5. import time
  6. from typing import Optional
  7. import advanced_acl
  8. from auth_session import load_config
  9. import protector_list
  10. from analysis_connections import find_high_connections_for_src_ports
  11. from log_util import get_logger
# Module-level logger used throughout this file, obtained from
# log_util.get_logger under the name "protector".
logger = get_logger("protector")
  13. class ProtectorRunner:
  14. def __init__(self):
  15. cfg = load_config()
  16. self.interval = int(cfg.get("scan_interval", 60))
  17. self._stop = threading.Event()
  18. self.last_run_time: Optional[float] = None
  19. # pause control: _pause_event is set when running, cleared when paused
  20. self._pause_event = threading.Event()
  21. self._pause_event.set()
  22. # optional callback invoked after each run_once() completes. Signature: fn()
  23. self.on_run_complete = None
  24. def run_once(self):
  25. # if paused, skip immediate execution (protect against external triggers)
  26. try:
  27. if not self._pause_event.is_set():
  28. logger.info("Protector is paused, skipping run_once")
  29. return
  30. except Exception:
  31. pass
  32. # record start time to help external coordinators avoid duplicate runs
  33. try:
  34. self.last_run_time = time.time()
  35. except Exception:
  36. pass
  37. items = protector_list.load_list()
  38. for entry in items:
  39. target = entry.get("target_ip")
  40. src_port = int(entry.get("src_port"))
  41. threshold = entry.get("threshold")
  42. try:
  43. res = find_high_connections_for_src_ports(target, src_port, threshold=threshold)
  44. except Exception as e:
  45. logger.exception(f"检查失败 {target}:{src_port} - {e}")
  46. continue
  47. matches = res.get("by_src_port", {}).get(int(src_port), [])
  48. if not matches:
  49. logger.info(f"{target}:{src_port} - 无异常连接")
  50. continue
  51. for m in matches:
  52. dst = m.get("dst_addr")
  53. # Call advanced_acl.add_ip to block
  54. try:
  55. # Do not pass custom comment here so add_ip will place the IP into Test_* groups
  56. # (comment was causing new rules to be named AutoBlock_..., not grouped).
  57. result = advanced_acl.add_ip(dst)
  58. # Support both legacy (resp, data) tuple and new dict result
  59. if isinstance(result, dict):
  60. added = result.get("added")
  61. rule = result.get("rule")
  62. row_id = result.get("row_id")
  63. msg = result.get("message")
  64. logger.info(f"已尝试阻断 {dst}, added={added}, rule={rule}, row_id={row_id}, msg={msg}")
  65. else:
  66. # assume (resp, data)
  67. resp, data = result
  68. logger.info(f"已尝试阻断 {dst}, 状态: {resp.status_code}, 返回: {data}")
  69. except Exception as e:
  70. logger.exception(f"阻断失败 {dst}: {e}")
  71. # invoke completion callback (best-effort, do not allow callback exceptions to break flow)
  72. try:
  73. if callable(self.on_run_complete):
  74. try:
  75. self.on_run_complete()
  76. except Exception:
  77. logger.exception("on_run_complete callback failed")
  78. except Exception:
  79. pass
  80. def start(self):
  81. self._stop.clear()
  82. # mark last_run_time when loop is started so external callers can compute
  83. # next run time predictably (will be updated when run_once actually runs)
  84. try:
  85. self.last_run_time = time.time()
  86. except Exception:
  87. pass
  88. def _loop():
  89. while not self._stop.is_set():
  90. # respect pause state
  91. if not self._pause_event.is_set():
  92. # paused: wait until unpaused or stopped
  93. # wake every 1s to check stop flag
  94. self._pause_event.wait(timeout=1)
  95. continue
  96. try:
  97. self.run_once()
  98. except Exception:
  99. logger.exception("运行一次保护检查失败")
  100. # reload interval in case config changed
  101. try:
  102. cfg = load_config()
  103. self.interval = int(cfg.get("scan_interval", self.interval))
  104. except Exception:
  105. pass
  106. # sleep in one-second steps so we can be responsive to pause/stop
  107. slept = 0
  108. while slept < self.interval and not self._stop.is_set():
  109. if not self._pause_event.is_set():
  110. break
  111. time.sleep(1)
  112. slept += 1
  113. t = threading.Thread(target=_loop, daemon=True)
  114. t.start()
  115. return t
  116. def stop(self):
  117. self._stop.set()
  118. def get_interval(self) -> int:
  119. """Return current scan interval in seconds."""
  120. try:
  121. cfg = load_config()
  122. return int(cfg.get("scan_interval", self.interval))
  123. except Exception:
  124. return self.interval
  125. def get_next_run_time(self) -> float | None:
  126. """Return the epoch timestamp of the next scheduled run, or None if unknown."""
  127. try:
  128. interval = self.get_interval()
  129. if self.last_run_time:
  130. return float(self.last_run_time + interval)
  131. # if never run, next run is interval seconds from now (if started)
  132. return float(time.time() + interval)
  133. except Exception:
  134. return None
  135. def pause(self):
  136. """Pause periodic execution."""
  137. try:
  138. self._pause_event.clear()
  139. logger.info("Protector 已暂停")
  140. except Exception:
  141. pass
  142. def resume(self):
  143. """Resume periodic execution."""
  144. try:
  145. self._pause_event.set()
  146. logger.info("Protector 已恢复")
  147. except Exception:
  148. pass
  149. def is_paused(self) -> bool:
  150. return not self._pause_event.is_set()
  151. def run_protector_blocking_once():
  152. pr = ProtectorRunner()
  153. pr.run_once()
  154. if __name__ == "__main__":
  155. r = ProtectorRunner()
  156. r.start()
  157. try:
  158. while True:
  159. time.sleep(1)
  160. except KeyboardInterrupt:
  161. r.stop()