284 lines
8.1 KiB
Go
284 lines
8.1 KiB
Go
package cmd
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"time"
|
|
|
|
"git.mvl.sh/vleeuwenmenno/sshtunnel/pkg/monitor"
|
|
"git.mvl.sh/vleeuwenmenno/sshtunnel/pkg/stats"
|
|
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
var (
|
|
tunnelIDFlag int
|
|
watchFlag bool
|
|
watchInterval int
|
|
allFlag bool
|
|
)
|
|
|
|
// statsCmd represents the stats command
|
|
var statsCmd = &cobra.Command{
|
|
Use: "stats",
|
|
Short: "View SSH tunnel statistics",
|
|
Long: `View traffic statistics for active SSH tunnels.
|
|
This command shows you how much data has been transferred through your
|
|
SSH tunnels, in both directions (incoming and outgoing).
|
|
|
|
You can view statistics for a specific tunnel by specifying its ID,
|
|
or view statistics for all active tunnels.
|
|
|
|
Examples:
|
|
# View statistics for all tunnels
|
|
sshtunnel stats --all
|
|
|
|
# View statistics for a specific tunnel
|
|
sshtunnel stats --id 1
|
|
|
|
# Continuously monitor a specific tunnel (every 5 seconds)
|
|
sshtunnel stats --id 1 --watch
|
|
|
|
# Continuously monitor with custom interval (2 seconds)
|
|
sshtunnel stats --id 1 --watch --interval 2`,
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
if !allFlag && tunnelIDFlag <= 0 {
|
|
fmt.Println("Error: You must specify either a tunnel ID with --id or use --all")
|
|
cmd.Help()
|
|
os.Exit(1)
|
|
}
|
|
|
|
if allFlag && tunnelIDFlag > 0 {
|
|
fmt.Println("Error: You cannot specify both --id and --all flags")
|
|
os.Exit(1)
|
|
}
|
|
|
|
if watchFlag && !allFlag && tunnelIDFlag <= 0 {
|
|
fmt.Println("Error: Watch mode requires specifying a tunnel ID with --id")
|
|
os.Exit(1)
|
|
}
|
|
|
|
if watchFlag && allFlag {
|
|
fmt.Println("Error: Watch mode can only be used with a specific tunnel ID")
|
|
os.Exit(1)
|
|
}
|
|
|
|
statsManager, err := stats.NewStatsManager()
|
|
if err != nil {
|
|
fmt.Printf("Error initializing statistics: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
if allFlag {
|
|
showAllStats(statsManager)
|
|
} else {
|
|
if watchFlag {
|
|
watchTunnelStats(tunnelIDFlag, watchInterval)
|
|
} else {
|
|
showTunnelStats(statsManager, tunnelIDFlag)
|
|
}
|
|
}
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
rootCmd.AddCommand(statsCmd)
|
|
|
|
statsCmd.Flags().IntVar(&tunnelIDFlag, "id", 0, "ID of the tunnel to show statistics for")
|
|
statsCmd.Flags().BoolVar(&allFlag, "all", false, "Show statistics for all tunnels")
|
|
statsCmd.Flags().BoolVar(&watchFlag, "watch", false, "Continuously monitor tunnel statistics")
|
|
statsCmd.Flags().IntVar(&watchInterval, "interval", 5, "Update interval for watch mode (in seconds)")
|
|
}
|
|
|
|
func showAllStats(statsManager *stats.StatsManager) {
|
|
tunnelStats, err := statsManager.GetAllStats()
|
|
if err != nil {
|
|
fmt.Printf("Error retrieving statistics: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
if len(tunnelStats) == 0 {
|
|
fmt.Println("No statistics available for any tunnels.")
|
|
return
|
|
}
|
|
|
|
fmt.Println("SSH Tunnel Statistics:")
|
|
fmt.Println("=============================================")
|
|
for _, s := range tunnelStats {
|
|
// Verify if this tunnel is still active
|
|
tunnels, _ := getTunnels()
|
|
active := false
|
|
for _, t := range tunnels {
|
|
if t.ID == s.TunnelID {
|
|
active = true
|
|
break
|
|
}
|
|
}
|
|
|
|
uptime := time.Since(s.StartTime)
|
|
inRate := stats.CalculateRate(s.BytesIn, uptime)
|
|
outRate := stats.CalculateRate(s.BytesOut, uptime)
|
|
|
|
status := "ACTIVE"
|
|
if !active {
|
|
status = "INACTIVE"
|
|
}
|
|
|
|
fmt.Printf("Tunnel #%d (Port: %d) [%s]\n", s.TunnelID, s.LocalPort, status)
|
|
fmt.Printf(" Uptime: %s\n", formatDuration(uptime))
|
|
fmt.Printf(" Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
|
|
fmt.Printf(" Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
|
|
fmt.Printf(" Total: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
|
|
fmt.Println("---------------------------------------------")
|
|
}
|
|
}
|
|
|
|
func showTunnelStats(statsManager *stats.StatsManager, tunnelID int) {
|
|
s, err := statsManager.GetStats(tunnelID)
|
|
if err != nil {
|
|
fmt.Printf("Warning: Could not retrieve statistics for tunnel #%d: %v\n", tunnelID, err)
|
|
s = stats.Stats{
|
|
TunnelID: tunnelID,
|
|
BytesIn: 0,
|
|
BytesOut: 0,
|
|
StartTime: time.Now(),
|
|
LastUpdated: time.Now(),
|
|
}
|
|
}
|
|
|
|
// Get tunnel info to validate it exists
|
|
tunnel, err := getTunnel(tunnelID)
|
|
if err != nil {
|
|
fmt.Printf("Note: Tunnel #%d does not appear to be active, showing historical data\n", tunnelID)
|
|
}
|
|
|
|
// Display stats
|
|
uptime := time.Since(s.StartTime)
|
|
inRate := stats.CalculateRate(s.BytesIn, uptime)
|
|
outRate := stats.CalculateRate(s.BytesOut, uptime)
|
|
|
|
fmt.Printf("Statistics for Tunnel #%d:\n", tunnelID)
|
|
fmt.Println("=============================================")
|
|
if tunnel != nil {
|
|
fmt.Printf("Local Port: %d\n", tunnel.LocalPort)
|
|
fmt.Printf("Remote Target: %s:%d\n", tunnel.RemoteHost, tunnel.RemotePort)
|
|
fmt.Printf("SSH Server: %s\n", tunnel.SSHServer)
|
|
fmt.Printf("PID: %d\n", tunnel.PID)
|
|
} else if s.LocalPort > 0 {
|
|
// If we don't have tunnel info but do have port in the stats
|
|
fmt.Printf("Local Port: %d\n", s.LocalPort)
|
|
fmt.Printf("Status: INACTIVE (historical data)\n")
|
|
}
|
|
fmt.Printf("Uptime: %s\n", formatDuration(uptime))
|
|
fmt.Printf("Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
|
|
fmt.Printf("Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
|
|
fmt.Printf("Total Traffic: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
|
|
}
|
|
|
|
func watchTunnelStats(tunnelID int, interval int) {
|
|
// Get tunnel info to validate it exists
|
|
tunnel, err := getTunnel(tunnelID)
|
|
if err != nil {
|
|
fmt.Printf("Warning: Tunnel #%d may not be active: %v\n", tunnelID, err)
|
|
fmt.Println("Attempting to monitor anyway...")
|
|
|
|
// Try to get port from stats
|
|
statsManager, _ := stats.NewStatsManager()
|
|
s, err := statsManager.GetStats(tunnelID)
|
|
if err == nil && s.LocalPort > 0 {
|
|
// Create a tunnel object with the information we have
|
|
tunnel = &Tunnel{
|
|
ID: tunnelID,
|
|
LocalPort: s.LocalPort,
|
|
RemoteHost: "unknown",
|
|
RemotePort: 0,
|
|
SSHServer: "unknown",
|
|
PID: 0,
|
|
}
|
|
} else {
|
|
fmt.Println("Error: Cannot determine port for tunnel. Please specify a valid tunnel ID.")
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// Create a monitor for this tunnel
|
|
mon, err := monitor.NewMonitor(tunnelID, tunnel.LocalPort)
|
|
if err != nil {
|
|
fmt.Printf("Warning: Error creating monitor: %v\n", err)
|
|
fmt.Println("Will attempt to continue with limited functionality")
|
|
}
|
|
|
|
// Start monitoring with error handling
|
|
err = mon.Start()
|
|
if err != nil {
|
|
fmt.Printf("Warning: %v\n", err)
|
|
fmt.Println("Continuing with estimated statistics...")
|
|
}
|
|
defer mon.Stop()
|
|
|
|
fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)
|
|
|
|
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
// Initial stats display
|
|
statsStr, err := mon.FormatStats()
|
|
if err != nil {
|
|
fmt.Printf("Warning: %v\n", err)
|
|
fmt.Println("Displaying minimal statistics...")
|
|
statsStr = fmt.Sprintf("Monitoring tunnel #%d on port %d...",
|
|
tunnelID, tunnel.LocalPort)
|
|
}
|
|
fmt.Println(statsStr)
|
|
fmt.Println("\nUpdating every", interval, "seconds...")
|
|
|
|
// Clear screen and update stats periodically
|
|
for range ticker.C {
|
|
// ANSI escape code to clear screen
|
|
fmt.Print("\033[H\033[2J")
|
|
|
|
fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)
|
|
|
|
statsStr, err := mon.FormatStats()
|
|
if err != nil {
|
|
statsStr = fmt.Sprintf("Warning: %v\n\nStill monitoring tunnel #%d on port %d...",
|
|
err, tunnelID, tunnel.LocalPort)
|
|
}
|
|
|
|
fmt.Println(statsStr)
|
|
fmt.Println("\nUpdating every", interval, "seconds...")
|
|
}
|
|
}
|
|
|
|
func formatDuration(d time.Duration) string {
|
|
days := int(d.Hours() / 24)
|
|
hours := int(d.Hours()) % 24
|
|
minutes := int(d.Minutes()) % 60
|
|
seconds := int(d.Seconds()) % 60
|
|
|
|
if days > 0 {
|
|
return fmt.Sprintf("%dd %dh %dm %ds", days, hours, minutes, seconds)
|
|
} else if hours > 0 {
|
|
return fmt.Sprintf("%dh %dm %ds", hours, minutes, seconds)
|
|
} else if minutes > 0 {
|
|
return fmt.Sprintf("%dm %ds", minutes, seconds)
|
|
}
|
|
return fmt.Sprintf("%ds", seconds)
|
|
}
|
|
|
|
// getTunnel gets a specific tunnel by ID
|
|
func getTunnel(id int) (*Tunnel, error) {
|
|
tunnels, err := getTunnels()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, t := range tunnels {
|
|
if t.ID == id {
|
|
return &t, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("tunnel %d not found", id)
|
|
}
|