"""Optional remote SSH installer — disabled by default (ENABLE_REMOTE_INSTALLER).""" from __future__ import annotations import asyncio import ipaddress import json import re import shlex import socket import time import uuid from collections import defaultdict from typing import Annotated, Any, Literal, Optional, Union from urllib.parse import urlparse import asyncssh from fastapi import APIRouter, HTTPException, Request, WebSocket from pydantic import BaseModel, Field, field_validator from app.core.config import get_settings router = APIRouter(prefix="/public-install", tags=["public-install"]) _jobs: dict[str, dict[str, Any]] = {} _rate_buckets: dict[str, list[float]] = defaultdict(list) _jobs_lock = asyncio.Lock() def _safe_host(host: str) -> str: h = host.strip() if not h or len(h) > 253: raise ValueError("invalid host") if re.search(r"[\s;|&$`\\'\"\n<>()]", h): raise ValueError("invalid host characters") try: ipaddress.ip_address(h) return h except ValueError: pass if not re.match(r"^[a-zA-Z0-9.\-]+$", h): raise ValueError("invalid hostname") return h def _validate_install_url(url: str) -> str: p = urlparse(url.strip()) if p.scheme != "https": raise ValueError("install_url must use https") if not p.netloc or p.username is not None or p.password is not None: raise ValueError("invalid install URL") return url.strip() async def _target_ip_allowed(host: str) -> bool: settings = get_settings() raw = settings.remote_install_allowed_target_cidrs.strip() if not raw: return True cidrs: list = [] for part in raw.split(","): part = part.strip() if not part: continue try: cidrs.append(ipaddress.ip_network(part, strict=False)) except ValueError: continue if not cidrs: return True loop = asyncio.get_event_loop() def resolve() -> list[str]: out: list[str] = [] try: for fam, _ty, _pr, _cn, sa in socket.getaddrinfo(host, None, type=socket.SOCK_STREAM): out.append(sa[0]) except socket.gaierror: return [] return out addrs = await loop.run_in_executor(None, resolve) if not addrs: return False for addr in addrs: try: ip = ipaddress.ip_address(addr) if any(ip in net for net in cidrs): return True except ValueError: continue return False def _check_rate_limit(client_ip: str) -> None: settings = get_settings() lim = settings.remote_install_rate_limit_per_ip if lim <= 0: return window_sec = max(1, settings.remote_install_rate_window_minutes) * 60 now = time.monotonic() bucket = _rate_buckets[client_ip] bucket[:] = [t for t in bucket if now - t < window_sec] if len(bucket) >= lim: raise HTTPException(status_code=429, detail="Rate limit exceeded. Try again later.") bucket.append(now) class AuthKey(BaseModel): type: Literal["key"] = "key" private_key: str = Field(..., min_length=1) passphrase: Optional[str] = None class AuthPassword(BaseModel): type: Literal["password"] = "password" password: str = Field(..., min_length=1) class CreateJobRequest(BaseModel): host: str port: int = Field(default=22, ge=1, le=65535) username: str = Field(..., min_length=1, max_length=64) auth: Annotated[Union[AuthKey, AuthPassword], Field(discriminator="type")] install_url: Optional[str] = None @field_validator("host") @classmethod def host_ok(cls, v: str) -> str: return _safe_host(v) @field_validator("username") @classmethod def user_ok(cls, v: str) -> str: u = v.strip() if re.search(r"[\s;|&$`\\'\"\n<>]", u): raise ValueError("invalid username") return u @field_validator("install_url") @classmethod def url_ok(cls, v: Optional[str]) -> Optional[str]: if v is None or v == "": return None return _validate_install_url(v) class CreateJobResponse(BaseModel): job_id: str def _broadcast(job: dict[str, Any], msg: str) -> None: for q in list(job.get("channels", ())): try: q.put_nowait(msg) except asyncio.QueueFull: pass except Exception: pass @router.get("/config") async def installer_config(): s = get_settings() return { "enabled": s.enable_remote_installer, "default_install_url": s.remote_install_default_url, } @router.post("/jobs", response_model=CreateJobResponse) async def create_job(body: CreateJobRequest, request: Request): settings = get_settings() if not settings.enable_remote_installer: raise HTTPException(status_code=403, detail="Remote installer is disabled") client_ip = request.client.host if request.client else "unknown" _check_rate_limit(client_ip) host = body.host if not await _target_ip_allowed(host): raise HTTPException(status_code=400, detail="Target host is not in allowed CIDR list") url = body.install_url or settings.remote_install_default_url try: url = _validate_install_url(url) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) job_id = uuid.uuid4().hex channels: set[asyncio.Queue] = set() inner = f"curl -fsSL {shlex.quote(url)} | bash" if body.username == "root": remote_cmd = f"bash -lc {shlex.quote(inner)}" else: remote_cmd = f"sudo -n bash -lc {shlex.quote(inner)}" auth_payload = body.auth async def runner() -> None: async with _jobs_lock: job = _jobs.get(job_id) if not job: return exit_code: Optional[int] = None def broadcast(msg: str) -> None: _broadcast(job, msg) try: connect_kw: dict[str, Any] = { "host": host, "port": body.port, "username": body.username, "known_hosts": None, "connect_timeout": 30, } if auth_payload.type == "key": try: key = asyncssh.import_private_key( auth_payload.private_key.encode(), passphrase=auth_payload.passphrase or None, ) except Exception: broadcast(json.dumps({"type": "line", "text": "Invalid private key or passphrase"})) broadcast(json.dumps({"type": "done", "exit_code": -1})) return connect_kw["client_keys"] = [key] else: connect_kw["password"] = auth_payload.password async with asyncssh.connect(**connect_kw) as conn: async with conn.create_process(remote_cmd) as proc: async def pump(stream: Any, is_err: bool) -> None: while True: line = await stream.readline() if not line: break text = line.decode(errors="replace").rstrip("\n\r") prefix = "[stderr] " if is_err else "" broadcast(json.dumps({"type": "line", "text": prefix + text})) await asyncio.gather( pump(proc.stdout, False), pump(proc.stderr, True), ) await proc.wait() exit_code = proc.exit_status except asyncssh.Error as e: msg = str(e).split("\n")[0][:240] broadcast(json.dumps({"type": "line", "text": "SSH error: " + msg})) except OSError as e: broadcast(json.dumps({"type": "line", "text": "Connection error: " + str(e)[:200]})) except Exception: broadcast(json.dumps({"type": "line", "text": "Unexpected installer error"})) broadcast(json.dumps({"type": "done", "exit_code": exit_code if exit_code is not None else -1})) loop = asyncio.get_event_loop() def _purge() -> None: _jobs.pop(job_id, None) loop.call_later(900, _purge) async with _jobs_lock: _jobs[job_id] = {"channels": channels} task = asyncio.create_task(runner()) _jobs[job_id]["task"] = task return CreateJobResponse(job_id=job_id) @router.websocket("/ws/{job_id}") async def job_ws(websocket: WebSocket, job_id: str): settings = get_settings() if not settings.enable_remote_installer: await websocket.close(code=4403) return await websocket.accept() async with _jobs_lock: job = _jobs.get(job_id) if not job: await websocket.send_text(json.dumps({"type": "line", "text": "Unknown or expired job_id"})) await websocket.close() return q: asyncio.Queue = asyncio.Queue(maxsize=500) job["channels"].add(q) try: while True: try: msg = await asyncio.wait_for(q.get(), timeout=7200.0) except asyncio.TimeoutError: await websocket.send_text(json.dumps({"type": "line", "text": "… idle timeout"})) break await websocket.send_text(msg) try: data = json.loads(msg) if data.get("type") == "done": break except json.JSONDecodeError: break finally: job["channels"].discard(q) await websocket.close()