sshtunnel/cmd/stop.go

106 lines
2.2 KiB
Go

package cmd
import (
	"errors"
	"fmt"
	"os"
	"syscall"

	"git.mvl.sh/vleeuwenmenno/sshtunnel/pkg/stats"
	"github.com/spf13/cobra"
)
// Variables bound to the stop command's flags when those flags are
// registered in init. Before registration they hold Go zero values.
var (
	// all is set by --all/-a and requests that every tunnel be stopped.
	all bool
	// tunnelID is set by --id/-i; the flag's registered default of -1
	// means "no ID was given".
	tunnelID int
)
// stopCmd implements `stop`, which terminates either a single tunnel
// (--id) or every tunnel (--all). If both flags are given, --all takes
// precedence. Tunnels are obtained from getTunnels and each one is stopped
// via killTunnel. Usage errors, lookup failures and an unknown ID terminate
// the process with exit status 1; diagnostics are written to stderr.
var stopCmd = &cobra.Command{
	Use:   "stop",
	Short: "Stop an SSH tunnel",
	Long: `Stop an active SSH tunnel by its ID or stop all tunnels.
Use the list command to see all active tunnels and their IDs.`,
	Run: func(cmd *cobra.Command, args []string) {
		// -1 is the --id flag's registered default and means "no ID given".
		if !all && tunnelID == -1 {
			fmt.Fprintln(os.Stderr, "Error: must specify either --id or --all")
			_ = cmd.Help() // best-effort usage output; we exit with an error regardless
			os.Exit(1)
		}

		tunnels, err := getTunnels()
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error getting tunnels: %v\n", err)
			os.Exit(1)
		}

		if all {
			// Having nothing to stop is not an error for --all.
			if len(tunnels) == 0 {
				fmt.Println("No active SSH tunnels found.")
				return
			}
			for _, t := range tunnels {
				killTunnel(t)
			}
			fmt.Println("All SSH tunnels stopped.")
			return
		}

		// Stop the one requested tunnel. An empty tunnel list is just a
		// special case of "ID not found" and must fail the same way, so
		// scripts relying on the exit status see a consistent result.
		for _, t := range tunnels {
			if t.ID == tunnelID {
				killTunnel(t)
				return
			}
		}
		fmt.Fprintf(os.Stderr, "Error: no tunnel with ID %d found\n", tunnelID)
		os.Exit(1)
	},
}
// init attaches stopCmd to rootCmd and binds its flags: --id/-i (default
// -1) to tunnelID and --all/-a (default false) to all.
func init() {
	rootCmd.AddCommand(stopCmd)

	flags := stopCmd.Flags()
	flags.IntVarP(&tunnelID, "id", "i", -1, "ID of the tunnel to stop")
	flags.BoolVarP(&all, "all", "a", false, "Stop all tunnels")
}
// killTunnel stops tunnel t by sending SIGTERM to its process. It then
// deletes the tunnel's record (removeFile) and its statistics
// (cleanupStats).
//
// The record is deleted in these cases:
//   - the signal is delivered,
//   - the process no longer exists,
//   - the process cannot be looked up,
//   - the stored PID is unusable.
//
// If sending the signal fails for any other reason (for example EPERM: the
// process is alive but may not be signalled by us), the record is kept so
// a still-running tunnel is not silently orphaned. Errors go to stderr.
func killTunnel(t Tunnel) {
	forget := func() {
		removeFile(t.ID)
		cleanupStats(t.ID)
	}

	// kill(2) interprets PIDs <= 0 as process-group or broadcast targets,
	// so a corrupt record must never reach Signal; just drop it.
	if t.PID <= 0 {
		fmt.Fprintf(os.Stderr, "Error: tunnel %d has invalid PID %d; removing its record\n", t.ID, t.PID)
		forget()
		return
	}

	process, err := os.FindProcess(t.PID)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error finding process %d: %v\n", t.PID, err)
		forget()
		return
	}

	if err := process.Signal(syscall.SIGTERM); err != nil {
		if !errors.Is(err, os.ErrProcessDone) && !errors.Is(err, syscall.ESRCH) {
			// The process may still be running; keep its record.
			fmt.Fprintf(os.Stderr, "Error stopping tunnel %d (PID %d): %v\n", t.ID, t.PID, err)
			return
		}
		// The process is already gone; only the stale record remains.
		forget()
		fmt.Printf("SSH tunnel (ID: %d) was no longer running (PID %d); removed stale record\n", t.ID, t.PID)
		return
	}

	forget()
	fmt.Printf("Stopped SSH tunnel (ID: %d): localhost:%d -> %s:%d\n",
		t.ID, t.LocalPort, t.RemoteHost, t.RemotePort)
}
// cleanupStats deletes the stored statistics for the tunnel with the given
// ID. It is best-effort: failures to open the stats manager or to delete
// the entry are reported as warnings on stderr and never abort the caller.
//
// The parameter is named id (not tunnelID) so it does not shadow the
// package-level --id flag variable.
func cleanupStats(id int) {
	statsManager, err := stats.NewStatsManager()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Warning: Could not clean up statistics: %v\n", err)
		return
	}
	if err := statsManager.DeleteStats(id); err != nil {
		fmt.Fprintf(os.Stderr, "Warning: Failed to delete statistics for tunnel %d: %v\n", id, err)
	}
}