sshtunnel/cmd/stats.go

284 lines
8.1 KiB
Go

package cmd
import (
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"

	"git.mvl.sh/vleeuwenmenno/sshtunnel/pkg/monitor"
	"git.mvl.sh/vleeuwenmenno/sshtunnel/pkg/stats"
	"github.com/spf13/cobra"
)
// Command-line flag values for the stats command; bound to flags in init.
var (
	// tunnelIDFlag holds --id; a value <= 0 is treated as "not given".
	tunnelIDFlag int

	// allFlag holds --all: report on every tunnel with recorded statistics.
	allFlag bool

	// watchFlag holds --watch: periodically redraw a single tunnel's stats.
	watchFlag bool

	// watchInterval holds --interval, the watch-mode refresh period in seconds.
	watchInterval int
)
// statsCmd implements `sshtunnel stats`. It validates the flag combination
// (exactly one of --id / --all; --watch only with --id; a positive
// --interval in watch mode) and then dispatches to the all-tunnels report,
// the single-tunnel report, or live watch mode. Invalid input prints an
// error and exits with status 1.
var statsCmd = &cobra.Command{
	Use:   "stats",
	Short: "View SSH tunnel statistics",
	Long: `View traffic statistics for active SSH tunnels.
This command shows you how much data has been transferred through your
SSH tunnels, in both directions (incoming and outgoing).
You can view statistics for a specific tunnel by specifying its ID,
or view statistics for all active tunnels.
Examples:
# View statistics for all tunnels
sshtunnel stats --all
# View statistics for a specific tunnel
sshtunnel stats --id 1
# Continuously monitor a specific tunnel (every 5 seconds)
sshtunnel stats --id 1 --watch
# Continuously monitor with custom interval (2 seconds)
sshtunnel stats --id 1 --watch --interval 2`,
	Run: func(cmd *cobra.Command, args []string) {
		// Exactly one of --id / --all must be selected.
		if !allFlag && tunnelIDFlag <= 0 {
			fmt.Println("Error: You must specify either a tunnel ID with --id or use --all")
			_ = cmd.Help() // best effort; we exit with an error regardless
			os.Exit(1)
		}
		if allFlag && tunnelIDFlag > 0 {
			fmt.Println("Error: You cannot specify both --id and --all flags")
			os.Exit(1)
		}
		// From here on, a valid --id is guaranteed whenever --all is unset,
		// so watch mode only needs to reject --all.
		if watchFlag && allFlag {
			fmt.Println("Error: Watch mode can only be used with a specific tunnel ID")
			os.Exit(1)
		}
		// time.NewTicker panics on a non-positive duration.
		if watchFlag && watchInterval <= 0 {
			fmt.Println("Error: --interval must be a positive number of seconds")
			os.Exit(1)
		}

		statsManager, err := stats.NewStatsManager()
		if err != nil {
			fmt.Printf("Error initializing statistics: %v\n", err)
			os.Exit(1)
		}

		switch {
		case allFlag:
			showAllStats(statsManager)
		case watchFlag:
			watchTunnelStats(tunnelIDFlag, watchInterval)
		default:
			showTunnelStats(statsManager, tunnelIDFlag)
		}
	},
}
func init() {
rootCmd.AddCommand(statsCmd)
statsCmd.Flags().IntVar(&tunnelIDFlag, "id", 0, "ID of the tunnel to show statistics for")
statsCmd.Flags().BoolVar(&allFlag, "all", false, "Show statistics for all tunnels")
statsCmd.Flags().BoolVar(&watchFlag, "watch", false, "Continuously monitor tunnel statistics")
statsCmd.Flags().IntVar(&watchInterval, "interval", 5, "Update interval for watch mode (in seconds)")
}
// showAllStats prints a report for every tunnel that has recorded statistics.
// Each entry is marked ACTIVE if its ID appears in the current tunnel list and
// INACTIVE otherwise, followed by uptime, bytes received/sent with average
// rates, and the total. It exits with status 1 if the statistics cannot be
// read.
func showAllStats(statsManager *stats.StatsManager) {
	tunnelStats, err := statsManager.GetAllStats()
	if err != nil {
		fmt.Printf("Error retrieving statistics: %v\n", err)
		os.Exit(1)
	}
	if len(tunnelStats) == 0 {
		fmt.Println("No statistics available for any tunnels.")
		return
	}

	// Fetch the live tunnel list once instead of once per stats entry.
	tunnels, err := getTunnels()
	if err != nil {
		// Without the list every entry would silently read INACTIVE, so tell
		// the user that the status column cannot be trusted.
		fmt.Printf("Warning: Could not list active tunnels, status may be inaccurate: %v\n", err)
	}
	active := make(map[int]bool, len(tunnels))
	for _, t := range tunnels {
		active[t.ID] = true
	}

	fmt.Println("SSH Tunnel Statistics:")
	fmt.Println("=============================================")
	for _, s := range tunnelStats {
		uptime := time.Since(s.StartTime)
		inRate := stats.CalculateRate(s.BytesIn, uptime)
		outRate := stats.CalculateRate(s.BytesOut, uptime)

		status := "INACTIVE"
		if active[s.TunnelID] {
			status = "ACTIVE"
		}

		fmt.Printf("Tunnel #%d (Port: %d) [%s]\n", s.TunnelID, s.LocalPort, status)
		fmt.Printf(" Uptime: %s\n", formatDuration(uptime))
		fmt.Printf(" Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
		fmt.Printf(" Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
		fmt.Printf(" Total: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
		fmt.Println("---------------------------------------------")
	}
}
// showTunnelStats prints a one-off report for a single tunnel.
//
// Traffic figures come from statsManager. If they cannot be read, a warning
// is printed and zeroed figures with a start time of "now" are shown instead.
// Connection details come from the live tunnel list when the tunnel is found
// there. Otherwise, if the stored stats carry a local port, that port is shown
// and the tunnel is reported as inactive.
func showTunnelStats(statsManager *stats.StatsManager, tunnelID int) {
	s, err := statsManager.GetStats(tunnelID)
	haveStats := err == nil
	if !haveStats {
		fmt.Printf("Warning: Could not retrieve statistics for tunnel #%d: %v\n", tunnelID, err)
		now := time.Now()
		s = stats.Stats{
			TunnelID:    tunnelID,
			StartTime:   now,
			LastUpdated: now,
		}
	}

	// Get tunnel info to validate it exists.
	tunnel, tunnelErr := getTunnel(tunnelID)
	if tunnelErr != nil {
		if haveStats {
			fmt.Printf("Note: Tunnel #%d does not appear to be active, showing historical data\n", tunnelID)
		} else {
			// Nothing was recorded either, so there is no history to show.
			fmt.Printf("Note: Tunnel #%d does not appear to be active\n", tunnelID)
		}
	}

	uptime := time.Since(s.StartTime)
	inRate := stats.CalculateRate(s.BytesIn, uptime)
	outRate := stats.CalculateRate(s.BytesOut, uptime)

	fmt.Printf("Statistics for Tunnel #%d:\n", tunnelID)
	fmt.Println("=============================================")
	if tunnel != nil {
		fmt.Printf("Local Port: %d\n", tunnel.LocalPort)
		fmt.Printf("Remote Target: %s:%d\n", tunnel.RemoteHost, tunnel.RemotePort)
		fmt.Printf("SSH Server: %s\n", tunnel.SSHServer)
		fmt.Printf("PID: %d\n", tunnel.PID)
	} else if s.LocalPort > 0 {
		// No live tunnel info, but the stored stats still know the port.
		fmt.Printf("Local Port: %d\n", s.LocalPort)
		fmt.Println("Status: INACTIVE (historical data)")
	}
	fmt.Printf("Uptime: %s\n", formatDuration(uptime))
	fmt.Printf("Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
	fmt.Printf("Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
	fmt.Printf("Total Traffic: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
}
// watchTunnelStats continuously displays live traffic statistics for one
// tunnel. It clears the terminal and redraws every interval seconds until the
// process receives an interrupt (Ctrl+C) or SIGTERM, then returns so that
// deferred cleanup runs.
//
// If the tunnel is not in the current tunnel list, its local port is
// recovered from the recorded statistics. If that fails, or if interval is
// not positive, the process exits with status 1. Monitor errors are reported
// as warnings, and the display falls back to a minimal status line.
func watchTunnelStats(tunnelID int, interval int) {
	// time.NewTicker panics on a non-positive duration, so reject it up front.
	if interval <= 0 {
		fmt.Printf("Error: Watch interval must be a positive number of seconds (got %d)\n", interval)
		os.Exit(1)
	}

	// Get tunnel info to validate it exists.
	tunnel, err := getTunnel(tunnelID)
	if err != nil {
		fmt.Printf("Warning: Tunnel #%d may not be active: %v\n", tunnelID, err)
		fmt.Println("Attempting to monitor anyway...")

		// Try to get the port from the recorded statistics. Check the
		// constructor error rather than using a possibly invalid manager.
		statsManager, smErr := stats.NewStatsManager()
		if smErr != nil {
			fmt.Printf("Error initializing statistics: %v\n", smErr)
			os.Exit(1)
		}
		s, sErr := statsManager.GetStats(tunnelID)
		if sErr != nil || s.LocalPort <= 0 {
			fmt.Println("Error: Cannot determine port for tunnel. Please specify a valid tunnel ID.")
			os.Exit(1)
		}
		// Create a tunnel object with the information we have.
		tunnel = &Tunnel{
			ID:         tunnelID,
			LocalPort:  s.LocalPort,
			RemoteHost: "unknown",
			RemotePort: 0,
			SSHServer:  "unknown",
			PID:        0,
		}
	}

	// Catch Ctrl+C / SIGTERM so the loop below can return normally. Without
	// this, the default handler kills the process and the deferred cleanup
	// (mon.Stop, ticker.Stop) never runs.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(sigCh)

	// Create a monitor for this tunnel.
	mon, err := monitor.NewMonitor(tunnelID, tunnel.LocalPort)
	if err != nil {
		fmt.Printf("Warning: Error creating monitor: %v\n", err)
		fmt.Println("Will attempt to continue with limited functionality")
	}
	// A failed constructor may hand back nil; calling Start on it would panic.
	if mon != nil {
		if err := mon.Start(); err != nil {
			fmt.Printf("Warning: %v\n", err)
			fmt.Println("Continuing with estimated statistics...")
		}
		defer mon.Stop()
	}

	// printHeader writes the banner shown above every redraw.
	printHeader := func() {
		fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)
	}
	// currentStats returns the monitor's formatted output, or an error when
	// it cannot be produced (including when no monitor could be created).
	currentStats := func() (string, error) {
		if mon == nil {
			return "", fmt.Errorf("monitor unavailable for tunnel #%d", tunnelID)
		}
		return mon.FormatStats()
	}

	printHeader()
	ticker := time.NewTicker(time.Duration(interval) * time.Second)
	defer ticker.Stop()

	// Initial stats display.
	statsStr, err := currentStats()
	if err != nil {
		fmt.Printf("Warning: %v\n", err)
		fmt.Println("Displaying minimal statistics...")
		statsStr = fmt.Sprintf("Monitoring tunnel #%d on port %d...",
			tunnelID, tunnel.LocalPort)
	}
	fmt.Println(statsStr)
	fmt.Println("\nUpdating every", interval, "seconds...")

	for {
		select {
		case <-sigCh:
			fmt.Printf("\nStopped monitoring tunnel #%d\n", tunnelID)
			return
		case <-ticker.C:
			// ANSI escape codes: cursor home + clear screen.
			fmt.Print("\033[H\033[2J")
			printHeader()
			statsStr, err = currentStats()
			if err != nil {
				statsStr = fmt.Sprintf("Warning: %v\n\nStill monitoring tunnel #%d on port %d...",
					err, tunnelID, tunnel.LocalPort)
			}
			fmt.Println(statsStr)
			fmt.Println("\nUpdating every", interval, "seconds...")
		}
	}
}
// formatDuration renders d as a compact human-readable string such as
// "2d 3h 4m 5s", "3h 4m 5s", "4m 5s" or "5s", omitting leading zero units.
// Sub-second precision is truncated toward zero. A negative duration (for
// example, a start time in the future after a clock change) is rendered with
// a leading "-" followed by the breakdown of its absolute value.
func formatDuration(d time.Duration) string {
	// Use whole seconds and integer arithmetic. Dividing before negating
	// keeps the most negative Duration from overflowing.
	total := int64(d / time.Second)
	sign := ""
	if total < 0 {
		sign = "-"
		total = -total
	}

	days := total / 86400
	hours := total / 3600 % 24
	minutes := total / 60 % 60
	seconds := total % 60

	switch {
	case days > 0:
		return fmt.Sprintf("%s%dd %dh %dm %ds", sign, days, hours, minutes, seconds)
	case hours > 0:
		return fmt.Sprintf("%s%dh %dm %ds", sign, hours, minutes, seconds)
	case minutes > 0:
		return fmt.Sprintf("%s%dm %ds", sign, minutes, seconds)
	default:
		return fmt.Sprintf("%s%ds", sign, seconds)
	}
}
// getTunnel looks up the tunnel with the given ID in the list returned by
// getTunnels. On a match it returns a pointer to a copy of that entry. It
// returns an error if the list cannot be read or no entry has that ID.
func getTunnel(id int) (*Tunnel, error) {
	tunnels, err := getTunnels()
	if err != nil {
		return nil, err
	}
	for i := range tunnels {
		if match := tunnels[i]; match.ID == id {
			return &match, nil
		}
	}
	return nil, fmt.Errorf("tunnel %d not found", id)
}