package cmd import ( "fmt" "os" "os/exec" "path/filepath" "strconv" "strings" "syscall" ) const ( tunnelDir = ".sshtunnels" ) // Tunnel represents an SSH tunnel configuration type Tunnel struct { ID int // Unique identifier for the tunnel LocalPort int // Local port being forwarded RemotePort int // Remote port being forwarded to RemoteHost string // Remote host being forwarded to SSHServer string // SSH server (user@host) PID int // Process ID of the SSH process } // checkIfSSHProcess checks if a process ID belongs to an SSH process func checkIfSSHProcess(pid int) bool { // Try to read /proc/{pid}/comm if on Linux if _, err := os.Stat("/proc"); err == nil { data, err := os.ReadFile(fmt.Sprintf("/proc/%d/comm", pid)) if err == nil && strings.TrimSpace(string(data)) == "ssh" { return true } } // Alternative approach - use ps command cmd := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "comm=") output, err := cmd.Output() if err == nil && strings.Contains(string(output), "ssh") { return true } // Last resort - just check if process exists process, err := os.FindProcess(pid) if err != nil { return false } // Send signal 0 to check if process exists return process.Signal(syscall.Signal(0)) == nil } // findSSHTunnelPID attempts to find the PID of an SSH tunnel process by its local port func findSSHTunnelPID(port int) (int, error) { // Try using lsof first (most reliable) cmd := exec.Command("lsof", "-i", fmt.Sprintf("TCP:%d", port), "-t") output, err := cmd.Output() if err == nil && len(output) > 0 { pidStr := strings.TrimSpace(string(output)) pid, err := strconv.Atoi(pidStr) if err == nil { // Verify this is an SSH process if checkIfSSHProcess(pid) { return pid, nil } } } // Try netstat as a fallback cmd = exec.Command("netstat", "-tlnp") output, err = cmd.Output() if err == nil { lines := strings.Split(string(output), "\n") portStr := fmt.Sprintf(":%d", port) for _, line := range lines { if strings.Contains(line, portStr) && strings.Contains(line, "ssh") { // Extract PID from the line parts := strings.Fields(line) if len(parts) >= 7 { pidPart := parts[6] pidStr := strings.Split(pidPart, "/")[0] pid, err := strconv.Atoi(pidStr) if err == nil { return pid, nil } } } } } // Try ps as a last resort cmd = exec.Command("ps", "aux") output, err = cmd.Output() if err == nil { lines := strings.Split(string(output), "\n") portStr := fmt.Sprintf(":%d", port) for _, line := range lines { if strings.Contains(line, "ssh") && strings.Contains(line, portStr) { parts := strings.Fields(line) if len(parts) >= 2 { pid, err := strconv.Atoi(parts[1]) if err == nil { return pid, nil } } } } } return 0, fmt.Errorf("could not find SSH tunnel PID for port %d", port) } func getTunnels() ([]Tunnel, error) { homeDir, err := os.UserHomeDir() if err != nil { return nil, fmt.Errorf("error getting home directory: %v", err) } tunnelPath := filepath.Join(homeDir, tunnelDir) entries, err := os.ReadDir(tunnelPath) if err != nil { if os.IsNotExist(err) { return []Tunnel{}, nil } return nil, fmt.Errorf("error reading tunnel directory: %v", err) } var tunnels []Tunnel for _, entry := range entries { if !entry.IsDir() && strings.HasPrefix(entry.Name(), "tunnel-") { filePath := filepath.Join(tunnelPath, entry.Name()) data, err := os.ReadFile(filePath) if err != nil { fmt.Printf("Error reading tunnel file %s: %v\n", filePath, err) continue } var t Tunnel parts := strings.Split(string(data), ":") if len(parts) != 5 { fmt.Printf("Invalid tunnel file format: %s\n", filePath) continue } idStr := strings.TrimPrefix(entry.Name(), "tunnel-") id, err := strconv.Atoi(idStr) if err != nil { fmt.Printf("Invalid tunnel ID: %s\n", idStr) continue } t.ID = id t.LocalPort, err = strconv.Atoi(parts[0]) if err != nil { continue } t.RemoteHost = parts[1] t.RemotePort, err = strconv.Atoi(parts[2]) if err != nil { continue } t.SSHServer = parts[3] t.PID, err = strconv.Atoi(parts[4]) if err != nil { continue } // Verify if this is actually a SSH process isSSH := checkIfSSHProcess(t.PID) if !isSSH { fmt.Printf("Process %d is not an SSH process anymore, cleaning up\n", t.PID) removeFile(t.ID) continue } tunnels = append(tunnels, t) } } return tunnels, nil } func saveTunnel(t Tunnel) error { homeDir, err := os.UserHomeDir() if err != nil { return fmt.Errorf("error getting home directory: %v", err) } tunnelPath := filepath.Join(homeDir, tunnelDir, fmt.Sprintf("tunnel-%d", t.ID)) data := fmt.Sprintf("%d:%s:%d:%s:%d", t.LocalPort, t.RemoteHost, t.RemotePort, t.SSHServer, t.PID) err = os.WriteFile(tunnelPath, []byte(data), 0644) if err != nil { return err } // Verify the file was written correctly _, err = os.Stat(tunnelPath) return err } func removeFile(id int) { homeDir, err := os.UserHomeDir() if err != nil { fmt.Printf("Error getting home directory: %v\n", err) return } tunnelPath := filepath.Join(homeDir, tunnelDir, fmt.Sprintf("tunnel-%d", id)) if err := os.Remove(tunnelPath); err != nil && !os.IsNotExist(err) { fmt.Printf("Error removing tunnel file %s: %v\n", tunnelPath, err) } } func generateTunnelID() int { homeDir, err := os.UserHomeDir() if err != nil { return int(os.Getpid()) } tunnelPath := filepath.Join(homeDir, tunnelDir) entries, err := os.ReadDir(tunnelPath) if err != nil { return int(os.Getpid()) } id := 1 for _, entry := range entries { if !entry.IsDir() && strings.HasPrefix(entry.Name(), "tunnel-") { idStr := strings.TrimPrefix(entry.Name(), "tunnel-") if val, err := strconv.Atoi(idStr); err == nil && val >= id { id = val + 1 } } } return id }