package cmd

import (
	"errors"
	"fmt"
	"net"
	"os"
	"os/exec"
	"strings"
	"time"

	"git.mvl.sh/vleeuwenmenno/sshtunnel/pkg/stats"
	"github.com/spf13/cobra"
)

// Flag values for the start command.
var (
	localPort  int
	remotePort int
	remoteHost string
	sshServer  string
	identity   string
)

// startCmd represents the start command. It launches a backgrounded
// `ssh -N -f -L` process, records it via saveTunnel, initializes its
// statistics, and reports whether the local port is accepting connections.
var startCmd = &cobra.Command{
	Use:   "start",
	Short: "Start a new SSH tunnel",
	Long:  `Start a new SSH tunnel with specified local port, remote port, host and SSH server. The tunnel will run in the background and can be managed using the list and stop commands.`,
	Run: func(cmd *cobra.Command, args []string) {
		// MarkFlagRequired only guarantees the flags were passed, not that
		// their values are usable, so validate them before invoking ssh.
		if err := validateStartFlags(localPort, remotePort, sshServer); err != nil {
			fmt.Printf("Error: %v\n", err)
			os.Exit(1)
		}

		sshCmd := exec.Command("ssh", buildStartSSHArgs(localPort, remotePort, remoteHost, sshServer, identity)...)

		// Capture ssh's output so it can be shown if startup fails.
		// NOTE(review): because these writers are not *os.File, os/exec copies
		// the output through pipes and Run also waits for those copies to end.
		// If the backgrounded ssh keeps the inherited stdout/stderr open (some
		// OpenSSH versions may not redirect them after -f), Run could block —
		// confirm against the ssh versions in use; Cmd.WaitDelay (Go 1.20+)
		// would bound the wait.
		var outputBuffer strings.Builder
		sshCmd.Stdout = &outputBuffer
		sshCmd.Stderr = &outputBuffer

		// Run (not Start): with -f the ssh parent returns after going to the
		// background, so its exit status tells us whether setup succeeded.
		if err := sshCmd.Run(); err != nil {
			fmt.Printf("Error starting SSH tunnel: %v\n", err)
			fmt.Printf("SSH output: %s\n", outputBuffer.String())
			os.Exit(1)
		}

		// ssh -f forks, so sshCmd.Process.Pid belongs to the parent that has
		// already exited; look up the PID of the backgrounded tunnel instead.
		actualPID, err := findSSHTunnelPID(localPort)
		if err != nil {
			fmt.Printf("Warning: Could not determine tunnel PID: %v\n", err)
		}
		pid := 0 // 0 records "unknown PID"
		if actualPID > 0 {
			pid = actualPID
		}

		// Store tunnel information.
		id := generateTunnelID()
		tunnel := Tunnel{
			ID:         id,
			LocalPort:  localPort,
			RemotePort: remotePort,
			RemoteHost: remoteHost,
			SSHServer:  sshServer,
			PID:        pid,
		}
		if err := saveTunnel(tunnel); err != nil {
			fmt.Printf("Error saving tunnel information: %v\n", err)
			// Without a saved record the tunnel could not be managed by the
			// list/stop commands, so make a best-effort attempt to kill it.
			if pid > 0 {
				if process, findErr := os.FindProcess(pid); findErr == nil {
					_ = process.Kill() // best effort: we exit with an error regardless
				}
			}
			os.Exit(1)
		}

		// Initialize statistics for this tunnel; failures are non-fatal.
		statsManager, err := stats.NewStatsManager()
		if err != nil {
			fmt.Printf("Warning: Failed to initialize statistics: %v\n", err)
		} else if err := statsManager.InitStats(id, localPort); err != nil {
			fmt.Printf("Warning: Failed to initialize statistics: %v\n", err)
		}

		// Give the forward a moment, then check that the local port accepts
		// connections; an unreachable port is reported as UNKNOWN, not fatal.
		time.Sleep(500 * time.Millisecond)
		status := "UNKNOWN"
		if verifyTunnelActive(localPort) {
			status = "ACTIVE"
		}

		fmt.Printf("Started SSH tunnel (ID: %d): localhost:%d -> %s:%d (%s) [PID: %d] [Status: %s]\n",
			id, localPort, remoteHost, remotePort, sshServer, pid, status)
	},
}

// validateStartFlags checks that the start command's flag values can form a
// usable forwarding specification: both ports in 1-65535 and a non-blank server.
func validateStartFlags(local, remote int, server string) error {
	if local < 1 || local > 65535 {
		return fmt.Errorf("local port must be between 1 and 65535, got %d", local)
	}
	if remote < 1 || remote > 65535 {
		return fmt.Errorf("remote port must be between 1 and 65535, got %d", remote)
	}
	if strings.TrimSpace(server) == "" {
		return errors.New("SSH server must not be empty")
	}
	return nil
}

// buildStartSSHArgs returns the ssh arguments for a backgrounded local port
// forward from local to host:remote through server, optionally using
// identityFile as the private key.
func buildStartSSHArgs(local, remote int, host, server, identityFile string) []string {
	args := []string{
		"-N", // don't execute a remote command
		"-f", // go to the background after authentication
		// Exit non-zero instead of backgrounding a tunnel whose forward could
		// not be established (e.g. the local port is already in use).
		"-o", "ExitOnForwardFailure=yes",
		"-L", fmt.Sprintf("%d:%s:%d", local, bracketIPv6Host(host), remote),
	}
	if identityFile != "" {
		args = append(args, "-i", identityFile)
	}
	return append(args, server)
}

// bracketIPv6Host wraps a bare IPv6 literal in square brackets, which ssh
// requires inside a colon-separated -L forwarding specification. Hostnames,
// IPv4 addresses and already-bracketed addresses are returned unchanged.
func bracketIPv6Host(host string) string {
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		return "[" + host + "]"
	}
	return host
}

func init() {
	rootCmd.AddCommand(startCmd)

	// Add flags for the start command
	startCmd.Flags().IntVarP(&localPort, "local", "l", 0, "Local port to forward")
	startCmd.Flags().IntVarP(&remotePort, "remote", "r", 0, "Remote port to forward to")
	startCmd.Flags().StringVarP(&remoteHost, "host", "H", "localhost", "Remote host to forward to")
	startCmd.Flags().StringVarP(&sshServer, "server", "s", "", "SSH server address (user@host)")
	startCmd.Flags().StringVarP(&identity, "identity", "i", "", "Path to SSH identity file")

	// Mark required flags (presence only; values are checked in Run).
	startCmd.MarkFlagRequired("local")
	startCmd.MarkFlagRequired("remote")
	startCmd.MarkFlagRequired("server")
}

// verifyTunnelActive reports whether something accepts TCP connections on
// 127.0.0.1:port within 500ms; the probe connection is closed immediately.
func verifyTunnelActive(port int) bool {
	conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 500*time.Millisecond)
	if err != nil {
		return false
	}
	conn.Close()
	return true
}