245 lines
5.9 KiB
Go
245 lines
5.9 KiB
Go
package cmd
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
)
|
|
|
|
// tunnelDir is the directory, relative to the user's home directory, that
// holds one "tunnel-<ID>" record file per tunnel.
const tunnelDir = ".sshtunnels"
|
|
|
|
// Tunnel describes one SSH port forward and the ssh process that carries it.
type Tunnel struct {
	// ID uniquely identifies the tunnel. LocalPort is the local port being
	// forwarded, and RemotePort is the port on RemoteHost it is forwarded to.
	ID, LocalPort, RemotePort int

	// RemoteHost is the host the local port is forwarded to; SSHServer is the
	// SSH server the tunnel goes through, written as user@host.
	RemoteHost, SSHServer string

	// PID is the process ID of the SSH process.
	PID int
}
|
|
|
|
// checkIfSSHProcess reports whether pid identifies a running ssh process.
//
// The process name is looked up first in /proc/<pid>/comm (Linux) and then
// via `ps -p <pid> -o comm=`. Whichever source answers first decides the
// result. Only a process whose executable base name is exactly "ssh" matches,
// so "sshd" or "ssh-agent" do not. If neither source can name the process,
// the function falls back to a liveness probe (signal 0). In that case it
// reports true for any live process, because nothing better is available.
//
// Non-positive PIDs are rejected outright: signalling pid 0 or a negative pid
// addresses a group of processes rather than a single one (see kill(2)).
func checkIfSSHProcess(pid int) bool {
	if pid <= 0 {
		return false
	}

	isSSH := func(name string) bool {
		// ps may print a full path (e.g. /usr/bin/ssh), so compare only the
		// base name of the trimmed output.
		return filepath.Base(strings.TrimSpace(name)) == "ssh"
	}

	// Linux: the kernel exposes the command name directly. A successful read
	// is authoritative, whatever name it returns.
	if data, err := os.ReadFile(fmt.Sprintf("/proc/%d/comm", pid)); err == nil {
		return isSSH(string(data))
	}

	// Portable fallback. ps exits non-zero when the PID does not exist, which
	// sends us on to the liveness probe below (and that probe then fails too).
	if output, err := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "comm=").Output(); err == nil {
		return isSSH(string(output))
	}

	// Last resort: the name cannot be determined here, so settle for checking
	// that the process exists at all.
	process, err := os.FindProcess(pid)
	if err != nil {
		return false
	}
	return process.Signal(syscall.Signal(0)) == nil
}
|
|
|
|
// findSSHTunnelPID attempts to find the PID of an SSH tunnel process by its local port
|
|
func findSSHTunnelPID(port int) (int, error) {
|
|
// Try using lsof first (most reliable)
|
|
cmd := exec.Command("lsof", "-i", fmt.Sprintf("TCP:%d", port), "-t")
|
|
output, err := cmd.Output()
|
|
if err == nil && len(output) > 0 {
|
|
pidStr := strings.TrimSpace(string(output))
|
|
pid, err := strconv.Atoi(pidStr)
|
|
if err == nil {
|
|
// Verify this is an SSH process
|
|
if checkIfSSHProcess(pid) {
|
|
return pid, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// Try netstat as a fallback
|
|
cmd = exec.Command("netstat", "-tlnp")
|
|
output, err = cmd.Output()
|
|
if err == nil {
|
|
lines := strings.Split(string(output), "\n")
|
|
portStr := fmt.Sprintf(":%d", port)
|
|
|
|
for _, line := range lines {
|
|
if strings.Contains(line, portStr) && strings.Contains(line, "ssh") {
|
|
// Extract PID from the line
|
|
parts := strings.Fields(line)
|
|
if len(parts) >= 7 {
|
|
pidPart := parts[6]
|
|
pidStr := strings.Split(pidPart, "/")[0]
|
|
pid, err := strconv.Atoi(pidStr)
|
|
if err == nil {
|
|
return pid, nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Try ps as a last resort
|
|
cmd = exec.Command("ps", "aux")
|
|
output, err = cmd.Output()
|
|
if err == nil {
|
|
lines := strings.Split(string(output), "\n")
|
|
portStr := fmt.Sprintf(":%d", port)
|
|
|
|
for _, line := range lines {
|
|
if strings.Contains(line, "ssh") && strings.Contains(line, portStr) {
|
|
parts := strings.Fields(line)
|
|
if len(parts) >= 2 {
|
|
pid, err := strconv.Atoi(parts[1])
|
|
if err == nil {
|
|
return pid, nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return 0, fmt.Errorf("could not find SSH tunnel PID for port %d", port)
|
|
}
|
|
|
|
func getTunnels() ([]Tunnel, error) {
|
|
homeDir, err := os.UserHomeDir()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting home directory: %v", err)
|
|
}
|
|
|
|
tunnelPath := filepath.Join(homeDir, tunnelDir)
|
|
entries, err := os.ReadDir(tunnelPath)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return []Tunnel{}, nil
|
|
}
|
|
return nil, fmt.Errorf("error reading tunnel directory: %v", err)
|
|
}
|
|
|
|
var tunnels []Tunnel
|
|
for _, entry := range entries {
|
|
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "tunnel-") {
|
|
filePath := filepath.Join(tunnelPath, entry.Name())
|
|
data, err := os.ReadFile(filePath)
|
|
if err != nil {
|
|
fmt.Printf("Error reading tunnel file %s: %v\n", filePath, err)
|
|
continue
|
|
}
|
|
|
|
var t Tunnel
|
|
parts := strings.Split(string(data), ":")
|
|
if len(parts) != 5 {
|
|
fmt.Printf("Invalid tunnel file format: %s\n", filePath)
|
|
continue
|
|
}
|
|
|
|
idStr := strings.TrimPrefix(entry.Name(), "tunnel-")
|
|
id, err := strconv.Atoi(idStr)
|
|
if err != nil {
|
|
fmt.Printf("Invalid tunnel ID: %s\n", idStr)
|
|
continue
|
|
}
|
|
t.ID = id
|
|
|
|
t.LocalPort, err = strconv.Atoi(parts[0])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
t.RemoteHost = parts[1]
|
|
|
|
t.RemotePort, err = strconv.Atoi(parts[2])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
t.SSHServer = parts[3]
|
|
|
|
t.PID, err = strconv.Atoi(parts[4])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Verify if this is actually a SSH process
|
|
isSSH := checkIfSSHProcess(t.PID)
|
|
if !isSSH {
|
|
fmt.Printf("Process %d is not an SSH process anymore, cleaning up\n", t.PID)
|
|
removeFile(t.ID)
|
|
continue
|
|
}
|
|
|
|
tunnels = append(tunnels, t)
|
|
}
|
|
}
|
|
|
|
return tunnels, nil
|
|
}
|
|
|
|
func saveTunnel(t Tunnel) error {
|
|
homeDir, err := os.UserHomeDir()
|
|
if err != nil {
|
|
return fmt.Errorf("error getting home directory: %v", err)
|
|
}
|
|
|
|
tunnelPath := filepath.Join(homeDir, tunnelDir, fmt.Sprintf("tunnel-%d", t.ID))
|
|
data := fmt.Sprintf("%d:%s:%d:%s:%d", t.LocalPort, t.RemoteHost, t.RemotePort, t.SSHServer, t.PID)
|
|
|
|
err = os.WriteFile(tunnelPath, []byte(data), 0644)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Verify the file was written correctly
|
|
_, err = os.Stat(tunnelPath)
|
|
return err
|
|
}
|
|
|
|
func removeFile(id int) {
|
|
homeDir, err := os.UserHomeDir()
|
|
if err != nil {
|
|
fmt.Printf("Error getting home directory: %v\n", err)
|
|
return
|
|
}
|
|
|
|
tunnelPath := filepath.Join(homeDir, tunnelDir, fmt.Sprintf("tunnel-%d", id))
|
|
if err := os.Remove(tunnelPath); err != nil && !os.IsNotExist(err) {
|
|
fmt.Printf("Error removing tunnel file %s: %v\n", tunnelPath, err)
|
|
}
|
|
}
|
|
|
|
func generateTunnelID() int {
|
|
homeDir, err := os.UserHomeDir()
|
|
if err != nil {
|
|
return int(os.Getpid())
|
|
}
|
|
|
|
tunnelPath := filepath.Join(homeDir, tunnelDir)
|
|
entries, err := os.ReadDir(tunnelPath)
|
|
if err != nil {
|
|
return int(os.Getpid())
|
|
}
|
|
|
|
id := 1
|
|
for _, entry := range entries {
|
|
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "tunnel-") {
|
|
idStr := strings.TrimPrefix(entry.Name(), "tunnel-")
|
|
if val, err := strconv.Atoi(idStr); err == nil && val >= id {
|
|
id = val + 1
|
|
}
|
|
}
|
|
}
|
|
|
|
return id
|
|
} |