312 lines
9.5 KiB
Python
312 lines
9.5 KiB
Python
"""Optional remote SSH installer — disabled by default (ENABLE_REMOTE_INSTALLER)."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import ipaddress
|
|
import json
|
|
import re
|
|
import shlex
|
|
import socket
|
|
import time
|
|
import uuid
|
|
from collections import defaultdict
|
|
from typing import Annotated, Any, Literal, Optional, Union
|
|
from urllib.parse import urlparse
|
|
|
|
from fastapi import APIRouter, HTTPException, Request, WebSocket
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
from app.core.config import get_settings
|
|
|
|
# FastAPI router for the installer endpoints, mounted under /public-install.
router = APIRouter(tags=["public-install"], prefix="/public-install")

# In-memory job registry: job_id -> job record (a dict holding the subscriber
# "channels" set and the runner "task").
_jobs: dict[str, dict[str, Any]] = dict()

# Per-client-IP submission timestamps (time.monotonic() values) for the
# sliding-window rate limiter; unseen IPs start with an empty list.
_rate_buckets: dict[str, list[float]] = defaultdict(list)

# Serialises access to _jobs between request handlers and job runners.
_jobs_lock = asyncio.Lock()
|
|
|
|
|
|
def _safe_host(host: str) -> str:
    """Validate a client-supplied SSH target and return it without surrounding whitespace.

    The target must be either an IP literal (anything ``ipaddress.ip_address``
    accepts) or a hostname made only of ASCII letters, digits, dots and
    hyphens. Whitespace, quotes, shell metacharacters, parentheses and angle
    brackets are rejected before either form is considered.

    Raises:
        ValueError: "invalid host" when empty or longer than 253 characters;
            "invalid host characters" for a forbidden character;
            "invalid hostname" when it is neither an IP nor a plain hostname.
    """
    candidate = host.strip()

    # 253 characters is the usual ceiling for a textual DNS name.
    if not 0 < len(candidate) <= 253:
        raise ValueError("invalid host")

    if re.search(r"[\s;|&$`\\'\"\n<>()]", candidate) is not None:
        raise ValueError("invalid host characters")

    try:
        ipaddress.ip_address(candidate)
    except ValueError:
        is_ip_literal = False
    else:
        is_ip_literal = True

    if not is_ip_literal and re.match(r"^[a-zA-Z0-9.\-]+$", candidate) is None:
        raise ValueError("invalid hostname")
    return candidate
|
|
|
|
|
|
def _validate_install_url(url: str) -> str:
    """Check that *url* is an acceptable installer download location.

    The URL must use ``https``, name a host, carry no ``user:password@``
    credentials, and contain no whitespace or control characters.

    Returns:
        The URL with surrounding whitespace stripped (exactly the string that
        was validated).

    Raises:
        ValueError: "install_url must use https" for any other scheme, or
            "invalid install URL" for every other problem (``urlparse`` itself
            may also raise ValueError for malformed IPv6 hosts).
    """
    candidate = url.strip()
    parsed = urlparse(candidate)
    if parsed.scheme != "https":
        raise ValueError("install_url must use https")
    # Recent Python releases make urlparse silently drop tab/CR/LF, so the
    # parsed result can differ from the string we hand to curl. Reject such
    # characters rather than validating one string and using another.
    if any(ch.isspace() or ord(ch) < 0x20 or ord(ch) == 0x7F for ch in candidate):
        raise ValueError("invalid install URL")
    # hostname (not just netloc) must be present: "https://:443/" has a
    # non-empty netloc but no host.
    if not parsed.hostname or parsed.username is not None or parsed.password is not None:
        raise ValueError("invalid install URL")
    return candidate
|
|
|
|
|
|
def _parse_allowed_cidrs(raw: str) -> list[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
    """Parse a comma-separated CIDR list, skipping blank and malformed entries."""
    networks: list[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = []
    for part in raw.split(","):
        part = part.strip()
        if not part:
            continue
        try:
            networks.append(ipaddress.ip_network(part, strict=False))
        except ValueError:
            continue
    return networks


def _resolve_host_addrs(host: str) -> list[str]:
    """Resolve *host* to address strings (blocking; run it in an executor).

    Returns an empty list when resolution fails.
    """
    try:
        infos = socket.getaddrinfo(host, None, type=socket.SOCK_STREAM)
    except (OSError, UnicodeError):
        # socket.gaierror is an OSError; UnicodeError is raised while
        # IDNA-encoding names with empty or over-long labels (e.g. "a..b").
        return []
    return [sockaddr[0] for _fam, _ty, _pr, _cn, sockaddr in infos]


def _addr_in_networks(
    addr: str, networks: list[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]
) -> bool:
    """Return True if *addr* parses as an IP inside at least one of *networks*."""
    try:
        ip = ipaddress.ip_address(addr)
    except ValueError:
        return False
    return any(ip in net for net in networks)


async def _target_ip_allowed(host: str) -> bool:
    """Return True if *host* may be targeted under the configured CIDR allow-list.

    ``remote_install_allowed_target_cidrs`` is a comma-separated network list.
    When it is blank the check is disabled and every host is allowed.
    Otherwise:

    * if no entry parses, the host is refused (fail closed on a
      misconfigured list instead of silently allowing everything);
    * *host* is resolved in the default executor (``getaddrinfo`` blocks);
      an unresolvable host is refused;
    * EVERY resolved address must lie in an allowed network. The SSH client
      may connect to any of them, so one allowed address is not enough.

    NOTE(review): the SSH connection resolves the name again, so a DNS answer
    that changes between this check and the connect is not covered. Connecting
    to a vetted address would close that gap.
    """
    settings = get_settings()
    raw = (settings.remote_install_allowed_target_cidrs or "").strip()
    if not raw:
        return True

    cidrs = _parse_allowed_cidrs(raw)
    if not cidrs:
        return False

    loop = asyncio.get_running_loop()
    addrs = await loop.run_in_executor(None, _resolve_host_addrs, host)
    if not addrs:
        return False
    return all(_addr_in_networks(addr, cidrs) for addr in addrs)
|
|
|
|
|
|
def _check_rate_limit(client_ip: str) -> None:
    """Enforce the per-client sliding-window limit on job creation.

    At most ``remote_install_rate_limit_per_ip`` calls are accepted per client
    IP within the last ``remote_install_rate_window_minutes`` minutes (the
    window is at least one minute). A limit of 0 or less disables the check.
    Accepted calls are recorded; rejected calls are not.

    Raises:
        HTTPException: 429 when *client_ip* has reached its limit.
    """
    settings = get_settings()
    lim = settings.remote_install_rate_limit_per_ip
    if lim <= 0:
        return

    window_sec = max(1, settings.remote_install_rate_window_minutes) * 60
    now = time.monotonic()

    # Expire stale timestamps for every client, not just this one, and drop
    # buckets that end up empty. Otherwise each distinct IP that ever called
    # keeps a key in _rate_buckets forever. Cost is linear in the stored
    # timestamps, which is cheap for a rate-limited endpoint.
    for ip in list(_rate_buckets):
        recent = [t for t in _rate_buckets[ip] if now - t < window_sec]
        if recent:
            _rate_buckets[ip] = recent
        else:
            del _rate_buckets[ip]

    bucket = _rate_buckets.setdefault(client_ip, [])
    if len(bucket) >= lim:
        raise HTTPException(status_code=429, detail="Rate limit exceeded. Try again later.")
    bucket.append(now)
|
|
|
|
|
|
class AuthKey(BaseModel):
    # SSH public-key credentials for a job. Secret fields use repr=False so
    # they never appear in repr()/str() output of the model (log lines,
    # tracebacks, debug dumps of the request body).

    # Discriminator value selecting this variant in CreateJobRequest.auth.
    type: Literal["key"] = "key"
    # Private key text; the job runner loads it with asyncssh.import_private_key.
    private_key: str = Field(..., min_length=1, repr=False)
    # Optional passphrase for an encrypted private key.
    passphrase: Optional[str] = Field(default=None, repr=False)
|
|
|
|
|
|
class AuthPassword(BaseModel):
    # SSH password credentials for a job. The password uses repr=False so it
    # never appears in repr()/str() output of the model (log lines,
    # tracebacks, debug dumps of the request body).

    # Discriminator value selecting this variant in CreateJobRequest.auth.
    type: Literal["password"] = "password"
    # Login password for the target account.
    password: str = Field(..., min_length=1, repr=False)
|
|
|
|
|
|
class CreateJobRequest(BaseModel):
    # Request body for POST /public-install/jobs.

    # SSH target: IP literal or plain hostname (normalised by _safe_host).
    host: str
    # SSH port on the target.
    port: int = Field(default=22, ge=1, le=65535)
    # Remote login name (stripped of surrounding whitespace by user_ok).
    username: str = Field(..., min_length=1, max_length=64)
    # Credentials; the "type" field picks the AuthKey or AuthPassword variant.
    auth: Annotated[Union[AuthKey, AuthPassword], Field(discriminator="type")]
    # Optional https installer URL; None means "use the configured default".
    install_url: Optional[str] = None

    @field_validator("host")
    @classmethod
    def host_ok(cls, v: str) -> str:
        """Validate and strip the target host (see _safe_host)."""
        return _safe_host(v)

    @field_validator("username")
    @classmethod
    def user_ok(cls, v: str) -> str:
        """Strip the username and reject empty values or shell metacharacters."""
        u = v.strip()
        # min_length is checked before this "after" validator runs, so a
        # whitespace-only value would otherwise be accepted as "".
        if not u:
            raise ValueError("invalid username")
        if re.search(r"[\s;|&$`\\'\"\n<>]", u):
            raise ValueError("invalid username")
        return u

    @field_validator("install_url")
    @classmethod
    def url_ok(cls, v: Optional[str]) -> Optional[str]:
        """Map blank input to None; otherwise validate via _validate_install_url."""
        if v is None or not v.strip():
            return None
        return _validate_install_url(v)
|
|
|
|
|
|
class CreateJobResponse(BaseModel):
    # Response body for POST /public-install/jobs.

    # Id of the new job; clients stream its output from /ws/{job_id}.
    job_id: str = Field(...)
|
|
|
|
|
|
def _broadcast(job: dict[str, Any], msg: str) -> None:
    """Deliver *msg* to every subscriber queue in ``job["channels"]`` without blocking.

    When a subscriber's queue is full, its OLDEST pending message is discarded
    to make room. The newest message, in particular the final ``done`` event,
    therefore always reaches slow consumers. Delivery is best-effort: a
    failure for one subscriber never affects the others. A job without
    ``"channels"`` is a no-op.
    """
    # Iterate over a snapshot: subscribers may be added or removed concurrently.
    for q in list(job.get("channels", ())):
        try:
            q.put_nowait(msg)
        except asyncio.QueueFull:
            try:
                q.get_nowait()
                q.put_nowait(msg)
            except (asyncio.QueueEmpty, asyncio.QueueFull):
                pass
        except Exception:
            # Deliberately best-effort fan-out; see docstring.
            pass
|
|
|
|
|
|
@router.get("/config")
async def installer_config():
    """Report whether the remote installer is enabled and its default install URL.

    This endpoint answers even when the installer is disabled, so a UI can
    show the current state.
    """
    settings = get_settings()
    enabled = settings.enable_remote_installer
    default_url = settings.remote_install_default_url
    return {"enabled": enabled, "default_install_url": default_url}
|
|
|
|
|
|
# How long (seconds) a finished job stays registered, so clients can still
# attach to its WebSocket and replay its output.
_JOB_RETENTION_SEC = 900

# Maximum number of output messages kept per job for replay to subscribers
# that connect after the job started; the oldest are discarded first.
_JOB_HISTORY_LIMIT = 1000


def _remote_install_command(url: str, username: str) -> str:
    """Build the remote shell command that downloads *url* and pipes it to bash.

    The URL is shell-quoted inside the inner ``curl … | bash`` pipeline, and the
    whole pipeline is quoted again as the single argument of ``bash -lc``.
    Users other than root run it through ``sudo -n``.
    """
    inner = f"curl -fsSL {shlex.quote(url)} | bash"
    launcher = "bash -lc" if username == "root" else "sudo -n bash -lc"
    return f"{launcher} {shlex.quote(inner)}"


def _publish(job: dict[str, Any], msg: str) -> None:
    """Append *msg* to the job's replay history, then deliver it to live subscribers."""
    job["history"].append(msg)
    _broadcast(job, msg)


def _is_terminal_message(msg: str) -> bool:
    """Return True if *msg* should end a subscriber's stream.

    That is the job's ``{"type": "done", ...}`` event, or any message that is
    not a JSON object.
    """
    try:
        data = json.loads(msg)
    except json.JSONDecodeError:
        return True
    return not isinstance(data, dict) or data.get("type") == "done"


@router.post("/jobs", response_model=CreateJobResponse)
async def create_job(body: CreateJobRequest, request: Request):
    """Validate an install request, start the SSH job in the background, return its id.

    Checks, in order: installer enabled (403), per-client rate limit (429),
    target host against the CIDR allow-list (400), install URL (400).

    The job then runs as an asyncio task. Its output is published as JSON
    messages: ``{"type": "line", "text": ...}`` lines followed by exactly one
    ``{"type": "done", "exit_code": ...}`` (-1 when no exit status is known).
    Messages go to ``/ws/{job_id}`` subscribers and into a bounded history
    that late subscribers replay. The job record is dropped
    ``_JOB_RETENTION_SEC`` seconds after the task finishes.
    """
    from collections import deque

    settings = get_settings()
    if not settings.enable_remote_installer:
        raise HTTPException(status_code=403, detail="Remote installer is disabled")

    client_ip = request.client.host if request.client else "unknown"
    _check_rate_limit(client_ip)

    host = body.host
    if not await _target_ip_allowed(host):
        raise HTTPException(status_code=400, detail="Target host is not in allowed CIDR list")

    url = body.install_url or settings.remote_install_default_url
    if not url:
        # Neither the request nor the settings supply a URL. Calling the
        # validator with None would crash with AttributeError (HTTP 500).
        raise HTTPException(status_code=400, detail="No install URL given and no default is configured")
    try:
        url = _validate_install_url(url)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    job_id = uuid.uuid4().hex
    remote_cmd = _remote_install_command(url, body.username)
    auth_payload = body.auth
    job: dict[str, Any] = {
        "channels": set(),  # one asyncio.Queue per connected WebSocket
        "history": deque(maxlen=_JOB_HISTORY_LIMIT),
    }

    async def runner() -> None:
        def broadcast(msg: str) -> None:
            _publish(job, msg)

        try:
            import asyncssh
        except ImportError:
            # Without this, subscribers would never see a "done" event.
            broadcast(json.dumps({"type": "line", "text": "SSH support (asyncssh) is not installed"}))
            broadcast(json.dumps({"type": "done", "exit_code": -1}))
            return

        exit_code: Optional[int] = None
        try:
            connect_kw: dict[str, Any] = {
                "host": host,
                "port": body.port,
                "username": body.username,
                # NOTE(review): host-key verification is disabled, so a
                # man-in-the-middle could receive the submitted credentials.
                "known_hosts": None,
                "connect_timeout": 30,
            }
            if auth_payload.type == "key":
                try:
                    key = asyncssh.import_private_key(
                        auth_payload.private_key.encode(),
                        passphrase=auth_payload.passphrase or None,
                    )
                except Exception:
                    broadcast(json.dumps({"type": "line", "text": "Invalid private key or passphrase"}))
                    broadcast(json.dumps({"type": "done", "exit_code": -1}))
                    return
                connect_kw["client_keys"] = [key]
            else:
                connect_kw["password"] = auth_payload.password

            async with asyncssh.connect(**connect_kw) as conn:
                # encoding=None makes the streams yield bytes. With asyncssh's
                # default (UTF-8 text) readline() returns str, and the
                # line.decode() below would fail on the first output line.
                async with conn.create_process(remote_cmd, encoding=None) as proc:

                    async def pump(stream: Any, is_err: bool) -> None:
                        prefix = "[stderr] " if is_err else ""
                        while True:
                            line = await stream.readline()
                            if not line:
                                break
                            text = line.decode(errors="replace").rstrip("\n\r")
                            broadcast(json.dumps({"type": "line", "text": prefix + text}))

                    await asyncio.gather(
                        pump(proc.stdout, False),
                        pump(proc.stderr, True),
                    )
                    await proc.wait()
                    exit_code = proc.exit_status
        except asyncssh.Error as e:
            msg = str(e).split("\n")[0][:240]
            broadcast(json.dumps({"type": "line", "text": "SSH error: " + msg}))
        except OSError as e:
            broadcast(json.dumps({"type": "line", "text": "Connection error: " + str(e)[:200]}))
        except Exception:
            broadcast(json.dumps({"type": "line", "text": "Unexpected installer error"}))

        broadcast(json.dumps({"type": "done", "exit_code": exit_code if exit_code is not None else -1}))

    loop = asyncio.get_running_loop()

    def schedule_purge(_finished: asyncio.Task) -> None:
        # Retention counts from completion. A long install is never
        # unregistered (and its task reference dropped) while still running.
        loop.call_later(_JOB_RETENTION_SEC, _jobs.pop, job_id, None)

    async with _jobs_lock:
        _jobs[job_id] = job
        task = asyncio.create_task(runner())
        job["task"] = task
    task.add_done_callback(schedule_purge)

    return CreateJobResponse(job_id=job_id)


@router.websocket("/ws/{job_id}")
async def job_ws(websocket: WebSocket, job_id: str):
    """Stream a job's output messages over a WebSocket.

    First replays the job's recorded history, then forwards live messages.
    Streaming stops at the ``done`` event, at a non-JSON message, or after
    7200 s without output. When the installer is disabled, the socket is
    closed with code 4403 before being accepted. An unknown job id gets one
    explanatory line and a close.
    """
    settings = get_settings()
    if not settings.enable_remote_installer:
        await websocket.close(code=4403)
        return

    await websocket.accept()

    q: asyncio.Queue = asyncio.Queue(maxsize=500)
    backlog: list[str] = []
    async with _jobs_lock:
        job = _jobs.get(job_id)
        if job is not None:
            # Snapshot the backlog and subscribe with no await in between.
            # Every message then arrives exactly once: earlier ones via the
            # replay, later ones via the queue.
            backlog = list(job.get("history", ()))
            job["channels"].add(q)

    if job is None:
        await websocket.send_text(json.dumps({"type": "line", "text": "Unknown or expired job_id"}))
        await websocket.close()
        return

    try:
        finished = False
        for msg in backlog:
            await websocket.send_text(msg)
            if _is_terminal_message(msg):
                finished = True
                break
        while not finished:
            try:
                msg = await asyncio.wait_for(q.get(), timeout=7200.0)
            except asyncio.TimeoutError:
                await websocket.send_text(json.dumps({"type": "line", "text": "… idle timeout"}))
                break
            await websocket.send_text(msg)
            finished = _is_terminal_message(msg)
    finally:
        job["channels"].discard(q)
        try:
            await websocket.close()
        except Exception:
            # Best-effort: the client may already have disconnected.
            pass
|