initial commit
This commit is contained in:
283
cmd/stats.go
Normal file
283
cmd/stats.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"sshtunnel/pkg/monitor"
|
||||
"sshtunnel/pkg/stats"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
tunnelIDFlag int
|
||||
watchFlag bool
|
||||
watchInterval int
|
||||
allFlag bool
|
||||
)
|
||||
|
||||
// statsCmd represents the stats command
|
||||
var statsCmd = &cobra.Command{
|
||||
Use: "stats",
|
||||
Short: "View SSH tunnel statistics",
|
||||
Long: `View traffic statistics for active SSH tunnels.
|
||||
This command shows you how much data has been transferred through your
|
||||
SSH tunnels, in both directions (incoming and outgoing).
|
||||
|
||||
You can view statistics for a specific tunnel by specifying its ID,
|
||||
or view statistics for all active tunnels.
|
||||
|
||||
Examples:
|
||||
# View statistics for all tunnels
|
||||
sshtunnel stats --all
|
||||
|
||||
# View statistics for a specific tunnel
|
||||
sshtunnel stats --id 1
|
||||
|
||||
# Continuously monitor a specific tunnel (every 5 seconds)
|
||||
sshtunnel stats --id 1 --watch
|
||||
|
||||
# Continuously monitor with custom interval (2 seconds)
|
||||
sshtunnel stats --id 1 --watch --interval 2`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !allFlag && tunnelIDFlag <= 0 {
|
||||
fmt.Println("Error: You must specify either a tunnel ID with --id or use --all")
|
||||
cmd.Help()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if allFlag && tunnelIDFlag > 0 {
|
||||
fmt.Println("Error: You cannot specify both --id and --all flags")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if watchFlag && !allFlag && tunnelIDFlag <= 0 {
|
||||
fmt.Println("Error: Watch mode requires specifying a tunnel ID with --id")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if watchFlag && allFlag {
|
||||
fmt.Println("Error: Watch mode can only be used with a specific tunnel ID")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
statsManager, err := stats.NewStatsManager()
|
||||
if err != nil {
|
||||
fmt.Printf("Error initializing statistics: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if allFlag {
|
||||
showAllStats(statsManager)
|
||||
} else {
|
||||
if watchFlag {
|
||||
watchTunnelStats(tunnelIDFlag, watchInterval)
|
||||
} else {
|
||||
showTunnelStats(statsManager, tunnelIDFlag)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(statsCmd)
|
||||
|
||||
statsCmd.Flags().IntVar(&tunnelIDFlag, "id", 0, "ID of the tunnel to show statistics for")
|
||||
statsCmd.Flags().BoolVar(&allFlag, "all", false, "Show statistics for all tunnels")
|
||||
statsCmd.Flags().BoolVar(&watchFlag, "watch", false, "Continuously monitor tunnel statistics")
|
||||
statsCmd.Flags().IntVar(&watchInterval, "interval", 5, "Update interval for watch mode (in seconds)")
|
||||
}
|
||||
|
||||
func showAllStats(statsManager *stats.StatsManager) {
|
||||
tunnelStats, err := statsManager.GetAllStats()
|
||||
if err != nil {
|
||||
fmt.Printf("Error retrieving statistics: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(tunnelStats) == 0 {
|
||||
fmt.Println("No statistics available for any tunnels.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("SSH Tunnel Statistics:")
|
||||
fmt.Println("=============================================")
|
||||
for _, s := range tunnelStats {
|
||||
// Verify if this tunnel is still active
|
||||
tunnels, _ := getTunnels()
|
||||
active := false
|
||||
for _, t := range tunnels {
|
||||
if t.ID == s.TunnelID {
|
||||
active = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
uptime := time.Since(s.StartTime)
|
||||
inRate := stats.CalculateRate(s.BytesIn, uptime)
|
||||
outRate := stats.CalculateRate(s.BytesOut, uptime)
|
||||
|
||||
status := "ACTIVE"
|
||||
if !active {
|
||||
status = "INACTIVE"
|
||||
}
|
||||
|
||||
fmt.Printf("Tunnel #%d (Port: %d) [%s]\n", s.TunnelID, s.LocalPort, status)
|
||||
fmt.Printf(" Uptime: %s\n", formatDuration(uptime))
|
||||
fmt.Printf(" Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
|
||||
fmt.Printf(" Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
|
||||
fmt.Printf(" Total: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
|
||||
fmt.Println("---------------------------------------------")
|
||||
}
|
||||
}
|
||||
|
||||
func showTunnelStats(statsManager *stats.StatsManager, tunnelID int) {
|
||||
s, err := statsManager.GetStats(tunnelID)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Could not retrieve statistics for tunnel #%d: %v\n", tunnelID, err)
|
||||
s = stats.Stats{
|
||||
TunnelID: tunnelID,
|
||||
BytesIn: 0,
|
||||
BytesOut: 0,
|
||||
StartTime: time.Now(),
|
||||
LastUpdated: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get tunnel info to validate it exists
|
||||
tunnel, err := getTunnel(tunnelID)
|
||||
if err != nil {
|
||||
fmt.Printf("Note: Tunnel #%d does not appear to be active, showing historical data\n", tunnelID)
|
||||
}
|
||||
|
||||
// Display stats
|
||||
uptime := time.Since(s.StartTime)
|
||||
inRate := stats.CalculateRate(s.BytesIn, uptime)
|
||||
outRate := stats.CalculateRate(s.BytesOut, uptime)
|
||||
|
||||
fmt.Printf("Statistics for Tunnel #%d:\n", tunnelID)
|
||||
fmt.Println("=============================================")
|
||||
if tunnel != nil {
|
||||
fmt.Printf("Local Port: %d\n", tunnel.LocalPort)
|
||||
fmt.Printf("Remote Target: %s:%d\n", tunnel.RemoteHost, tunnel.RemotePort)
|
||||
fmt.Printf("SSH Server: %s\n", tunnel.SSHServer)
|
||||
fmt.Printf("PID: %d\n", tunnel.PID)
|
||||
} else if s.LocalPort > 0 {
|
||||
// If we don't have tunnel info but do have port in the stats
|
||||
fmt.Printf("Local Port: %d\n", s.LocalPort)
|
||||
fmt.Printf("Status: INACTIVE (historical data)\n")
|
||||
}
|
||||
fmt.Printf("Uptime: %s\n", formatDuration(uptime))
|
||||
fmt.Printf("Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
|
||||
fmt.Printf("Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
|
||||
fmt.Printf("Total Traffic: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
|
||||
}
|
||||
|
||||
func watchTunnelStats(tunnelID int, interval int) {
|
||||
// Get tunnel info to validate it exists
|
||||
tunnel, err := getTunnel(tunnelID)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Tunnel #%d may not be active: %v\n", tunnelID, err)
|
||||
fmt.Println("Attempting to monitor anyway...")
|
||||
|
||||
// Try to get port from stats
|
||||
statsManager, _ := stats.NewStatsManager()
|
||||
s, err := statsManager.GetStats(tunnelID)
|
||||
if err == nil && s.LocalPort > 0 {
|
||||
// Create a tunnel object with the information we have
|
||||
tunnel = &Tunnel{
|
||||
ID: tunnelID,
|
||||
LocalPort: s.LocalPort,
|
||||
RemoteHost: "unknown",
|
||||
RemotePort: 0,
|
||||
SSHServer: "unknown",
|
||||
PID: 0,
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Error: Cannot determine port for tunnel. Please specify a valid tunnel ID.")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a monitor for this tunnel
|
||||
mon, err := monitor.NewMonitor(tunnelID, tunnel.LocalPort)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Error creating monitor: %v\n", err)
|
||||
fmt.Println("Will attempt to continue with limited functionality")
|
||||
}
|
||||
|
||||
// Start monitoring with error handling
|
||||
err = mon.Start()
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: %v\n", err)
|
||||
fmt.Println("Continuing with estimated statistics...")
|
||||
}
|
||||
defer mon.Stop()
|
||||
|
||||
fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)
|
||||
|
||||
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Initial stats display
|
||||
statsStr, err := mon.FormatStats()
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: %v\n", err)
|
||||
fmt.Println("Displaying minimal statistics...")
|
||||
statsStr = fmt.Sprintf("Monitoring tunnel #%d on port %d...",
|
||||
tunnelID, tunnel.LocalPort)
|
||||
}
|
||||
fmt.Println(statsStr)
|
||||
fmt.Println("\nUpdating every", interval, "seconds...")
|
||||
|
||||
// Clear screen and update stats periodically
|
||||
for range ticker.C {
|
||||
// ANSI escape code to clear screen
|
||||
fmt.Print("\033[H\033[2J")
|
||||
|
||||
fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)
|
||||
|
||||
statsStr, err := mon.FormatStats()
|
||||
if err != nil {
|
||||
statsStr = fmt.Sprintf("Warning: %v\n\nStill monitoring tunnel #%d on port %d...",
|
||||
err, tunnelID, tunnel.LocalPort)
|
||||
}
|
||||
|
||||
fmt.Println(statsStr)
|
||||
fmt.Println("\nUpdating every", interval, "seconds...")
|
||||
}
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
days := int(d.Hours() / 24)
|
||||
hours := int(d.Hours()) % 24
|
||||
minutes := int(d.Minutes()) % 60
|
||||
seconds := int(d.Seconds()) % 60
|
||||
|
||||
if days > 0 {
|
||||
return fmt.Sprintf("%dd %dh %dm %ds", days, hours, minutes, seconds)
|
||||
} else if hours > 0 {
|
||||
return fmt.Sprintf("%dh %dm %ds", hours, minutes, seconds)
|
||||
} else if minutes > 0 {
|
||||
return fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
}
|
||||
return fmt.Sprintf("%ds", seconds)
|
||||
}
|
||||
|
||||
// getTunnel gets a specific tunnel by ID
|
||||
func getTunnel(id int) (*Tunnel, error) {
|
||||
tunnels, err := getTunnels()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, t := range tunnels {
|
||||
if t.ID == id {
|
||||
return &t, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("tunnel %d not found", id)
|
||||
}
|
Reference in New Issue
Block a user