101 lines
3.7 KiB
Python
101 lines
3.7 KiB
Python
from PyQt6.QtCore import QThread, pyqtSignal
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
# Project helpers invoked by the worker threads defined below.
|
|
from core.scanner import scan_network
|
|
from core.flasher import flash_device
|
|
from core.ssh_flasher import flash_device_ssh
|
|
from utils.network import _resolve_hostname
|
|
|
|
class ScanThread(QThread):
    """Run a network scan in a background thread so the UI doesn't freeze.

    After ``scan_network`` returns, the hostname of every discovered device
    is looked up in parallel and stored under the device's ``"name"`` key.

    Signals:
        finished(list): emitted once with the scan results on success.
        error(str): emitted with an error description if anything raises.
        scan_progress(int, int): ``(done, total)`` of the ping sweep.
        stage(str): name of the current scan phase (``"hostname"`` is
            emitted here before name resolution starts).
    """

    # NOTE(review): this custom ``finished`` hides QThread's built-in
    # ``finished`` signal, so ``thread.finished`` refers to this
    # list-carrying signal. The name is kept for existing callers.
    finished = pyqtSignal(list)
    error = pyqtSignal(str)
    scan_progress = pyqtSignal(int, int)  # done, total (ping sweep)
    stage = pyqtSignal(str)  # current scan phase

    # Thread-pool size used for the parallel hostname lookups.
    HOSTNAME_WORKERS = 50

    def __init__(self, network):
        """Store *network*; it is passed unchanged to ``scan_network``."""
        super().__init__()
        self.network = network

    def run(self):
        """Scan, resolve hostnames, then emit ``finished`` or ``error``."""
        try:
            def on_ping_progress(done, total):
                self.scan_progress.emit(done, total)

            def on_stage(s):
                self.stage.emit(s)

            results = scan_network(self.network,
                                   progress_cb=on_ping_progress,
                                   stage_cb=on_stage)

            self.stage.emit("hostname")
            self._resolve_names(results)
            self.finished.emit(results)
        except Exception as e:
            # Some exceptions stringify to "" (e.g. a bare TimeoutError);
            # fall back to the type name so the UI never shows a blank error.
            self.error.emit(str(e) or type(e).__name__)

    def _resolve_names(self, devices):
        """Set ``dev["name"]`` for each device via parallel hostname lookups.

        Each device is read as a mapping with an ``"ip"`` key. A lookup
        that raises leaves an empty name rather than failing the scan.
        """
        with ThreadPoolExecutor(max_workers=self.HOSTNAME_WORKERS) as executor:
            future_to_dev = {
                executor.submit(_resolve_hostname, dev["ip"]): dev
                for dev in devices
            }
            for future in as_completed(future_to_dev):
                dev = future_to_dev[future]
                # as_completed only yields finished futures, so result()
                # returns immediately. How long a lookup may take is bounded
                # only by _resolve_hostname itself, not here.
                try:
                    dev["name"] = future.result()
                except Exception:
                    # Best effort: an unresolvable host simply gets no name.
                    dev["name"] = ""
|
|
|
|
|
|
class FlashThread(QThread):
    """Run firmware flashing in the background so the UI stays responsive.

    Devices are flashed concurrently on a thread pool, either with
    ``flash_device_ssh`` (``method="ssh"``) or otherwise ``flash_device``.

    Signals:
        device_status(int, str): ``(index, message)`` progress for a device.
        device_done(int, str): ``(index, result)`` when a device finishes;
            the result is ``"FAIL: <error>"`` if flashing raised.
        all_done(): emitted once after every device has been processed.
    """

    device_status = pyqtSignal(int, str)  # index, status message
    device_done = pyqtSignal(int, str)  # index, result
    all_done = pyqtSignal()

    def __init__(self, devices, firmware_path, max_workers=10,
                 method="api", ssh_user="root", ssh_password="admin123a",
                 set_passwd=False):
        """Configure a flashing batch.

        Args:
            devices: sequence of device mappings; each one's ``"ip"`` is read.
            firmware_path: firmware file passed to the flasher functions.
            max_workers: pool size; ``<= 0`` means one worker per device.
            method: ``"ssh"`` uses ``flash_device_ssh``; any other value
                uses ``flash_device``.
            ssh_user, ssh_password, set_passwd: forwarded to
                ``flash_device_ssh`` (unused for other methods).
        """
        # NOTE(review): ssh_password has a hard-coded default credential;
        # consider requiring callers to supply it explicitly.
        super().__init__()
        self.devices = devices
        self.firmware_path = firmware_path
        self.max_workers = max_workers
        self.method = method
        self.ssh_user = ssh_user
        self.ssh_password = ssh_password
        self.set_passwd = set_passwd

    def _flash_one(self, i, dev):
        """Flash one device and report its outcome through ``device_done``.

        Args:
            i: index of *dev* in ``self.devices``; echoed in every signal.
            dev: device mapping; only its ``"ip"`` key is read here.
        """
        def on_status(msg):
            self.device_status.emit(i, msg)

        try:
            if self.method == "ssh":
                result = flash_device_ssh(
                    dev["ip"], self.firmware_path,
                    user=self.ssh_user,
                    password=self.ssh_password,
                    set_passwd=self.set_passwd,
                    status_cb=on_status,
                )
            else:
                result = flash_device(
                    dev["ip"], self.firmware_path,
                    status_cb=on_status,
                )
            self.device_done.emit(i, result)
        except Exception as e:
            # Report this device as failed instead of aborting the batch.
            self.device_done.emit(i, f"FAIL: {e}")

    def run(self):
        """Flash every device concurrently, then emit ``all_done`` once."""
        # max_workers <= 0 means "one worker per device". Clamp to at least 1:
        # ThreadPoolExecutor raises ValueError for max_workers=0 (an empty
        # device list), which would otherwise kill run() before all_done.
        if self.max_workers > 0:
            workers = self.max_workers
        else:
            workers = max(1, len(self.devices))

        try:
            with ThreadPoolExecutor(max_workers=workers) as executor:
                futures = [
                    executor.submit(self._flash_one, i, dev)
                    for i, dev in enumerate(self.devices)
                ]
                for future in futures:
                    future.result()
        finally:
            # Always signal completion so the UI is never left waiting.
            self.all_done.emit()
|