// Command ssh is a transparent wrapper around the system SSH client. It adds
// two features on top of plain ssh:
//
//   - smart aliases: a host name that resolves to a "primary" or "fallback"
//     SSH config host depending on whether a check host is reachable;
//   - tunnel management: opening, listing and closing background SSH tunnels
//     whose state is tracked in JSON files under ~/.config/ssh-util/tunnels.
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"net"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
	"syscall"
	"time"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"github.com/spf13/cobra"
	"gopkg.in/yaml.v3"
)

// LoggingConfig represents logging configuration
type LoggingConfig struct {
	Enabled bool   `yaml:"enabled"`
	Level   string `yaml:"level"`
	Format  string `yaml:"format"`
}

// SmartAlias represents a smart SSH alias configuration
type SmartAlias struct {
	Primary   string `yaml:"primary"`    // SSH config host to use when local
	Fallback  string `yaml:"fallback"`   // SSH config host to use when remote
	CheckHost string `yaml:"check_host"` // IP to ping for connectivity test
	Timeout   string `yaml:"timeout"`    // Ping timeout (default: "2s")
}

// TunnelDefinition represents a tunnel configuration
type TunnelDefinition struct {
	Type       string `yaml:"type"`        // local, remote, dynamic
	LocalPort  int    `yaml:"local_port"`  // Local port for binding
	RemoteHost string `yaml:"remote_host"` // Remote host (for local/remote tunnels)
	RemotePort int    `yaml:"remote_port"` // Remote port (for local/remote tunnels)
	SSHHost    string `yaml:"ssh_host"`    // SSH host to tunnel through
}

// TunnelState represents runtime state of an active tunnel
type TunnelState struct {
	Name            string    `json:"name"`
	Source          string    `json:"source"` // "config" or "adhoc"
	Type            string    `json:"type"`   // local, remote, dynamic
	LocalPort       int       `json:"local_port"`
	RemoteHost      string    `json:"remote_host"`
	RemotePort      int       `json:"remote_port"`
	SSHHost         string    `json:"ssh_host"`
	SSHHostResolved string    `json:"ssh_host_resolved"` // After smart alias resolution
	PID             int       `json:"pid"`
	Status          string    `json:"status"`
	StartedAt       time.Time `json:"started_at"`
	LastSeen        time.Time `json:"last_seen"`
	CommandLine     string    `json:"command_line"`
}

// Config represents the YAML configuration structure
type Config struct {
	Logging      LoggingConfig               `yaml:"logging"`
	SmartAliases map[string]SmartAlias       `yaml:"smart_aliases"`
	Tunnels      map[string]TunnelDefinition `yaml:"tunnels"`
}

// tunnelRequest is the result of manually parsing tunnel-mode arguments.
// Empty fields mean "not given on the command line".
type tunnelRequest struct {
	name    string // tunnel name (first positional argument)
	action  string // "open", "close", "list" or ""
	local   string // --local value (port:host:port)
	remote  string // --remote value (port:host:port)
	via     string // --via value (SSH host to tunnel through)
	dynamic int    // --dynamic value (SOCKS proxy port)
}

const (
	realSSHPath = "/usr/bin/ssh"

	// defaultSSHPort is returned whenever no Port directive is found.
	defaultSSHPort = 22

	// defaultPingTimeout is used when a smart alias has no usable timeout.
	defaultPingTimeout = 2 * time.Second
)

var (
	configDir  string
	tunnelsDir string
	config     *Config

	// Global flags
	tunnelMode bool

	// Tunnel command flags
	tunnelOpen    bool
	tunnelClose   bool
	tunnelList    bool
	tunnelLocal   string
	tunnelRemote  string
	tunnelDynamic int
	tunnelVia     string
)

var rootCmd = &cobra.Command{
	Use:   "ssh",
	Short: "Smart SSH utility with tunnel management",
	Long:  "A transparent SSH wrapper that provides smart alias resolution and background tunnel management",
	Run:   handleSSH,
	// The root command has a subcommand; with Args left nil, cobra's legacy
	// validation appears to reject a positional tunnel name ("ssh --tunnel
	// --open web") as an unknown subcommand. Accept arbitrary positionals.
	Args:               cobra.ArbitraryArgs,
	DisableFlagParsing: true,
}

var tunnelCmd = &cobra.Command{
	Use:   "tunnel [tunnel-name]",
	Short: "Manage background SSH tunnels",
	Long:  "Create, list, and manage persistent SSH tunnels in the background",
	Run: func(cmd *cobra.Command, args []string) {
		// Flags were already parsed by cobra into the package-level tunnel*
		// variables; handleTunnelManual falls back to those.
		handleTunnelManual(append([]string{"--tunnel"}, args...))
	},
	Args: cobra.MaximumNArgs(1),
}

// defaultConfig returns the configuration used when no config file exists or
// when the config file cannot be loaded.
func defaultConfig() *Config {
	return &Config{
		Logging: LoggingConfig{
			Enabled: true,
			Level:   "info",
			Format:  "console",
		},
		SmartAliases: make(map[string]SmartAlias),
		Tunnels:      make(map[string]TunnelDefinition),
	}
}

// registerTunnelFlags binds the tunnel flags of cmd to the package-level
// tunnel* variables. The root and tunnel commands share the same set.
func registerTunnelFlags(cmd *cobra.Command) {
	flags := cmd.Flags()
	flags.BoolVarP(&tunnelOpen, "open", "O", false, "Open a tunnel")
	flags.BoolVarP(&tunnelClose, "close", "C", false, "Close a tunnel")
	flags.BoolVarP(&tunnelList, "list", "L", false, "List active tunnels")
	flags.StringVar(&tunnelLocal, "local", "", "Local port forwarding (port:host:port)")
	flags.StringVar(&tunnelRemote, "remote", "", "Remote port forwarding (port:host:port)")
	flags.IntVar(&tunnelDynamic, "dynamic", 0, "Dynamic port forwarding (SOCKS proxy port)")
	flags.StringVar(&tunnelVia, "via", "", "SSH host to tunnel through")
}

// init resolves the config and state directories, loads the configuration,
// configures logging and wires up the cobra commands and flags.
func init() {
	// Initialize config directory
	homeDir, err := os.UserHomeDir()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: Failed to get home directory: %v\n", err)
		os.Exit(1)
	}

	configDir = filepath.Join(homeDir, ".config", "ssh-util")
	tunnelsDir = filepath.Join(configDir, "tunnels")

	// Ensure directories exist. A failure is not fatal (plain ssh pass-through
	// does not need them) but is reported once logging is configured.
	mkdirErr := os.MkdirAll(tunnelsDir, 0755)

	// Load configuration
	var configErr error
	config, configErr = loadConfig()
	if configErr != nil {
		// Use default config if loading fails
		config = defaultConfig()
	}

	// Initialize logging
	initLogging(config.Logging)

	// Logging was not available when these errors occurred; report them now
	// instead of silently swallowing them.
	if mkdirErr != nil {
		log.Warn().Err(mkdirErr).Str("dir", tunnelsDir).Msg("Failed to create tunnels directory")
	}
	if configErr != nil {
		log.Warn().Err(configErr).Msg("Failed to load configuration, using defaults")
	}

	// Global flags
	rootCmd.PersistentFlags().BoolVarP(&tunnelMode, "tunnel", "T", false, "Enable tunnel mode")
	registerTunnelFlags(rootCmd)

	// Add tunnel command
	rootCmd.AddCommand(tunnelCmd)

	// Tunnel command flags (same as root for consistency)
	registerTunnelFlags(tunnelCmd)

	// Handle combined flags like -TO, -TC, -TL
	rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
		return handleCombinedFlags(cmd, args)
	}
}

// isTunnelInvocation reports whether args request tunnel mode: any argument
// that is "--tunnel", starts with "-T" (covers "-T", "-TO", "-TC", "-TL") or
// is the literal "tunnel".
//
// NOTE(review): real ssh also accepts "-T" (disable pseudo-terminal
// allocation) and "tunnel" may legitimately appear as a remote command, so
// e.g. "ssh -T git@host" is routed to tunnel mode. Left as is because the flag
// is part of this tool's documented interface; confirm this is intended.
func isTunnelInvocation(args []string) bool {
	for _, arg := range args {
		if arg == "--tunnel" || arg == "tunnel" || strings.HasPrefix(arg, "-T") {
			return true
		}
	}
	return false
}

func main() {
	// Check if this is a tunnel command first
	args := os.Args[1:]

	if !isTunnelInvocation(args) {
		// Bypass Cobra for regular SSH commands (smart alias resolution)
		handleSSHDirect(args)
		return
	}

	// Use Cobra for tunnel commands
	if err := rootCmd.Execute(); err != nil {
		log.Error().Err(err).Msg("Command execution failed")
		os.Exit(1)
	}
}

// handleCombinedFlags scans os.Args for a combined short flag such as -TO,
// -TC or -TL and sets the matching package-level flag variables. Only the
// first such argument is considered. It always returns nil.
func handleCombinedFlags(cmd *cobra.Command, args []string) error {
	// Check for combined tunnel flags in os.Args
	for _, arg := range os.Args {
		if !strings.HasPrefix(arg, "-T") || len(arg) <= 2 {
			continue
		}

		// Handle combined flags like -TO, -TC, -TL
		tunnelMode = true
		suffix := arg[2:]
		if strings.Contains(suffix, "O") {
			tunnelOpen = true
		}
		if strings.Contains(suffix, "C") {
			tunnelClose = true
		}
		if strings.Contains(suffix, "L") {
			tunnelList = true
		}
		break
	}
	return nil
}

// handleSSH is the cobra Run function of the root command. Flag parsing is
// disabled on the root command, so the raw process arguments are handed to the
// manual tunnel parser.
func handleSSH(cmd *cobra.Command, args []string) {
	// This handles tunnel commands via Cobra
	handleTunnelManual(os.Args[1:])
}

// handleSSHDirect runs the real ssh with args, replacing the first argument by
// the primary or fallback host when it names a smart alias. On success it does
// not return: executeRealSSH replaces the current process.
func handleSSHDirect(args []string) {
	log.Debug().Strs("original_args", args).Msg("SSH utility started")

	// Pass through immediately if no args, starts with dash (flags), or contains @
	if len(args) == 0 || strings.HasPrefix(args[0], "-") || strings.Contains(args[0], "@") {
		log.Debug().Msg("Passing through to real SSH (no smart alias detected)")
		executeRealSSH(args)
		return
	}

	// Check if first argument is a smart alias
	aliasName := args[0]
	modifiedArgs := make([]string, len(args))
	copy(modifiedArgs, args)

	smartAlias, exists := config.SmartAliases[aliasName]
	if !exists {
		log.Debug().Str("host", aliasName).Msg("Not a smart alias, passing through")
	} else {
		log.Info().Str("alias", aliasName).Msg("Smart alias detected")

		// Parse timeout
		timeout := parseTimeout(smartAlias.Timeout)
		log.Debug().Dur("timeout", timeout).Msg("Parsed timeout")

		// Get the port for the primary host from SSH config
		port := getSSHConfigPort(smartAlias.Primary)
		log.Debug().Str("host", smartAlias.Primary).Int("port", port).Msg("Extracted port from SSH config")

		// Test connectivity to determine which host to use
		log.Info().Str("check_host", smartAlias.CheckHost).Int("port", port).Msg("Testing connectivity")
		if pingHost(smartAlias.CheckHost, timeout, port) {
			// Local network is reachable, use primary
			log.Info().Str("chosen_host", smartAlias.Primary).Msg("Local network reachable, using primary host")
			modifiedArgs[0] = smartAlias.Primary
		} else {
			// Local network not reachable, use fallback
			log.Info().Str("chosen_host", smartAlias.Fallback).Msg("Local network not reachable, using fallback host")
			modifiedArgs[0] = smartAlias.Fallback
		}
	}

	// Execute the real SSH with potentially modified arguments
	log.Debug().Strs("final_args", modifiedArgs).Msg("Executing real SSH")
	executeRealSSH(modifiedArgs)
}

// parseTunnelArgs extracts the tunnel name, action and forwarding options from
// raw command-line arguments.
//
// Value-taking flags (--local, --remote, --via, --dynamic) accept both
// "--flag value" and "--flag=value". The value is consumed together with the
// flag so that it is never mistaken for the tunnel name. A value-taking flag
// that is the last argument is ignored. Unknown dash-prefixed arguments are
// ignored; the first remaining positional argument is the tunnel name. When
// several action flags are given, the last one wins.
func parseTunnelArgs(args []string) tunnelRequest {
	var req tunnelRequest

	for i := 0; i < len(args); i++ {
		arg := args[i]

		// Split the --flag=value spelling into flag and inline value.
		inline, hasInline := "", false
		if strings.HasPrefix(arg, "--") {
			if eq := strings.Index(arg, "="); eq > 2 {
				arg, inline, hasInline = arg[:eq], arg[eq+1:], true
			}
		}

		switch arg {
		case "--tunnel", "-T":
			// Mode marker only; nothing to record.
		case "--open", "-O":
			req.action = "open"
		case "--close", "-C":
			req.action = "close"
		case "--list", "-L":
			req.action = "list"
		case "--local", "--remote", "--via", "--dynamic":
			value := inline
			if !hasInline {
				if i+1 >= len(args) {
					continue
				}
				i++ // consume the value so it is not parsed as the tunnel name
				value = args[i]
			}
			switch arg {
			case "--local":
				req.local = value
			case "--remote":
				req.remote = value
			case "--via":
				req.via = value
			case "--dynamic":
				port, err := strconv.Atoi(value)
				if err != nil {
					log.Warn().Str("value", value).Msg("Ignoring invalid --dynamic port")
					continue
				}
				req.dynamic = port
			}
		default:
			switch {
			case strings.HasPrefix(arg, "-T") && len(arg) > 2:
				// Combined short flags such as -TO, -TC, -TL.
				suffix := arg[2:]
				if strings.Contains(suffix, "O") {
					req.action = "open"
				}
				if strings.Contains(suffix, "C") {
					req.action = "close"
				}
				if strings.Contains(suffix, "L") {
					req.action = "list"
				}
			case !strings.HasPrefix(arg, "-") && req.name == "":
				req.name = arg
			}
		}
	}

	return req
}

// handleTunnelManual executes a tunnel command described by args. Anything not
// present in args falls back to the values cobra parsed into the package-level
// tunnel* variables (the "ssh tunnel ..." subcommand path). It exits the
// process with status 1 on usage errors or when the action fails.
func handleTunnelManual(args []string) {
	log.Debug().Msg("Tunnel mode activated")

	// Always validate tunnel states first
	if err := validateTunnelStates(); err != nil {
		log.Error().Err(err).Msg("Failed to validate tunnel states")
	}

	// Parse tunnel arguments manually
	req := parseTunnelArgs(args)

	// The subcommand path strips flags before calling us, so consult the
	// cobra-parsed flags when no action was found in args. Precedence mirrors
	// the combined-flag handling: list, then close, then open.
	if req.action == "" {
		switch {
		case tunnelList:
			req.action = "list"
		case tunnelClose:
			req.action = "close"
		case tunnelOpen:
			req.action = "open"
		}
	}

	// Handle tunnel commands
	if req.action == "list" {
		listTunnels()
		return
	}

	if req.name == "" {
		fmt.Fprintf(os.Stderr, "Error: tunnel name required\n")
		fmt.Fprintf(os.Stderr, "Usage: ssh --tunnel --open [flags]\n")
		os.Exit(1)
	}

	switch req.action {
	case "open":
		// Set global variables for openTunnel function. Only override what was
		// actually given in args so cobra-parsed values are preserved.
		if req.local != "" {
			tunnelLocal = req.local
		}
		if req.remote != "" {
			tunnelRemote = req.remote
		}
		if req.dynamic != 0 {
			tunnelDynamic = req.dynamic
		}
		if req.via != "" {
			tunnelVia = req.via
		}

		if err := openTunnel(req.name); err != nil {
			log.Error().Err(err).Str("tunnel", req.name).Msg("Failed to open tunnel")
			fmt.Fprintf(os.Stderr, "Error: %v\n", err)
			os.Exit(1)
		}
	case "close":
		if err := closeTunnel(req.name); err != nil {
			log.Error().Err(err).Str("tunnel", req.name).Msg("Failed to close tunnel")
			fmt.Fprintf(os.Stderr, "Error: %v\n", err)
			os.Exit(1)
		}
	default:
		fmt.Fprintf(os.Stderr, "Error: must specify --open, --close, or --list\n")
		os.Exit(1)
	}
}

// validateTunnelStates walks the "<pid>.json" files in tunnelsDir, removes
// those whose name is not a PID or whose process is no longer alive, and
// refreshes LastSeen on the rest. It returns an error only when the directory
// cannot be read.
func validateTunnelStates() error {
	files, err := os.ReadDir(tunnelsDir)
	if err != nil {
		return fmt.Errorf("failed to read tunnels directory: %w", err)
	}

	for _, file := range files {
		if !strings.HasSuffix(file.Name(), ".json") {
			continue
		}
		stateFile := filepath.Join(tunnelsDir, file.Name())

		pid, err := strconv.Atoi(strings.TrimSuffix(file.Name(), ".json"))
		if err != nil {
			log.Warn().Str("file", file.Name()).Msg("Invalid PID filename, removing")
			os.Remove(stateFile) // best effort
			continue
		}

		if !isProcessAlive(pid) {
			log.Info().Int("pid", pid).Msg("Removing state for dead tunnel")
			os.Remove(stateFile) // best effort
			continue
		}

		// Update last seen time
		state, err := loadTunnelState(stateFile)
		if err != nil {
			log.Warn().Str("file", stateFile).Err(err).Msg("Failed to load tunnel state")
			continue
		}

		state.LastSeen = time.Now()
		if err := saveTunnelState(stateFile, state); err != nil {
			log.Warn().Str("file", stateFile).Err(err).Msg("Failed to update tunnel state")
		}
	}

	return nil
}

// listTunnels prints a table of all tunnels with a readable state file, or
// "No active tunnels" when there are none.
func listTunnels() {
	files, err := os.ReadDir(tunnelsDir)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: Failed to read tunnels directory: %v\n", err)
		return
	}

	// Load first so the header is only printed when there is something to show.
	var states []TunnelState
	for _, file := range files {
		if !strings.HasSuffix(file.Name(), ".json") {
			continue
		}
		state, err := loadTunnelState(filepath.Join(tunnelsDir, file.Name()))
		if err != nil {
			continue
		}
		states = append(states, state)
	}

	if len(states) == 0 {
		fmt.Println("No active tunnels")
		return
	}

	fmt.Printf("%-20s %-8s %-8s %-25s %-12s %-8s %s\n", "NAME", "TYPE", "LOCAL", "REMOTE", "HOST", "PID", "UPTIME")
	fmt.Println(strings.Repeat("-", 80))

	for _, state := range states {
		uptime := time.Since(state.StartedAt).Truncate(time.Second)

		remote := ""
		switch state.Type {
		case "local", "remote":
			remote = fmt.Sprintf("%s:%d", state.RemoteHost, state.RemotePort)
		case "dynamic":
			remote = "SOCKS"
		}

		fmt.Printf("%-20s %-8s %-8d %-25s %-12s %-8d %s\n",
			state.Name, state.Type, state.LocalPort, remote, state.SSHHostResolved, state.PID, uptime)
	}
}

// isValidPort reports whether port is a usable TCP port number.
func isValidPort(port int) bool {
	return port >= 1 && port <= 65535
}

// validateTunnel checks that a tunnel definition is complete enough to build
// an ssh command from: a known type, an SSH host, and sane ports/hosts for
// that type. Remote tunnels may use port 0 as their first port.
func validateTunnel(tunnel TunnelDefinition) error {
	if tunnel.SSHHost == "" {
		return fmt.Errorf("no SSH host specified")
	}

	switch tunnel.Type {
	case "local", "remote":
		if tunnel.RemoteHost == "" {
			return fmt.Errorf("no remote host specified for %s tunnel", tunnel.Type)
		}
		if !isValidPort(tunnel.RemotePort) {
			return fmt.Errorf("invalid remote port: %d", tunnel.RemotePort)
		}
	case "dynamic":
	default:
		return fmt.Errorf("unknown tunnel type %q (expected local, remote, or dynamic)", tunnel.Type)
	}

	// For -R the first port is bound on the server, where 0 is meaningful.
	if !isValidPort(tunnel.LocalPort) && !(tunnel.Type == "remote" && tunnel.LocalPort == 0) {
		return fmt.Errorf("invalid local port: %d", tunnel.LocalPort)
	}
	return nil
}

// openTunnel starts the tunnel called name. The definition comes from the
// config file when one exists under that name, otherwise from the ad-hoc
// flags. On success a "<pid>.json" state file is written to tunnelsDir.
func openTunnel(name string) error {
	// Check if tunnel is already running
	if existingPID := findTunnelByName(name); existingPID != 0 {
		return fmt.Errorf("tunnel '%s' already running (PID %d)", name, existingPID)
	}

	var tunnel TunnelDefinition
	var source string

	// Check if tunnel is defined in config
	if configTunnel, exists := config.Tunnels[name]; exists {
		tunnel = configTunnel
		source = "config"
		log.Info().Str("tunnel", name).Msg("Using tunnel definition from config")
	} else {
		// Create ad-hoc tunnel from flags
		var err error
		tunnel, err = createAdhocTunnel()
		if err != nil {
			return fmt.Errorf("tunnel '%s' not found in config and invalid adhoc parameters: %w", name, err)
		}
		source = "adhoc"
		log.Info().Str("tunnel", name).Msg("Creating ad-hoc tunnel")
	}

	// Reject incomplete definitions (e.g. a config entry with a misspelled
	// type) instead of starting an ssh process that forwards nothing.
	if err := validateTunnel(tunnel); err != nil {
		return fmt.Errorf("invalid definition for tunnel '%s': %w", name, err)
	}

	// Check for port conflicts. For remote (-R) tunnels the first port is
	// bound on the SSH server, so a local check would be meaningless.
	if tunnel.Type != "remote" && isPortInUse(tunnel.LocalPort) {
		return fmt.Errorf("port %d already in use", tunnel.LocalPort)
	}

	// Resolve SSH host using smart alias logic
	resolvedSSHHost := resolveSSHHost(tunnel.SSHHost)
	log.Debug().Str("original", tunnel.SSHHost).Str("resolved", resolvedSSHHost).Msg("SSH host resolution")

	// Build SSH command
	cmdArgs := buildSSHCommand(tunnel, resolvedSSHHost)
	log.Debug().Strs("command", cmdArgs).Msg("Starting SSH tunnel")

	// Start SSH process
	cmd := &exec.Cmd{
		Path: realSSHPath,
		Args: cmdArgs,
	}

	if err := cmd.Start(); err != nil {
		return fmt.Errorf("failed to start SSH tunnel: %w", err)
	}

	// NOTE(review): the command uses "ssh -f", which sends ssh to the
	// background after authentication. If ssh does this by forking, the PID
	// recorded here belongs to the short-lived foreground process and not to
	// the long-lived tunnel, so later liveness checks and closeTunnel would
	// target the wrong process. Verify, and consider dropping -f or tracking
	// the tunnel through a control socket.
	pid := cmd.Process.Pid
	log.Info().Str("tunnel", name).Int("pid", pid).Msg("SSH tunnel started")

	// Create tunnel state
	now := time.Now()
	state := TunnelState{
		Name:            name,
		Source:          source,
		Type:            tunnel.Type,
		LocalPort:       tunnel.LocalPort,
		RemoteHost:      tunnel.RemoteHost,
		RemotePort:      tunnel.RemotePort,
		SSHHost:         tunnel.SSHHost,
		SSHHostResolved: resolvedSSHHost,
		PID:             pid,
		Status:          "active",
		StartedAt:       now,
		LastSeen:        now,
		CommandLine:     strings.Join(cmdArgs, " "),
	}

	// Save state file
	stateFile := filepath.Join(tunnelsDir, fmt.Sprintf("%d.json", pid))
	if err := saveTunnelState(stateFile, state); err != nil {
		// If we can't save state, kill the process
		cmd.Process.Kill()
		return fmt.Errorf("failed to save tunnel state: %w", err)
	}

	fmt.Printf("Tunnel '%s' opened on port %d (PID %d)\n", name, tunnel.LocalPort, pid)
	return nil
}

// closeTunnel kills the process recorded for the tunnel called name and
// removes its state file. It returns an error when no such tunnel is recorded
// or when the process cannot be killed.
func closeTunnel(name string) error {
	pid := findTunnelByName(name)
	if pid == 0 {
		return fmt.Errorf("tunnel '%s' not found", name)
	}

	// Kill the process
	process, err := os.FindProcess(pid)
	if err != nil {
		return fmt.Errorf("failed to find process %d: %w", pid, err)
	}

	if err := process.Kill(); err != nil {
		return fmt.Errorf("failed to kill process %d: %w", pid, err)
	}

	// Remove state file
	stateFile := filepath.Join(tunnelsDir, fmt.Sprintf("%d.json", pid))
	if err := os.Remove(stateFile); err != nil {
		log.Warn().Str("file", stateFile).Err(err).Msg("Failed to remove state file")
	}

	log.Info().Str("tunnel", name).Int("pid", pid).Msg("Tunnel closed")
	fmt.Printf("Tunnel '%s' closed\n", name)
	return nil
}

// parseForwardSpec parses a "port:host:port" forwarding specification given
// through the flag called flagName (used only in the format error message).
// Hosts that themselves contain ':' (e.g. bare IPv6 literals) are rejected by
// the three-part check.
func parseForwardSpec(spec, flagName string) (localPort int, host string, remotePort int, err error) {
	parts := strings.Split(spec, ":")
	if len(parts) != 3 {
		return 0, "", 0, fmt.Errorf("invalid %s format, expected port:host:port", flagName)
	}

	localPort, err = strconv.Atoi(parts[0])
	if err != nil {
		return 0, "", 0, fmt.Errorf("invalid local port: %s", parts[0])
	}

	remotePort, err = strconv.Atoi(parts[2])
	if err != nil {
		return 0, "", 0, fmt.Errorf("invalid remote port: %s", parts[2])
	}

	return localPort, parts[1], remotePort, nil
}

// createAdhocTunnel builds a tunnel definition from the tunnelVia,
// tunnelLocal, tunnelRemote and tunnelDynamic flag variables. --via is
// mandatory; of the forwarding flags the first one set wins, in the order
// --local, --remote, --dynamic.
func createAdhocTunnel() (TunnelDefinition, error) {
	if tunnelVia == "" {
		return TunnelDefinition{}, fmt.Errorf("--via flag required for ad-hoc tunnels")
	}

	tunnel := TunnelDefinition{SSHHost: tunnelVia}

	switch {
	case tunnelLocal != "":
		localPort, host, remotePort, err := parseForwardSpec(tunnelLocal, "--local")
		if err != nil {
			return TunnelDefinition{}, err
		}
		tunnel.Type = "local"
		tunnel.LocalPort = localPort
		tunnel.RemoteHost = host
		tunnel.RemotePort = remotePort
	case tunnelRemote != "":
		localPort, host, remotePort, err := parseForwardSpec(tunnelRemote, "--remote")
		if err != nil {
			return TunnelDefinition{}, err
		}
		tunnel.Type = "remote"
		tunnel.LocalPort = localPort
		tunnel.RemoteHost = host
		tunnel.RemotePort = remotePort
	case tunnelDynamic != 0:
		tunnel.Type = "dynamic"
		tunnel.LocalPort = tunnelDynamic
	default:
		return TunnelDefinition{}, fmt.Errorf("must specify --local, --remote, or --dynamic")
	}

	return tunnel, nil
}

// buildSSHCommand returns the argv (including argv[0] "ssh") for a background
// ssh process that sets up the given tunnel through sshHost. An unknown tunnel
// type yields a command without any forwarding option.
func buildSSHCommand(tunnel TunnelDefinition, sshHost string) []string {
	args := []string{"ssh", "-f", "-N"}

	switch tunnel.Type {
	case "local":
		args = append(args, "-L", fmt.Sprintf("%d:%s:%d", tunnel.LocalPort, tunnel.RemoteHost, tunnel.RemotePort))
	case "remote":
		args = append(args, "-R", fmt.Sprintf("%d:%s:%d", tunnel.LocalPort, tunnel.RemoteHost, tunnel.RemotePort))
	case "dynamic":
		args = append(args, "-D", strconv.Itoa(tunnel.LocalPort))
	}

	return append(args, sshHost)
}

// findTunnelByName returns the PID encoded in the file name of the first state
// file whose recorded tunnel name equals name, or 0 when there is none (or the
// tunnels directory cannot be read).
func findTunnelByName(name string) int {
	files, err := os.ReadDir(tunnelsDir)
	if err != nil {
		return 0
	}

	for _, file := range files {
		if !strings.HasSuffix(file.Name(), ".json") {
			continue
		}

		state, err := loadTunnelState(filepath.Join(tunnelsDir, file.Name()))
		if err != nil || state.Name != name {
			continue
		}

		pid, err := strconv.Atoi(strings.TrimSuffix(file.Name(), ".json"))
		if err != nil {
			// Not a "<pid>.json" file; it cannot identify a process.
			continue
		}
		return pid
	}

	return 0
}

// isPortInUse reports whether a TCP listener cannot be opened on port on all
// local interfaces. Any listen error counts as "in use".
func isPortInUse(port int) bool {
	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
		return true
	}
	listener.Close()
	return false
}

// isProcessAlive reports whether signal 0 can be delivered to pid.
func isProcessAlive(pid int) bool {
	process, err := os.FindProcess(pid)
	if err != nil {
		return false
	}

	// Send signal 0 to check if process exists
	return process.Signal(syscall.Signal(0)) == nil
}

// resolveSSHHost returns sshHost unchanged unless it names a smart alias, in
// which case it returns the alias's primary host when the check host is
// reachable and its fallback host otherwise.
func resolveSSHHost(sshHost string) string {
	// Apply smart alias logic if host is a smart alias
	smartAlias, exists := config.SmartAliases[sshHost]
	if !exists {
		return sshHost
	}

	timeout := parseTimeout(smartAlias.Timeout)
	port := getSSHConfigPort(smartAlias.Primary)

	if pingHost(smartAlias.CheckHost, timeout, port) {
		log.Debug().Str("host", sshHost).Str("resolved", smartAlias.Primary).Msg("Smart alias resolved to primary")
		return smartAlias.Primary
	}

	log.Debug().Str("host", sshHost).Str("resolved", smartAlias.Fallback).Msg("Smart alias resolved to fallback")
	return smartAlias.Fallback
}

// loadTunnelState reads and decodes the JSON tunnel state stored in stateFile.
func loadTunnelState(stateFile string) (TunnelState, error) {
	var state TunnelState

	data, err := os.ReadFile(stateFile)
	if err != nil {
		return state, fmt.Errorf("failed to read state file: %w", err)
	}

	if err := json.Unmarshal(data, &state); err != nil {
		return state, fmt.Errorf("failed to parse state file: %w", err)
	}

	return state, nil
}

// saveTunnelState writes state to stateFile as indented JSON.
func saveTunnelState(stateFile string, state TunnelState) error {
	data, err := json.MarshalIndent(state, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal state: %w", err)
	}

	if err := os.WriteFile(stateFile, data, 0644); err != nil {
		return fmt.Errorf("failed to write state file: %w", err)
	}

	return nil
}

// initLogging configures zerolog based on the logging configuration
func initLogging(cfg LoggingConfig) {
	if !cfg.Enabled {
		zerolog.SetGlobalLevel(zerolog.Disabled)
		return
	}

	// Unknown levels fall back to info.
	level := zerolog.InfoLevel
	switch strings.ToLower(cfg.Level) {
	case "debug":
		level = zerolog.DebugLevel
	case "warn", "warning":
		level = zerolog.WarnLevel
	case "error":
		level = zerolog.ErrorLevel
	}
	zerolog.SetGlobalLevel(level)

	if strings.ToLower(cfg.Format) == "console" {
		log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
	} else {
		log.Logger = log.Output(os.Stderr)
	}
}

// loadConfig loads the YAML configuration file. A missing file is not an
// error: the default configuration is returned instead. Empty logging level
// and format fall back to "info" and "console".
func loadConfig() (*Config, error) {
	configFile := filepath.Join(configDir, "config.yaml")

	data, err := os.ReadFile(configFile)
	if err != nil {
		if os.IsNotExist(err) {
			return defaultConfig(), nil
		}
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}

	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("failed to parse config YAML: %w", err)
	}

	if cfg.Logging.Level == "" {
		cfg.Logging.Level = "info"
	}
	if cfg.Logging.Format == "" {
		cfg.Logging.Format = "console"
	}

	return &cfg, nil
}

// parseTimeout converts timeout string to time.Duration, defaults to 2s.
// Empty, unparsable, zero and negative values all yield the default, since a
// non-positive dial timeout would not bound the connectivity test sensibly.
func parseTimeout(timeoutStr string) time.Duration {
	if timeoutStr == "" {
		return defaultPingTimeout
	}

	duration, err := time.ParseDuration(timeoutStr)
	if err != nil || duration <= 0 {
		return defaultPingTimeout
	}

	return duration
}

// pingHost checks if a host is reachable via TCP connection test
func pingHost(host string, timeout time.Duration, port int) bool {
	result := tcpConnectTest(host, timeout, port)
	log.Debug().Str("host", host).Int("port", port).Bool("reachable", result).Msg("Connectivity test result")
	return result
}

// tcpConnectTest tests TCP connection on the specified port
func tcpConnectTest(host string, timeout time.Duration, port int) bool {
	address := net.JoinHostPort(host, strconv.Itoa(port))

	conn, err := net.DialTimeout("tcp", address, timeout)
	if err != nil {
		log.Debug().Str("address", address).Err(err).Msg("TCP connection failed")
		return false
	}
	conn.Close()

	log.Debug().Str("address", address).Msg("TCP connection successful")
	return true
}

// splitSSHConfigLine splits one ssh_config line into its lower-cased keyword
// and whitespace-separated arguments. Keyword and arguments may be separated
// by whitespace or by optional whitespace and a single '='. Blank lines and
// comments yield an empty keyword. Quoting is not interpreted.
func splitSSHConfigLine(line string) (string, []string) {
	line = strings.TrimSpace(line)
	if line == "" || strings.HasPrefix(line, "#") {
		return "", nil
	}

	end := strings.IndexAny(line, " \t=")
	if end < 0 {
		return strings.ToLower(line), nil
	}

	rest := strings.TrimLeft(line[end:], " \t")
	rest = strings.TrimPrefix(rest, "=")
	return strings.ToLower(line[:end]), strings.Fields(rest)
}

// sshConfigIncludes returns the files referenced by Include directives in
// mainConfigFile. A leading "~" is replaced by homeDir and relative paths are
// resolved against homeDir/.ssh; each path is expanded as a glob pattern.
// Includes are not followed recursively.
func sshConfigIncludes(mainConfigFile, homeDir string) []string {
	file, err := os.Open(mainConfigFile)
	if err != nil {
		log.Debug().Str("config_file", mainConfigFile).Err(err).Msg("Main SSH config file not found")
		return nil
	}
	defer file.Close()

	var included []string
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		keyword, paths := splitSSHConfigLine(scanner.Text())
		if keyword != "include" {
			continue
		}

		// An Include line may list several paths.
		for _, includePath := range paths {
			if strings.HasPrefix(includePath, "~") {
				includePath = strings.Replace(includePath, "~", homeDir, 1)
			} else if !filepath.IsAbs(includePath) {
				includePath = filepath.Join(homeDir, ".ssh", includePath)
			}

			if matches, err := filepath.Glob(includePath); err == nil {
				included = append(included, matches...)
			} else {
				// Malformed pattern: try it as a literal path.
				included = append(included, includePath)
			}
		}
	}

	return included
}

// getSSHConfigPort parses the SSH config file to find the port for a given host.
// It looks at ~/.ssh/config, the files it includes and ~/.ssh/config.d/*, in
// that order, and returns the first port that differs from 22, or 22.
func getSSHConfigPort(hostname string) int {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		log.Debug().Err(err).Msg("Failed to get home directory")
		return defaultSSHPort
	}

	mainConfigFile := filepath.Join(homeDir, ".ssh", "config")
	configFiles := []string{mainConfigFile}

	// Check if main config exists and parse it for includes
	configFiles = append(configFiles, sshConfigIncludes(mainConfigFile, homeDir)...)

	// Also check common config.d directory pattern
	configDPattern := filepath.Join(homeDir, ".ssh", "config.d", "*")
	if matches, err := filepath.Glob(configDPattern); err == nil {
		configFiles = append(configFiles, matches...)
	}

	log.Debug().Str("hostname", hostname).Strs("config_files", configFiles).Msg("Parsing SSH config files for port")

	// Parse all config files
	for _, configFile := range configFiles {
		if port := parseSSHConfigFile(configFile, hostname); port != defaultSSHPort {
			log.Debug().Str("hostname", hostname).Int("port", port).Str("config_file", configFile).Msg("Found port in SSH config")
			return port
		}
	}

	log.Debug().Str("hostname", hostname).Int("port", defaultSSHPort).Msg("Using default port")
	return defaultSSHPort
}

// parseSSHConfigFile parses a single SSH config file for a host's port.
// It returns the first numeric Port found inside a Host section that lists
// hostname literally among its patterns, or 22 when there is none or the file
// cannot be opened. Wildcard and negated patterns are not evaluated, and a
// Match block ends the current Host section.
func parseSSHConfigFile(configFile, hostname string) int {
	file, err := os.Open(configFile)
	if err != nil {
		log.Debug().Str("config_file", configFile).Err(err).Msg("Could not open SSH config file")
		return defaultSSHPort
	}
	defer file.Close()

	inTargetHost := false
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		keyword, values := splitSSHConfigLine(scanner.Text())

		switch keyword {
		case "host":
			// Simple match - for more complex patterns, we'd need glob matching
			inTargetHost = false
			for _, pattern := range values {
				if pattern == hostname {
					inTargetHost = true
					break
				}
			}
		case "match":
			inTargetHost = false
		case "port":
			if !inTargetHost || len(values) == 0 {
				continue
			}
			if port, err := strconv.Atoi(values[0]); err == nil {
				return port // Found the port, return immediately
			}
		}
	}

	return defaultSSHPort // default port
}

// executeRealSSH executes the real SSH binary with given arguments.
// On success it never returns because the current process image is replaced;
// on failure it prints an error and exits with status 1.
func executeRealSSH(args []string) {
	// Check if real SSH exists
	if _, err := os.Stat(realSSHPath); os.IsNotExist(err) {
		log.Error().Str("path", realSSHPath).Msg("Real SSH binary not found")
		fmt.Fprintf(os.Stderr, "Error: Real SSH binary not found at %s\n", realSSHPath)
		os.Exit(1)
	}

	log.Debug().Str("ssh_path", realSSHPath).Strs("args", args).Msg("Executing real SSH")

	// Execute the real SSH binary
	// Using syscall.Exec to replace current process (like exec in shell)
	argv := append([]string{"ssh"}, args...)
	if err := syscall.Exec(realSSHPath, argv, os.Environ()); err != nil {
		log.Error().Err(err).Msg("Failed to execute SSH")
		fmt.Fprintf(os.Stderr, "Error executing SSH: %v\n", err)
		os.Exit(1)
	}
}