Initial YakPanel commit

This commit is contained in:
Niranjan
2026-04-07 02:04:22 +05:30
commit 2826d3e7f3
5359 changed files with 1390724 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
# Package public API: re-export the SSH executor, the transfer rate-limiting
# helpers, and the SSH configuration checker from their submodules so callers
# can import them directly from the package root.
from .ssh_executor import SSHExecutor, CommandResult
from .rate_limiter import (
    RateLimiter,
    ProgressTracker,
    TokenBucketRateLimiter,
    LeakyBucketRateLimiter
)
from .util import test_ssh_config

# Names exported by `from <package> import *`.
__all__ = [
    "CommandResult",
    "SSHExecutor",
    "RateLimiter",
    "ProgressTracker",
    "TokenBucketRateLimiter",
    "LeakyBucketRateLimiter",
    "test_ssh_config"
]

View File

@@ -0,0 +1,255 @@
"""
Rate limiter for file transfers with clean separation of concerns.
NOTE: These rate limiters are NOT thread-safe by design.
Rate limiting operations are typically single-threaded, and removing
threading overhead improves performance.
"""
import time
from typing import Optional, Callable
class TokenBucketRateLimiter:
    """
    Token-bucket rate limiter: sustained rate with burst support.

    Tokens (bytes) accumulate at ``rate`` per second up to ``capacity``;
    each transferred chunk spends tokens, and the caller sleeps when the
    bucket runs dry.

    NOTE: This class is NOT thread-safe. Rate limiting operations are
    typically single-threaded, and removing threading overhead improves
    performance.

    The token bucket algorithm allows for:
    - Sustained rate limiting (tokens per second)
    - Burst control (bucket capacity)
    - Peak rate control

    Usage:
        limiter = TokenBucketRateLimiter(
            rate=1024*1024,        # 1MB/s sustained rate
            capacity=2*1024*1024   # 2MB burst capacity
        )
        for chunk in data_chunks:
            limiter.wait_if_needed(len(chunk))
            # transfer chunk
    """

    def __init__(self,
                 rate: Optional[int] = None,
                 capacity: Optional[int] = None,
                 initial_tokens: Optional[int] = None):
        """
        Initialize token bucket rate limiter.

        Args:
            rate: Tokens (bytes) per second. None disables rate limiting.
            capacity: Maximum bucket size in bytes; defaults to ``rate``.
            initial_tokens: Tokens present at start; defaults to ``capacity``.
        """
        self.rate = rate
        self.capacity = rate if capacity is None else capacity
        self.initial_tokens = self.capacity if initial_tokens is None else initial_tokens
        # Mutable state: current token balance and last refill timestamp.
        self.tokens = self.initial_tokens
        self.last_update = None
        self._started = False

    def start(self):
        """Arm the limiter: fill the bucket and begin the refill clock (idempotent)."""
        if not self.rate or self._started:
            return
        self._started = True
        self.last_update = time.time()
        self.tokens = self.initial_tokens

    def _refill(self, now: float):
        """Credit tokens earned since the last update, capped at capacity."""
        earned = (now - self.last_update) * self.rate
        self.tokens = min(self.capacity, self.tokens + earned)
        self.last_update = now

    def wait_if_needed(self, chunk_size: int):
        """
        Spend ``chunk_size`` tokens, sleeping first if the bucket lacks them.

        Args:
            chunk_size: Size of the chunk just transferred in bytes.
        """
        # No-op when rate limiting is disabled or start() was never called.
        if not self.rate or not self.last_update:
            return
        self._refill(time.time())
        if self.tokens < chunk_size:
            # Sleep exactly long enough to earn the missing tokens.
            shortfall = chunk_size - self.tokens
            time.sleep(shortfall / self.rate)
            self._refill(time.time())
        self.tokens -= chunk_size
class LeakyBucketRateLimiter:
    """
    Leaky bucket rate limiter for strict rate limiting without burst support.

    NOTE: This class is NOT thread-safe. Rate limiting operations are typically
    single-threaded, and removing threading overhead improves performance.

    The leaky bucket algorithm provides:
    - Strict rate limiting (no burst)
    - Predictable output rate
    - Better for network protocols that can't handle bursts

    Usage:
        limiter = LeakyBucketRateLimiter(rate=1024*1024)  # 1MB/s strict rate
        for chunk in data_chunks:
            limiter.wait_if_needed(len(chunk))
            # transfer chunk
    """

    def __init__(self, rate: Optional[int] = None):
        """
        Initialize leaky bucket rate limiter.

        Args:
            rate: Tokens (bytes) per second. None means no rate limiting.
        """
        self.rate = rate
        self.last_update = None
        self._started = False

    def start(self):
        """Start timing for rate limiting (idempotent).

        Bug fix: previously ``_started`` was never set, so every call to
        start() silently reset the pacing clock mid-transfer. Now it arms
        only once, matching TokenBucketRateLimiter.start().
        """
        if self.rate and not self._started:
            self._started = True
            self.last_update = time.time()

    def wait_if_needed(self, chunk_size: int):
        """
        Wait if necessary to maintain the specified transfer rate.

        Args:
            chunk_size: Size of the chunk just transferred in bytes.
        """
        # No-op when rate limiting is disabled or start() was never called.
        if not self.rate or not self.last_update:
            return
        now = time.time()
        elapsed = now - self.last_update
        # Each chunk must occupy at least chunk_size / rate seconds.
        min_time = chunk_size / self.rate
        if elapsed < min_time:
            wait_time = min_time - elapsed
            time.sleep(wait_time)
        self.last_update = time.time()
class RateLimiter:
    """
    Unified facade over the rate-limiting algorithms.

    NOTE: This class is NOT thread-safe. Rate limiting operations are typically
    single-threaded, and removing threading overhead improves performance.

    Delegates to one of:
    - TokenBucketRateLimiter: burst-capable rate limiting
    - LeakyBucketRateLimiter: strict rate limiting
    """

    def __init__(self,
                 bytes_per_second: Optional[int] = None,
                 algorithm: str = "token_bucket",
                 burst_capacity: Optional[int] = None):
        """
        Initialize rate limiter.

        Args:
            bytes_per_second: Transfer rate limit in bytes per second.
            algorithm: Rate limiting algorithm ("token_bucket" or "leaky_bucket")
            burst_capacity: Token bucket only — maximum burst size in bytes
                (defaults to bytes_per_second).
        Raises:
            ValueError: If ``algorithm`` names an unknown algorithm.
        """
        if algorithm == "token_bucket":
            backend = TokenBucketRateLimiter(
                rate=bytes_per_second,
                capacity=burst_capacity,
            )
        elif algorithm == "leaky_bucket":
            backend = LeakyBucketRateLimiter(bytes_per_second)
        else:
            raise ValueError(f"Unknown algorithm: {algorithm}. Use 'token_bucket' or 'leaky_bucket'")
        self._limiter = backend
        self.bytes_per_second = bytes_per_second
        self.algorithm = algorithm

    def start(self):
        """Arm the underlying limiter's clock."""
        self._limiter.start()

    def wait_if_needed(self, chunk_size: int):
        """
        Delegate pacing of one chunk to the underlying limiter.

        Args:
            chunk_size: Size of the chunk just transferred in bytes.
        """
        self._limiter.wait_if_needed(chunk_size)
class ProgressTracker:
    """
    Throttled progress reporting helper.

    Invokes a user callback with (transferred, total) byte counts, but no
    more often than ``update_interval`` seconds apart; finish() always fires.

    NOTE: This class is NOT thread-safe. Progress tracking operations are
    typically single-threaded, and removing threading overhead improves
    performance.
    """

    def __init__(self,
                 callback: Optional[Callable[[int, int], None]] = None,
                 update_interval: float = 0.1):
        """
        Initialize progress tracker.

        Args:
            callback: Function to call with (transferred, total) progress updates
            update_interval: Minimum seconds between progress updates
        """
        self.callback = callback
        self.update_interval = update_interval
        self.last_update_time = 0

    def start(self):
        """Reset the throttle so the next update() fires immediately."""
        self.last_update_time = 0

    def update(self, transferred: int, total: int):
        """
        Report progress unless the previous report was too recent.

        Args:
            transferred: Bytes transferred so far
            total: Total bytes to transfer
        """
        if not self.callback:
            return
        now = time.time()
        if now - self.last_update_time >= self.update_interval:
            self.callback(transferred, total)
            self.last_update_time = now

    def finish(self, total: int):
        """Emit an unconditional final 100% progress report."""
        if self.callback:
            self.callback(total, total)

View File

@@ -0,0 +1,808 @@
from __future__ import annotations
import os
import stat
from typing import Optional, Tuple, Callable, Union, Dict, Any, Iterator
import time
import io
import math
import paramiko
from dataclasses import dataclass
from .rate_limiter import RateLimiter, ProgressTracker
@dataclass
class CommandResult:
    """Result of a remote command execution over SSH."""
    # Process exit status as reported by the remote shell (0 = success).
    exit_code: int
    # Captured standard output, decoded as UTF-8 (undecodable bytes replaced).
    stdout: str
    # Captured standard error, decoded as UTF-8 (undecodable bytes replaced).
    stderr: str
class SSHExecutor:
    """
    High-level SSH executor wrapping Paramiko for command execution and SFTP
    upload/download, with optional rate limiting, resume support, and remote
    script execution.

    Usage:
        with SSHExecutor(host, user, password=...) as ssh:
            code, out, err = ssh.run("uname -a")
            ssh.upload("./local.txt", "/tmp/remote.txt")
    """

    def __init__(
        self,
        host: str,
        username: str,
        port: int = 22,
        password: Optional[str] = None,
        key_file: Optional[str] = None,
        passphrase: Optional[str] = None,
        key_data: Optional[str] = None,
        timeout: Optional[int] = None,
        strict_host_key_checking: bool = False,
        allow_agent: bool = False,
        look_for_keys: bool = False,
        # Threading mode defaults to False. When True, every get_sftp() call
        # opens a fresh SFTP client (owned and closed by the caller) instead
        # of reusing a single cached client.
        threading_mod: bool = False,
    ) -> None:
        """
        Args:
            host: Remote host name or address.
            username: Login user name.
            port: SSH port (default 22).
            password: Password for password authentication (optional).
            key_file: Path to a private key file (optional; wins over key_data).
            passphrase: Passphrase for an encrypted private key.
            key_data: Private key material as a string (optional).
            timeout: Connect/banner/auth/command timeout in seconds; a falsy
                value (None or 0) falls back to 20.
            strict_host_key_checking: Reject unknown host keys when True.
            allow_agent: Allow SSH agent authentication.
            look_for_keys: Let Paramiko search ~/.ssh for keys.
            threading_mod: See inline comment above.
        """
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.key_file = key_file
        self.passphrase = passphrase
        self.key_data = key_data
        self.timeout = timeout or 20
        self.strict_host_key_checking = strict_host_key_checking
        self.allow_agent = allow_agent
        self.look_for_keys = look_for_keys
        self._client: Optional[paramiko.SSHClient] = None
        self._sftp: Optional[paramiko.SFTPClient] = None
        self._threading_mod = threading_mod

    def open(self) -> None:
        """Establish the SSH connection (no-op if already connected).

        Raises:
            RuntimeError: If the connection or authentication fails.
        """
        if self._client is not None:
            return
        client = paramiko.SSHClient()
        if self.strict_host_key_checking:
            client.set_missing_host_key_policy(paramiko.RejectPolicy())
        else:
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        connect_kwargs: Dict[str, Any] = {
            "hostname": self.host,
            "port": self.port,
            "username": self.username,
            "look_for_keys": self.look_for_keys,
            "allow_agent": self.allow_agent,
        }
        if self.timeout is not None:
            connect_kwargs.update({
                "timeout": self.timeout,
                "banner_timeout": self.timeout,
                "auth_timeout": self.timeout,
            })
        if self.password:
            connect_kwargs["password"] = self.password
        if self.key_file or self.key_data:
            pkey = self._load_private_key(self.key_file, self.key_data, self.passphrase)
            connect_kwargs["pkey"] = pkey
        try:
            client.connect(**connect_kwargs)
        except Exception as e:
            client.close()
            raise RuntimeError(f"SSH connection failed: {e}")
        self._client = client

    def close(self) -> None:
        """Close the cached SFTP client (if any) and the SSH connection."""
        if self._sftp is not None:
            try:
                self._sftp.close()
            finally:
                self._sftp = None
        if self._client is not None:
            try:
                self._client.close()
            finally:
                self._client = None

    def __enter__(self) -> "SSHExecutor":
        self.open()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def run(self, command: str, timeout: Optional[int] = None) -> Tuple[int, str, str]:
        """Execute a remote command and collect its output.

        Args:
            command: Command line to execute on the remote host.
            timeout: Per-command timeout in seconds (defaults to self.timeout).
        Returns:
            (exit_status, stdout, stderr) with output decoded as UTF-8.
        Raises:
            RuntimeError: If execution fails at the transport level.
        """
        client = self._require_client()
        try:
            effective_timeout = timeout if timeout is not None else self.timeout
            stdin, stdout, stderr = client.exec_command(command, timeout=effective_timeout)
            # NOTE(review): waiting for the exit status before reading relies
            # on Paramiko buffering the channel output in memory; for very
            # large outputs prefer run_streaming().
            exit_status = stdout.channel.recv_exit_status()
            out = stdout.read().decode("utf-8", errors="replace")
            err = stderr.read().decode("utf-8", errors="replace")
            return exit_status, out, err
        except Exception as e:
            raise RuntimeError(f"Command execution failed: {e}")

    def upload(
        self,
        local_path: str,
        remote_path: str,
        rate_limit: Optional[int] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
        resume: bool = False,
        rate_algorithm: str = "token_bucket",
        burst_capacity: Optional[int] = None,
        rate_limiter: Optional[RateLimiter] = None,
    ) -> None:
        """
        Upload a file via SFTP with optional rate limiting and resume support.

        Args:
            local_path: Local file path
            remote_path: Remote destination path
            rate_limit: Rate limit in bytes per second (None = no limit)
            progress_callback: Callback(transferred_bytes, total_bytes) for progress updates
            resume: Whether to resume upload if remote file exists and is smaller.
                   WARNING: Only checks file size, no content verification. Use with caution.
            rate_algorithm: Rate limiting algorithm ("token_bucket" or "leaky_bucket")
            burst_capacity: For token bucket, maximum burst capacity in bytes
            rate_limiter: External RateLimiter instance for unified control across multiple transfers
        Raises:
            FileNotFoundError: If the local file does not exist.
            RuntimeError: If the SFTP transfer fails.
        """
        expanded_local = os.path.expanduser(local_path)
        if not os.path.isfile(expanded_local):
            raise FileNotFoundError(f"Local file not found: {expanded_local}")
        local_size = os.path.getsize(expanded_local)
        resume_offset = 0
        if resume:
            # Best-effort probe of the remote file size; any failure simply
            # falls back to a full upload.
            sftp = None
            try:
                sftp = self.get_sftp()
                remote_stat = sftp.stat(remote_path)
                if remote_stat.st_size < local_size:
                    resume_offset = remote_stat.st_size
                    if progress_callback:
                        progress_callback(resume_offset, local_size)
            except FileNotFoundError:
                pass
            except Exception:
                pass
            finally:
                if self._threading_mod and sftp:
                    sftp.close()
        sftp = None
        try:
            sftp = self.get_sftp()
            # Use external rate limiter if provided, otherwise create new one
            if rate_limiter is None:
                rate_limiter = RateLimiter(rate_limit, rate_algorithm, burst_capacity)
            progress_tracker = ProgressTracker(progress_callback)
            # Use chunked transfer for better control
            self._upload_chunked(sftp, expanded_local, remote_path, resume_offset, rate_limiter, progress_tracker)
        except Exception as e:
            raise RuntimeError(f"SFTP upload failed: {e}")
        finally:
            if self._threading_mod and sftp:
                sftp.close()

    def _ensure_remote_dir(self, sftp, path):
        """Create the remote directory `path` (recursively) if it is missing."""
        try:
            sftp.stat(path)
        except FileNotFoundError:
            self._create_remote_dir_recursive(sftp, path)

    @staticmethod
    def _create_remote_dir_recursive(sftp, path):
        """mkdir -p equivalent over SFTP: create each missing path component."""
        dirs = [d for d in path.split('/') if d]
        current = ''
        for d in dirs:
            current += '/' + d
            try:
                sftp.stat(current)
            except FileNotFoundError:
                sftp.mkdir(current)

    def _upload_chunked(
        self,
        sftp: paramiko.SFTPClient,
        local_path: str,
        remote_path: str,
        resume_offset: int,
        rate_limiter: Optional[RateLimiter] = None,
        progress_tracker: Optional[ProgressTracker] = None,
    ) -> None:
        """Upload file in chunks with optional rate limiting and progress tracking."""
        local_size = os.path.getsize(local_path)
        chunk_size = 32768  # 32KB chunks
        transferred = resume_offset
        # Initialize components if provided
        if rate_limiter:
            rate_limiter.start()
        if progress_tracker:
            progress_tracker.start()
        with open(local_path, "rb") as local_file:
            if resume_offset > 0:
                local_file.seek(resume_offset)
            # Make sure the destination directory exists, creating it if needed.
            self._ensure_remote_dir(sftp, os.path.dirname(remote_path))
            with sftp.file(remote_path, "ab" if resume_offset > 0 else "wb") as remote_file:
                while transferred < local_size:
                    chunk = local_file.read(chunk_size)
                    if not chunk:
                        break
                    # Apply rate limiting before transfer
                    if rate_limiter:
                        rate_limiter.wait_if_needed(len(chunk))
                    # Perform the actual transfer
                    remote_file.write(chunk)
                    transferred += len(chunk)
                    # Update progress after transfer
                    if progress_tracker:
                        progress_tracker.update(transferred, local_size)
        # Final progress update
        if progress_tracker:
            progress_tracker.finish(local_size)

    def download(
        self,
        remote_path: str,
        local_path: str,
        rate_limit: Optional[int] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
        resume: bool = False,
        rate_algorithm: str = "token_bucket",
        burst_capacity: Optional[int] = None,
        rate_limiter: Optional[RateLimiter] = None,
    ) -> None:
        """
        Download a file via SFTP with optional rate limiting and resume support.

        Args:
            remote_path: Remote file path
            local_path: Local destination path
            rate_limit: Rate limit in bytes per second (None = no limit)
            progress_callback: Callback(transferred_bytes, total_bytes) for progress updates
            resume: Whether to resume download if local file exists and is smaller.
                   WARNING: Only checks file size, no content verification. Use with caution.
            rate_algorithm: Rate limiting algorithm ("token_bucket" or "leaky_bucket")
            burst_capacity: For token bucket, maximum burst capacity in bytes
            rate_limiter: External RateLimiter instance for unified control across multiple transfers
        Raises:
            RuntimeError: If the SFTP transfer fails.
        """
        expanded_local = os.path.expanduser(local_path)
        resume_offset = 0
        sftp = None
        if resume and os.path.exists(expanded_local):
            local_size = os.path.getsize(expanded_local)
            # Best-effort probe; failures fall back to a full download.
            try:
                sftp = self.get_sftp()
                remote_stat = sftp.stat(remote_path)
                if local_size < remote_stat.st_size:
                    resume_offset = local_size
                    if progress_callback:
                        progress_callback(resume_offset, remote_stat.st_size)
            except Exception:
                pass
            finally:
                if self._threading_mod and sftp:
                    sftp.close()
        sftp = None
        try:
            sftp = self.get_sftp()
            # Use external rate limiter if provided, otherwise create new one
            if rate_limiter is None:
                rate_limiter = RateLimiter(rate_limit, rate_algorithm, burst_capacity)
            progress_tracker = ProgressTracker(progress_callback)
            # Use chunked transfer for better control
            self._download_chunked(sftp, remote_path, expanded_local, resume_offset, rate_limiter, progress_tracker)
        except Exception as e:
            raise RuntimeError(f"SFTP download failed: {e}")
        finally:
            if self._threading_mod and sftp:
                sftp.close()

    @staticmethod
    def _download_chunked(
        sftp,
        remote_path: str,
        local_path: str,
        resume_offset: int,
        rate_limiter: Optional[RateLimiter] = None,
        progress_tracker: Optional[ProgressTracker] = None,
    ) -> None:
        """Download file in chunks with optional rate limiting and progress tracking."""
        remote_size = sftp.stat(remote_path).st_size
        chunk_size = 32768  # 32KB chunks
        transferred = resume_offset
        # Initialize components if provided
        if rate_limiter:
            rate_limiter.start()
        if progress_tracker:
            progress_tracker.start()
        mode = "ab" if resume_offset > 0 else "wb"
        with open(local_path, mode) as local_file:
            with sftp.file(remote_path, "rb") as remote_file:
                if resume_offset > 0:
                    remote_file.seek(resume_offset)
                while transferred < remote_size:
                    # Apply rate limiting before transfer
                    if rate_limiter:
                        rate_limiter.wait_if_needed(chunk_size)
                    # Perform the actual transfer
                    chunk = remote_file.read(chunk_size)
                    if not chunk:
                        break
                    local_file.write(chunk)
                    transferred += len(chunk)
                    # Update progress after transfer
                    if progress_tracker:
                        progress_tracker.update(transferred, remote_size)
        # Final progress update
        if progress_tracker:
            progress_tracker.finish(remote_size)

    def _require_client(self) -> paramiko.SSHClient:
        """Return the live SSH client or raise if not connected."""
        if self._client is None:
            raise RuntimeError("SSH client is not connected. Call open() or use a with-context.")
        return self._client

    def get_sftp(self) -> paramiko.SFTPClient:
        """Return an SFTP client.

        In threading mode a fresh client is opened on every call and the
        caller is responsible for closing it; otherwise one cached client is
        lazily created and reused (closed by close()).
        """
        if self._threading_mod:
            th_sftp = self._require_client().open_sftp()
            return th_sftp
        if self._sftp is None:
            self._sftp = self._require_client().open_sftp()
        return self._sftp

    @staticmethod
    def _load_private_key(
        key_file: Optional[str],
        key_data: Optional[str],
        passphrase: Optional[str],
    ) -> paramiko.PKey:
        """Load a private key by normalizing to key_data and parsing it.

        Priority is mutually exclusive by design: key_file > key_data.
        Supported types: RSA, DSS, ECDSA, Ed25519.

        Raises:
            RuntimeError: If no key is provided, the file cannot be read,
                the key is encrypted without a passphrase, or no supported
                key class can parse the data.
        """
        if not key_data and key_file:
            path = os.path.expanduser(key_file)
            try:
                with open(path, "r", encoding="utf-8") as f:
                    key_data = f.read()
            except Exception as e:
                raise RuntimeError(f"Failed to read private key file: {e}")
        if not key_data:
            raise RuntimeError("No private key provided")
        stream = io.StringIO(key_data)
        last_error: Optional[Exception] = None
        key_classes = [paramiko.RSAKey, paramiko.ECDSAKey, paramiko.Ed25519Key]
        if hasattr(paramiko, "DSSKey"):  # Tolerate paramiko builds without DSSKey support
            key_classes.append(paramiko.DSSKey)
        for key_cls in key_classes:
            try:
                stream.seek(0)
                return key_cls.from_private_key(stream, password=passphrase)
            except paramiko.PasswordRequiredException:
                raise RuntimeError("Private key is encrypted; provide passphrase.")
            except Exception as e:
                last_error = e
        raise RuntimeError(f"Failed to load private key from data: {last_error}")

    def run_streaming(
        self,
        command: str,
        on_stdout: Optional[Callable[[bytes], None]] = None,
        on_stderr: Optional[Callable[[bytes], None]] = None,
        timeout: Optional[int] = None,
        read_chunk_size: int = 32768,
        poll_interval_sec: float = 0.05,
    ) -> int:
        """
        Execute a remote command and stream output chunks to callbacks to minimize memory usage.

        Args:
            command: Command line to execute remotely.
            on_stdout: Callback receiving stdout chunks (bytes).
            on_stderr: Callback receiving stderr chunks (bytes).
            timeout: Wall-clock timeout in seconds (defaults to self.timeout).
            read_chunk_size: Maximum bytes read per channel poll.
            poll_interval_sec: Sleep between polls when the channel is idle.
        Returns:
            The process exit status when the command completes.
        Raises:
            TimeoutError: If the command exceeds the effective timeout.
            RuntimeError: If the SSH transport is unavailable.
        """
        client = self._require_client()
        transport = client.get_transport()
        if transport is None:
            raise RuntimeError("SSH transport is not available")
        effective_timeout = timeout if timeout is not None else self.timeout
        chan = transport.open_session(timeout=effective_timeout)
        chan.exec_command(command)
        start_time = time.time()
        try:
            while True:
                if chan.recv_ready():
                    data = chan.recv(read_chunk_size)
                    if data and on_stdout is not None:
                        on_stdout(data)
                if chan.recv_stderr_ready():
                    data = chan.recv_stderr(read_chunk_size)
                    if data and on_stderr is not None:
                        on_stderr(data)
                # Stop only after the process exited AND both streams drained.
                if chan.exit_status_ready() and not chan.recv_ready() and not chan.recv_stderr_ready():
                    break
                if effective_timeout is not None and (time.time() - start_time) > effective_timeout:
                    chan.close()
                    raise TimeoutError("Command execution timed out")
                time.sleep(poll_interval_sec)
            exit_code = chan.recv_exit_status()
            return exit_code
        finally:
            try:
                chan.close()
            except Exception:
                pass

    def execute_script_streaming(
        self,
        script_content: str,
        script_type: str = "shell",
        remote_dir: str = "/tmp",
        script_name: Optional[str] = None,
        timeout: Optional[int] = None,
        cleanup: bool = True,
        env_vars: Optional[Dict[str, str]] = None,
        on_stdout: Optional[Callable[[bytes], None]] = None,
        on_stderr: Optional[Callable[[bytes], None]] = None,
    ) -> int:
        """
        Execute a script (shell or python) with streaming output.

        Args:
            script_content: The script content to execute
            script_type: "shell" or "python"
            remote_dir: Remote directory to place the script (default: /tmp)
            script_name: Name for the script file (auto-generated if None)
            timeout: Command execution timeout in seconds
            cleanup: Whether to delete the script file after execution
            env_vars: Environment variables to set before script execution
            on_stdout: Callback to receive stdout chunks (bytes)
            on_stderr: Callback to receive stderr chunks (bytes)
        Returns:
            int: Exit code of the script execution
        Raises:
            RuntimeError: If script upload or execution fails
        """
        remote_script_path = self._prepare_script(script_content, remote_dir, script_name)
        try:
            command = self._build_command(remote_script_path, script_type, env_vars)
            return self.run_streaming(
                command,
                on_stdout=on_stdout,
                on_stderr=on_stderr,
                timeout=timeout,
            )
        finally:
            if cleanup:
                self._cleanup_script(remote_script_path)

    def execute_script_collect(
        self,
        script_content: str,
        script_type: str = "shell",
        remote_dir: str = "/tmp",
        script_name: Optional[str] = None,
        timeout: Optional[int] = None,
        cleanup: bool = True,
        env_vars: Optional[Dict[str, str]] = None,
    ) -> CommandResult:
        """
        Execute a script (shell or python) and collect all output.

        Args:
            script_content: The script content to execute
            script_type: "shell" or "python"
            remote_dir: Remote directory to place the script (default: /tmp)
            script_name: Name for the script file (auto-generated if None)
            timeout: Command execution timeout in seconds
            cleanup: Whether to delete the script file after execution
            env_vars: Environment variables to set before script execution
        Returns:
            CommandResult: The execution result with exit_code, stdout, stderr
        Raises:
            RuntimeError: If script upload or execution fails
        """
        remote_script_path = self._prepare_script(script_content, remote_dir, script_name)
        try:
            command = self._build_command(remote_script_path, script_type, env_vars)
            code, out, err = self.run(command, timeout=timeout)
            return CommandResult(exit_code=code, stdout=out, stderr=err)
        finally:
            if cleanup:
                self._cleanup_script(remote_script_path)

    def execute_local_script_streaming(
        self,
        local_script_path: str,
        script_type: str = "shell",
        remote_dir: str = "/tmp",
        script_name: Optional[str] = None,
        timeout: Optional[int] = None,
        cleanup: bool = True,
        env_vars: Optional[Dict[str, str]] = None,
        on_stdout: Optional[Callable[[bytes], None]] = None,
        on_stderr: Optional[Callable[[bytes], None]] = None,
    ) -> int:
        """
        Execute a local script file remotely with streaming output.

        Args:
            local_script_path: Path to the local script file
            script_type: "shell" or "python"
            remote_dir: Remote directory to place the script (default: /tmp)
            script_name: Name for the script file (uses basename if None)
            timeout: Command execution timeout in seconds
            cleanup: Whether to delete the script file after execution
            env_vars: Environment variables to set before script execution
            on_stdout: Callback to receive stdout chunks (bytes)
            on_stderr: Callback to receive stderr chunks (bytes)
        Returns:
            int: Exit code of the script execution
        Raises:
            FileNotFoundError: If local script file not found
            RuntimeError: If script upload or execution fails
        """
        if not os.path.isfile(local_script_path):
            raise FileNotFoundError(f"Local script not found: {local_script_path}")
        if not script_name:
            script_name = os.path.basename(local_script_path)
        remote_script_path = f"{remote_dir.rstrip('/')}/{script_name}"
        # Upload the local script file via SFTP with LF normalization
        sftp = self.get_sftp()
        try:
            with open(local_script_path, "r", encoding="utf-8", newline="") as f:
                content = f.read()
            content_lf = content.replace("\r\n", "\n").replace("\r", "\n")
            with sftp.file(remote_script_path, "w") as remote_file:
                remote_file.write(content_lf.encode("utf-8"))
        finally:
            # Bug fix: in threading mode each get_sftp() returns a dedicated
            # client we own; previously it was never closed here (leak).
            if self._threading_mod and sftp:
                sftp.close()
        try:
            command = self._build_command(remote_script_path, script_type, env_vars)
            return self.run_streaming(
                command,
                on_stdout=on_stdout,
                on_stderr=on_stderr,
                timeout=timeout,
            )
        finally:
            if cleanup:
                self._cleanup_script(remote_script_path)

    def execute_local_script_collect(
        self,
        local_script_path: str,
        script_type: str = "shell",
        remote_dir: str = "/tmp",
        script_name: Optional[str] = None,
        timeout: Optional[int] = None,
        cleanup: bool = True,
        env_vars: Optional[Dict[str, str]] = None,
    ) -> CommandResult:
        """
        Execute a local script file remotely and collect all output.

        Args:
            local_script_path: Path to the local script file
            script_type: "shell" or "python"
            remote_dir: Remote directory to place the script (default: /tmp)
            script_name: Name for the script file (uses basename if None)
            timeout: Command execution timeout in seconds
            cleanup: Whether to delete the script file after execution
            env_vars: Environment variables to set before script execution
        Returns:
            CommandResult: The execution result with exit_code, stdout, stderr
        Raises:
            FileNotFoundError: If local script file not found
            RuntimeError: If script upload or execution fails
        """
        if not os.path.isfile(local_script_path):
            raise FileNotFoundError(f"Local script not found: {local_script_path}")
        if not script_name:
            script_name = os.path.basename(local_script_path)
        remote_script_path = f"{remote_dir.rstrip('/')}/{script_name}"
        # Upload the local script file via SFTP with LF normalization
        sftp = self.get_sftp()
        try:
            with open(local_script_path, "r", encoding="utf-8", newline="") as f:
                content = f.read()
            content_lf = content.replace("\r\n", "\n").replace("\r", "\n")
            with sftp.file(remote_script_path, "w") as remote_file:
                remote_file.write(content_lf.encode("utf-8"))
        finally:
            # Bug fix: close threading-mode SFTP clients (previously leaked).
            if self._threading_mod and sftp:
                sftp.close()
        try:
            command = self._build_command(remote_script_path, script_type, env_vars)
            code, out, err = self.run(command, timeout=timeout)
            return CommandResult(exit_code=code, stdout=out, stderr=err)
        finally:
            if cleanup:
                self._cleanup_script(remote_script_path)

    def _prepare_script(self, script_content: str, remote_dir: str, script_name: Optional[str]) -> str:
        """Upload script content (LF normalized) and return its remote path.

        Raises:
            RuntimeError: If the upload fails. (Bug fix: errors were
                previously swallowed with `except: pass`, contradicting the
                documented contract of the execute_script_* callers and
                causing a confusing failure when the nonexistent script was
                then executed.)
        """
        if not script_name:
            import uuid
            script_name = f"script_{uuid.uuid4().hex[:8]}"
        remote_script_path = f"{remote_dir.rstrip('/')}/{script_name}"
        sftp = None
        try:
            # Upload script content directly via SFTP (normalize to LF)
            sftp = self.get_sftp()
            content_lf = script_content.replace("\r\n", "\n").replace("\r", "\n")
            with sftp.file(remote_script_path, "w") as remote_file:
                remote_file.write(content_lf.encode("utf-8"))
        except Exception as e:
            raise RuntimeError(f"Script upload failed: {e}")
        finally:
            if sftp and self._threading_mod:
                sftp.close()
        return remote_script_path

    @staticmethod
    def _build_command(
            remote_script_path: str,
            script_type: str = "shell",
            env_vars: Optional[Dict[str, str]] = None) -> str:
        """Build the command string with environment variables.

        Args:
            remote_script_path: Remote path of the uploaded script.
            script_type: "shell" (run with bash) or "python" (locate an
                interpreter, run the script, and on failure echo the
                interpreter info before exiting with the script's status).
            env_vars: Environment variables prepended as KEY='value' pairs.
        Raises:
            ValueError: If script_type is not "shell" or "python".
        """
        env_string = ""
        if env_vars:
            env_pairs = [f"{k}='{v}'" for k, v in env_vars.items()]
            env_string = " ".join(env_pairs) + " "
        if script_type == "shell":
            return f"{env_string}bash {remote_script_path}"
        elif script_type == "python":
            get_py_bin = "pyBin=$(which python3 2> /dev/null || which python 2> /dev/null || echo 'python')"
            py_info = "echo ""; echo \"Current Python environment:${pyBin} $(${pyBin} -c 'import sys,platform;print(sys.version.split()[0],platform.platform())')\""
            cmd = "%s;${pyBin} %s; ret=$?; [ $ret -eq 0 ] && exit $ret; %s;exit $ret;" % (
                get_py_bin, remote_script_path, py_info
            )
            return cmd
        else:
            raise ValueError("Invalid script type")

    def _cleanup_script(self, remote_script_path: str) -> None:
        """Clean up the remote script file via SFTP without invoking shell."""
        sftp = None
        try:
            sftp = self.get_sftp()
            # Ensure path exists before removal
            try:
                sftp.stat(remote_script_path)
            except FileNotFoundError:
                return
            sftp.remove(remote_script_path)
        except Exception:
            # Swallow cleanup errors
            pass
        finally:
            if sftp and self._threading_mod:
                sftp.close()

    def path_exists(self, path: str) -> Tuple[bool, str]:
        """
        Check if a path exists on the remote server.

        Args:
            path: Path to check
        Returns:
            Tuple[bool, str]: A tuple containing a boolean indicating whether the path exists and an error message
        """
        sftp = None
        try:
            sftp = self.get_sftp()
            try:
                sftp.stat(path)
                return True, ""
            except FileNotFoundError:
                return False, ""
        except Exception as e:
            return False, str(e)
        finally:
            if sftp and self._threading_mod:
                sftp.close()

    def create_dir(self, path: str) -> Tuple[bool, str]:
        """
        Create a directory (recursively) on the remote server.

        Args:
            path: Path to create
        Returns:
            Tuple[bool, str]: A tuple containing a boolean indicating whether the directory was created successfully and an error message
        """
        sftp = None
        try:
            sftp = self.get_sftp()
            self._ensure_remote_dir(sftp, path)
            return True, ""
        except Exception as e:
            return False, str(e)
        finally:
            # Bug fix: previously this was `return sftp.close()`, which
            # discarded the (bool, str) result and returned None whenever
            # threading mode was enabled.
            if sftp and self._threading_mod:
                sftp.close()

    def path_info(self, path: str) -> Dict:
        """
        Get information about a path on the remote server.

        Args:
            path: Path to get information about
        Returns:
            Dict: A dictionary containing information about the path, including path, isdir, size, mtime, mode, uid, gid, and exists
        """
        sftp = None
        not_found = {"path": path, "isdir": False, "size": 0, "mtime": 0, "mode": 0, "uid": 0, "gid": 0, "exists": False}
        try:
            sftp = self.get_sftp()
            info = sftp.stat(path)
            return {
                "path": path,
                "isdir": stat.S_ISDIR(info.st_mode),
                "size": info.st_size,
                "mtime": info.st_mtime,
                "mode": info.st_mode,
                "uid": info.st_uid,
                "gid": info.st_gid,
                "exists": True
            }
        except FileNotFoundError:
            return not_found
        except Exception:
            # Any other stat failure is treated as "not found" by design.
            return not_found
        finally:
            if sftp and self._threading_mod:
                sftp.close()

View File

@@ -0,0 +1,51 @@
import io
import paramiko
def test_ssh_config(host, port, username, password, pkey, pkey_passwd, timeout: int = 10) -> str:
try:
ssh = paramiko.SSHClient()
pkey_obj = None
if pkey:
pky_io = io.StringIO(pkey)
key_cls_list = [paramiko.RSAKey, paramiko.ECDSAKey, paramiko.Ed25519Key]
if hasattr(paramiko, "DSSKey"):
key_cls_list.append(paramiko.DSSKey)
for key_cls in key_cls_list:
pky_io.seek(0)
try:
pkey_obj = key_cls.from_private_key(pky_io, password=(pkey_passwd if pkey_passwd else None))
except Exception as e:
if "base64 decoding error" in str(e):
return "Private key data error, please check if it is a complete copy of the private key information"
elif "Private key file is encrypted" in str(e):
return "The private key has been encrypted, but the password for the private key has not been provided, so the private key information cannot be verified"
elif "Invalid key" in str(e):
return "Private key parsing error, please check if the password for the private key is correct"
continue
else:
break
else:
return "Private key parsing error, please confirm that the entered key format is correct"
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
# look_for_keys 一定要是False排除不必要的私钥尝试导致的错误
ssh.connect(hostname=host, port=port, username=username, password=(password if password else None),
pkey=pkey_obj, look_for_keys=False, auth_timeout=timeout)
ssh.close()
return ""
except Exception as e:
err_str = str(e)
auth_str = "{}@{}:{}".format(username, host, port)
if err_str.find('Authentication timeout') != -1:
return 'Authentication timeout, [{}] error{}'.format(auth_str, e)
if err_str.find('Authentication failed') != -1:
if pkey:
return 'Authentication failed, please check if the private key is correct: ' + auth_str
return 'Account or password error:' + auth_str
if err_str.find('Bad authentication type; allowed types') != -1:
return 'Unsupported authentication type: {}'.format(err_str)
if err_str.find('Connection reset by peer') != -1:
return 'The target server actively rejects the connection'
if err_str.find('Error reading SSH protocol banner') != -1:
return 'Protocol header response timeout, error' + err_str
return "Connection failed" + err_str