Initial YakPanel commit
This commit is contained in:
310
YakPanel-server/backend/app/api/public_installer.py
Normal file
310
YakPanel-server/backend/app/api/public_installer.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Optional remote SSH installer — disabled by default (ENABLE_REMOTE_INSTALLER)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import json
|
||||
import re
|
||||
import shlex
|
||||
import socket
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import asyncssh
|
||||
from fastapi import APIRouter, HTTPException, Request, WebSocket
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
router = APIRouter(prefix="/public-install", tags=["public-install"])
|
||||
|
||||
_jobs: dict[str, dict[str, Any]] = {}
|
||||
_rate_buckets: dict[str, list[float]] = defaultdict(list)
|
||||
_jobs_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _safe_host(host: str) -> str:
|
||||
h = host.strip()
|
||||
if not h or len(h) > 253:
|
||||
raise ValueError("invalid host")
|
||||
if re.search(r"[\s;|&$`\\'\"\n<>()]", h):
|
||||
raise ValueError("invalid host characters")
|
||||
try:
|
||||
ipaddress.ip_address(h)
|
||||
return h
|
||||
except ValueError:
|
||||
pass
|
||||
if not re.match(r"^[a-zA-Z0-9.\-]+$", h):
|
||||
raise ValueError("invalid hostname")
|
||||
return h
|
||||
|
||||
|
||||
def _validate_install_url(url: str) -> str:
|
||||
p = urlparse(url.strip())
|
||||
if p.scheme != "https":
|
||||
raise ValueError("install_url must use https")
|
||||
if not p.netloc or p.username is not None or p.password is not None:
|
||||
raise ValueError("invalid install URL")
|
||||
return url.strip()
|
||||
|
||||
|
||||
async def _target_ip_allowed(host: str) -> bool:
|
||||
settings = get_settings()
|
||||
raw = settings.remote_install_allowed_target_cidrs.strip()
|
||||
if not raw:
|
||||
return True
|
||||
cidrs: list = []
|
||||
for part in raw.split(","):
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
try:
|
||||
cidrs.append(ipaddress.ip_network(part, strict=False))
|
||||
except ValueError:
|
||||
continue
|
||||
if not cidrs:
|
||||
return True
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def resolve() -> list[str]:
|
||||
out: list[str] = []
|
||||
try:
|
||||
for fam, _ty, _pr, _cn, sa in socket.getaddrinfo(host, None, type=socket.SOCK_STREAM):
|
||||
out.append(sa[0])
|
||||
except socket.gaierror:
|
||||
return []
|
||||
return out
|
||||
|
||||
addrs = await loop.run_in_executor(None, resolve)
|
||||
if not addrs:
|
||||
return False
|
||||
for addr in addrs:
|
||||
try:
|
||||
ip = ipaddress.ip_address(addr)
|
||||
if any(ip in net for net in cidrs):
|
||||
return True
|
||||
except ValueError:
|
||||
continue
|
||||
return False
|
||||
|
||||
|
||||
def _check_rate_limit(client_ip: str) -> None:
|
||||
settings = get_settings()
|
||||
lim = settings.remote_install_rate_limit_per_ip
|
||||
if lim <= 0:
|
||||
return
|
||||
window_sec = max(1, settings.remote_install_rate_window_minutes) * 60
|
||||
now = time.monotonic()
|
||||
bucket = _rate_buckets[client_ip]
|
||||
bucket[:] = [t for t in bucket if now - t < window_sec]
|
||||
if len(bucket) >= lim:
|
||||
raise HTTPException(status_code=429, detail="Rate limit exceeded. Try again later.")
|
||||
bucket.append(now)
|
||||
|
||||
|
||||
class AuthKey(BaseModel):
|
||||
type: Literal["key"] = "key"
|
||||
private_key: str = Field(..., min_length=1)
|
||||
passphrase: Optional[str] = None
|
||||
|
||||
|
||||
class AuthPassword(BaseModel):
|
||||
type: Literal["password"] = "password"
|
||||
password: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
class CreateJobRequest(BaseModel):
|
||||
host: str
|
||||
port: int = Field(default=22, ge=1, le=65535)
|
||||
username: str = Field(..., min_length=1, max_length=64)
|
||||
auth: Annotated[Union[AuthKey, AuthPassword], Field(discriminator="type")]
|
||||
install_url: Optional[str] = None
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def host_ok(cls, v: str) -> str:
|
||||
return _safe_host(v)
|
||||
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
def user_ok(cls, v: str) -> str:
|
||||
u = v.strip()
|
||||
if re.search(r"[\s;|&$`\\'\"\n<>]", u):
|
||||
raise ValueError("invalid username")
|
||||
return u
|
||||
|
||||
@field_validator("install_url")
|
||||
@classmethod
|
||||
def url_ok(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
return _validate_install_url(v)
|
||||
|
||||
|
||||
class CreateJobResponse(BaseModel):
|
||||
job_id: str
|
||||
|
||||
|
||||
def _broadcast(job: dict[str, Any], msg: str) -> None:
|
||||
for q in list(job.get("channels", ())):
|
||||
try:
|
||||
q.put_nowait(msg)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def installer_config():
|
||||
s = get_settings()
|
||||
return {
|
||||
"enabled": s.enable_remote_installer,
|
||||
"default_install_url": s.remote_install_default_url,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/jobs", response_model=CreateJobResponse)
|
||||
async def create_job(body: CreateJobRequest, request: Request):
|
||||
settings = get_settings()
|
||||
if not settings.enable_remote_installer:
|
||||
raise HTTPException(status_code=403, detail="Remote installer is disabled")
|
||||
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
_check_rate_limit(client_ip)
|
||||
|
||||
host = body.host
|
||||
if not await _target_ip_allowed(host):
|
||||
raise HTTPException(status_code=400, detail="Target host is not in allowed CIDR list")
|
||||
|
||||
url = body.install_url or settings.remote_install_default_url
|
||||
try:
|
||||
url = _validate_install_url(url)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
job_id = uuid.uuid4().hex
|
||||
channels: set[asyncio.Queue] = set()
|
||||
|
||||
inner = f"curl -fsSL {shlex.quote(url)} | bash"
|
||||
if body.username == "root":
|
||||
remote_cmd = f"bash -lc {shlex.quote(inner)}"
|
||||
else:
|
||||
remote_cmd = f"sudo -n bash -lc {shlex.quote(inner)}"
|
||||
|
||||
auth_payload = body.auth
|
||||
|
||||
async def runner() -> None:
|
||||
async with _jobs_lock:
|
||||
job = _jobs.get(job_id)
|
||||
if not job:
|
||||
return
|
||||
|
||||
exit_code: Optional[int] = None
|
||||
|
||||
def broadcast(msg: str) -> None:
|
||||
_broadcast(job, msg)
|
||||
|
||||
try:
|
||||
connect_kw: dict[str, Any] = {
|
||||
"host": host,
|
||||
"port": body.port,
|
||||
"username": body.username,
|
||||
"known_hosts": None,
|
||||
"connect_timeout": 30,
|
||||
}
|
||||
if auth_payload.type == "key":
|
||||
try:
|
||||
key = asyncssh.import_private_key(
|
||||
auth_payload.private_key.encode(),
|
||||
passphrase=auth_payload.passphrase or None,
|
||||
)
|
||||
except Exception:
|
||||
broadcast(json.dumps({"type": "line", "text": "Invalid private key or passphrase"}))
|
||||
broadcast(json.dumps({"type": "done", "exit_code": -1}))
|
||||
return
|
||||
connect_kw["client_keys"] = [key]
|
||||
else:
|
||||
connect_kw["password"] = auth_payload.password
|
||||
|
||||
async with asyncssh.connect(**connect_kw) as conn:
|
||||
async with conn.create_process(remote_cmd) as proc:
|
||||
|
||||
async def pump(stream: Any, is_err: bool) -> None:
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
break
|
||||
text = line.decode(errors="replace").rstrip("\n\r")
|
||||
prefix = "[stderr] " if is_err else ""
|
||||
broadcast(json.dumps({"type": "line", "text": prefix + text}))
|
||||
|
||||
await asyncio.gather(
|
||||
pump(proc.stdout, False),
|
||||
pump(proc.stderr, True),
|
||||
)
|
||||
await proc.wait()
|
||||
exit_code = proc.exit_status
|
||||
except asyncssh.Error as e:
|
||||
msg = str(e).split("\n")[0][:240]
|
||||
broadcast(json.dumps({"type": "line", "text": "SSH error: " + msg}))
|
||||
except OSError as e:
|
||||
broadcast(json.dumps({"type": "line", "text": "Connection error: " + str(e)[:200]}))
|
||||
except Exception:
|
||||
broadcast(json.dumps({"type": "line", "text": "Unexpected installer error"}))
|
||||
|
||||
broadcast(json.dumps({"type": "done", "exit_code": exit_code if exit_code is not None else -1}))
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _purge() -> None:
|
||||
_jobs.pop(job_id, None)
|
||||
|
||||
loop.call_later(900, _purge)
|
||||
|
||||
async with _jobs_lock:
|
||||
_jobs[job_id] = {"channels": channels}
|
||||
task = asyncio.create_task(runner())
|
||||
_jobs[job_id]["task"] = task
|
||||
|
||||
return CreateJobResponse(job_id=job_id)
|
||||
|
||||
|
||||
@router.websocket("/ws/{job_id}")
|
||||
async def job_ws(websocket: WebSocket, job_id: str):
|
||||
settings = get_settings()
|
||||
if not settings.enable_remote_installer:
|
||||
await websocket.close(code=4403)
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
async with _jobs_lock:
|
||||
job = _jobs.get(job_id)
|
||||
if not job:
|
||||
await websocket.send_text(json.dumps({"type": "line", "text": "Unknown or expired job_id"}))
|
||||
await websocket.close()
|
||||
return
|
||||
|
||||
q: asyncio.Queue = asyncio.Queue(maxsize=500)
|
||||
job["channels"].add(q)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(q.get(), timeout=7200.0)
|
||||
except asyncio.TimeoutError:
|
||||
await websocket.send_text(json.dumps({"type": "line", "text": "… idle timeout"}))
|
||||
break
|
||||
await websocket.send_text(msg)
|
||||
try:
|
||||
data = json.loads(msg)
|
||||
if data.get("type") == "done":
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
break
|
||||
finally:
|
||||
job["channels"].discard(q)
|
||||
await websocket.close()
|
||||
Reference in New Issue
Block a user