initial commit
This commit is contained in:
245
cmd/common.go
Normal file
245
cmd/common.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const (
|
||||
tunnelDir = ".sshtunnels"
|
||||
)
|
||||
|
||||
// Tunnel represents an SSH tunnel configuration
|
||||
type Tunnel struct {
|
||||
ID int // Unique identifier for the tunnel
|
||||
LocalPort int // Local port being forwarded
|
||||
RemotePort int // Remote port being forwarded to
|
||||
RemoteHost string // Remote host being forwarded to
|
||||
SSHServer string // SSH server (user@host)
|
||||
PID int // Process ID of the SSH process
|
||||
}
|
||||
|
||||
// checkIfSSHProcess checks if a process ID belongs to an SSH process
|
||||
func checkIfSSHProcess(pid int) bool {
|
||||
// Try to read /proc/{pid}/comm if on Linux
|
||||
if _, err := os.Stat("/proc"); err == nil {
|
||||
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/comm", pid))
|
||||
if err == nil && strings.TrimSpace(string(data)) == "ssh" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Alternative approach - use ps command
|
||||
cmd := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "comm=")
|
||||
output, err := cmd.Output()
|
||||
if err == nil && strings.Contains(string(output), "ssh") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Last resort - just check if process exists
|
||||
process, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Send signal 0 to check if process exists
|
||||
return process.Signal(syscall.Signal(0)) == nil
|
||||
}
|
||||
|
||||
// findSSHTunnelPID attempts to find the PID of an SSH tunnel process by its local port
|
||||
func findSSHTunnelPID(port int) (int, error) {
|
||||
// Try using lsof first (most reliable)
|
||||
cmd := exec.Command("lsof", "-i", fmt.Sprintf("TCP:%d", port), "-t")
|
||||
output, err := cmd.Output()
|
||||
if err == nil && len(output) > 0 {
|
||||
pidStr := strings.TrimSpace(string(output))
|
||||
pid, err := strconv.Atoi(pidStr)
|
||||
if err == nil {
|
||||
// Verify this is an SSH process
|
||||
if checkIfSSHProcess(pid) {
|
||||
return pid, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try netstat as a fallback
|
||||
cmd = exec.Command("netstat", "-tlnp")
|
||||
output, err = cmd.Output()
|
||||
if err == nil {
|
||||
lines := strings.Split(string(output), "\n")
|
||||
portStr := fmt.Sprintf(":%d", port)
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, portStr) && strings.Contains(line, "ssh") {
|
||||
// Extract PID from the line
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) >= 7 {
|
||||
pidPart := parts[6]
|
||||
pidStr := strings.Split(pidPart, "/")[0]
|
||||
pid, err := strconv.Atoi(pidStr)
|
||||
if err == nil {
|
||||
return pid, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try ps as a last resort
|
||||
cmd = exec.Command("ps", "aux")
|
||||
output, err = cmd.Output()
|
||||
if err == nil {
|
||||
lines := strings.Split(string(output), "\n")
|
||||
portStr := fmt.Sprintf(":%d", port)
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "ssh") && strings.Contains(line, portStr) {
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) >= 2 {
|
||||
pid, err := strconv.Atoi(parts[1])
|
||||
if err == nil {
|
||||
return pid, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("could not find SSH tunnel PID for port %d", port)
|
||||
}
|
||||
|
||||
func getTunnels() ([]Tunnel, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting home directory: %v", err)
|
||||
}
|
||||
|
||||
tunnelPath := filepath.Join(homeDir, tunnelDir)
|
||||
entries, err := os.ReadDir(tunnelPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return []Tunnel{}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("error reading tunnel directory: %v", err)
|
||||
}
|
||||
|
||||
var tunnels []Tunnel
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "tunnel-") {
|
||||
filePath := filepath.Join(tunnelPath, entry.Name())
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error reading tunnel file %s: %v\n", filePath, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var t Tunnel
|
||||
parts := strings.Split(string(data), ":")
|
||||
if len(parts) != 5 {
|
||||
fmt.Printf("Invalid tunnel file format: %s\n", filePath)
|
||||
continue
|
||||
}
|
||||
|
||||
idStr := strings.TrimPrefix(entry.Name(), "tunnel-")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid tunnel ID: %s\n", idStr)
|
||||
continue
|
||||
}
|
||||
t.ID = id
|
||||
|
||||
t.LocalPort, err = strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
t.RemoteHost = parts[1]
|
||||
|
||||
t.RemotePort, err = strconv.Atoi(parts[2])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
t.SSHServer = parts[3]
|
||||
|
||||
t.PID, err = strconv.Atoi(parts[4])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify if this is actually a SSH process
|
||||
isSSH := checkIfSSHProcess(t.PID)
|
||||
if !isSSH {
|
||||
fmt.Printf("Process %d is not an SSH process anymore, cleaning up\n", t.PID)
|
||||
removeFile(t.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
tunnels = append(tunnels, t)
|
||||
}
|
||||
}
|
||||
|
||||
return tunnels, nil
|
||||
}
|
||||
|
||||
func saveTunnel(t Tunnel) error {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting home directory: %v", err)
|
||||
}
|
||||
|
||||
tunnelPath := filepath.Join(homeDir, tunnelDir, fmt.Sprintf("tunnel-%d", t.ID))
|
||||
data := fmt.Sprintf("%d:%s:%d:%s:%d", t.LocalPort, t.RemoteHost, t.RemotePort, t.SSHServer, t.PID)
|
||||
|
||||
err = os.WriteFile(tunnelPath, []byte(data), 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify the file was written correctly
|
||||
_, err = os.Stat(tunnelPath)
|
||||
return err
|
||||
}
|
||||
|
||||
func removeFile(id int) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
fmt.Printf("Error getting home directory: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
tunnelPath := filepath.Join(homeDir, tunnelDir, fmt.Sprintf("tunnel-%d", id))
|
||||
if err := os.Remove(tunnelPath); err != nil && !os.IsNotExist(err) {
|
||||
fmt.Printf("Error removing tunnel file %s: %v\n", tunnelPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
func generateTunnelID() int {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return int(os.Getpid())
|
||||
}
|
||||
|
||||
tunnelPath := filepath.Join(homeDir, tunnelDir)
|
||||
entries, err := os.ReadDir(tunnelPath)
|
||||
if err != nil {
|
||||
return int(os.Getpid())
|
||||
}
|
||||
|
||||
id := 1
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "tunnel-") {
|
||||
idStr := strings.TrimPrefix(entry.Name(), "tunnel-")
|
||||
if val, err := strconv.Atoi(idStr); err == nil && val >= id {
|
||||
id = val + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return id
|
||||
}
|
207
cmd/debug.go
Normal file
207
cmd/debug.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
Use: "debug",
|
||||
Short: "Run diagnostics on SSH tunnels",
|
||||
Long: `Run diagnostic checks on your SSH tunnel setup, including:
|
||||
- SSH client availability
|
||||
- Tunnel directory integrity
|
||||
- Recorded tunnels status
|
||||
- Active SSH processes verification`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runDebugCommand()
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
}
|
||||
|
||||
// runDebugCommand handles the debug subcommand logic
|
||||
func runDebugCommand() {
|
||||
fmt.Println("SSH Tunnel Manager Diagnostics")
|
||||
fmt.Println("==============================")
|
||||
|
||||
// Check SSH client availability
|
||||
fmt.Println("\n1. Checking SSH client:")
|
||||
checkSSHClient()
|
||||
|
||||
// Check tunnel directory
|
||||
fmt.Println("\n2. Checking tunnel directory:")
|
||||
checkTunnelDirectory()
|
||||
|
||||
// Check recorded tunnels
|
||||
fmt.Println("\n3. Checking recorded tunnels:")
|
||||
checkRecordedTunnels()
|
||||
|
||||
// Check active SSH processes
|
||||
fmt.Println("\n4. Checking active SSH processes:")
|
||||
checkActiveSSHProcesses()
|
||||
}
|
||||
|
||||
func checkSSHClient() {
|
||||
path, err := exec.LookPath("ssh")
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ SSH client not found in PATH: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf(" ✅ SSH client found at: %s\n", path)
|
||||
|
||||
// Get SSH version
|
||||
cmd := exec.Command("ssh", "-V")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
fmt.Printf(" ⚠️ Could not determine SSH version: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf(" ✅ SSH version info: %s\n", strings.TrimSpace(string(output)))
|
||||
}
|
||||
|
||||
func checkTunnelDirectory() {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ Could not determine home directory: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
tunnelPath := filepath.Join(homeDir, tunnelDir)
|
||||
info, err := os.Stat(tunnelPath)
|
||||
if os.IsNotExist(err) {
|
||||
fmt.Printf(" ⚠️ Tunnel directory does not exist: %s\n", tunnelPath)
|
||||
return
|
||||
} else if err != nil {
|
||||
fmt.Printf(" ❌ Error accessing tunnel directory: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf(" ✅ Tunnel directory exists: %s\n", tunnelPath)
|
||||
fmt.Printf(" ✅ Permissions: %s\n", info.Mode().String())
|
||||
|
||||
entries, err := os.ReadDir(tunnelPath)
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ Could not read tunnel directory: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf(" ✅ Directory contains %d entries\n", len(entries))
|
||||
}
|
||||
|
||||
func checkRecordedTunnels() {
|
||||
tunnels, err := getTunnels()
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ Error reading tunnels: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(tunnels) == 0 {
|
||||
fmt.Printf(" ℹ️ No recorded tunnels found\n")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf(" ✅ Found %d recorded tunnels\n", len(tunnels))
|
||||
for i, t := range tunnels {
|
||||
fmt.Printf("\n Tunnel #%d (ID: %d):\n", i+1, t.ID)
|
||||
fmt.Printf(" Local port: %d\n", t.LocalPort)
|
||||
fmt.Printf(" Remote: %s:%d\n", t.RemoteHost, t.RemotePort)
|
||||
fmt.Printf(" Server: %s\n", t.SSHServer)
|
||||
fmt.Printf(" PID: %d\n", t.PID)
|
||||
|
||||
// Check if process exists
|
||||
process, err := os.FindProcess(t.PID)
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ Process not found: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to send signal 0 to check if process exists
|
||||
err = process.Signal(syscall.Signal(0))
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ Process not running: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf(" ✅ Process is running\n")
|
||||
|
||||
// Check if it's actually an SSH process
|
||||
isSSH := checkIfSSHProcess(t.PID)
|
||||
if isSSH {
|
||||
fmt.Printf(" ✅ Process is an SSH process\n")
|
||||
} else {
|
||||
fmt.Printf(" ⚠️ Process is not an SSH process!\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkActiveSSHProcesses() {
|
||||
// Try using ps to find SSH processes
|
||||
cmd := exec.Command("ps", "-eo", "pid,command")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ Could not list processes: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
lines := strings.Split(string(output), "\n")
|
||||
sshProcesses := []string{}
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "ssh") && strings.Contains(line, "-L") {
|
||||
sshProcesses = append(sshProcesses, strings.TrimSpace(line))
|
||||
}
|
||||
}
|
||||
|
||||
if len(sshProcesses) == 0 {
|
||||
fmt.Printf(" ℹ️ No SSH tunnel processes found\n")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf(" ✅ Found %d SSH tunnel processes:\n", len(sshProcesses))
|
||||
for _, proc := range sshProcesses {
|
||||
fmt.Printf(" %s\n", proc)
|
||||
|
||||
// Extract PID
|
||||
fields := strings.Fields(proc)
|
||||
if len(fields) > 0 {
|
||||
pid, err := strconv.Atoi(fields[0])
|
||||
if err == nil {
|
||||
// Check if this process is in our records
|
||||
found := false
|
||||
tunnels, _ := getTunnels()
|
||||
for _, t := range tunnels {
|
||||
if t.PID == pid {
|
||||
fmt.Printf(" ✅ This process is tracked as tunnel ID %d\n", t.ID)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
fmt.Printf(" ⚠️ This process is not tracked by the tunnel manager\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func verifyTunnelConnectivity(t Tunnel) error {
|
||||
// Try to connect to the local port to verify the tunnel is working
|
||||
cmd := exec.Command("nc", "-z", "-w", "1", "localhost", strconv.Itoa(t.LocalPort))
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not connect to local port %d: %v", t.LocalPort, err)
|
||||
}
|
||||
return nil
|
||||
}
|
56
cmd/list.go
Normal file
56
cmd/list.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var listCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List all active SSH tunnels",
|
||||
Long: `Display all currently active SSH tunnels managed by this tool.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
listTunnels()
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(listCmd)
|
||||
}
|
||||
|
||||
func listTunnels() {
|
||||
tunnels, err := getTunnels()
|
||||
if err != nil {
|
||||
fmt.Printf("Error listing tunnels: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(tunnels) == 0 {
|
||||
fmt.Println("No active SSH tunnels found.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Active SSH tunnels:")
|
||||
fmt.Printf("%-5s %-15s %-20s %-10s\n", "ID", "LOCAL", "REMOTE", "PID")
|
||||
fmt.Println(strings.Repeat("-", 60))
|
||||
|
||||
for _, t := range tunnels {
|
||||
// Check if the process is still running
|
||||
process, err := os.FindProcess(t.PID)
|
||||
if err != nil || process.Signal(syscall.Signal(0)) != nil {
|
||||
// Process does not exist anymore, clean up the tunnel file
|
||||
removeFile(t.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("%-5d %-15s %-20s %-10d\n",
|
||||
t.ID,
|
||||
fmt.Sprintf("localhost:%d", t.LocalPort),
|
||||
fmt.Sprintf("%s:%d", t.RemoteHost, t.RemotePort),
|
||||
t.PID)
|
||||
}
|
||||
}
|
32
cmd/root.go
Normal file
32
cmd/root.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "sshtunnel",
|
||||
Short: "SSH tunnel manager",
|
||||
Long: `SSH Tunnel Manager is a CLI tool for creating and managing SSH tunnels.
|
||||
It allows you to easily create, list, and terminate SSH port forwarding tunnels
|
||||
in the background without having to remember complex SSH commands.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
// If no subcommand is provided, print help
|
||||
cmd.Help()
|
||||
},
|
||||
}
|
||||
|
||||
// Execute executes the root command
|
||||
func Execute() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Global flags can be defined here
|
||||
}
|
141
cmd/start.go
Normal file
141
cmd/start.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"sshtunnel/pkg/stats"
|
||||
)
|
||||
|
||||
var (
|
||||
localPort int
|
||||
remotePort int
|
||||
remoteHost string
|
||||
sshServer string
|
||||
identity string
|
||||
)
|
||||
|
||||
// startCmd represents the start command
|
||||
var startCmd = &cobra.Command{
|
||||
Use: "start",
|
||||
Short: "Start a new SSH tunnel",
|
||||
Long: `Start a new SSH tunnel with specified local port, remote port, host and SSH server.
|
||||
The tunnel will run in the background and can be managed using the list and stop commands.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
// Check required flags
|
||||
// Generate the SSH command with appropriate flags for reliable background operation
|
||||
sshArgs := []string{
|
||||
"-N", // Don't execute remote command
|
||||
"-f", // Run in background
|
||||
"-L", fmt.Sprintf("%d:%s:%d", localPort, remoteHost, remotePort),
|
||||
}
|
||||
|
||||
if identity != "" {
|
||||
sshArgs = append(sshArgs, "-i", identity)
|
||||
}
|
||||
|
||||
sshArgs = append(sshArgs, sshServer)
|
||||
sshCmd := exec.Command("ssh", sshArgs...)
|
||||
|
||||
// Capture output for debugging
|
||||
var outputBuffer strings.Builder
|
||||
sshCmd.Stdout = &outputBuffer
|
||||
sshCmd.Stderr = &outputBuffer
|
||||
|
||||
// Run the command (not just Start) - the -f flag means it will return immediately
|
||||
// after going to the background
|
||||
if err := sshCmd.Run(); err != nil {
|
||||
fmt.Printf("Error starting SSH tunnel: %v\n", err)
|
||||
fmt.Printf("SSH output: %s\n", outputBuffer.String())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// The PID from cmd.Process is no longer valid since ssh -f forks
|
||||
// We need to find the actual SSH process PID
|
||||
actualPID, err := findSSHTunnelPID(localPort)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Could not determine tunnel PID: %v\n", err)
|
||||
}
|
||||
|
||||
// Store tunnel information
|
||||
id := generateTunnelID()
|
||||
pid := 0
|
||||
|
||||
if actualPID > 0 {
|
||||
pid = actualPID
|
||||
}
|
||||
|
||||
tunnel := Tunnel{
|
||||
ID: id,
|
||||
LocalPort: localPort,
|
||||
RemotePort: remotePort,
|
||||
RemoteHost: remoteHost,
|
||||
SSHServer: sshServer,
|
||||
PID: pid,
|
||||
}
|
||||
|
||||
if err := saveTunnel(tunnel); err != nil {
|
||||
fmt.Printf("Error saving tunnel information: %v\n", err)
|
||||
if pid > 0 {
|
||||
// Try to kill the process
|
||||
process, _ := os.FindProcess(pid)
|
||||
if process != nil {
|
||||
process.Kill()
|
||||
}
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Initialize statistics for this tunnel
|
||||
statsManager, err := stats.NewStatsManager()
|
||||
if err == nil {
|
||||
err = statsManager.InitStats(id, localPort)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Failed to initialize statistics: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify tunnel is actually working
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
active := verifyTunnelActive(localPort)
|
||||
status := "ACTIVE"
|
||||
if !active {
|
||||
status = "UNKNOWN"
|
||||
}
|
||||
|
||||
fmt.Printf("Started SSH tunnel (ID: %d): localhost:%d -> %s:%d (%s) [PID: %d] [Status: %s]\n",
|
||||
id, localPort, remoteHost, remotePort, sshServer, pid, status)
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(startCmd)
|
||||
|
||||
// Add flags for the start command
|
||||
startCmd.Flags().IntVarP(&localPort, "local", "l", 0, "Local port to forward")
|
||||
startCmd.Flags().IntVarP(&remotePort, "remote", "r", 0, "Remote port to forward to")
|
||||
startCmd.Flags().StringVarP(&remoteHost, "host", "H", "localhost", "Remote host to forward to")
|
||||
startCmd.Flags().StringVarP(&sshServer, "server", "s", "", "SSH server address (user@host)")
|
||||
startCmd.Flags().StringVarP(&identity, "identity", "i", "", "Path to SSH identity file")
|
||||
|
||||
// Mark required flags
|
||||
startCmd.MarkFlagRequired("local")
|
||||
startCmd.MarkFlagRequired("remote")
|
||||
startCmd.MarkFlagRequired("server")
|
||||
}
|
||||
|
||||
// verifyTunnelActive checks if the tunnel is actually working
|
||||
func verifyTunnelActive(port int) bool {
|
||||
// Try to connect to the port to verify it's open
|
||||
conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 500*time.Millisecond)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
283
cmd/stats.go
Normal file
283
cmd/stats.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"sshtunnel/pkg/monitor"
|
||||
"sshtunnel/pkg/stats"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
tunnelIDFlag int
|
||||
watchFlag bool
|
||||
watchInterval int
|
||||
allFlag bool
|
||||
)
|
||||
|
||||
// statsCmd represents the stats command
|
||||
var statsCmd = &cobra.Command{
|
||||
Use: "stats",
|
||||
Short: "View SSH tunnel statistics",
|
||||
Long: `View traffic statistics for active SSH tunnels.
|
||||
This command shows you how much data has been transferred through your
|
||||
SSH tunnels, in both directions (incoming and outgoing).
|
||||
|
||||
You can view statistics for a specific tunnel by specifying its ID,
|
||||
or view statistics for all active tunnels.
|
||||
|
||||
Examples:
|
||||
# View statistics for all tunnels
|
||||
sshtunnel stats --all
|
||||
|
||||
# View statistics for a specific tunnel
|
||||
sshtunnel stats --id 1
|
||||
|
||||
# Continuously monitor a specific tunnel (every 5 seconds)
|
||||
sshtunnel stats --id 1 --watch
|
||||
|
||||
# Continuously monitor with custom interval (2 seconds)
|
||||
sshtunnel stats --id 1 --watch --interval 2`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !allFlag && tunnelIDFlag <= 0 {
|
||||
fmt.Println("Error: You must specify either a tunnel ID with --id or use --all")
|
||||
cmd.Help()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if allFlag && tunnelIDFlag > 0 {
|
||||
fmt.Println("Error: You cannot specify both --id and --all flags")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if watchFlag && !allFlag && tunnelIDFlag <= 0 {
|
||||
fmt.Println("Error: Watch mode requires specifying a tunnel ID with --id")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if watchFlag && allFlag {
|
||||
fmt.Println("Error: Watch mode can only be used with a specific tunnel ID")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
statsManager, err := stats.NewStatsManager()
|
||||
if err != nil {
|
||||
fmt.Printf("Error initializing statistics: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if allFlag {
|
||||
showAllStats(statsManager)
|
||||
} else {
|
||||
if watchFlag {
|
||||
watchTunnelStats(tunnelIDFlag, watchInterval)
|
||||
} else {
|
||||
showTunnelStats(statsManager, tunnelIDFlag)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(statsCmd)
|
||||
|
||||
statsCmd.Flags().IntVar(&tunnelIDFlag, "id", 0, "ID of the tunnel to show statistics for")
|
||||
statsCmd.Flags().BoolVar(&allFlag, "all", false, "Show statistics for all tunnels")
|
||||
statsCmd.Flags().BoolVar(&watchFlag, "watch", false, "Continuously monitor tunnel statistics")
|
||||
statsCmd.Flags().IntVar(&watchInterval, "interval", 5, "Update interval for watch mode (in seconds)")
|
||||
}
|
||||
|
||||
func showAllStats(statsManager *stats.StatsManager) {
|
||||
tunnelStats, err := statsManager.GetAllStats()
|
||||
if err != nil {
|
||||
fmt.Printf("Error retrieving statistics: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(tunnelStats) == 0 {
|
||||
fmt.Println("No statistics available for any tunnels.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("SSH Tunnel Statistics:")
|
||||
fmt.Println("=============================================")
|
||||
for _, s := range tunnelStats {
|
||||
// Verify if this tunnel is still active
|
||||
tunnels, _ := getTunnels()
|
||||
active := false
|
||||
for _, t := range tunnels {
|
||||
if t.ID == s.TunnelID {
|
||||
active = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
uptime := time.Since(s.StartTime)
|
||||
inRate := stats.CalculateRate(s.BytesIn, uptime)
|
||||
outRate := stats.CalculateRate(s.BytesOut, uptime)
|
||||
|
||||
status := "ACTIVE"
|
||||
if !active {
|
||||
status = "INACTIVE"
|
||||
}
|
||||
|
||||
fmt.Printf("Tunnel #%d (Port: %d) [%s]\n", s.TunnelID, s.LocalPort, status)
|
||||
fmt.Printf(" Uptime: %s\n", formatDuration(uptime))
|
||||
fmt.Printf(" Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
|
||||
fmt.Printf(" Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
|
||||
fmt.Printf(" Total: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
|
||||
fmt.Println("---------------------------------------------")
|
||||
}
|
||||
}
|
||||
|
||||
func showTunnelStats(statsManager *stats.StatsManager, tunnelID int) {
|
||||
s, err := statsManager.GetStats(tunnelID)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Could not retrieve statistics for tunnel #%d: %v\n", tunnelID, err)
|
||||
s = stats.Stats{
|
||||
TunnelID: tunnelID,
|
||||
BytesIn: 0,
|
||||
BytesOut: 0,
|
||||
StartTime: time.Now(),
|
||||
LastUpdated: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get tunnel info to validate it exists
|
||||
tunnel, err := getTunnel(tunnelID)
|
||||
if err != nil {
|
||||
fmt.Printf("Note: Tunnel #%d does not appear to be active, showing historical data\n", tunnelID)
|
||||
}
|
||||
|
||||
// Display stats
|
||||
uptime := time.Since(s.StartTime)
|
||||
inRate := stats.CalculateRate(s.BytesIn, uptime)
|
||||
outRate := stats.CalculateRate(s.BytesOut, uptime)
|
||||
|
||||
fmt.Printf("Statistics for Tunnel #%d:\n", tunnelID)
|
||||
fmt.Println("=============================================")
|
||||
if tunnel != nil {
|
||||
fmt.Printf("Local Port: %d\n", tunnel.LocalPort)
|
||||
fmt.Printf("Remote Target: %s:%d\n", tunnel.RemoteHost, tunnel.RemotePort)
|
||||
fmt.Printf("SSH Server: %s\n", tunnel.SSHServer)
|
||||
fmt.Printf("PID: %d\n", tunnel.PID)
|
||||
} else if s.LocalPort > 0 {
|
||||
// If we don't have tunnel info but do have port in the stats
|
||||
fmt.Printf("Local Port: %d\n", s.LocalPort)
|
||||
fmt.Printf("Status: INACTIVE (historical data)\n")
|
||||
}
|
||||
fmt.Printf("Uptime: %s\n", formatDuration(uptime))
|
||||
fmt.Printf("Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
|
||||
fmt.Printf("Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
|
||||
fmt.Printf("Total Traffic: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
|
||||
}
|
||||
|
||||
func watchTunnelStats(tunnelID int, interval int) {
|
||||
// Get tunnel info to validate it exists
|
||||
tunnel, err := getTunnel(tunnelID)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Tunnel #%d may not be active: %v\n", tunnelID, err)
|
||||
fmt.Println("Attempting to monitor anyway...")
|
||||
|
||||
// Try to get port from stats
|
||||
statsManager, _ := stats.NewStatsManager()
|
||||
s, err := statsManager.GetStats(tunnelID)
|
||||
if err == nil && s.LocalPort > 0 {
|
||||
// Create a tunnel object with the information we have
|
||||
tunnel = &Tunnel{
|
||||
ID: tunnelID,
|
||||
LocalPort: s.LocalPort,
|
||||
RemoteHost: "unknown",
|
||||
RemotePort: 0,
|
||||
SSHServer: "unknown",
|
||||
PID: 0,
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Error: Cannot determine port for tunnel. Please specify a valid tunnel ID.")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a monitor for this tunnel
|
||||
mon, err := monitor.NewMonitor(tunnelID, tunnel.LocalPort)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Error creating monitor: %v\n", err)
|
||||
fmt.Println("Will attempt to continue with limited functionality")
|
||||
}
|
||||
|
||||
// Start monitoring with error handling
|
||||
err = mon.Start()
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: %v\n", err)
|
||||
fmt.Println("Continuing with estimated statistics...")
|
||||
}
|
||||
defer mon.Stop()
|
||||
|
||||
fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)
|
||||
|
||||
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Initial stats display
|
||||
statsStr, err := mon.FormatStats()
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: %v\n", err)
|
||||
fmt.Println("Displaying minimal statistics...")
|
||||
statsStr = fmt.Sprintf("Monitoring tunnel #%d on port %d...",
|
||||
tunnelID, tunnel.LocalPort)
|
||||
}
|
||||
fmt.Println(statsStr)
|
||||
fmt.Println("\nUpdating every", interval, "seconds...")
|
||||
|
||||
// Clear screen and update stats periodically
|
||||
for range ticker.C {
|
||||
// ANSI escape code to clear screen
|
||||
fmt.Print("\033[H\033[2J")
|
||||
|
||||
fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)
|
||||
|
||||
statsStr, err := mon.FormatStats()
|
||||
if err != nil {
|
||||
statsStr = fmt.Sprintf("Warning: %v\n\nStill monitoring tunnel #%d on port %d...",
|
||||
err, tunnelID, tunnel.LocalPort)
|
||||
}
|
||||
|
||||
fmt.Println(statsStr)
|
||||
fmt.Println("\nUpdating every", interval, "seconds...")
|
||||
}
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
days := int(d.Hours() / 24)
|
||||
hours := int(d.Hours()) % 24
|
||||
minutes := int(d.Minutes()) % 60
|
||||
seconds := int(d.Seconds()) % 60
|
||||
|
||||
if days > 0 {
|
||||
return fmt.Sprintf("%dd %dh %dm %ds", days, hours, minutes, seconds)
|
||||
} else if hours > 0 {
|
||||
return fmt.Sprintf("%dh %dm %ds", hours, minutes, seconds)
|
||||
} else if minutes > 0 {
|
||||
return fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
}
|
||||
return fmt.Sprintf("%ds", seconds)
|
||||
}
|
||||
|
||||
// getTunnel gets a specific tunnel by ID
|
||||
func getTunnel(id int) (*Tunnel, error) {
|
||||
tunnels, err := getTunnels()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, t := range tunnels {
|
||||
if t.ID == id {
|
||||
return &t, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("tunnel %d not found", id)
|
||||
}
|
105
cmd/stop.go
Normal file
105
cmd/stop.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"sshtunnel/pkg/stats"
|
||||
)
|
||||
|
||||
var (
|
||||
tunnelID int
|
||||
all bool
|
||||
)
|
||||
|
||||
var stopCmd = &cobra.Command{
|
||||
Use: "stop",
|
||||
Short: "Stop an SSH tunnel",
|
||||
Long: `Stop an active SSH tunnel by its ID or stop all tunnels.
|
||||
Use the list command to see all active tunnels and their IDs.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !all && tunnelID == -1 {
|
||||
fmt.Println("Error: must specify either --id or --all")
|
||||
cmd.Help()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
tunnels, err := getTunnels()
|
||||
if err != nil {
|
||||
fmt.Printf("Error getting tunnels: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(tunnels) == 0 {
|
||||
fmt.Println("No active SSH tunnels found.")
|
||||
return
|
||||
}
|
||||
|
||||
if all {
|
||||
for _, t := range tunnels {
|
||||
killTunnel(t)
|
||||
}
|
||||
fmt.Println("All SSH tunnels stopped.")
|
||||
return
|
||||
}
|
||||
|
||||
// Stop specific tunnel
|
||||
found := false
|
||||
for _, t := range tunnels {
|
||||
if t.ID == tunnelID {
|
||||
killTunnel(t)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
fmt.Printf("Error: no tunnel with ID %d found\n", tunnelID)
|
||||
os.Exit(1)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(stopCmd)
|
||||
|
||||
// Add flags for the stop command
|
||||
stopCmd.Flags().IntVarP(&tunnelID, "id", "i", -1, "ID of the tunnel to stop")
|
||||
stopCmd.Flags().BoolVarP(&all, "all", "a", false, "Stop all tunnels")
|
||||
}
|
||||
|
||||
func killTunnel(t Tunnel) {
|
||||
process, err := os.FindProcess(t.PID)
|
||||
if err != nil {
|
||||
fmt.Printf("Error finding process %d: %v\n", t.PID, err)
|
||||
removeFile(t.ID)
|
||||
cleanupStats(t.ID)
|
||||
return
|
||||
}
|
||||
|
||||
if err := process.Signal(syscall.SIGTERM); err != nil {
|
||||
fmt.Printf("Error stopping tunnel %d (PID %d): %v\n", t.ID, t.PID, err)
|
||||
// Process might not exist anymore, so clean up anyway
|
||||
}
|
||||
|
||||
removeFile(t.ID)
|
||||
cleanupStats(t.ID)
|
||||
fmt.Printf("Stopped SSH tunnel (ID: %d): localhost:%d -> %s:%d\n",
|
||||
t.ID, t.LocalPort, t.RemoteHost, t.RemotePort)
|
||||
}
|
||||
|
||||
// cleanupStats removes the statistics for a tunnel
|
||||
func cleanupStats(tunnelID int) {
|
||||
statsManager, err := stats.NewStatsManager()
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Could not clean up statistics: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = statsManager.DeleteStats(tunnelID)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Failed to delete statistics for tunnel %d: %v\n", tunnelID, err)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user