Files
dotfiles/config/ansible/tasks/global/utils/ssh/ssh.go
Menno van Leeuwen f0bf6bc8aa
Some checks failed
Ansible Lint Check / check-ansible (push) Failing after 9s
Nix Format Check / check-format (push) Failing after 22s
Python Lint Check / check-python (push) Failing after 7s
wip
Signed-off-by: Menno van Leeuwen <menno@vleeuwen.me>
2025-07-25 14:54:29 +02:00

951 lines
27 KiB
Go

package main
import (
"bufio"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
)
// LoggingConfig holds the `logging` section of config.yaml and controls the
// wrapper's own zerolog output (see initLogging).
type LoggingConfig struct {
	// Enabled switches the wrapper's logging on or off entirely.
	Enabled bool `yaml:"enabled"`

	// Level is debug, info, warn/warning or error; other values mean info.
	Level string `yaml:"level"`

	// Format "console" selects human-readable output; any other value
	// selects zerolog's default structured output.
	Format string `yaml:"format"`
}
// SmartAlias maps one alias name onto two SSH config hosts and picks between
// them at connection time: if CheckHost accepts a TCP connection on the
// Primary host's SSH port, Primary is used, otherwise Fallback.
type SmartAlias struct {
	// Primary is the SSH config host used when CheckHost is reachable
	// (i.e. when on the local network).
	Primary string `yaml:"primary"`

	// Fallback is the SSH config host used when CheckHost is unreachable.
	Fallback string `yaml:"fallback"`

	// CheckHost is the address probed with a TCP connect (not ICMP) to decide
	// between Primary and Fallback.
	CheckHost string `yaml:"check_host"`

	// Timeout is a Go duration string such as "500ms" bounding that probe;
	// empty or unparsable values fall back to 2s (see parseTimeout).
	Timeout string `yaml:"timeout"`
}
// TunnelDefinition describes one tunnel, either from the `tunnels` section
// of config.yaml or built from the --local/--remote/--dynamic/--via flags.
type TunnelDefinition struct {
	// Type is "local" (ssh -L), "remote" (ssh -R) or "dynamic" (ssh -D).
	Type string `yaml:"type"`

	// LocalPort is the listening port: on this machine for local and dynamic
	// tunnels, on the SSH server for remote tunnels (ssh -R semantics).
	LocalPort int `yaml:"local_port"`

	// RemoteHost and RemotePort are the forwarding target of local and
	// remote tunnels; dynamic tunnels do not use them.
	RemoteHost string `yaml:"remote_host"`
	RemotePort int    `yaml:"remote_port"`

	// SSHHost is the host to tunnel through; smart aliases are resolved
	// before ssh is started.
	SSHHost string `yaml:"ssh_host"`
}
// TunnelState is the JSON document persisted as <tunnelsDir>/<pid>.json for
// each tunnel started by this tool; LastSeen is refreshed on every tunnel-mode
// invocation that finds the process still alive.
type TunnelState struct {
	// Identity and origin; Source is "config" or "adhoc".
	Name   string `json:"name"`
	Source string `json:"source"`

	// Forwarding parameters copied from the TunnelDefinition
	// (Type is local, remote or dynamic).
	Type       string `json:"type"`
	LocalPort  int    `json:"local_port"`
	RemoteHost string `json:"remote_host"`
	RemotePort int    `json:"remote_port"`

	// SSHHost is the configured host; SSHHostResolved is the host it resolved
	// to after smart alias resolution.
	SSHHost         string `json:"ssh_host"`
	SSHHostResolved string `json:"ssh_host_resolved"`

	// Process bookkeeping; CommandLine is the space-joined ssh argv.
	PID         int       `json:"pid"`
	Status      string    `json:"status"`
	StartedAt   time.Time `json:"started_at"`
	LastSeen    time.Time `json:"last_seen"`
	CommandLine string    `json:"command_line"`
}
// Config mirrors config.yaml: logging settings, smart aliases keyed by alias
// name, and named tunnel definitions keyed by tunnel name.
type Config struct {
	Logging LoggingConfig `yaml:"logging"`

	SmartAliases map[string]SmartAlias       `yaml:"smart_aliases"`
	Tunnels      map[string]TunnelDefinition `yaml:"tunnels"`
}
// realSSHPath is the absolute path of the system OpenSSH client that this
// wrapper ultimately executes. An absolute path is used instead of a $PATH
// lookup; since this binary is itself installed as "ssh", a lookup could
// presumably resolve back to the wrapper.
const realSSHPath = "/usr/bin/ssh"
// Paths and configuration, populated by init.
var (
	configDir  string  // ~/.config/ssh-util
	tunnelsDir string  // <configDir>/tunnels; one <pid>.json per tunnel
	config     *Config // loaded config.yaml, or built-in defaults
)

// Tunnel-mode flag destinations shared by rootCmd and tunnelCmd (openTunnel
// also reads the ad-hoc forwarding values from here).
var (
	tunnelMode bool // --tunnel / -T (persistent)

	tunnelOpen    bool   // --open / -O
	tunnelClose   bool   // --close / -C
	tunnelList    bool   // --list / -L
	tunnelLocal   string // --local port:host:port
	tunnelRemote  string // --remote port:host:port
	tunnelDynamic int    // --dynamic SOCKS port
	tunnelVia     string // --via SSH host
)
// rootCmd is the top-level command; main only executes it for tunnel
// invocations. Flag parsing is disabled so handleSSH receives the raw
// arguments and parses them itself.
//
// Args must be set explicitly. Without an Args validator, cobra applies its
// legacy rule that a root command with subcommands (tunnelCmd here) treats
// any positional argument as an unknown subcommand. That made
// `ssh --tunnel --open NAME` fail with `unknown command "NAME"` before
// handleSSH ever ran.
var rootCmd = &cobra.Command{
	Use:                "ssh",
	Short:              "Smart SSH utility with tunnel management",
	Long:               "A transparent SSH wrapper that provides smart alias resolution and background tunnel management",
	Args:               cobra.ArbitraryArgs,
	DisableFlagParsing: true,
	Run:                handleSSH,
}
// tunnelCmd implements `ssh tunnel [tunnel-name]`.
//
// Cobra parses this subcommand's flags into the package-level tunnel*
// variables, whereas handleTunnelManual works on a raw argument list and
// ignores those variables. Run therefore turns the parsed flags back into
// arguments. Without this step, --open/--close/--list and every forwarding
// flag were dropped, so the subcommand could never open, close or list
// anything. The tunnel name is placed first so no flag value can be
// mistaken for it.
var tunnelCmd = &cobra.Command{
	Use:   "tunnel [tunnel-name]",
	Short: "Manage background SSH tunnels",
	Long:  "Create, list, and manage persistent SSH tunnels in the background",
	Args:  cobra.MaximumNArgs(1),
	Run: func(cmd *cobra.Command, args []string) {
		forwarded := append([]string{"--tunnel"}, args...)
		if tunnelOpen {
			forwarded = append(forwarded, "--open")
		}
		if tunnelClose {
			forwarded = append(forwarded, "--close")
		}
		if tunnelList {
			forwarded = append(forwarded, "--list")
		}
		if tunnelLocal != "" {
			forwarded = append(forwarded, "--local", tunnelLocal)
		}
		if tunnelRemote != "" {
			forwarded = append(forwarded, "--remote", tunnelRemote)
		}
		if tunnelVia != "" {
			forwarded = append(forwarded, "--via", tunnelVia)
		}
		if tunnelDynamic != 0 {
			forwarded = append(forwarded, "--dynamic", strconv.Itoa(tunnelDynamic))
		}
		handleTunnelManual(forwarded)
	},
}
// init resolves the wrapper's directories, loads config.yaml (falling back to
// built-in defaults), configures logging and registers the cobra flags.
// It exits with status 1 only when the home directory cannot be determined.
// A broken config file or an uncreatable tunnels directory is logged as a
// warning rather than ignored.
func init() {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: Failed to get home directory: %v\n", err)
		os.Exit(1)
	}
	configDir = filepath.Join(homeDir, ".config", "ssh-util")
	tunnelsDir = filepath.Join(configDir, "tunnels")

	// Ensure the state directory exists. Plain SSH passthrough never uses it,
	// so a failure is not fatal; it is reported once logging is set up.
	mkdirErr := os.MkdirAll(tunnelsDir, 0755)

	// Load configuration, using defaults if loading fails.
	var configErr error
	config, configErr = loadConfig()
	if configErr != nil {
		config = &Config{
			Logging: LoggingConfig{
				Enabled: true,
				Level:   "info",
				Format:  "console",
			},
			SmartAliases: make(map[string]SmartAlias),
			Tunnels:      make(map[string]TunnelDefinition),
		}
	}

	initLogging(config.Logging)

	// Surface the problems above now that the logger is configured. A
	// silently ignored parse error would make every setting appear ignored.
	if configErr != nil {
		log.Warn().Err(configErr).Msg("Failed to load config, using defaults")
	}
	if mkdirErr != nil {
		log.Warn().Err(mkdirErr).Str("dir", tunnelsDir).Msg("Failed to create tunnels directory")
	}

	// Global flags
	rootCmd.PersistentFlags().BoolVarP(&tunnelMode, "tunnel", "T", false, "Enable tunnel mode")
	rootCmd.Flags().BoolVarP(&tunnelOpen, "open", "O", false, "Open a tunnel")
	rootCmd.Flags().BoolVarP(&tunnelClose, "close", "C", false, "Close a tunnel")
	rootCmd.Flags().BoolVarP(&tunnelList, "list", "L", false, "List active tunnels")
	rootCmd.Flags().StringVar(&tunnelLocal, "local", "", "Local port forwarding (port:host:port)")
	rootCmd.Flags().StringVar(&tunnelRemote, "remote", "", "Remote port forwarding (port:host:port)")
	rootCmd.Flags().IntVar(&tunnelDynamic, "dynamic", 0, "Dynamic port forwarding (SOCKS proxy port)")
	rootCmd.Flags().StringVar(&tunnelVia, "via", "", "SSH host to tunnel through")

	// Add tunnel command
	rootCmd.AddCommand(tunnelCmd)

	// Tunnel command flags (same as root for consistency)
	tunnelCmd.Flags().BoolVarP(&tunnelOpen, "open", "O", false, "Open a tunnel")
	tunnelCmd.Flags().BoolVarP(&tunnelClose, "close", "C", false, "Close a tunnel")
	tunnelCmd.Flags().BoolVarP(&tunnelList, "list", "L", false, "List active tunnels")
	tunnelCmd.Flags().StringVar(&tunnelLocal, "local", "", "Local port forwarding (port:host:port)")
	tunnelCmd.Flags().StringVar(&tunnelRemote, "remote", "", "Remote port forwarding (port:host:port)")
	tunnelCmd.Flags().IntVar(&tunnelDynamic, "dynamic", 0, "Dynamic port forwarding (SOCKS proxy port)")
	tunnelCmd.Flags().StringVar(&tunnelVia, "via", "", "SSH host to tunnel through")

	// Handle combined flags like -TO, -TC, -TL
	rootCmd.PersistentPreRunE = handleCombinedFlags
}
func main() {
// Check if this is a tunnel command first
args := os.Args[1:]
isTunnelCommand := false
for _, arg := range args {
if arg == "--tunnel" || arg == "-T" || strings.HasPrefix(arg, "-T") {
isTunnelCommand = true
break
}
if arg == "tunnel" {
isTunnelCommand = true
break
}
}
if isTunnelCommand {
// Use Cobra for tunnel commands
if err := rootCmd.Execute(); err != nil {
log.Error().Err(err).Msg("Command execution failed")
os.Exit(1)
}
} else {
// Bypass Cobra for regular SSH commands (smart alias resolution)
handleSSHDirect(args)
}
}
// handleCombinedFlags is rootCmd's PersistentPreRunE. It finds the first
// argument of the form -T<letters> (e.g. -TO, -TC, -TL) in os.Args and sets
// tunnelMode plus tunnelOpen/tunnelClose/tunnelList for each of O, C and L
// present in it. Flags are only ever switched on, never off. It always
// returns nil.
func handleCombinedFlags(cmd *cobra.Command, args []string) error {
	for _, candidate := range os.Args {
		if len(candidate) <= 2 || !strings.HasPrefix(candidate, "-T") {
			continue
		}

		tunnelMode = true
		modes := candidate[2:]
		tunnelOpen = tunnelOpen || strings.Contains(modes, "O")
		tunnelClose = tunnelClose || strings.Contains(modes, "C")
		tunnelList = tunnelList || strings.Contains(modes, "L")
		break
	}
	return nil
}
// handleSSH is rootCmd's Run function. Because rootCmd disables flag parsing,
// cobra's args are not used; the raw process arguments (without the program
// name) are handed to the manual tunnel parser instead.
func handleSSH(cmd *cobra.Command, args []string) {
	rawArgs := os.Args[1:]
	handleTunnelManual(rawArgs)
}
// handleSSHDirect runs a regular (non-tunnel) ssh invocation and replaces the
// process with the real ssh.
//
// If the first argument is a configured smart alias, it is rewritten to the
// alias's Primary host when CheckHost accepts a TCP connection on the
// Primary's SSH config port, and to the Fallback host otherwise. Invocations
// that are empty, start with an option, or have a user@host first argument
// are passed through untouched.
func handleSSHDirect(args []string) {
	log.Debug().Strs("original_args", args).Msg("SSH utility started")

	if len(args) == 0 || strings.HasPrefix(args[0], "-") || strings.Contains(args[0], "@") {
		log.Debug().Msg("Passing through to real SSH (no smart alias detected)")
		executeRealSSH(args)
		return
	}

	// Work on a copy so the caller's slice is never modified.
	finalArgs := append([]string(nil), args...)
	target := finalArgs[0]

	alias, isAlias := config.SmartAliases[target]
	if !isAlias {
		log.Debug().Str("host", target).Msg("Not a smart alias, passing through")
	} else {
		log.Info().Str("alias", target).Msg("Smart alias detected")

		wait := parseTimeout(alias.Timeout)
		log.Debug().Dur("timeout", wait).Msg("Parsed timeout")

		sshPort := getSSHConfigPort(alias.Primary)
		log.Debug().Str("host", alias.Primary).Int("port", sshPort).Msg("Extracted port from SSH config")

		log.Info().Str("check_host", alias.CheckHost).Int("port", sshPort).Msg("Testing connectivity")
		if reachable := pingHost(alias.CheckHost, wait, sshPort); reachable {
			log.Info().Str("chosen_host", alias.Primary).Msg("Local network reachable, using primary host")
			finalArgs[0] = alias.Primary
		} else {
			log.Info().Str("chosen_host", alias.Fallback).Msg("Local network not reachable, using fallback host")
			finalArgs[0] = alias.Fallback
		}
	}

	log.Debug().Strs("final_args", finalArgs).Msg("Executing real SSH")
	executeRealSSH(finalArgs)
}
// handleTunnelManual implements tunnel management (--open, --close, --list)
// from a raw argument list, such as
// `--tunnel --open NAME --local 8080:db:80 --via bastion` or `-TO NAME ...`.
// It first prunes stale tunnel state, then performs the requested action.
// Usage and action errors are printed to stderr and terminate the process
// with exit status 1.
func handleTunnelManual(args []string) {
	log.Debug().Msg("Tunnel mode activated")

	// Always validate tunnel states first
	if err := validateTunnelStates(); err != nil {
		log.Error().Err(err).Msg("Failed to validate tunnel states")
	}

	inv, err := parseTunnelArgs(args)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		os.Exit(1)
	}

	if inv.action == "list" {
		listTunnels()
		return
	}
	if inv.name == "" {
		fmt.Fprintf(os.Stderr, "Error: tunnel name required\n")
		fmt.Fprintf(os.Stderr, "Usage: ssh --tunnel --open <name> [flags]\n")
		os.Exit(1)
	}

	switch inv.action {
	case "open":
		// openTunnel reads ad-hoc tunnel parameters from these package-level
		// flag variables.
		tunnelLocal = inv.localForward
		tunnelRemote = inv.remoteForward
		tunnelDynamic = inv.dynamicPort
		tunnelVia = inv.via
		if err := openTunnel(inv.name); err != nil {
			log.Error().Err(err).Str("tunnel", inv.name).Msg("Failed to open tunnel")
			fmt.Fprintf(os.Stderr, "Error: %v\n", err)
			os.Exit(1)
		}
	case "close":
		if err := closeTunnel(inv.name); err != nil {
			log.Error().Err(err).Str("tunnel", inv.name).Msg("Failed to close tunnel")
			fmt.Fprintf(os.Stderr, "Error: %v\n", err)
			os.Exit(1)
		}
	default:
		fmt.Fprintf(os.Stderr, "Error: must specify --open, --close, or --list\n")
		os.Exit(1)
	}
}

// tunnelInvocation is the result of parsing a tunnel-mode argument list.
type tunnelInvocation struct {
	name          string // first positional argument
	action        string // "open", "close", "list" or "" (the last one given wins)
	localForward  string // --local value
	remoteForward string // --remote value
	via           string // --via value
	dynamicPort   int    // --dynamic value
}

// parseTunnelArgs parses tunnel-mode arguments. --tunnel/-T are mode markers.
// --open/-O, --close/-C and --list/-L (or a combined form such as -TO) select
// the action. --local, --remote, --via and --dynamic take a value, given
// either as the next argument or as --flag=value. A consumed value is never
// mistaken for the tunnel name, which is the first remaining positional
// argument. It returns an error for a value flag with no value or a
// non-numeric --dynamic port.
func parseTunnelArgs(args []string) (tunnelInvocation, error) {
	var inv tunnelInvocation
	for i := 0; i < len(args); i++ {
		arg := args[i]

		flag, value, hasValue := arg, "", false
		if eq := strings.Index(arg, "="); eq >= 0 && strings.HasPrefix(arg, "--") {
			flag, value, hasValue = arg[:eq], arg[eq+1:], true
		}

		switch flag {
		case "--tunnel", "-T":
			// Mode marker only.
		case "--open", "-O":
			inv.action = "open"
		case "--close", "-C":
			inv.action = "close"
		case "--list", "-L":
			inv.action = "list"
		case "--local", "--remote", "--via", "--dynamic":
			if !hasValue {
				if i+1 >= len(args) {
					return tunnelInvocation{}, fmt.Errorf("flag %s requires a value", flag)
				}
				i++ // consume the value so it cannot become the tunnel name
				value = args[i]
			}
			switch flag {
			case "--local":
				inv.localForward = value
			case "--remote":
				inv.remoteForward = value
			case "--via":
				inv.via = value
			default: // --dynamic
				port, err := strconv.Atoi(value)
				if err != nil {
					return tunnelInvocation{}, fmt.Errorf("invalid --dynamic port %q", value)
				}
				inv.dynamicPort = port
			}
		default:
			switch {
			case strings.HasPrefix(arg, "-T") && len(arg) > 2:
				// Combined short flags such as -TO, -TC, -TL.
				suffix := arg[2:]
				if strings.Contains(suffix, "O") {
					inv.action = "open"
				}
				if strings.Contains(suffix, "C") {
					inv.action = "close"
				}
				if strings.Contains(suffix, "L") {
					inv.action = "list"
				}
			case !strings.HasPrefix(arg, "-") && inv.name == "":
				inv.name = arg
			}
		}
	}
	return inv, nil
}
// validateTunnelStates prunes the tunnels directory:
//   - .json state files whose name is not a PID, or whose PID is no longer
//     alive, are deleted;
//   - surviving states get LastSeen refreshed.
//
// Problems with individual files are logged and skipped. Only a failure to
// list the directory is returned.
func validateTunnelStates() error {
	entries, err := os.ReadDir(tunnelsDir)
	if err != nil {
		return fmt.Errorf("failed to read tunnels directory: %w", err)
	}

	for _, entry := range entries {
		fileName := entry.Name()
		if !strings.HasSuffix(fileName, ".json") {
			continue
		}
		statePath := filepath.Join(tunnelsDir, fileName)

		pid, convErr := strconv.Atoi(strings.TrimSuffix(fileName, ".json"))
		switch {
		case convErr != nil:
			log.Warn().Str("file", fileName).Msg("Invalid PID filename, removing")
			os.Remove(statePath)
		case !isProcessAlive(pid):
			log.Info().Int("pid", pid).Msg("Removing state for dead tunnel")
			os.Remove(statePath)
		default:
			state, loadErr := loadTunnelState(statePath)
			if loadErr != nil {
				log.Warn().Str("file", statePath).Err(loadErr).Msg("Failed to load tunnel state")
				continue
			}
			state.LastSeen = time.Now()
			if saveErr := saveTunnelState(statePath, state); saveErr != nil {
				log.Warn().Str("file", statePath).Err(saveErr).Msg("Failed to update tunnel state")
			}
		}
	}
	return nil
}
// listTunnels prints a table of the tunnels tracked in tunnelsDir, or
// "No active tunnels" when none can be loaded. Unreadable state files are
// skipped (logged at debug level). A directory read error is printed to
// stderr.
func listTunnels() {
	files, err := os.ReadDir(tunnelsDir)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: Failed to read tunnels directory: %v\n", err)
		return
	}

	// Load every state before printing, so a directory holding only stray or
	// unreadable files reports "No active tunnels" instead of an empty table.
	states := make([]TunnelState, 0, len(files))
	for _, file := range files {
		if !strings.HasSuffix(file.Name(), ".json") {
			continue
		}
		stateFile := filepath.Join(tunnelsDir, file.Name())
		state, err := loadTunnelState(stateFile)
		if err != nil {
			log.Debug().Str("file", stateFile).Err(err).Msg("Skipping unreadable tunnel state")
			continue
		}
		states = append(states, state)
	}
	if len(states) == 0 {
		fmt.Println("No active tunnels")
		return
	}

	header := fmt.Sprintf("%-20s %-8s %-8s %-25s %-12s %-8s %s",
		"NAME", "TYPE", "LOCAL", "REMOTE", "HOST", "PID", "UPTIME")
	fmt.Println(header)
	// Size the rule to the header instead of a fixed width.
	fmt.Println(strings.Repeat("-", len(header)))

	for _, state := range states {
		remote := ""
		switch state.Type {
		case "local", "remote":
			remote = fmt.Sprintf("%s:%d", state.RemoteHost, state.RemotePort)
		case "dynamic":
			remote = "SOCKS"
		}
		uptime := time.Since(state.StartedAt).Truncate(time.Second)
		fmt.Printf("%-20s %-8s %-8d %-25s %-12s %-8d %s\n",
			state.Name, state.Type, state.LocalPort, remote,
			state.SSHHostResolved, state.PID, uptime)
	}
}
func openTunnel(name string) error {
// Check if tunnel is already running
if existingPID := findTunnelByName(name); existingPID != 0 {
return fmt.Errorf("tunnel '%s' already running (PID %d)", name, existingPID)
}
var tunnel TunnelDefinition
var source string
// Check if tunnel is defined in config
if configTunnel, exists := config.Tunnels[name]; exists {
tunnel = configTunnel
source = "config"
log.Info().Str("tunnel", name).Msg("Using tunnel definition from config")
} else {
// Create ad-hoc tunnel from flags
var err error
tunnel, err = createAdhocTunnel()
if err != nil {
return fmt.Errorf("tunnel '%s' not found in config and invalid adhoc parameters: %w", name, err)
}
source = "adhoc"
log.Info().Str("tunnel", name).Msg("Creating ad-hoc tunnel")
}
// Check for port conflicts
if isPortInUse(tunnel.LocalPort) {
return fmt.Errorf("port %d already in use", tunnel.LocalPort)
}
// Resolve SSH host using smart alias logic
resolvedSSHHost := resolveSSHHost(tunnel.SSHHost)
log.Debug().Str("original", tunnel.SSHHost).Str("resolved", resolvedSSHHost).Msg("SSH host resolution")
// Build SSH command
cmdArgs := buildSSHCommand(tunnel, resolvedSSHHost)
log.Debug().Strs("command", cmdArgs).Msg("Starting SSH tunnel")
// Start SSH process
cmd := &exec.Cmd{
Path: realSSHPath,
Args: cmdArgs,
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start SSH tunnel: %w", err)
}
pid := cmd.Process.Pid
log.Info().Str("tunnel", name).Int("pid", pid).Msg("SSH tunnel started")
// Create tunnel state
state := TunnelState{
Name: name,
Source: source,
Type: tunnel.Type,
LocalPort: tunnel.LocalPort,
RemoteHost: tunnel.RemoteHost,
RemotePort: tunnel.RemotePort,
SSHHost: tunnel.SSHHost,
SSHHostResolved: resolvedSSHHost,
PID: pid,
Status: "active",
StartedAt: time.Now(),
LastSeen: time.Now(),
CommandLine: strings.Join(cmdArgs, " "),
}
// Save state file
stateFile := filepath.Join(tunnelsDir, fmt.Sprintf("%d.json", pid))
if err := saveTunnelState(stateFile, state); err != nil {
// If we can't save state, kill the process
cmd.Process.Kill()
return fmt.Errorf("failed to save tunnel state: %w", err)
}
fmt.Printf("Tunnel '%s' opened on port %d (PID %d)\n", name, tunnel.LocalPort, pid)
return nil
}
// closeTunnel stops the tracked tunnel called name and deletes its state file.
//
// ssh receives SIGTERM rather than SIGKILL, so it can shut its forwards down
// cleanly. A process that has already exited (ESRCH) counts as closed, and
// its stale state file is still removed. A failure to delete the state file
// is only logged.
//
// NOTE(review): if the recorded PID has been reused by an unrelated process,
// that process would be signalled; the state file does not record enough to
// rule this out.
func closeTunnel(name string) error {
	pid := findTunnelByName(name)
	if pid == 0 {
		return fmt.Errorf("tunnel '%s' not found", name)
	}
	// kill(2) treats non-positive PIDs as process groups; never signal those.
	if pid < 0 {
		return fmt.Errorf("tunnel '%s' has invalid PID %d", name, pid)
	}

	if err := syscall.Kill(pid, syscall.SIGTERM); err == syscall.ESRCH {
		log.Debug().Int("pid", pid).Msg("Tunnel process already exited")
	} else if err != nil {
		return fmt.Errorf("failed to terminate process %d: %w", pid, err)
	}

	stateFile := filepath.Join(tunnelsDir, fmt.Sprintf("%d.json", pid))
	if err := os.Remove(stateFile); err != nil {
		log.Warn().Str("file", stateFile).Err(err).Msg("Failed to remove state file")
	}

	log.Info().Str("tunnel", name).Int("pid", pid).Msg("Tunnel closed")
	fmt.Printf("Tunnel '%s' closed\n", name)
	return nil
}
// createAdhocTunnel builds a TunnelDefinition from the ad-hoc flag variables.
// --via is required. Exactly one forwarding kind is used, in this precedence:
// --local, then --remote, then --dynamic. On error it returns a zero
// TunnelDefinition.
func createAdhocTunnel() (TunnelDefinition, error) {
	if tunnelVia == "" {
		return TunnelDefinition{}, fmt.Errorf("--via flag required for ad-hoc tunnels")
	}
	tunnel := TunnelDefinition{SSHHost: tunnelVia}

	switch {
	case tunnelLocal != "":
		localPort, host, remotePort, err := parseForwardSpec("--local", tunnelLocal)
		if err != nil {
			return TunnelDefinition{}, err
		}
		tunnel.Type = "local"
		tunnel.LocalPort, tunnel.RemoteHost, tunnel.RemotePort = localPort, host, remotePort
	case tunnelRemote != "":
		localPort, host, remotePort, err := parseForwardSpec("--remote", tunnelRemote)
		if err != nil {
			return TunnelDefinition{}, err
		}
		tunnel.Type = "remote"
		tunnel.LocalPort, tunnel.RemoteHost, tunnel.RemotePort = localPort, host, remotePort
	case tunnelDynamic != 0:
		tunnel.Type = "dynamic"
		tunnel.LocalPort = tunnelDynamic
	default:
		return TunnelDefinition{}, fmt.Errorf("must specify --local, --remote, or --dynamic")
	}
	return tunnel, nil
}

// parseForwardSpec parses a "port:host:port" spec as given to --local or
// --remote; flagName is used in error messages. The host is the text between
// the first and last colon, so a bracketed IPv6 target such as "8080:[::1]:80"
// is accepted, while an unbracketed host containing ':' is rejected as
// ambiguous. Both ports must be in the range 1-65535.
func parseForwardSpec(flagName, spec string) (int, string, int, error) {
	first, last := strings.Index(spec, ":"), strings.LastIndex(spec, ":")
	if first < 0 || first == last {
		return 0, "", 0, fmt.Errorf("invalid %s format, expected port:host:port", flagName)
	}
	localStr, host, remoteStr := spec[:first], spec[first+1:last], spec[last+1:]

	bracketed := strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]")
	if host == "" || (strings.Contains(host, ":") && !bracketed) {
		return 0, "", 0, fmt.Errorf("invalid %s format, expected port:host:port", flagName)
	}

	localPort, err := strconv.Atoi(localStr)
	if err != nil || localPort < 1 || localPort > 65535 {
		return 0, "", 0, fmt.Errorf("invalid local port: %s", localStr)
	}
	remotePort, err := strconv.Atoi(remoteStr)
	if err != nil || remotePort < 1 || remotePort > 65535 {
		return 0, "", 0, fmt.Errorf("invalid remote port: %s", remoteStr)
	}
	return localPort, host, remotePort, nil
}
// buildSSHCommand returns the argv (argv[0] is "ssh") for a background tunnel
// of the given type through sshHost.
//
// ssh is deliberately not given -f. With -f, ssh forks after authenticating
// and the process that was started exits, so the PID used as the state file
// name would be dead almost immediately. The next invocation would then
// prune the state, and the still-running tunnel could no longer be listed or
// closed. Instead, ssh keeps running as the started process, with these
// options:
//   - -N: no remote command, forwarding only;
//   - -n: stdin from /dev/null, as ssh(1) requires for background use;
//   - ExitOnForwardFailure=yes: exit non-zero if the forward cannot be set up,
//     instead of staying connected without a tunnel;
//   - BatchMode=yes: never prompt (password, passphrase, host key), since no
//     one can answer once the wrapper has exited. Authentication must
//     therefore succeed non-interactively, e.g. via ssh-agent.
//
// An unknown tunnel type adds no forwarding option.
func buildSSHCommand(tunnel TunnelDefinition, sshHost string) []string {
	args := []string{
		"ssh", "-N", "-n",
		"-o", "ExitOnForwardFailure=yes",
		"-o", "BatchMode=yes",
	}
	switch tunnel.Type {
	case "local":
		args = append(args, "-L", fmt.Sprintf("%d:%s:%d", tunnel.LocalPort, tunnel.RemoteHost, tunnel.RemotePort))
	case "remote":
		args = append(args, "-R", fmt.Sprintf("%d:%s:%d", tunnel.LocalPort, tunnel.RemoteHost, tunnel.RemotePort))
	case "dynamic":
		args = append(args, "-D", strconv.Itoa(tunnel.LocalPort))
	}
	return append(args, sshHost)
}
// findTunnelByName returns the PID of the first tracked tunnel called name,
// or 0 if none is found or the directory cannot be read. The PID is taken
// from the state file's name (<pid>.json). Unreadable state files are
// skipped.
func findTunnelByName(name string) int {
	entries, err := os.ReadDir(tunnelsDir)
	if err != nil {
		return 0
	}

	for _, entry := range entries {
		fileName := entry.Name()
		if !strings.HasSuffix(fileName, ".json") {
			continue
		}
		st, loadErr := loadTunnelState(filepath.Join(tunnelsDir, fileName))
		if loadErr != nil || st.Name != name {
			continue
		}
		// A non-numeric file name yields 0, which callers read as "not found".
		pid, _ := strconv.Atoi(strings.TrimSuffix(fileName, ".json"))
		return pid
	}
	return 0
}
// isPortInUse reports whether a TCP listener cannot be opened on port. It
// checks both the wildcard address and 127.0.0.1, the IPv4 loopback address
// ssh binds local forwards to by default (GatewayPorts=no). On BSD/macOS a
// wildcard bind can succeed even while the loopback port is taken, so the
// wildcard check alone is not enough. Any listen failure, including a
// permission error, counts as "in use".
func isPortInUse(port int) bool {
	portStr := strconv.Itoa(port)
	for _, host := range []string{"", "127.0.0.1"} {
		ln, err := net.Listen("tcp", net.JoinHostPort(host, portStr))
		if err != nil {
			return true
		}
		ln.Close()
	}
	return false
}
// isProcessAlive reports whether a process with the given PID exists. It uses
// kill(pid, 0), which performs the existence and permission checks without
// delivering a signal. EPERM means the process exists but belongs to another
// user, so it counts as alive. Non-positive PIDs are rejected, because
// kill(2) interprets them as process groups rather than single processes.
func isProcessAlive(pid int) bool {
	if pid <= 0 {
		return false
	}
	err := syscall.Kill(pid, syscall.Signal(0))
	return err == nil || err == syscall.EPERM
}
// resolveSSHHost maps sshHost through the smart alias table. Hosts that are
// not smart aliases are returned unchanged. For an alias, the Primary host is
// returned if CheckHost accepts a TCP connection on the Primary's SSH config
// port within the alias timeout, and the Fallback host otherwise.
func resolveSSHHost(sshHost string) string {
	alias, ok := config.SmartAliases[sshHost]
	if !ok {
		return sshHost
	}

	timeout := parseTimeout(alias.Timeout)
	port := getSSHConfigPort(alias.Primary)
	if !pingHost(alias.CheckHost, timeout, port) {
		log.Debug().Str("host", sshHost).Str("resolved", alias.Fallback).Msg("Smart alias resolved to fallback")
		return alias.Fallback
	}

	log.Debug().Str("host", sshHost).Str("resolved", alias.Primary).Msg("Smart alias resolved to primary")
	return alias.Primary
}
// loadTunnelState reads and decodes the JSON tunnel state stored at
// stateFile. Read and decode failures are wrapped with context. The returned
// state must not be relied on when the error is non-nil.
func loadTunnelState(stateFile string) (state TunnelState, err error) {
	raw, readErr := os.ReadFile(stateFile)
	if readErr != nil {
		return state, fmt.Errorf("failed to read state file: %w", readErr)
	}

	if decodeErr := json.Unmarshal(raw, &state); decodeErr != nil {
		return state, fmt.Errorf("failed to parse state file: %w", decodeErr)
	}
	return state, nil
}
// saveTunnelState writes state as indented JSON to stateFile, mode 0644.
//
// The write is atomic. Other invocations read state files concurrently, and
// this file is rewritten on every tunnel-mode run, so the data is first
// written to a temporary file in the same directory and then renamed into
// place; readers never see a truncated document. The temporary name does
// not end in ".json", so directory scans ignore it.
func saveTunnelState(stateFile string, state TunnelState) error {
	data, err := json.MarshalIndent(state, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal state: %w", err)
	}

	tmp, err := os.CreateTemp(filepath.Dir(stateFile), "."+filepath.Base(stateFile)+".tmp-*")
	if err != nil {
		return fmt.Errorf("failed to write state file: %w", err)
	}
	tmpName := tmp.Name()
	// Best-effort cleanup; after a successful rename the temp file is gone.
	defer os.Remove(tmpName)

	if _, err := tmp.Write(data); err != nil {
		tmp.Close()
		return fmt.Errorf("failed to write state file: %w", err)
	}
	// CreateTemp uses mode 0600; keep the permissions state files always had.
	if err := tmp.Chmod(0644); err != nil {
		tmp.Close()
		return fmt.Errorf("failed to write state file: %w", err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("failed to write state file: %w", err)
	}
	if err := os.Rename(tmpName, stateFile); err != nil {
		return fmt.Errorf("failed to write state file: %w", err)
	}
	return nil
}
// initLogging applies cfg to zerolog's global logger:
//   - a disabled config turns logging off entirely;
//   - the level is chosen from debug/info/warn/warning/error, case-insensitively,
//     defaulting to info;
//   - output goes to stderr, human-readable for the "console" format and in
//     zerolog's default structured form otherwise.
func initLogging(cfg LoggingConfig) {
	if !cfg.Enabled {
		zerolog.SetGlobalLevel(zerolog.Disabled)
		return
	}

	levels := map[string]zerolog.Level{
		"debug":   zerolog.DebugLevel,
		"info":    zerolog.InfoLevel,
		"warn":    zerolog.WarnLevel,
		"warning": zerolog.WarnLevel,
		"error":   zerolog.ErrorLevel,
	}
	level, known := levels[strings.ToLower(cfg.Level)]
	if !known {
		level = zerolog.InfoLevel
	}
	zerolog.SetGlobalLevel(level)

	if strings.ToLower(cfg.Format) != "console" {
		log.Logger = log.Output(os.Stderr)
		return
	}
	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
}
// defaultConfig returns the configuration used when no config file exists:
// logging enabled at info level with console output, and no smart aliases
// or tunnels.
func defaultConfig() *Config {
	return &Config{
		Logging: LoggingConfig{
			Enabled: true,
			Level:   "info",
			Format:  "console",
		},
		SmartAliases: make(map[string]SmartAlias),
		Tunnels:      make(map[string]TunnelDefinition),
	}
}

// loadConfig reads <configDir>/config.yaml.
//
// A missing file yields defaultConfig(). Otherwise the YAML is decoded on top
// of the defaults, so keys omitted from the file keep their default values;
// for example, a file without `logging.enabled` keeps logging enabled instead
// of decoding it as false. After decoding:
//   - an empty level or format is reset to "info"/"console";
//   - nil alias/tunnel maps are replaced by empty ones.
//
// Read and parse errors are returned with a nil config.
func loadConfig() (*Config, error) {
	configFile := filepath.Join(configDir, "config.yaml")

	data, err := os.ReadFile(configFile)
	if os.IsNotExist(err) {
		return defaultConfig(), nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}

	cfg := defaultConfig()
	if err := yaml.Unmarshal(data, cfg); err != nil {
		return nil, fmt.Errorf("failed to parse config YAML: %w", err)
	}

	if cfg.Logging.Level == "" {
		cfg.Logging.Level = "info"
	}
	if cfg.Logging.Format == "" {
		cfg.Logging.Format = "console"
	}
	// An explicit `null` in the YAML resets these maps to nil.
	if cfg.SmartAliases == nil {
		cfg.SmartAliases = make(map[string]SmartAlias)
	}
	if cfg.Tunnels == nil {
		cfg.Tunnels = make(map[string]TunnelDefinition)
	}
	return cfg, nil
}
// parseTimeout converts a Go duration string (e.g. "500ms", "3s") into a
// time.Duration. Surrounding whitespace is ignored. Empty, unparsable, zero
// and negative values all yield the 2s default. A non-positive timeout must
// not reach net.DialTimeout: zero there means "no timeout", and a negative
// value fails instantly.
func parseTimeout(timeoutStr string) time.Duration {
	const fallback = 2 * time.Second

	d, err := time.ParseDuration(strings.TrimSpace(timeoutStr))
	if err != nil || d <= 0 {
		return fallback
	}
	return d
}
// pingHost reports whether host accepts a TCP connection on port within
// timeout, and logs the outcome at debug level. Despite the name, no ICMP
// echo is sent; the check is a plain TCP connect (see tcpConnectTest).
func pingHost(host string, timeout time.Duration, port int) bool {
	reachable := tcpConnectTest(host, timeout, port)
	log.Debug().
		Str("host", host).
		Int("port", port).
		Bool("reachable", reachable).
		Msg("Connectivity test result")
	return reachable
}
// tcpConnectTest reports whether a TCP connection to host:port can be
// established within timeout; the connection is closed immediately. An empty
// host is treated as unreachable: net.Dial would interpret ":port" as the
// local system, so an unset check_host would silently probe this machine
// instead of the intended network.
func tcpConnectTest(host string, timeout time.Duration, port int) bool {
	if host == "" {
		log.Debug().Int("port", port).Msg("No host to test, treating as unreachable")
		return false
	}

	address := net.JoinHostPort(host, strconv.Itoa(port))
	conn, err := net.DialTimeout("tcp", address, timeout)
	if err != nil {
		log.Debug().Str("address", address).Err(err).Msg("TCP connection failed")
		return false
	}
	defer conn.Close()

	log.Debug().Str("address", address).Msg("TCP connection successful")
	return true
}
// getSSHConfigPort returns the Port configured for hostname in the user's
// OpenSSH client configuration, or 22 when none is found.
//
// It scans, in order:
//   - ~/.ssh/config;
//   - every file named by that file's Include directives;
//   - ~/.ssh/config.d/*, whether or not it is included.
//
// It returns the first non-22 port found. This approximates, rather than
// replicates, ssh's own evaluation order.
func getSSHConfigPort(hostname string) int {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		log.Debug().Err(err).Msg("Failed to get home directory")
		return 22
	}
	sshDir := filepath.Join(homeDir, ".ssh")
	mainConfigFile := filepath.Join(sshDir, "config")

	configFiles := []string{mainConfigFile}
	includes, err := sshConfigIncludes(mainConfigFile, homeDir, sshDir)
	if err != nil {
		log.Debug().Str("config_file", mainConfigFile).Err(err).Msg("Could not read main SSH config file")
	}
	configFiles = append(configFiles, includes...)

	// Also check common config.d directory pattern
	if matches, err := filepath.Glob(filepath.Join(sshDir, "config.d", "*")); err == nil {
		configFiles = append(configFiles, matches...)
	}

	log.Debug().Str("hostname", hostname).Strs("config_files", configFiles).Msg("Parsing SSH config files for port")

	for _, configFile := range configFiles {
		if port := parseSSHConfigFile(configFile, hostname); port != 22 {
			log.Debug().Str("hostname", hostname).Int("port", port).Str("config_file", configFile).Msg("Found port in SSH config")
			return port
		}
	}

	log.Debug().Str("hostname", hostname).Int("port", 22).Msg("Using default port")
	return 22
}

// sshConfigIncludes returns, in order, the files named by Include directives
// in configFile. A single Include line may list several paths. Path handling:
//   - a leading "~" is expanded to homeDir;
//   - relative paths are taken relative to sshDir, as ssh does for user
//     config files;
//   - glob patterns are expanded, and a pattern that is not a valid glob is
//     kept verbatim.
//
// The keyword is matched case-insensitively and may be followed by any
// whitespace. It returns whatever was collected together with any open or
// read error.
func sshConfigIncludes(configFile, homeDir, sshDir string) ([]string, error) {
	file, err := os.Open(configFile)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	var includes []string
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		if len(fields) < 2 || !strings.EqualFold(fields[0], "include") {
			continue
		}
		for _, includePath := range fields[1:] {
			if strings.HasPrefix(includePath, "~") {
				includePath = strings.Replace(includePath, "~", homeDir, 1)
			} else if !filepath.IsAbs(includePath) {
				includePath = filepath.Join(sshDir, includePath)
			}
			if matches, err := filepath.Glob(includePath); err == nil {
				includes = append(includes, matches...)
			} else {
				includes = append(includes, includePath)
			}
		}
	}
	return includes, scanner.Err()
}
// parseSSHConfigFile returns the Port set for hostname in one ssh_config
// file, or 22 if the file cannot be opened or sets no port for it.
//
// Section rules:
//   - a Host line applies when one of its patterns matches hostname and none
//     of its negated patterns do (see matchSSHHostPatterns);
//   - Match blocks are not evaluated and are treated as not applying;
//   - directives before the first Host line are ignored.
//
// The first Port found in an applicable section wins, mirroring ssh's rule
// that the first obtained value of each parameter is used.
func parseSSHConfigFile(configFile, hostname string) int {
	file, err := os.Open(configFile)
	if err != nil {
		log.Debug().Str("config_file", configFile).Err(err).Msg("Could not open SSH config file")
		return 22
	}
	defer file.Close()

	inTargetHost := false
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		keyword, values := splitSSHConfigLine(scanner.Text())
		switch keyword {
		case "host":
			inTargetHost = matchSSHHostPatterns(hostname, values)
		case "match":
			inTargetHost = false
		case "port":
			if !inTargetHost || len(values) == 0 {
				continue
			}
			if port, err := strconv.Atoi(values[0]); err == nil {
				return port
			}
		}
	}
	if err := scanner.Err(); err != nil {
		log.Debug().Str("config_file", configFile).Err(err).Msg("Error reading SSH config file")
	}
	return 22
}

// splitSSHConfigLine splits one ssh_config line into a lower-cased keyword
// and its whitespace-separated arguments. Both "Keyword value" and
// "Keyword=value" (optionally spaced around "=") are accepted. Blank lines
// and comment lines yield an empty keyword.
func splitSSHConfigLine(line string) (string, []string) {
	line = strings.TrimSpace(line)
	if line == "" || strings.HasPrefix(line, "#") {
		return "", nil
	}
	end := strings.IndexAny(line, " \t=")
	if end < 0 {
		return strings.ToLower(line), nil
	}
	rest := strings.TrimPrefix(strings.TrimSpace(line[end:]), "=")
	return strings.ToLower(line[:end]), strings.Fields(rest)
}

// matchSSHHostPatterns reports whether hostname matches the patterns of a
// Host line: at least one positive pattern must match, and no negated
// ("!pattern") pattern may match. ssh's "*" and "?" wildcards are evaluated
// with filepath.Match; invalid patterns never match.
func matchSSHHostPatterns(hostname string, patterns []string) bool {
	matched := false
	for _, pattern := range patterns {
		negated := strings.HasPrefix(pattern, "!")
		ok, err := filepath.Match(strings.TrimPrefix(pattern, "!"), hostname)
		if err != nil || !ok {
			continue
		}
		if negated {
			return false
		}
		matched = true
	}
	return matched
}
// executeRealSSH replaces the current process with the real ssh binary via
// execve. argv[0] is "ssh", followed by args, and the current environment is
// passed through. On success it never returns. It exits with status 1 if the
// binary does not exist or the exec itself fails.
func executeRealSSH(args []string) {
	if _, statErr := os.Stat(realSSHPath); os.IsNotExist(statErr) {
		log.Error().Str("path", realSSHPath).Msg("Real SSH binary not found")
		fmt.Fprintf(os.Stderr, "Error: Real SSH binary not found at %s\n", realSSHPath)
		os.Exit(1)
	}

	log.Debug().Str("ssh_path", realSSHPath).Strs("args", args).Msg("Executing real SSH")

	argv := make([]string, 0, len(args)+1)
	argv = append(argv, "ssh")
	argv = append(argv, args...)

	// syscall.Exec only returns when the exec itself failed.
	execErr := syscall.Exec(realSSHPath, argv, os.Environ())
	if execErr == nil {
		return
	}
	log.Error().Err(execErr).Msg("Failed to execute SSH")
	fmt.Fprintf(os.Stderr, "Error executing SSH: %v\n", execErr)
	os.Exit(1)
}