106 lines
2.2 KiB
Go
106 lines
2.2 KiB
Go
package cmd
|
|
|
|
import (
	"errors"
	"fmt"
	"os"
	"syscall"

	"git.mvl.sh/vleeuwenmenno/sshtunnel/pkg/stats"
	"github.com/spf13/cobra"
)
|
|
|
|
// Flag values for the stop command; init binds both to command-line flags.
var (
	// tunnelID holds the value of --id. Its Go zero value is 0; init
	// registers the flag with a default of -1.
	tunnelID int

	// all holds the value of --all.
	all bool
)
|
|
|
|
var stopCmd = &cobra.Command{
|
|
Use: "stop",
|
|
Short: "Stop an SSH tunnel",
|
|
Long: `Stop an active SSH tunnel by its ID or stop all tunnels.
|
|
Use the list command to see all active tunnels and their IDs.`,
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
if !all && tunnelID == -1 {
|
|
fmt.Println("Error: must specify either --id or --all")
|
|
cmd.Help()
|
|
os.Exit(1)
|
|
}
|
|
|
|
tunnels, err := getTunnels()
|
|
if err != nil {
|
|
fmt.Printf("Error getting tunnels: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
if len(tunnels) == 0 {
|
|
fmt.Println("No active SSH tunnels found.")
|
|
return
|
|
}
|
|
|
|
if all {
|
|
for _, t := range tunnels {
|
|
killTunnel(t)
|
|
}
|
|
fmt.Println("All SSH tunnels stopped.")
|
|
return
|
|
}
|
|
|
|
// Stop specific tunnel
|
|
found := false
|
|
for _, t := range tunnels {
|
|
if t.ID == tunnelID {
|
|
killTunnel(t)
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
fmt.Printf("Error: no tunnel with ID %d found\n", tunnelID)
|
|
os.Exit(1)
|
|
}
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
rootCmd.AddCommand(stopCmd)
|
|
|
|
// Add flags for the stop command
|
|
stopCmd.Flags().IntVarP(&tunnelID, "id", "i", -1, "ID of the tunnel to stop")
|
|
stopCmd.Flags().BoolVarP(&all, "all", "a", false, "Stop all tunnels")
|
|
}
|
|
|
|
func killTunnel(t Tunnel) {
|
|
process, err := os.FindProcess(t.PID)
|
|
if err != nil {
|
|
fmt.Printf("Error finding process %d: %v\n", t.PID, err)
|
|
removeFile(t.ID)
|
|
cleanupStats(t.ID)
|
|
return
|
|
}
|
|
|
|
if err := process.Signal(syscall.SIGTERM); err != nil {
|
|
fmt.Printf("Error stopping tunnel %d (PID %d): %v\n", t.ID, t.PID, err)
|
|
// Process might not exist anymore, so clean up anyway
|
|
}
|
|
|
|
removeFile(t.ID)
|
|
cleanupStats(t.ID)
|
|
fmt.Printf("Stopped SSH tunnel (ID: %d): localhost:%d -> %s:%d\n",
|
|
t.ID, t.LocalPort, t.RemoteHost, t.RemotePort)
|
|
}
|
|
|
|
// cleanupStats removes the statistics for a tunnel
|
|
func cleanupStats(tunnelID int) {
|
|
statsManager, err := stats.NewStatsManager()
|
|
if err != nil {
|
|
fmt.Printf("Warning: Could not clean up statistics: %v\n", err)
|
|
return
|
|
}
|
|
|
|
err = statsManager.DeleteStats(tunnelID)
|
|
if err != nil {
|
|
fmt.Printf("Warning: Failed to delete statistics for tunnel %d: %v\n", tunnelID, err)
|
|
}
|
|
}
|