From f0bf6bc8aaab3bd370fa753ae705f1436de8d690 Mon Sep 17 00:00:00 2001 From: Menno van Leeuwen Date: Fri, 25 Jul 2025 14:54:29 +0200 Subject: [PATCH] wip Signed-off-by: Menno van Leeuwen --- config/ansible/tasks/global/utils.yml | 27 +- .../ansible/tasks/global/utils/ssh/README.md | 119 +++ .../tasks/global/utils/ssh/config.yaml | 77 ++ config/ansible/tasks/global/utils/ssh/go.mod | 17 + config/ansible/tasks/global/utils/ssh/go.sum | 27 + config/ansible/tasks/global/utils/ssh/ssh.go | 950 ++++++++++++++++++ config/home-manager/flake.lock | 18 +- config/ssh/ssh-util-config.yaml | 17 + 8 files changed, 1241 insertions(+), 11 deletions(-) create mode 100644 config/ansible/tasks/global/utils/ssh/README.md create mode 100644 config/ansible/tasks/global/utils/ssh/config.yaml create mode 100644 config/ansible/tasks/global/utils/ssh/go.mod create mode 100644 config/ansible/tasks/global/utils/ssh/go.sum create mode 100644 config/ansible/tasks/global/utils/ssh/ssh.go create mode 100644 config/ssh/ssh-util-config.yaml diff --git a/config/ansible/tasks/global/utils.yml b/config/ansible/tasks/global/utils.yml index 9c4bcbd..f36e07b 100644 --- a/config/ansible/tasks/global/utils.yml +++ b/config/ansible/tasks/global/utils.yml @@ -13,13 +13,28 @@ mode: "0755" become: false - - name: Scan utils folder and create symlinks in ~/.local/bin + - name: Scan utils folder for files ansible.builtin.find: paths: "{{ dotfiles_path }}/config/ansible/tasks/global/utils" file_type: file register: utils_files become: false + - name: Scan utils folder for Go projects (directories with go.mod) + ansible.builtin.find: + paths: "{{ dotfiles_path }}/config/ansible/tasks/global/utils" + file_type: directory + recurse: true + register: utils_dirs + become: false + + - name: Filter directories that contain go.mod files + ansible.builtin.stat: + path: "{{ item.path }}/go.mod" + loop: "{{ utils_dirs.files }}" + register: go_mod_check + become: false + - name: Create symlinks for utils scripts ansible.builtin.file: src: "{{ item.path }}" @@ -29,11 +44,19 @@ when: not item.path.endswith('.go') become: false - - name: Compile Go files and place binaries in ~/.local/bin + - name: Compile standalone Go files and place binaries in ~/.local/bin ansible.builtin.command: cmd: go build -o "{{ ansible_env.HOME }}/.local/bin/{{ item.path | basename | regex_replace('\.go$', '') }}" "{{ item.path }}" loop: "{{ utils_files.files }}" when: item.path.endswith('.go') become: false + + - name: Compile Go projects and place binaries in ~/.local/bin + ansible.builtin.command: + cmd: go build -o "{{ ansible_env.HOME }}/.local/bin/{{ item.item.path | basename }}" . + chdir: "{{ item.item.path }}" + loop: "{{ go_mod_check.results }}" + when: item.stat.exists + become: false tags: - utils diff --git a/config/ansible/tasks/global/utils/ssh/README.md b/config/ansible/tasks/global/utils/ssh/README.md new file mode 100644 index 0000000..fc9b083 --- /dev/null +++ b/config/ansible/tasks/global/utils/ssh/README.md @@ -0,0 +1,119 @@ +# SSH Utility - Smart SSH Connection Manager + +A transparent SSH wrapper that automatically chooses between local and remote connections based on network connectivity. + +## What it does + +This utility acts as a drop-in replacement for the `ssh` command that intelligently routes connections: + +- When you type `ssh desktop`, it automatically checks if your local network is available +- If local: connects via `desktop-local` (faster local connection) +- If remote: connects via `desktop` (Tailscale/VPN connection) +- All other SSH usage passes through unchanged (`ssh --help`, `ssh user@host`, etc.) + +## Installation + +The utility is automatically compiled and installed to `~/.local/bin/ssh` via Ansible when you run your dotfiles setup. + +## Configuration + +1. Copy the example config: + ```bash + mkdir -p ~/.config/ssh-util + cp ~/.dotfiles/config/ssh-util/config.yaml ~/.config/ssh-util/ + ``` + +2. Edit `~/.config/ssh-util/config.yaml` to match your setup: + ```yaml + smart_aliases: + desktop: + primary: "desktop-local" # SSH config entry for local connection + fallback: "desktop" # SSH config entry for remote connection + check_host: "192.168.86.22" # IP to ping for connectivity test + timeout: "2s" # Ping timeout + ``` + +3. Ensure your `~/.ssh/config` contains the referenced host entries: + ``` + Host desktop + HostName mennos-cachyos-desktop + User menno + Port 400 + ForwardAgent yes + AddKeysToAgent yes + + Host desktop-local + HostName 192.168.86.22 + User menno + Port 400 + ForwardAgent yes + AddKeysToAgent yes + ``` + +## Usage + +Once configured, simply use SSH as normal: + +```bash +# Smart connection - automatically chooses local vs remote +ssh desktop + +# All other SSH usage works exactly the same +ssh --help +ssh --version +ssh user@example.com +ssh -L 8080:localhost:80 server +``` + +## How it works + +1. When you run `ssh `, the utility checks if `` is defined in the smart_aliases config +2. If yes, it pings the `check_host` IP address +3. If ping succeeds: executes `ssh ` instead +4. If ping fails: executes `ssh ` instead +5. If not a smart alias: passes through to real SSH unchanged + +## Troubleshooting + +### SSH utility not found +Make sure `~/.local/bin` is in your PATH: +```bash +echo $PATH | grep -o ~/.local/bin +``` + +### Config not loading +Check the config file exists and has correct syntax: +```bash +ls -la ~/.config/ssh-util/config.yaml +cat ~/.config/ssh-util/config.yaml +``` + +### Connectivity test failing +Test manually: +```bash +ping -c 1 -W 2 192.168.86.22 +``` + +### Falls back to real SSH +If there are any errors loading config or parsing, the utility safely falls back to executing the real SSH binary at `/usr/bin/ssh`. + +## Adding more aliases + +To add more smart aliases, just extend the config: + +```yaml +smart_aliases: + desktop: + primary: "desktop-local" + fallback: "desktop" + check_host: "192.168.86.22" + timeout: "2s" + + server: + primary: "server-local" + fallback: "server-remote" + check_host: "192.168.1.100" + timeout: "1s" +``` + +Remember to create the corresponding entries in your `~/.ssh/config`. diff --git a/config/ansible/tasks/global/utils/ssh/config.yaml b/config/ansible/tasks/global/utils/ssh/config.yaml new file mode 100644 index 0000000..9fcc7ae --- /dev/null +++ b/config/ansible/tasks/global/utils/ssh/config.yaml @@ -0,0 +1,77 @@ +# SSH Utility Configuration +# This file defines smart aliases that automatically choose between local and remote connections + +# Logging configuration +logging: + enabled: true + # Levels: debug, info, warn, error + level: "info" + # Formats: console, json + format: "console" + +smart_aliases: + # Desktop connection - tries local network first, falls back to Tailscale + desktop: + primary: "desktop-local" # Use this SSH config entry when local network is available + fallback: "desktop" # Use this SSH config entry when local network is not available + check_host: "192.168.86.22" # IP address to ping for connectivity test + timeout: "2s" # Timeout for connectivity check + +# Background SSH Tunnel Definitions +tunnels: + # Example: Desktop database tunnel + desktop-database: + type: local + local_port: 5432 + remote_host: database + remote_port: 5432 + ssh_host: desktop # Uses smart alias logic (desktop-local/desktop) + + # Example: Development API tunnel + dev-api: + type: local + local_port: 8080 + remote_host: api + remote_port: 80 + ssh_host: dev-server + + # Example: SOCKS proxy tunnel + socks-proxy: + type: dynamic + local_port: 1080 + ssh_host: bastion +# Tunnel Management Commands: +# ssh --tunnel --open desktop-database (or ssh -TO desktop-database) +# ssh --tunnel --close desktop-database (or ssh -TC desktop-database) +# ssh --tunnel --list (or ssh -TL) +# +# Ad-hoc tunnels (not in config): +# ssh -TO temp-api --local 8080:api:80 --via server + +# Logging options: +# - enabled: true/false - whether to show any logs +# - level: debug (verbose), info (normal), warn (warnings only), error (errors only) +# - format: console (human readable), json (structured) +# Logs are written to stderr so they don't interfere with SSH output + +# How it works: +# 1. When you run: ssh desktop +# 2. The utility pings 192.168.86.22 with a 2s timeout +# 3. If ping succeeds: runs "ssh desktop-local" instead +# 4. If ping fails: runs "ssh desktop" instead +# 5. All other SSH usage (flags, user@host, etc.) passes through unchanged + +# Your SSH config should contain the actual host definitions: +# Host desktop +# HostName mennos-cachyos-desktop +# User menno +# Port 400 +# ForwardAgent yes +# AddKeysToAgent yes +# +# Host desktop-local +# HostName 192.168.86.22 +# User menno +# Port 400 +# ForwardAgent yes +# AddKeysToAgent yes diff --git a/config/ansible/tasks/global/utils/ssh/go.mod b/config/ansible/tasks/global/utils/ssh/go.mod new file mode 100644 index 0000000..3301abb --- /dev/null +++ b/config/ansible/tasks/global/utils/ssh/go.mod @@ -0,0 +1,17 @@ +module ssh-util + +go 1.21 + +require ( + github.com/rs/zerolog v1.31.0 + github.com/spf13/cobra v1.8.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/sys v0.12.0 // indirect +) diff --git a/config/ansible/tasks/global/utils/ssh/go.sum b/config/ansible/tasks/global/utils/ssh/go.sum new file mode 100644 index 0000000..630d9de --- /dev/null +++ b/config/ansible/tasks/global/utils/ssh/go.sum @@ -0,0 +1,27 @@ +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= +github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/config/ansible/tasks/global/utils/ssh/ssh.go b/config/ansible/tasks/global/utils/ssh/ssh.go new file mode 100644 index 0000000..6395949 --- /dev/null +++ b/config/ansible/tasks/global/utils/ssh/ssh.go @@ -0,0 +1,950 @@ +package main + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "syscall" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" +) + +// LoggingConfig represents logging configuration +type LoggingConfig struct { + Enabled bool `yaml:"enabled"` + Level string `yaml:"level"` + Format string `yaml:"format"` +} + +// SmartAlias represents a smart SSH alias configuration +type SmartAlias struct { + Primary string `yaml:"primary"` // SSH config host to use when local + Fallback string `yaml:"fallback"` // SSH config host to use when remote + CheckHost string `yaml:"check_host"` // IP to ping for connectivity test + Timeout string `yaml:"timeout"` // Ping timeout (default: "2s") +} + +// TunnelDefinition represents a tunnel configuration +type TunnelDefinition struct { + Type string `yaml:"type"` // local, remote, dynamic + LocalPort int `yaml:"local_port"` // Local port for binding + RemoteHost string `yaml:"remote_host"` // Remote host (for local/remote tunnels) + RemotePort int `yaml:"remote_port"` // Remote port (for local/remote tunnels) + SSHHost string `yaml:"ssh_host"` // SSH host to tunnel through +} + +// TunnelState represents runtime state of an active tunnel +type TunnelState struct { + Name string `json:"name"` + Source string `json:"source"` // "config" or "adhoc" + Type string `json:"type"` // local, remote, dynamic + LocalPort int `json:"local_port"` + RemoteHost string `json:"remote_host"` + RemotePort int `json:"remote_port"` + SSHHost string `json:"ssh_host"` + SSHHostResolved string `json:"ssh_host_resolved"` // After smart alias resolution + PID int `json:"pid"` + Status string `json:"status"` + StartedAt time.Time `json:"started_at"` + LastSeen time.Time `json:"last_seen"` + CommandLine string `json:"command_line"` +} + +// Config represents the YAML configuration structure +type Config struct { + Logging LoggingConfig `yaml:"logging"` + SmartAliases map[string]SmartAlias `yaml:"smart_aliases"` + Tunnels map[string]TunnelDefinition `yaml:"tunnels"` +} + +const ( + realSSHPath = "/usr/bin/ssh" +) + +var ( + configDir string + tunnelsDir string + config *Config + + // Global flags + tunnelMode bool + + // Tunnel command flags + tunnelOpen bool + tunnelClose bool + tunnelList bool + tunnelLocal string + tunnelRemote string + tunnelDynamic int + tunnelVia string +) + +var rootCmd = &cobra.Command{ + Use: "ssh", + Short: "Smart SSH utility with tunnel management", + Long: "A transparent SSH wrapper that provides smart alias resolution and background tunnel management", + Run: handleSSH, + DisableFlagParsing: true, +} + +var tunnelCmd = &cobra.Command{ + Use: "tunnel [tunnel-name]", + Short: "Manage background SSH tunnels", + Long: "Create, list, and manage persistent SSH tunnels in the background", + Run: func(cmd *cobra.Command, args []string) { + handleTunnelManual(append([]string{"--tunnel"}, args...)) + }, + Args: cobra.MaximumNArgs(1), +} + +func init() { + // Initialize config directory + homeDir, err := os.UserHomeDir() + if err != nil { + fmt.Fprintf(os.Stderr, "Error: Failed to get home directory: %v\n", err) + os.Exit(1) + } + configDir = filepath.Join(homeDir, ".config", "ssh-util") + tunnelsDir = filepath.Join(configDir, "tunnels") + + // Ensure directories exist + os.MkdirAll(tunnelsDir, 0755) + + // Load configuration + var configErr error + config, configErr = loadConfig() + if configErr != nil { + // Use default config if loading fails + config = &Config{ + Logging: LoggingConfig{ + Enabled: true, + Level: "info", + Format: "console", + }, + SmartAliases: make(map[string]SmartAlias), + Tunnels: make(map[string]TunnelDefinition), + } + } + + // Initialize logging + initLogging(config.Logging) + + // Global flags + rootCmd.PersistentFlags().BoolVarP(&tunnelMode, "tunnel", "T", false, "Enable tunnel mode") + rootCmd.Flags().BoolVarP(&tunnelOpen, "open", "O", false, "Open a tunnel") + rootCmd.Flags().BoolVarP(&tunnelClose, "close", "C", false, "Close a tunnel") + rootCmd.Flags().BoolVarP(&tunnelList, "list", "L", false, "List active tunnels") + rootCmd.Flags().StringVar(&tunnelLocal, "local", "", "Local port forwarding (port:host:port)") + rootCmd.Flags().StringVar(&tunnelRemote, "remote", "", "Remote port forwarding (port:host:port)") + rootCmd.Flags().IntVar(&tunnelDynamic, "dynamic", 0, "Dynamic port forwarding (SOCKS proxy port)") + rootCmd.Flags().StringVar(&tunnelVia, "via", "", "SSH host to tunnel through") + + // Add tunnel command + rootCmd.AddCommand(tunnelCmd) + + // Tunnel command flags (same as root for consistency) + tunnelCmd.Flags().BoolVarP(&tunnelOpen, "open", "O", false, "Open a tunnel") + tunnelCmd.Flags().BoolVarP(&tunnelClose, "close", "C", false, "Close a tunnel") + tunnelCmd.Flags().BoolVarP(&tunnelList, "list", "L", false, "List active tunnels") + tunnelCmd.Flags().StringVar(&tunnelLocal, "local", "", "Local port forwarding (port:host:port)") + tunnelCmd.Flags().StringVar(&tunnelRemote, "remote", "", "Remote port forwarding (port:host:port)") + tunnelCmd.Flags().IntVar(&tunnelDynamic, "dynamic", 0, "Dynamic port forwarding (SOCKS proxy port)") + tunnelCmd.Flags().StringVar(&tunnelVia, "via", "", "SSH host to tunnel through") + + // Handle combined flags like -TO, -TC, -TL + rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + return handleCombinedFlags(cmd, args) + } +} + +func main() { + // Check if this is a tunnel command first + args := os.Args[1:] + isTunnelCommand := false + + for _, arg := range args { + if arg == "--tunnel" || arg == "-T" || strings.HasPrefix(arg, "-T") { + isTunnelCommand = true + break + } + if arg == "tunnel" { + isTunnelCommand = true + break + } + } + + if isTunnelCommand { + // Use Cobra for tunnel commands + if err := rootCmd.Execute(); err != nil { + log.Error().Err(err).Msg("Command execution failed") + os.Exit(1) + } + } else { + // Bypass Cobra for regular SSH commands (smart alias resolution) + handleSSHDirect(args) + } +} + +func handleCombinedFlags(cmd *cobra.Command, args []string) error { + // Check for combined tunnel flags in os.Args + for _, arg := range os.Args { + if strings.HasPrefix(arg, "-T") && len(arg) > 2 { + // Handle combined flags like -TO, -TC, -TL + tunnelMode = true + suffix := arg[2:] + + if strings.Contains(suffix, "O") { + tunnelOpen = true + } + if strings.Contains(suffix, "C") { + tunnelClose = true + } + if strings.Contains(suffix, "L") { + tunnelList = true + } + break + } + } + return nil +} + +func handleSSH(cmd *cobra.Command, args []string) { + // This handles tunnel commands via Cobra + handleTunnelManual(os.Args[1:]) +} + +func handleSSHDirect(args []string) { + log.Debug().Strs("original_args", args).Msg("SSH utility started") + + // Pass through immediately if no args, starts with dash (flags), or contains @ + if len(args) == 0 || (len(args) > 0 && (strings.HasPrefix(args[0], "-") || strings.Contains(args[0], "@"))) { + log.Debug().Msg("Passing through to real SSH (no smart alias detected)") + executeRealSSH(args) + return + } + + // Check if first argument is a smart alias + aliasName := args[0] + modifiedArgs := make([]string, len(args)) + copy(modifiedArgs, args) + + if smartAlias, exists := config.SmartAliases[aliasName]; exists { + log.Info().Str("alias", aliasName).Msg("Smart alias detected") + + // Parse timeout + timeout := parseTimeout(smartAlias.Timeout) + log.Debug().Dur("timeout", timeout).Msg("Parsed timeout") + + // Get the port for the primary host from SSH config + port := getSSHConfigPort(smartAlias.Primary) + log.Debug().Str("host", smartAlias.Primary).Int("port", port).Msg("Extracted port from SSH config") + + // Test connectivity to determine which host to use + log.Info().Str("check_host", smartAlias.CheckHost).Int("port", port).Msg("Testing connectivity") + if pingHost(smartAlias.CheckHost, timeout, port) { + // Local network is reachable, use primary + log.Info().Str("chosen_host", smartAlias.Primary).Msg("Local network reachable, using primary host") + modifiedArgs[0] = smartAlias.Primary + } else { + // Local network not reachable, use fallback + log.Info().Str("chosen_host", smartAlias.Fallback).Msg("Local network not reachable, using fallback host") + modifiedArgs[0] = smartAlias.Fallback + } + } else { + log.Debug().Str("host", aliasName).Msg("Not a smart alias, passing through") + } + + // Execute the real SSH with potentially modified arguments + log.Debug().Strs("final_args", modifiedArgs).Msg("Executing real SSH") + executeRealSSH(modifiedArgs) +} + +func handleTunnelManual(args []string) { + log.Debug().Msg("Tunnel mode activated") + + // Always validate tunnel states first + if err := validateTunnelStates(); err != nil { + log.Error().Err(err).Msg("Failed to validate tunnel states") + } + + // Parse tunnel arguments manually + var tunnelName string + var action string + var localForward, remoteForward, via string + var dynamicPort int + + for i, arg := range args { + switch arg { + case "--tunnel", "-T": + continue + case "--open", "-O": + action = "open" + case "--close", "-C": + action = "close" + case "--list", "-L": + action = "list" + case "--local": + if i+1 < len(args) { + localForward = args[i+1] + } + case "--remote": + if i+1 < len(args) { + remoteForward = args[i+1] + } + case "--via": + if i+1 < len(args) { + via = args[i+1] + } + case "--dynamic": + if i+1 < len(args) { + fmt.Sscanf(args[i+1], "%d", &dynamicPort) + } + default: + if strings.HasPrefix(arg, "-T") && len(arg) > 2 { + suffix := arg[2:] + if strings.Contains(suffix, "O") { + action = "open" + } + if strings.Contains(suffix, "C") { + action = "close" + } + if strings.Contains(suffix, "L") { + action = "list" + } + } else if !strings.HasPrefix(arg, "-") && tunnelName == "" { + tunnelName = arg + } + } + } + + // Handle tunnel commands + if action == "list" { + listTunnels() + return + } + + if tunnelName == "" && action != "list" { + fmt.Fprintf(os.Stderr, "Error: tunnel name required\n") + fmt.Fprintf(os.Stderr, "Usage: ssh --tunnel --open [flags]\n") + os.Exit(1) + } + + if action == "open" { + // Set global variables for openTunnel function + tunnelLocal = localForward + tunnelRemote = remoteForward + tunnelDynamic = dynamicPort + tunnelVia = via + + if err := openTunnel(tunnelName); err != nil { + log.Error().Err(err).Str("tunnel", tunnelName).Msg("Failed to open tunnel") + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + } else if action == "close" { + if err := closeTunnel(tunnelName); err != nil { + log.Error().Err(err).Str("tunnel", tunnelName).Msg("Failed to close tunnel") + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + } else { + fmt.Fprintf(os.Stderr, "Error: must specify --open, --close, or --list\n") + os.Exit(1) + } +} + +func validateTunnelStates() error { + files, err := os.ReadDir(tunnelsDir) + if err != nil { + return fmt.Errorf("failed to read tunnels directory: %w", err) + } + + for _, file := range files { + if !strings.HasSuffix(file.Name(), ".json") { + continue + } + + pidStr := strings.TrimSuffix(file.Name(), ".json") + pid, err := strconv.Atoi(pidStr) + if err != nil { + log.Warn().Str("file", file.Name()).Msg("Invalid PID filename, removing") + os.Remove(filepath.Join(tunnelsDir, file.Name())) + continue + } + + if !isProcessAlive(pid) { + log.Info().Int("pid", pid).Msg("Removing state for dead tunnel") + os.Remove(filepath.Join(tunnelsDir, file.Name())) + continue + } + + // Update last seen time + stateFile := filepath.Join(tunnelsDir, file.Name()) + state, err := loadTunnelState(stateFile) + if err != nil { + log.Warn().Str("file", stateFile).Err(err).Msg("Failed to load tunnel state") + continue + } + + state.LastSeen = time.Now() + if err := saveTunnelState(stateFile, state); err != nil { + log.Warn().Str("file", stateFile).Err(err).Msg("Failed to update tunnel state") + } + } + + return nil +} + +func listTunnels() { + files, err := os.ReadDir(tunnelsDir) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: Failed to read tunnels directory: %v\n", err) + return + } + + if len(files) == 0 { + fmt.Println("No active tunnels") + return + } + + fmt.Printf("%-20s %-8s %-8s %-25s %-12s %-8s %s\n", + "NAME", "TYPE", "LOCAL", "REMOTE", "HOST", "PID", "UPTIME") + fmt.Println(strings.Repeat("-", 80)) + + for _, file := range files { + if !strings.HasSuffix(file.Name(), ".json") { + continue + } + + stateFile := filepath.Join(tunnelsDir, file.Name()) + state, err := loadTunnelState(stateFile) + if err != nil { + continue + } + + uptime := time.Since(state.StartedAt).Truncate(time.Second) + remote := "" + if state.Type == "local" || state.Type == "remote" { + remote = fmt.Sprintf("%s:%d", state.RemoteHost, state.RemotePort) + } else if state.Type == "dynamic" { + remote = "SOCKS" + } + + fmt.Printf("%-20s %-8s %-8d %-25s %-12s %-8d %s\n", + state.Name, state.Type, state.LocalPort, remote, + state.SSHHostResolved, state.PID, uptime) + } +} + +func openTunnel(name string) error { + // Check if tunnel is already running + if existingPID := findTunnelByName(name); existingPID != 0 { + return fmt.Errorf("tunnel '%s' already running (PID %d)", name, existingPID) + } + + var tunnel TunnelDefinition + var source string + + // Check if tunnel is defined in config + if configTunnel, exists := config.Tunnels[name]; exists { + tunnel = configTunnel + source = "config" + log.Info().Str("tunnel", name).Msg("Using tunnel definition from config") + } else { + // Create ad-hoc tunnel from flags + var err error + tunnel, err = createAdhocTunnel() + if err != nil { + return fmt.Errorf("tunnel '%s' not found in config and invalid adhoc parameters: %w", name, err) + } + source = "adhoc" + log.Info().Str("tunnel", name).Msg("Creating ad-hoc tunnel") + } + + // Check for port conflicts + if isPortInUse(tunnel.LocalPort) { + return fmt.Errorf("port %d already in use", tunnel.LocalPort) + } + + // Resolve SSH host using smart alias logic + resolvedSSHHost := resolveSSHHost(tunnel.SSHHost) + log.Debug().Str("original", tunnel.SSHHost).Str("resolved", resolvedSSHHost).Msg("SSH host resolution") + + // Build SSH command + cmdArgs := buildSSHCommand(tunnel, resolvedSSHHost) + log.Debug().Strs("command", cmdArgs).Msg("Starting SSH tunnel") + + // Start SSH process + cmd := &exec.Cmd{ + Path: realSSHPath, + Args: cmdArgs, + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start SSH tunnel: %w", err) + } + + pid := cmd.Process.Pid + log.Info().Str("tunnel", name).Int("pid", pid).Msg("SSH tunnel started") + + // Create tunnel state + state := TunnelState{ + Name: name, + Source: source, + Type: tunnel.Type, + LocalPort: tunnel.LocalPort, + RemoteHost: tunnel.RemoteHost, + RemotePort: tunnel.RemotePort, + SSHHost: tunnel.SSHHost, + SSHHostResolved: resolvedSSHHost, + PID: pid, + Status: "active", + StartedAt: time.Now(), + LastSeen: time.Now(), + CommandLine: strings.Join(cmdArgs, " "), + } + + // Save state file + stateFile := filepath.Join(tunnelsDir, fmt.Sprintf("%d.json", pid)) + if err := saveTunnelState(stateFile, state); err != nil { + // If we can't save state, kill the process + cmd.Process.Kill() + return fmt.Errorf("failed to save tunnel state: %w", err) + } + + fmt.Printf("Tunnel '%s' opened on port %d (PID %d)\n", name, tunnel.LocalPort, pid) + return nil +} + +func closeTunnel(name string) error { + pid := findTunnelByName(name) + if pid == 0 { + return fmt.Errorf("tunnel '%s' not found", name) + } + + // Kill the process + process, err := os.FindProcess(pid) + if err != nil { + return fmt.Errorf("failed to find process %d: %w", pid, err) + } + + if err := process.Kill(); err != nil { + return fmt.Errorf("failed to kill process %d: %w", pid, err) + } + + // Remove state file + stateFile := filepath.Join(tunnelsDir, fmt.Sprintf("%d.json", pid)) + if err := os.Remove(stateFile); err != nil { + log.Warn().Str("file", stateFile).Err(err).Msg("Failed to remove state file") + } + + log.Info().Str("tunnel", name).Int("pid", pid).Msg("Tunnel closed") + fmt.Printf("Tunnel '%s' closed\n", name) + return nil +} + +func createAdhocTunnel() (TunnelDefinition, error) { + tunnel := TunnelDefinition{} + + if tunnelVia == "" { + return tunnel, fmt.Errorf("--via flag required for ad-hoc tunnels") + } + tunnel.SSHHost = tunnelVia + + if tunnelLocal != "" { + parts := strings.Split(tunnelLocal, ":") + if len(parts) != 3 { + return tunnel, fmt.Errorf("invalid --local format, expected port:host:port") + } + + localPort, err := strconv.Atoi(parts[0]) + if err != nil { + return tunnel, fmt.Errorf("invalid local port: %s", parts[0]) + } + + remotePort, err := strconv.Atoi(parts[2]) + if err != nil { + return tunnel, fmt.Errorf("invalid remote port: %s", parts[2]) + } + + tunnel.Type = "local" + tunnel.LocalPort = localPort + tunnel.RemoteHost = parts[1] + tunnel.RemotePort = remotePort + } else if tunnelRemote != "" { + parts := strings.Split(tunnelRemote, ":") + if len(parts) != 3 { + return tunnel, fmt.Errorf("invalid --remote format, expected port:host:port") + } + + localPort, err := strconv.Atoi(parts[0]) + if err != nil { + return tunnel, fmt.Errorf("invalid local port: %s", parts[0]) + } + + remotePort, err := strconv.Atoi(parts[2]) + if err != nil { + return tunnel, fmt.Errorf("invalid remote port: %s", parts[2]) + } + + tunnel.Type = "remote" + tunnel.LocalPort = localPort + tunnel.RemoteHost = parts[1] + tunnel.RemotePort = remotePort + } else if tunnelDynamic != 0 { + tunnel.Type = "dynamic" + tunnel.LocalPort = tunnelDynamic + } else { + return tunnel, fmt.Errorf("must specify --local, --remote, or --dynamic") + } + + return tunnel, nil +} + +func buildSSHCommand(tunnel TunnelDefinition, sshHost string) []string { + args := []string{"ssh", "-f", "-N"} + + switch tunnel.Type { + case "local": + args = append(args, "-L", fmt.Sprintf("%d:%s:%d", tunnel.LocalPort, tunnel.RemoteHost, tunnel.RemotePort)) + case "remote": + args = append(args, "-R", fmt.Sprintf("%d:%s:%d", tunnel.LocalPort, tunnel.RemoteHost, tunnel.RemotePort)) + case "dynamic": + args = append(args, "-D", strconv.Itoa(tunnel.LocalPort)) + } + + args = append(args, sshHost) + return args +} + +func findTunnelByName(name string) int { + files, err := os.ReadDir(tunnelsDir) + if err != nil { + return 0 + } + + for _, file := range files { + if !strings.HasSuffix(file.Name(), ".json") { + continue + } + + stateFile := filepath.Join(tunnelsDir, file.Name()) + state, err := loadTunnelState(stateFile) + if err != nil { + continue + } + + if state.Name == name { + pidStr := strings.TrimSuffix(file.Name(), ".json") + pid, _ := strconv.Atoi(pidStr) + return pid + } + } + + return 0 +} + +func isPortInUse(port int) bool { + conn, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + return true + } + defer conn.Close() + return false +} + +func isProcessAlive(pid int) bool { + process, err := os.FindProcess(pid) + if err != nil { + return false + } + + // Send signal 0 to check if process exists + err = process.Signal(syscall.Signal(0)) + return err == nil +} + +func resolveSSHHost(sshHost string) string { + // Apply smart alias logic if host is a smart alias + if smartAlias, exists := config.SmartAliases[sshHost]; exists { + timeout := parseTimeout(smartAlias.Timeout) + port := getSSHConfigPort(smartAlias.Primary) + + if pingHost(smartAlias.CheckHost, timeout, port) { + log.Debug().Str("host", sshHost).Str("resolved", smartAlias.Primary).Msg("Smart alias resolved to primary") + return smartAlias.Primary + } else { + log.Debug().Str("host", sshHost).Str("resolved", smartAlias.Fallback).Msg("Smart alias resolved to fallback") + return smartAlias.Fallback + } + } + + return sshHost +} + +func loadTunnelState(stateFile string) (TunnelState, error) { + var state TunnelState + data, err := os.ReadFile(stateFile) + if err != nil { + return state, fmt.Errorf("failed to read state file: %w", err) + } + + if err := json.Unmarshal(data, &state); err != nil { + return state, fmt.Errorf("failed to parse state file: %w", err) + } + + return state, nil +} + +func saveTunnelState(stateFile string, state TunnelState) error { + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal state: %w", err) + } + + if err := os.WriteFile(stateFile, data, 0644); err != nil { + return fmt.Errorf("failed to write state file: %w", err) + } + + return nil +} + +// initLogging configures zerolog based on the logging configuration +func initLogging(cfg LoggingConfig) { + if !cfg.Enabled { + zerolog.SetGlobalLevel(zerolog.Disabled) + return + } + + switch strings.ToLower(cfg.Level) { + case "debug": + zerolog.SetGlobalLevel(zerolog.DebugLevel) + case "info": + zerolog.SetGlobalLevel(zerolog.InfoLevel) + case "warn", "warning": + zerolog.SetGlobalLevel(zerolog.WarnLevel) + case "error": + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + default: + zerolog.SetGlobalLevel(zerolog.InfoLevel) + } + + if strings.ToLower(cfg.Format) == "console" { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + } else { + log.Logger = log.Output(os.Stderr) + } +} + +// loadConfig loads the YAML configuration file +func loadConfig() (*Config, error) { + configFile := filepath.Join(configDir, "config.yaml") + + if _, err := os.Stat(configFile); os.IsNotExist(err) { + return &Config{ + Logging: LoggingConfig{ + Enabled: true, + Level: "info", + Format: "console", + }, + SmartAliases: make(map[string]SmartAlias), + Tunnels: make(map[string]TunnelDefinition), + }, nil + } + + data, err := os.ReadFile(configFile) + if err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + + var config Config + if err := yaml.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("failed to parse config YAML: %w", err) + } + + if config.Logging.Level == "" { + config.Logging.Level = "info" + } + if config.Logging.Format == "" { + config.Logging.Format = "console" + } + + return &config, nil +} + +// parseTimeout converts timeout string to time.Duration, defaults to 2s +func parseTimeout(timeoutStr string) time.Duration { + if timeoutStr == "" { + return 2 * time.Second + } + + duration, err := time.ParseDuration(timeoutStr) + if err != nil { + return 2 * time.Second + } + + return duration +} + +// pingHost checks if a host is reachable via TCP connection test +func pingHost(host string, timeout time.Duration, port int) bool { + result := tcpConnectTest(host, timeout, port) + log.Debug().Str("host", host).Int("port", port).Bool("reachable", result).Msg("Connectivity test result") + return result +} + +// tcpConnectTest tests TCP connection on the specified port +func tcpConnectTest(host string, timeout time.Duration, port int) bool { + portStr := strconv.Itoa(port) + address := net.JoinHostPort(host, portStr) + + conn, err := net.DialTimeout("tcp", address, timeout) + if err != nil { + log.Debug().Str("address", address).Err(err).Msg("TCP connection failed") + return false + } + defer conn.Close() + log.Debug().Str("address", address).Msg("TCP connection successful") + return true +} + +// getSSHConfigPort parses the SSH config file to find the port for a given host +func getSSHConfigPort(hostname string) int { + homeDir, err := os.UserHomeDir() + if err != nil { + log.Debug().Err(err).Msg("Failed to get home directory") + return 22 + } + + mainConfigFile := filepath.Join(homeDir, ".ssh", "config") + configFiles := []string{mainConfigFile} + + // Check if main config exists and parse it for includes + if file, err := os.Open(mainConfigFile); err == nil { + defer file.Close() + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(strings.ToLower(line), "include ") { + parts := strings.Fields(line) + if len(parts) >= 2 { + includePath := parts[1] + if strings.HasPrefix(includePath, "~") { + includePath = strings.Replace(includePath, "~", homeDir, 1) + } else if !filepath.IsAbs(includePath) { + includePath = filepath.Join(homeDir, ".ssh", includePath) + } + + if matches, err := filepath.Glob(includePath); err == nil { + configFiles = append(configFiles, matches...) + } else { + configFiles = append(configFiles, includePath) + } + } + } + } + } else { + log.Debug().Str("config_file", mainConfigFile).Err(err).Msg("Main SSH config file not found") + } + + // Also check common config.d directory pattern + configDPattern := filepath.Join(homeDir, ".ssh", "config.d", "*") + if matches, err := filepath.Glob(configDPattern); err == nil { + configFiles = append(configFiles, matches...) + } + + log.Debug().Str("hostname", hostname).Strs("config_files", configFiles).Msg("Parsing SSH config files for port") + + // Parse all config files + for _, configFile := range configFiles { + if port := parseSSHConfigFile(configFile, hostname); port != 22 { + log.Debug().Str("hostname", hostname).Int("port", port).Str("config_file", configFile).Msg("Found port in SSH config") + return port + } + } + + log.Debug().Str("hostname", hostname).Int("port", 22).Msg("Using default port") + return 22 +} + +// parseSSHConfigFile parses a single SSH config file for a host's port +func parseSSHConfigFile(configFile, hostname string) int { + file, err := os.Open(configFile) + if err != nil { + log.Debug().Str("config_file", configFile).Err(err).Msg("Could not open SSH config file") + return 22 + } + defer file.Close() + + scanner := bufio.NewScanner(file) + var inTargetHost bool + var port int = 22 + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // Check for Host directive + if strings.HasPrefix(strings.ToLower(line), "host ") { + parts := strings.Fields(line) + if len(parts) >= 2 { + hostPattern := parts[1] + // Simple match - for more complex patterns, we'd need glob matching + inTargetHost = (hostPattern == hostname) + } + continue + } + + // If we're in the target host section, look for Port directive + if inTargetHost && strings.HasPrefix(strings.ToLower(line), "port ") { + parts := strings.Fields(line) + if len(parts) >= 2 { + if parsedPort, err := strconv.Atoi(parts[1]); err == nil { + port = parsedPort + return port // Found the port, return immediately + } + } + } + + // If we hit another Host directive and we were in target host, we're done + if inTargetHost && strings.HasPrefix(strings.ToLower(line), "host ") { + break + } + } + + return 22 // default port +} + +// executeRealSSH executes the real SSH binary with given arguments +func executeRealSSH(args []string) { + // Check if real SSH exists + if _, err := os.Stat(realSSHPath); os.IsNotExist(err) { + log.Error().Str("path", realSSHPath).Msg("Real SSH binary not found") + fmt.Fprintf(os.Stderr, "Error: Real SSH binary not found at %s\n", realSSHPath) + os.Exit(1) + } + + log.Debug().Str("ssh_path", realSSHPath).Strs("args", args).Msg("Executing real SSH") + + // Execute the real SSH binary + // Using syscall.Exec to replace current process (like exec in shell) + err := syscall.Exec(realSSHPath, append([]string{"ssh"}, args...), os.Environ()) + if err != nil { + log.Error().Err(err).Msg("Failed to execute SSH") + fmt.Fprintf(os.Stderr, "Error executing SSH: %v\n", err) + os.Exit(1) + } +} diff --git a/config/home-manager/flake.lock b/config/home-manager/flake.lock index 65e3936..e714e15 100644 --- a/config/home-manager/flake.lock +++ b/config/home-manager/flake.lock @@ -25,11 +25,11 @@ ] }, "locked": { - "lastModified": 1753198507, - "narHash": "sha256-NCG6izg+B3zsCwcT6+ssiWT3Y202jhOqGL/zh6fofa4=", + "lastModified": 1753288231, + "narHash": "sha256-WcMW9yUDfER8kz4NdCaaI/ep0Ef91L+Nf7MetNzHZc4=", "owner": "nix-community", "repo": "home-manager", - "rev": "fce051eaf881220843401df545a1444ab676520c", + "rev": "7b5a978e00273b8676c530c03d315f5b75fae564", "type": "github" }, "original": { @@ -41,11 +41,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1753115646, - "narHash": "sha256-yLuz5cz5Z+sn8DRAfNkrd2Z1cV6DaYO9JMrEz4KZo/c=", + "lastModified": 1753345091, + "narHash": "sha256-CdX2Rtvp5I8HGu9swBmYuq+ILwRxpXdJwlpg8jvN4tU=", "owner": "nixos", "repo": "nixpkgs", - "rev": "92c2e04a475523e723c67ef872d8037379073681", + "rev": "3ff0e34b1383648053bba8ed03f201d3466f90c9", "type": "github" }, "original": { @@ -57,11 +57,11 @@ }, "nixpkgs-unstable": { "locked": { - "lastModified": 1752950548, - "narHash": "sha256-NS6BLD0lxOrnCiEOcvQCDVPXafX1/ek1dfJHX1nUIzc=", + "lastModified": 1753250450, + "narHash": "sha256-i+CQV2rPmP8wHxj0aq4siYyohHwVlsh40kV89f3nw1s=", "owner": "nixos", "repo": "nixpkgs", - "rev": "c87b95e25065c028d31a94f06a62927d18763fdf", + "rev": "fc02ee70efb805d3b2865908a13ddd4474557ecf", "type": "github" }, "original": { diff --git a/config/ssh/ssh-util-config.yaml b/config/ssh/ssh-util-config.yaml new file mode 100644 index 0000000..2a55dff --- /dev/null +++ b/config/ssh/ssh-util-config.yaml @@ -0,0 +1,17 @@ +logging: + enabled: true + level: "info" + format: "console" + +smart_aliases: + desktop: + primary: "desktop-local" + fallback: "desktop" + check_host: "192.168.86.254" + timeout: "2s" + + laptop: + primary: "laptop-local" + fallback: "laptop" + check_host: "192.168.86.22" + timeout: "2s"