sshtunnel/cmd/common.go
2025-05-23 15:08:44 +02:00

245 lines
5.9 KiB
Go

package cmd
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"syscall"
)
// tunnelDir is the directory, relative to the user's home directory, that
// holds one "tunnel-<id>" state file per tracked tunnel.
const tunnelDir = ".sshtunnels"
// Tunnel describes one SSH port forward tracked by this tool. saveTunnel
// persists it as a colon-separated record and getTunnels reloads it.
type Tunnel struct {
	// ID is the unique identifier; it also names the state file "tunnel-<ID>".
	ID int
	// LocalPort is the local port being forwarded; RemotePort is the port on
	// RemoteHost that the traffic is forwarded to.
	LocalPort, RemotePort int
	// RemoteHost is the host the traffic is forwarded to.
	RemoteHost string
	// SSHServer is the SSH server the tunnel runs through (user@host).
	SSHServer string
	// PID is the process ID of the ssh process carrying the tunnel.
	PID int
}
// checkIfSSHProcess reports whether pid identifies a running process whose
// command name is exactly "ssh".
//
// The command name comes from /proc/<pid>/comm when procfs is available and
// from `ps -p <pid> -o comm=` otherwise. A definitive answer from either
// source (including "no such process") is final. Only when neither source is
// usable does it fall back to a bare liveness check (signal 0), in which case
// any live process is accepted.
func checkIfSSHProcess(pid int) bool {
	// Non-positive PIDs never name a single process; signalling them would
	// address a process group instead.
	if pid <= 0 {
		return false
	}
	if name, known := lookupProcessComm(pid); known {
		return name == "ssh"
	}
	// Last resort: the command name could not be determined, so settle for
	// checking that the process exists at all.
	process, err := os.FindProcess(pid)
	if err != nil {
		return false
	}
	return process.Signal(syscall.Signal(0)) == nil
}

// lookupProcessComm returns the command name of pid. known is false when no
// lookup mechanism was usable; known is true with an empty name when the
// lookup succeeded in establishing that the process does not exist.
func lookupProcessComm(pid int) (name string, known bool) {
	data, err := os.ReadFile(fmt.Sprintf("/proc/%d/comm", pid))
	if err == nil {
		return strings.TrimSpace(string(data)), true
	}
	if os.IsNotExist(err) {
		// A missing entry only proves the process is gone if procfs itself
		// is mounted.
		if _, statErr := os.Stat("/proc/self"); statErr == nil {
			return "", true
		}
	}
	out, err := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "comm=").Output()
	if err != nil {
		// ps ran but exited non-zero: treated as "no such process".
		if _, ran := err.(*exec.ExitError); ran {
			return "", true
		}
		// ps could not be run at all.
		return "", false
	}
	// BSD/macOS ps prints the executable path rather than the bare name.
	return filepath.Base(strings.TrimSpace(string(out))), true
}
// findSSHTunnelPID attempts to find the PID of the ssh process that serves a
// tunnel on the given local port.
//
// Three strategies are tried in order, and the first match wins:
//   - lsof: processes listening on TCP port, filtered by checkIfSSHProcess;
//   - netstat -tlnp (Linux net-tools): a listener on :port owned by "ssh";
//   - ps: an ssh command line with a -L forward that binds port.
//
// An error is returned when none of them finds a process.
func findSSHTunnelPID(port int) (int, error) {
	strategies := []func(int) (int, bool){
		tunnelPIDFromLsof,
		tunnelPIDFromNetstat,
		tunnelPIDFromPs,
	}
	for _, find := range strategies {
		if pid, ok := find(port); ok {
			return pid, nil
		}
	}
	return 0, fmt.Errorf("could not find SSH tunnel PID for port %d", port)
}

// tunnelPIDFromLsof asks lsof for the processes listening on TCP port and
// returns the first one that checkIfSSHProcess accepts.
func tunnelPIDFromLsof(port int) (int, bool) {
	// -sTCP:LISTEN keeps only listening sockets, so processes that merely hold
	// a connection to some other host's port with the same number are ignored.
	cmd := exec.Command("lsof", "-i", fmt.Sprintf("TCP:%d", port), "-sTCP:LISTEN", "-t")
	output, err := cmd.Output()
	if err != nil {
		return 0, false
	}
	// -t prints one PID per line and several processes may match, so each one
	// is checked instead of parsing the whole output as a single number.
	for _, field := range strings.Fields(string(output)) {
		pid, err := strconv.Atoi(field)
		if err == nil && checkIfSSHProcess(pid) {
			return pid, true
		}
	}
	return 0, false
}

// tunnelPIDFromNetstat scans `netstat -tlnp` for a listener whose local
// address ends in ":port" and whose owning program is exactly "ssh".
func tunnelPIDFromNetstat(port int) (int, bool) {
	output, err := exec.Command("netstat", "-tlnp").Output()
	if err != nil {
		return 0, false
	}
	// Matching the suffix of the local-address column (rather than a
	// substring of the whole line) keeps :80 from matching :8080.
	suffix := ":" + strconv.Itoa(port)
	for _, line := range strings.Split(string(output), "\n") {
		// Columns: Proto Recv-Q Send-Q Local-Address Foreign-Address State PID/Program
		fields := strings.Fields(line)
		if len(fields) < 7 || !strings.HasSuffix(fields[3], suffix) {
			continue
		}
		pidAndProgram := strings.SplitN(fields[6], "/", 2)
		if len(pidAndProgram) != 2 || pidAndProgram[1] != "ssh" {
			continue
		}
		if pid, err := strconv.Atoi(pidAndProgram[0]); err == nil {
			return pid, true
		}
	}
	return 0, false
}

// tunnelPIDFromPs scans the command lines of running ssh processes for a
// local forward (-L) whose listening port is port.
func tunnelPIDFromPs(port int) (int, bool) {
	// Separate -o flags: in "pid=,args=" the text after '=' would be taken as
	// the header of a single pid column.
	output, err := exec.Command("ps", "-e", "-o", "pid=", "-o", "args=").Output()
	if err != nil {
		return 0, false
	}
	for _, line := range strings.Split(string(output), "\n") {
		fields := strings.Fields(line)
		if len(fields) < 2 || filepath.Base(fields[1]) != "ssh" {
			continue
		}
		pid, err := strconv.Atoi(fields[0])
		if err == nil && sshArgsForwardLocalPort(fields[2:], port) {
			return pid, true
		}
	}
	return 0, false
}

// sshArgsForwardLocalPort reports whether an ssh argument list contains a
// local forward ("-L spec", "-Lspec", or a flag cluster ending in L such as
// "-NL spec") that listens on port.
func sshArgsForwardLocalPort(args []string, port int) bool {
	for i, arg := range args {
		if !strings.HasPrefix(arg, "-") || strings.HasPrefix(arg, "--") {
			continue
		}
		idx := strings.IndexByte(arg, 'L')
		if idx < 0 {
			continue
		}
		spec := arg[idx+1:]
		if spec == "" {
			// The forward spec is the following argument.
			if i+1 >= len(args) {
				continue
			}
			spec = args[i+1]
		}
		if forwardSpecBindsPort(spec, port) {
			return true
		}
	}
	return false
}

// forwardSpecBindsPort reports whether an ssh -L specification of the form
// [bind_address:]port:host:hostport listens on port. Bracketed IPv6 addresses
// and Unix-socket forwards are not recognised.
func forwardSpecBindsPort(spec string, port int) bool {
	parts := strings.Split(spec, ":")
	var local string
	switch len(parts) {
	case 3: // port:host:hostport
		local = parts[0]
	case 4: // bind_address:port:host:hostport
		local = parts[1]
	default:
		return false
	}
	p, err := strconv.Atoi(local)
	return err == nil && p == port
}
// getTunnels loads every tunnel recorded in the tunnel directory under the
// user's home directory.
//
// Each "tunnel-<id>" file holds one record in saveTunnel's format. Files that
// cannot be read, parsed, or whose name has no numeric ID are reported on
// stdout and skipped. A record whose PID is no longer accepted by
// checkIfSSHProcess is treated as stale: its file is removed and it is left
// out of the result. A missing tunnel directory yields an empty, non-nil slice.
func getTunnels() ([]Tunnel, error) {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return nil, fmt.Errorf("error getting home directory: %w", err)
	}
	tunnelPath := filepath.Join(homeDir, tunnelDir)
	entries, err := os.ReadDir(tunnelPath)
	if err != nil {
		if os.IsNotExist(err) {
			return []Tunnel{}, nil
		}
		return nil, fmt.Errorf("error reading tunnel directory: %w", err)
	}
	var tunnels []Tunnel
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasPrefix(entry.Name(), "tunnel-") {
			continue
		}
		filePath := filepath.Join(tunnelPath, entry.Name())
		data, err := os.ReadFile(filePath)
		if err != nil {
			fmt.Printf("Error reading tunnel file %s: %v\n", filePath, err)
			continue
		}
		t, ok := parseTunnelRecord(string(data))
		if !ok {
			fmt.Printf("Invalid tunnel file format: %s\n", filePath)
			continue
		}
		idStr := strings.TrimPrefix(entry.Name(), "tunnel-")
		t.ID, err = strconv.Atoi(idStr)
		if err != nil {
			fmt.Printf("Invalid tunnel ID: %s\n", idStr)
			continue
		}
		// The recorded PID no longer belongs to ssh: the record is stale.
		if !checkIfSSHProcess(t.PID) {
			fmt.Printf("Process %d is not an SSH process anymore, cleaning up\n", t.PID)
			removeFile(t.ID)
			continue
		}
		tunnels = append(tunnels, t)
	}
	return tunnels, nil
}

// parseTunnelRecord parses a "localPort:remoteHost:remotePort:sshServer:pid"
// record as written by saveTunnel. Surrounding whitespace (e.g. a trailing
// newline) is ignored. remoteHost may itself contain colons (such as an IPv6
// literal) provided sshServer does not. The returned Tunnel has no ID set.
func parseTunnelRecord(record string) (Tunnel, bool) {
	fields := strings.Split(strings.TrimSpace(record), ":")
	n := len(fields)
	if n < 5 {
		return Tunnel{}, false
	}
	// localPort is first and remotePort, sshServer, pid are the last three;
	// everything in between belongs to the remote host.
	localPort, err := strconv.Atoi(fields[0])
	if err != nil {
		return Tunnel{}, false
	}
	remotePort, err := strconv.Atoi(fields[n-3])
	if err != nil {
		return Tunnel{}, false
	}
	pid, err := strconv.Atoi(fields[n-1])
	if err != nil {
		return Tunnel{}, false
	}
	return Tunnel{
		LocalPort:  localPort,
		RemoteHost: strings.Join(fields[1:n-3], ":"),
		RemotePort: remotePort,
		SSHServer:  fields[n-2],
		PID:        pid,
	}, true
}
// saveTunnel persists t to the file "tunnel-<ID>" in the tunnel directory
// under the user's home directory. The record has the colon-separated form
// "localPort:remoteHost:remotePort:sshServer:pid". The tunnel directory is
// created if needed, and any existing file for the same ID is overwritten.
func saveTunnel(t Tunnel) error {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("error getting home directory: %w", err)
	}
	dir := filepath.Join(homeDir, tunnelDir)
	// The directory may not exist yet on first use; MkdirAll is a no-op when
	// it already does.
	if err := os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("error creating tunnel directory: %w", err)
	}
	tunnelPath := filepath.Join(dir, fmt.Sprintf("tunnel-%d", t.ID))
	data := fmt.Sprintf("%d:%s:%d:%s:%d", t.LocalPort, t.RemoteHost, t.RemotePort, t.SSHServer, t.PID)
	// os.WriteFile already reports failed or short writes, so no follow-up
	// existence check is needed.
	return os.WriteFile(tunnelPath, []byte(data), 0644)
}
// removeFile deletes the state file "tunnel-<id>" from the tunnel directory.
// It is best-effort: a file that is already gone is not an error, and any
// other failure is printed to stdout rather than returned.
func removeFile(id int) {
	home, homeErr := os.UserHomeDir()
	if homeErr != nil {
		fmt.Printf("Error getting home directory: %v\n", homeErr)
		return
	}
	name := fmt.Sprintf("tunnel-%d", id)
	target := filepath.Join(home, tunnelDir, name)
	rmErr := os.Remove(target)
	if rmErr == nil || os.IsNotExist(rmErr) {
		// Removed, or there was nothing to remove.
		return
	}
	fmt.Printf("Error removing tunnel file %s: %v\n", target, rmErr)
}
// generateTunnelID returns the next tunnel ID: one more than the largest
// numeric ID among existing "tunnel-<id>" files, or 1 when there are none.
// A tunnel directory that does not exist yet counts as having none.
//
// If the home directory cannot be determined, or the tunnel directory exists
// but cannot be read, it falls back to the current process ID.
// NOTE(review): nothing here reserves the returned ID, so two concurrent
// callers can obtain the same value.
func generateTunnelID() int {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return os.Getpid()
	}
	entries, err := os.ReadDir(filepath.Join(homeDir, tunnelDir))
	if os.IsNotExist(err) {
		// Nothing has been saved yet: number from 1, exactly as an empty
		// directory would.
		return 1
	}
	if err != nil {
		return os.Getpid()
	}
	next := 1
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasPrefix(entry.Name(), "tunnel-") {
			continue
		}
		n, err := strconv.Atoi(strings.TrimPrefix(entry.Name(), "tunnel-"))
		if err == nil && n >= next {
			next = n + 1
		}
	}
	return next
}