initial commit

This commit is contained in:
2025-05-23 15:08:44 +02:00
commit e602d503e8
22 changed files with 2408 additions and 0 deletions

245
cmd/common.go Normal file
View File

@@ -0,0 +1,245 @@
package cmd
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"syscall"
)
const (
tunnelDir = ".sshtunnels"
)
// Tunnel represents an SSH tunnel configuration
type Tunnel struct {
ID int // Unique identifier for the tunnel
LocalPort int // Local port being forwarded
RemotePort int // Remote port being forwarded to
RemoteHost string // Remote host being forwarded to
SSHServer string // SSH server (user@host)
PID int // Process ID of the SSH process
}
// checkIfSSHProcess checks if a process ID belongs to an SSH process
func checkIfSSHProcess(pid int) bool {
// Try to read /proc/{pid}/comm if on Linux
if _, err := os.Stat("/proc"); err == nil {
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/comm", pid))
if err == nil && strings.TrimSpace(string(data)) == "ssh" {
return true
}
}
// Alternative approach - use ps command
cmd := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "comm=")
output, err := cmd.Output()
if err == nil && strings.Contains(string(output), "ssh") {
return true
}
// Last resort - just check if process exists
process, err := os.FindProcess(pid)
if err != nil {
return false
}
// Send signal 0 to check if process exists
return process.Signal(syscall.Signal(0)) == nil
}
// findSSHTunnelPID attempts to find the PID of an SSH tunnel process by its local port
func findSSHTunnelPID(port int) (int, error) {
// Try using lsof first (most reliable)
cmd := exec.Command("lsof", "-i", fmt.Sprintf("TCP:%d", port), "-t")
output, err := cmd.Output()
if err == nil && len(output) > 0 {
pidStr := strings.TrimSpace(string(output))
pid, err := strconv.Atoi(pidStr)
if err == nil {
// Verify this is an SSH process
if checkIfSSHProcess(pid) {
return pid, nil
}
}
}
// Try netstat as a fallback
cmd = exec.Command("netstat", "-tlnp")
output, err = cmd.Output()
if err == nil {
lines := strings.Split(string(output), "\n")
portStr := fmt.Sprintf(":%d", port)
for _, line := range lines {
if strings.Contains(line, portStr) && strings.Contains(line, "ssh") {
// Extract PID from the line
parts := strings.Fields(line)
if len(parts) >= 7 {
pidPart := parts[6]
pidStr := strings.Split(pidPart, "/")[0]
pid, err := strconv.Atoi(pidStr)
if err == nil {
return pid, nil
}
}
}
}
}
// Try ps as a last resort
cmd = exec.Command("ps", "aux")
output, err = cmd.Output()
if err == nil {
lines := strings.Split(string(output), "\n")
portStr := fmt.Sprintf(":%d", port)
for _, line := range lines {
if strings.Contains(line, "ssh") && strings.Contains(line, portStr) {
parts := strings.Fields(line)
if len(parts) >= 2 {
pid, err := strconv.Atoi(parts[1])
if err == nil {
return pid, nil
}
}
}
}
}
return 0, fmt.Errorf("could not find SSH tunnel PID for port %d", port)
}
func getTunnels() ([]Tunnel, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("error getting home directory: %v", err)
}
tunnelPath := filepath.Join(homeDir, tunnelDir)
entries, err := os.ReadDir(tunnelPath)
if err != nil {
if os.IsNotExist(err) {
return []Tunnel{}, nil
}
return nil, fmt.Errorf("error reading tunnel directory: %v", err)
}
var tunnels []Tunnel
for _, entry := range entries {
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "tunnel-") {
filePath := filepath.Join(tunnelPath, entry.Name())
data, err := os.ReadFile(filePath)
if err != nil {
fmt.Printf("Error reading tunnel file %s: %v\n", filePath, err)
continue
}
var t Tunnel
parts := strings.Split(string(data), ":")
if len(parts) != 5 {
fmt.Printf("Invalid tunnel file format: %s\n", filePath)
continue
}
idStr := strings.TrimPrefix(entry.Name(), "tunnel-")
id, err := strconv.Atoi(idStr)
if err != nil {
fmt.Printf("Invalid tunnel ID: %s\n", idStr)
continue
}
t.ID = id
t.LocalPort, err = strconv.Atoi(parts[0])
if err != nil {
continue
}
t.RemoteHost = parts[1]
t.RemotePort, err = strconv.Atoi(parts[2])
if err != nil {
continue
}
t.SSHServer = parts[3]
t.PID, err = strconv.Atoi(parts[4])
if err != nil {
continue
}
// Verify if this is actually a SSH process
isSSH := checkIfSSHProcess(t.PID)
if !isSSH {
fmt.Printf("Process %d is not an SSH process anymore, cleaning up\n", t.PID)
removeFile(t.ID)
continue
}
tunnels = append(tunnels, t)
}
}
return tunnels, nil
}
func saveTunnel(t Tunnel) error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("error getting home directory: %v", err)
}
tunnelPath := filepath.Join(homeDir, tunnelDir, fmt.Sprintf("tunnel-%d", t.ID))
data := fmt.Sprintf("%d:%s:%d:%s:%d", t.LocalPort, t.RemoteHost, t.RemotePort, t.SSHServer, t.PID)
err = os.WriteFile(tunnelPath, []byte(data), 0644)
if err != nil {
return err
}
// Verify the file was written correctly
_, err = os.Stat(tunnelPath)
return err
}
func removeFile(id int) {
homeDir, err := os.UserHomeDir()
if err != nil {
fmt.Printf("Error getting home directory: %v\n", err)
return
}
tunnelPath := filepath.Join(homeDir, tunnelDir, fmt.Sprintf("tunnel-%d", id))
if err := os.Remove(tunnelPath); err != nil && !os.IsNotExist(err) {
fmt.Printf("Error removing tunnel file %s: %v\n", tunnelPath, err)
}
}
func generateTunnelID() int {
homeDir, err := os.UserHomeDir()
if err != nil {
return int(os.Getpid())
}
tunnelPath := filepath.Join(homeDir, tunnelDir)
entries, err := os.ReadDir(tunnelPath)
if err != nil {
return int(os.Getpid())
}
id := 1
for _, entry := range entries {
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "tunnel-") {
idStr := strings.TrimPrefix(entry.Name(), "tunnel-")
if val, err := strconv.Atoi(idStr); err == nil && val >= id {
id = val + 1
}
}
}
return id
}

207
cmd/debug.go Normal file
View File

@@ -0,0 +1,207 @@
package cmd
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"syscall"
"github.com/spf13/cobra"
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Run diagnostics on SSH tunnels",
Long: `Run diagnostic checks on your SSH tunnel setup, including:
- SSH client availability
- Tunnel directory integrity
- Recorded tunnels status
- Active SSH processes verification`,
Run: func(cmd *cobra.Command, args []string) {
runDebugCommand()
},
}
func init() {
rootCmd.AddCommand(debugCmd)
}
// runDebugCommand handles the debug subcommand logic
func runDebugCommand() {
fmt.Println("SSH Tunnel Manager Diagnostics")
fmt.Println("==============================")
// Check SSH client availability
fmt.Println("\n1. Checking SSH client:")
checkSSHClient()
// Check tunnel directory
fmt.Println("\n2. Checking tunnel directory:")
checkTunnelDirectory()
// Check recorded tunnels
fmt.Println("\n3. Checking recorded tunnels:")
checkRecordedTunnels()
// Check active SSH processes
fmt.Println("\n4. Checking active SSH processes:")
checkActiveSSHProcesses()
}
func checkSSHClient() {
path, err := exec.LookPath("ssh")
if err != nil {
fmt.Printf(" ❌ SSH client not found in PATH: %v\n", err)
return
}
fmt.Printf(" ✅ SSH client found at: %s\n", path)
// Get SSH version
cmd := exec.Command("ssh", "-V")
output, err := cmd.CombinedOutput()
if err != nil {
fmt.Printf(" ⚠️ Could not determine SSH version: %v\n", err)
return
}
fmt.Printf(" ✅ SSH version info: %s\n", strings.TrimSpace(string(output)))
}
func checkTunnelDirectory() {
homeDir, err := os.UserHomeDir()
if err != nil {
fmt.Printf(" ❌ Could not determine home directory: %v\n", err)
return
}
tunnelPath := filepath.Join(homeDir, tunnelDir)
info, err := os.Stat(tunnelPath)
if os.IsNotExist(err) {
fmt.Printf(" ⚠️ Tunnel directory does not exist: %s\n", tunnelPath)
return
} else if err != nil {
fmt.Printf(" ❌ Error accessing tunnel directory: %v\n", err)
return
}
fmt.Printf(" ✅ Tunnel directory exists: %s\n", tunnelPath)
fmt.Printf(" ✅ Permissions: %s\n", info.Mode().String())
entries, err := os.ReadDir(tunnelPath)
if err != nil {
fmt.Printf(" ❌ Could not read tunnel directory: %v\n", err)
return
}
fmt.Printf(" ✅ Directory contains %d entries\n", len(entries))
}
func checkRecordedTunnels() {
tunnels, err := getTunnels()
if err != nil {
fmt.Printf(" ❌ Error reading tunnels: %v\n", err)
return
}
if len(tunnels) == 0 {
fmt.Printf(" No recorded tunnels found\n")
return
}
fmt.Printf(" ✅ Found %d recorded tunnels\n", len(tunnels))
for i, t := range tunnels {
fmt.Printf("\n Tunnel #%d (ID: %d):\n", i+1, t.ID)
fmt.Printf(" Local port: %d\n", t.LocalPort)
fmt.Printf(" Remote: %s:%d\n", t.RemoteHost, t.RemotePort)
fmt.Printf(" Server: %s\n", t.SSHServer)
fmt.Printf(" PID: %d\n", t.PID)
// Check if process exists
process, err := os.FindProcess(t.PID)
if err != nil {
fmt.Printf(" ❌ Process not found: %v\n", err)
continue
}
// Try to send signal 0 to check if process exists
err = process.Signal(syscall.Signal(0))
if err != nil {
fmt.Printf(" ❌ Process not running: %v\n", err)
} else {
fmt.Printf(" ✅ Process is running\n")
// Check if it's actually an SSH process
isSSH := checkIfSSHProcess(t.PID)
if isSSH {
fmt.Printf(" ✅ Process is an SSH process\n")
} else {
fmt.Printf(" ⚠️ Process is not an SSH process!\n")
}
}
}
}
func checkActiveSSHProcesses() {
// Try using ps to find SSH processes
cmd := exec.Command("ps", "-eo", "pid,command")
output, err := cmd.Output()
if err != nil {
fmt.Printf(" ❌ Could not list processes: %v\n", err)
return
}
lines := strings.Split(string(output), "\n")
sshProcesses := []string{}
for _, line := range lines {
if strings.Contains(line, "ssh") && strings.Contains(line, "-L") {
sshProcesses = append(sshProcesses, strings.TrimSpace(line))
}
}
if len(sshProcesses) == 0 {
fmt.Printf(" No SSH tunnel processes found\n")
return
}
fmt.Printf(" ✅ Found %d SSH tunnel processes:\n", len(sshProcesses))
for _, proc := range sshProcesses {
fmt.Printf(" %s\n", proc)
// Extract PID
fields := strings.Fields(proc)
if len(fields) > 0 {
pid, err := strconv.Atoi(fields[0])
if err == nil {
// Check if this process is in our records
found := false
tunnels, _ := getTunnels()
for _, t := range tunnels {
if t.PID == pid {
fmt.Printf(" ✅ This process is tracked as tunnel ID %d\n", t.ID)
found = true
break
}
}
if !found {
fmt.Printf(" ⚠️ This process is not tracked by the tunnel manager\n")
}
}
}
}
}
func verifyTunnelConnectivity(t Tunnel) error {
// Try to connect to the local port to verify the tunnel is working
cmd := exec.Command("nc", "-z", "-w", "1", "localhost", strconv.Itoa(t.LocalPort))
err := cmd.Run()
if err != nil {
return fmt.Errorf("could not connect to local port %d: %v", t.LocalPort, err)
}
return nil
}

56
cmd/list.go Normal file
View File

@@ -0,0 +1,56 @@
package cmd
import (
"fmt"
"os"
"strings"
"syscall"
"github.com/spf13/cobra"
)
var listCmd = &cobra.Command{
Use: "list",
Short: "List all active SSH tunnels",
Long: `Display all currently active SSH tunnels managed by this tool.`,
Run: func(cmd *cobra.Command, args []string) {
listTunnels()
},
}
func init() {
rootCmd.AddCommand(listCmd)
}
func listTunnels() {
tunnels, err := getTunnels()
if err != nil {
fmt.Printf("Error listing tunnels: %v\n", err)
os.Exit(1)
}
if len(tunnels) == 0 {
fmt.Println("No active SSH tunnels found.")
return
}
fmt.Println("Active SSH tunnels:")
fmt.Printf("%-5s %-15s %-20s %-10s\n", "ID", "LOCAL", "REMOTE", "PID")
fmt.Println(strings.Repeat("-", 60))
for _, t := range tunnels {
// Check if the process is still running
process, err := os.FindProcess(t.PID)
if err != nil || process.Signal(syscall.Signal(0)) != nil {
// Process does not exist anymore, clean up the tunnel file
removeFile(t.ID)
continue
}
fmt.Printf("%-5d %-15s %-20s %-10d\n",
t.ID,
fmt.Sprintf("localhost:%d", t.LocalPort),
fmt.Sprintf("%s:%d", t.RemoteHost, t.RemotePort),
t.PID)
}
}

32
cmd/root.go Normal file
View File

@@ -0,0 +1,32 @@
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "sshtunnel",
Short: "SSH tunnel manager",
Long: `SSH Tunnel Manager is a CLI tool for creating and managing SSH tunnels.
It allows you to easily create, list, and terminate SSH port forwarding tunnels
in the background without having to remember complex SSH commands.`,
Run: func(cmd *cobra.Command, args []string) {
// If no subcommand is provided, print help
cmd.Help()
},
}
// Execute executes the root command
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func init() {
// Global flags can be defined here
}

141
cmd/start.go Normal file
View File

@@ -0,0 +1,141 @@
package cmd
import (
"fmt"
"net"
"os"
"os/exec"
"strings"
"time"
"github.com/spf13/cobra"
"sshtunnel/pkg/stats"
)
var (
localPort int
remotePort int
remoteHost string
sshServer string
identity string
)
// startCmd represents the start command
var startCmd = &cobra.Command{
Use: "start",
Short: "Start a new SSH tunnel",
Long: `Start a new SSH tunnel with specified local port, remote port, host and SSH server.
The tunnel will run in the background and can be managed using the list and stop commands.`,
Run: func(cmd *cobra.Command, args []string) {
// Check required flags
// Generate the SSH command with appropriate flags for reliable background operation
sshArgs := []string{
"-N", // Don't execute remote command
"-f", // Run in background
"-L", fmt.Sprintf("%d:%s:%d", localPort, remoteHost, remotePort),
}
if identity != "" {
sshArgs = append(sshArgs, "-i", identity)
}
sshArgs = append(sshArgs, sshServer)
sshCmd := exec.Command("ssh", sshArgs...)
// Capture output for debugging
var outputBuffer strings.Builder
sshCmd.Stdout = &outputBuffer
sshCmd.Stderr = &outputBuffer
// Run the command (not just Start) - the -f flag means it will return immediately
// after going to the background
if err := sshCmd.Run(); err != nil {
fmt.Printf("Error starting SSH tunnel: %v\n", err)
fmt.Printf("SSH output: %s\n", outputBuffer.String())
os.Exit(1)
}
// The PID from cmd.Process is no longer valid since ssh -f forks
// We need to find the actual SSH process PID
actualPID, err := findSSHTunnelPID(localPort)
if err != nil {
fmt.Printf("Warning: Could not determine tunnel PID: %v\n", err)
}
// Store tunnel information
id := generateTunnelID()
pid := 0
if actualPID > 0 {
pid = actualPID
}
tunnel := Tunnel{
ID: id,
LocalPort: localPort,
RemotePort: remotePort,
RemoteHost: remoteHost,
SSHServer: sshServer,
PID: pid,
}
if err := saveTunnel(tunnel); err != nil {
fmt.Printf("Error saving tunnel information: %v\n", err)
if pid > 0 {
// Try to kill the process
process, _ := os.FindProcess(pid)
if process != nil {
process.Kill()
}
}
os.Exit(1)
}
// Initialize statistics for this tunnel
statsManager, err := stats.NewStatsManager()
if err == nil {
err = statsManager.InitStats(id, localPort)
if err != nil {
fmt.Printf("Warning: Failed to initialize statistics: %v\n", err)
}
}
// Verify tunnel is actually working
time.Sleep(500 * time.Millisecond)
active := verifyTunnelActive(localPort)
status := "ACTIVE"
if !active {
status = "UNKNOWN"
}
fmt.Printf("Started SSH tunnel (ID: %d): localhost:%d -> %s:%d (%s) [PID: %d] [Status: %s]\n",
id, localPort, remoteHost, remotePort, sshServer, pid, status)
},
}
func init() {
rootCmd.AddCommand(startCmd)
// Add flags for the start command
startCmd.Flags().IntVarP(&localPort, "local", "l", 0, "Local port to forward")
startCmd.Flags().IntVarP(&remotePort, "remote", "r", 0, "Remote port to forward to")
startCmd.Flags().StringVarP(&remoteHost, "host", "H", "localhost", "Remote host to forward to")
startCmd.Flags().StringVarP(&sshServer, "server", "s", "", "SSH server address (user@host)")
startCmd.Flags().StringVarP(&identity, "identity", "i", "", "Path to SSH identity file")
// Mark required flags
startCmd.MarkFlagRequired("local")
startCmd.MarkFlagRequired("remote")
startCmd.MarkFlagRequired("server")
}
// verifyTunnelActive checks if the tunnel is actually working
func verifyTunnelActive(port int) bool {
// Try to connect to the port to verify it's open
conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 500*time.Millisecond)
if err != nil {
return false
}
conn.Close()
return true
}

283
cmd/stats.go Normal file
View File

@@ -0,0 +1,283 @@
package cmd
import (
"fmt"
"os"
"time"
"sshtunnel/pkg/monitor"
"sshtunnel/pkg/stats"
"github.com/spf13/cobra"
)
var (
tunnelIDFlag int
watchFlag bool
watchInterval int
allFlag bool
)
// statsCmd represents the stats command
var statsCmd = &cobra.Command{
Use: "stats",
Short: "View SSH tunnel statistics",
Long: `View traffic statistics for active SSH tunnels.
This command shows you how much data has been transferred through your
SSH tunnels, in both directions (incoming and outgoing).
You can view statistics for a specific tunnel by specifying its ID,
or view statistics for all active tunnels.
Examples:
# View statistics for all tunnels
sshtunnel stats --all
# View statistics for a specific tunnel
sshtunnel stats --id 1
# Continuously monitor a specific tunnel (every 5 seconds)
sshtunnel stats --id 1 --watch
# Continuously monitor with custom interval (2 seconds)
sshtunnel stats --id 1 --watch --interval 2`,
Run: func(cmd *cobra.Command, args []string) {
if !allFlag && tunnelIDFlag <= 0 {
fmt.Println("Error: You must specify either a tunnel ID with --id or use --all")
cmd.Help()
os.Exit(1)
}
if allFlag && tunnelIDFlag > 0 {
fmt.Println("Error: You cannot specify both --id and --all flags")
os.Exit(1)
}
if watchFlag && !allFlag && tunnelIDFlag <= 0 {
fmt.Println("Error: Watch mode requires specifying a tunnel ID with --id")
os.Exit(1)
}
if watchFlag && allFlag {
fmt.Println("Error: Watch mode can only be used with a specific tunnel ID")
os.Exit(1)
}
statsManager, err := stats.NewStatsManager()
if err != nil {
fmt.Printf("Error initializing statistics: %v\n", err)
os.Exit(1)
}
if allFlag {
showAllStats(statsManager)
} else {
if watchFlag {
watchTunnelStats(tunnelIDFlag, watchInterval)
} else {
showTunnelStats(statsManager, tunnelIDFlag)
}
}
},
}
func init() {
rootCmd.AddCommand(statsCmd)
statsCmd.Flags().IntVar(&tunnelIDFlag, "id", 0, "ID of the tunnel to show statistics for")
statsCmd.Flags().BoolVar(&allFlag, "all", false, "Show statistics for all tunnels")
statsCmd.Flags().BoolVar(&watchFlag, "watch", false, "Continuously monitor tunnel statistics")
statsCmd.Flags().IntVar(&watchInterval, "interval", 5, "Update interval for watch mode (in seconds)")
}
func showAllStats(statsManager *stats.StatsManager) {
tunnelStats, err := statsManager.GetAllStats()
if err != nil {
fmt.Printf("Error retrieving statistics: %v\n", err)
os.Exit(1)
}
if len(tunnelStats) == 0 {
fmt.Println("No statistics available for any tunnels.")
return
}
fmt.Println("SSH Tunnel Statistics:")
fmt.Println("=============================================")
for _, s := range tunnelStats {
// Verify if this tunnel is still active
tunnels, _ := getTunnels()
active := false
for _, t := range tunnels {
if t.ID == s.TunnelID {
active = true
break
}
}
uptime := time.Since(s.StartTime)
inRate := stats.CalculateRate(s.BytesIn, uptime)
outRate := stats.CalculateRate(s.BytesOut, uptime)
status := "ACTIVE"
if !active {
status = "INACTIVE"
}
fmt.Printf("Tunnel #%d (Port: %d) [%s]\n", s.TunnelID, s.LocalPort, status)
fmt.Printf(" Uptime: %s\n", formatDuration(uptime))
fmt.Printf(" Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
fmt.Printf(" Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
fmt.Printf(" Total: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
fmt.Println("---------------------------------------------")
}
}
func showTunnelStats(statsManager *stats.StatsManager, tunnelID int) {
s, err := statsManager.GetStats(tunnelID)
if err != nil {
fmt.Printf("Warning: Could not retrieve statistics for tunnel #%d: %v\n", tunnelID, err)
s = stats.Stats{
TunnelID: tunnelID,
BytesIn: 0,
BytesOut: 0,
StartTime: time.Now(),
LastUpdated: time.Now(),
}
}
// Get tunnel info to validate it exists
tunnel, err := getTunnel(tunnelID)
if err != nil {
fmt.Printf("Note: Tunnel #%d does not appear to be active, showing historical data\n", tunnelID)
}
// Display stats
uptime := time.Since(s.StartTime)
inRate := stats.CalculateRate(s.BytesIn, uptime)
outRate := stats.CalculateRate(s.BytesOut, uptime)
fmt.Printf("Statistics for Tunnel #%d:\n", tunnelID)
fmt.Println("=============================================")
if tunnel != nil {
fmt.Printf("Local Port: %d\n", tunnel.LocalPort)
fmt.Printf("Remote Target: %s:%d\n", tunnel.RemoteHost, tunnel.RemotePort)
fmt.Printf("SSH Server: %s\n", tunnel.SSHServer)
fmt.Printf("PID: %d\n", tunnel.PID)
} else if s.LocalPort > 0 {
// If we don't have tunnel info but do have port in the stats
fmt.Printf("Local Port: %d\n", s.LocalPort)
fmt.Printf("Status: INACTIVE (historical data)\n")
}
fmt.Printf("Uptime: %s\n", formatDuration(uptime))
fmt.Printf("Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
fmt.Printf("Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
fmt.Printf("Total Traffic: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
}
func watchTunnelStats(tunnelID int, interval int) {
// Get tunnel info to validate it exists
tunnel, err := getTunnel(tunnelID)
if err != nil {
fmt.Printf("Warning: Tunnel #%d may not be active: %v\n", tunnelID, err)
fmt.Println("Attempting to monitor anyway...")
// Try to get port from stats
statsManager, _ := stats.NewStatsManager()
s, err := statsManager.GetStats(tunnelID)
if err == nil && s.LocalPort > 0 {
// Create a tunnel object with the information we have
tunnel = &Tunnel{
ID: tunnelID,
LocalPort: s.LocalPort,
RemoteHost: "unknown",
RemotePort: 0,
SSHServer: "unknown",
PID: 0,
}
} else {
fmt.Println("Error: Cannot determine port for tunnel. Please specify a valid tunnel ID.")
os.Exit(1)
}
}
// Create a monitor for this tunnel
mon, err := monitor.NewMonitor(tunnelID, tunnel.LocalPort)
if err != nil {
fmt.Printf("Warning: Error creating monitor: %v\n", err)
fmt.Println("Will attempt to continue with limited functionality")
}
// Start monitoring with error handling
err = mon.Start()
if err != nil {
fmt.Printf("Warning: %v\n", err)
fmt.Println("Continuing with estimated statistics...")
}
defer mon.Stop()
fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
// Initial stats display
statsStr, err := mon.FormatStats()
if err != nil {
fmt.Printf("Warning: %v\n", err)
fmt.Println("Displaying minimal statistics...")
statsStr = fmt.Sprintf("Monitoring tunnel #%d on port %d...",
tunnelID, tunnel.LocalPort)
}
fmt.Println(statsStr)
fmt.Println("\nUpdating every", interval, "seconds...")
// Clear screen and update stats periodically
for range ticker.C {
// ANSI escape code to clear screen
fmt.Print("\033[H\033[2J")
fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)
statsStr, err := mon.FormatStats()
if err != nil {
statsStr = fmt.Sprintf("Warning: %v\n\nStill monitoring tunnel #%d on port %d...",
err, tunnelID, tunnel.LocalPort)
}
fmt.Println(statsStr)
fmt.Println("\nUpdating every", interval, "seconds...")
}
}
func formatDuration(d time.Duration) string {
days := int(d.Hours() / 24)
hours := int(d.Hours()) % 24
minutes := int(d.Minutes()) % 60
seconds := int(d.Seconds()) % 60
if days > 0 {
return fmt.Sprintf("%dd %dh %dm %ds", days, hours, minutes, seconds)
} else if hours > 0 {
return fmt.Sprintf("%dh %dm %ds", hours, minutes, seconds)
} else if minutes > 0 {
return fmt.Sprintf("%dm %ds", minutes, seconds)
}
return fmt.Sprintf("%ds", seconds)
}
// getTunnel gets a specific tunnel by ID
func getTunnel(id int) (*Tunnel, error) {
tunnels, err := getTunnels()
if err != nil {
return nil, err
}
for _, t := range tunnels {
if t.ID == id {
return &t, nil
}
}
return nil, fmt.Errorf("tunnel %d not found", id)
}

105
cmd/stop.go Normal file
View File

@@ -0,0 +1,105 @@
package cmd
import (
"fmt"
"os"
"syscall"
"github.com/spf13/cobra"
"sshtunnel/pkg/stats"
)
var (
tunnelID int
all bool
)
var stopCmd = &cobra.Command{
Use: "stop",
Short: "Stop an SSH tunnel",
Long: `Stop an active SSH tunnel by its ID or stop all tunnels.
Use the list command to see all active tunnels and their IDs.`,
Run: func(cmd *cobra.Command, args []string) {
if !all && tunnelID == -1 {
fmt.Println("Error: must specify either --id or --all")
cmd.Help()
os.Exit(1)
}
tunnels, err := getTunnels()
if err != nil {
fmt.Printf("Error getting tunnels: %v\n", err)
os.Exit(1)
}
if len(tunnels) == 0 {
fmt.Println("No active SSH tunnels found.")
return
}
if all {
for _, t := range tunnels {
killTunnel(t)
}
fmt.Println("All SSH tunnels stopped.")
return
}
// Stop specific tunnel
found := false
for _, t := range tunnels {
if t.ID == tunnelID {
killTunnel(t)
found = true
break
}
}
if !found {
fmt.Printf("Error: no tunnel with ID %d found\n", tunnelID)
os.Exit(1)
}
},
}
func init() {
rootCmd.AddCommand(stopCmd)
// Add flags for the stop command
stopCmd.Flags().IntVarP(&tunnelID, "id", "i", -1, "ID of the tunnel to stop")
stopCmd.Flags().BoolVarP(&all, "all", "a", false, "Stop all tunnels")
}
func killTunnel(t Tunnel) {
process, err := os.FindProcess(t.PID)
if err != nil {
fmt.Printf("Error finding process %d: %v\n", t.PID, err)
removeFile(t.ID)
cleanupStats(t.ID)
return
}
if err := process.Signal(syscall.SIGTERM); err != nil {
fmt.Printf("Error stopping tunnel %d (PID %d): %v\n", t.ID, t.PID, err)
// Process might not exist anymore, so clean up anyway
}
removeFile(t.ID)
cleanupStats(t.ID)
fmt.Printf("Stopped SSH tunnel (ID: %d): localhost:%d -> %s:%d\n",
t.ID, t.LocalPort, t.RemoteHost, t.RemotePort)
}
// cleanupStats removes the statistics for a tunnel
func cleanupStats(tunnelID int) {
statsManager, err := stats.NewStatsManager()
if err != nil {
fmt.Printf("Warning: Could not clean up statistics: %v\n", err)
return
}
err = statsManager.DeleteStats(tunnelID)
if err != nil {
fmt.Printf("Warning: Failed to delete statistics for tunnel %d: %v\n", tunnelID, err)
}
}