package cmd

import (
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"

	"git.mvl.sh/vleeuwenmenno/sshtunnel/pkg/monitor"
	"git.mvl.sh/vleeuwenmenno/sshtunnel/pkg/stats"
	"github.com/spf13/cobra"
)

// Flag values for the stats command; bound to command-line flags in init.
var (
	tunnelIDFlag  int  // --id: tunnel to inspect (0 means "not given")
	watchFlag     bool // --watch: keep refreshing the display
	watchInterval int  // --interval: refresh period in seconds for --watch
	allFlag       bool // --all: show every tunnel that has recorded statistics
)

// statsCmd represents the stats command
var statsCmd = &cobra.Command{
	Use:   "stats",
	Short: "View SSH tunnel statistics",
	Long: `View traffic statistics for active SSH tunnels.

This command shows you how much data has been transferred through your SSH tunnels,
in both directions (incoming and outgoing).

You can view statistics for a specific tunnel by specifying its ID,
or view statistics for all active tunnels.

Examples:
  # View statistics for all tunnels
  sshtunnel stats --all

  # View statistics for a specific tunnel
  sshtunnel stats --id 1

  # Continuously monitor a specific tunnel (every 5 seconds)
  sshtunnel stats --id 1 --watch

  # Continuously monitor with custom interval (2 seconds)
  sshtunnel stats --id 1 --watch --interval 2`,
	Run: func(cmd *cobra.Command, args []string) {
		// Validate the flag combination up front; every failing case exits.
		switch {
		case !allFlag && tunnelIDFlag <= 0:
			fmt.Println("Error: You must specify either a tunnel ID with --id or use --all")
			_ = cmd.Help() // usage output is best-effort; we exit with an error regardless
			os.Exit(1)
		case allFlag && tunnelIDFlag > 0:
			fmt.Println("Error: You cannot specify both --id and --all flags")
			os.Exit(1)
		case watchFlag && allFlag:
			fmt.Println("Error: Watch mode can only be used with a specific tunnel ID")
			os.Exit(1)
		case watchFlag && watchInterval <= 0:
			// time.NewTicker panics on a non-positive duration.
			fmt.Println("Error: --interval must be a positive number of seconds")
			os.Exit(1)
		}

		statsManager, err := stats.NewStatsManager()
		if err != nil {
			fmt.Printf("Error initializing statistics: %v\n", err)
			os.Exit(1)
		}

		switch {
		case allFlag:
			showAllStats(statsManager)
		case watchFlag:
			watchTunnelStats(tunnelIDFlag, watchInterval)
		default:
			showTunnelStats(statsManager, tunnelIDFlag)
		}
	},
}

func init() {
	rootCmd.AddCommand(statsCmd)

	statsCmd.Flags().IntVar(&tunnelIDFlag, "id", 0, "ID of the tunnel to show statistics for")
	statsCmd.Flags().BoolVar(&allFlag, "all", false, "Show statistics for all tunnels")
	statsCmd.Flags().BoolVar(&watchFlag, "watch", false, "Continuously monitor tunnel statistics")
	statsCmd.Flags().IntVar(&watchInterval, "interval", 5, "Update interval for watch mode (in seconds)")
}

// showAllStats prints the statistics of every tunnel known to statsManager.
// Each entry is labelled ACTIVE when getTunnels still reports a tunnel with
// the same ID, INACTIVE otherwise. Exits with status 1 if the statistics
// cannot be retrieved.
func showAllStats(statsManager *stats.StatsManager) {
	tunnelStats, err := statsManager.GetAllStats()
	if err != nil {
		fmt.Printf("Error retrieving statistics: %v\n", err)
		os.Exit(1)
	}

	if len(tunnelStats) == 0 {
		fmt.Println("No statistics available for any tunnels.")
		return
	}

	// Look up the live tunnels once instead of once per stats entry. The
	// lookup is best-effort: an error only affects the ACTIVE/INACTIVE label.
	tunnels, _ := getTunnels()
	activeIDs := make(map[int]struct{}, len(tunnels))
	for _, t := range tunnels {
		activeIDs[t.ID] = struct{}{}
	}

	fmt.Println("SSH Tunnel Statistics:")
	fmt.Println("=============================================")

	for _, s := range tunnelStats {
		uptime := time.Since(s.StartTime)
		inRate := stats.CalculateRate(s.BytesIn, uptime)
		outRate := stats.CalculateRate(s.BytesOut, uptime)

		status := "INACTIVE"
		if _, ok := activeIDs[s.TunnelID]; ok {
			status = "ACTIVE"
		}

		fmt.Printf("Tunnel #%d (Port: %d) [%s]\n", s.TunnelID, s.LocalPort, status)
		fmt.Printf("  Uptime: %s\n", formatDuration(uptime))
		fmt.Printf("  Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
		fmt.Printf("  Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
		fmt.Printf("  Total: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
		fmt.Println("---------------------------------------------")
	}
}

// showTunnelStats prints the statistics recorded for a single tunnel.
// If no statistics can be read, it prints a warning and shows zeroed
// counters instead. Connection details (remote target, SSH server, PID) are
// shown only when getTunnel still finds the tunnel; otherwise the recorded
// local port, if any, is shown as historical data.
func showTunnelStats(statsManager *stats.StatsManager, tunnelID int) {
	s, err := statsManager.GetStats(tunnelID)
	if err != nil {
		fmt.Printf("Warning: Could not retrieve statistics for tunnel #%d: %v\n", tunnelID, err)
		now := time.Now()
		s = stats.Stats{
			TunnelID:    tunnelID,
			BytesIn:     0,
			BytesOut:    0,
			StartTime:   now,
			LastUpdated: now,
		}
	}

	// A missing tunnel is not fatal: the recorded statistics are still shown.
	tunnel, err := getTunnel(tunnelID)
	if err != nil {
		fmt.Printf("Note: Tunnel #%d does not appear to be active, showing historical data\n", tunnelID)
	}

	uptime := time.Since(s.StartTime)
	inRate := stats.CalculateRate(s.BytesIn, uptime)
	outRate := stats.CalculateRate(s.BytesOut, uptime)

	fmt.Printf("Statistics for Tunnel #%d:\n", tunnelID)
	fmt.Println("=============================================")

	if tunnel != nil {
		fmt.Printf("Local Port: %d\n", tunnel.LocalPort)
		fmt.Printf("Remote Target: %s:%d\n", tunnel.RemoteHost, tunnel.RemotePort)
		fmt.Printf("SSH Server: %s\n", tunnel.SSHServer)
		fmt.Printf("PID: %d\n", tunnel.PID)
	} else if s.LocalPort > 0 {
		// No live tunnel info, but the stats record still knows the port.
		fmt.Printf("Local Port: %d\n", s.LocalPort)
		fmt.Printf("Status: INACTIVE (historical data)\n")
	}

	fmt.Printf("Uptime: %s\n", formatDuration(uptime))
	fmt.Printf("Received: %s (avg: %s)\n", stats.FormatBytes(s.BytesIn), stats.FormatRate(inRate))
	fmt.Printf("Sent: %s (avg: %s)\n", stats.FormatBytes(s.BytesOut), stats.FormatRate(outRate))
	fmt.Printf("Total Traffic: %s\n", stats.FormatBytes(s.BytesIn+s.BytesOut))
}

// watchTunnelStats monitors one tunnel and redraws its live statistics every
// interval seconds until the process receives SIGINT (Ctrl+C) or SIGTERM.
// On those signals it returns normally, so the deferred monitor and ticker
// cleanup runs.
//
// If getTunnel does not find the tunnel, the local port recorded in the
// statistics store is used instead. If neither source yields a port, the
// process exits with status 1. interval must be positive; the stats command
// validates this before calling.
func watchTunnelStats(tunnelID int, interval int) {
	tunnel, err := getTunnel(tunnelID)
	if err != nil {
		fmt.Printf("Warning: Tunnel #%d may not be active: %v\n", tunnelID, err)
		fmt.Println("Attempting to monitor anyway...")

		// Fall back to the port recorded in the statistics store. A failure
		// to open the store is treated like a missing record.
		var s stats.Stats
		statsManager, lookupErr := stats.NewStatsManager()
		if lookupErr == nil {
			s, lookupErr = statsManager.GetStats(tunnelID)
		}
		if lookupErr != nil || s.LocalPort <= 0 {
			fmt.Println("Error: Cannot determine port for tunnel. Please specify a valid tunnel ID.")
			os.Exit(1)
		}

		// Build a placeholder tunnel holding only what the stats record knows.
		tunnel = &Tunnel{
			ID:         tunnelID,
			LocalPort:  s.LocalPort,
			RemoteHost: "unknown",
			RemotePort: 0,
			SSHServer:  "unknown",
			PID:        0,
		}
	}

	mon, err := monitor.NewMonitor(tunnelID, tunnel.LocalPort)
	if err != nil {
		fmt.Printf("Warning: Error creating monitor: %v\n", err)
		// NOTE(review): assumes NewMonitor returns a pointer. Continue with a
		// partially initialised monitor, but a nil one cannot be used at all.
		if mon == nil {
			fmt.Println("Error: No monitor available for this tunnel")
			os.Exit(1)
		}
		fmt.Println("Will attempt to continue with limited functionality")
	}

	if err := mon.Start(); err != nil {
		fmt.Printf("Warning: %v\n", err)
		fmt.Println("Continuing with estimated statistics...")
	}
	defer mon.Stop()

	// Handle termination signals ourselves. Otherwise Ctrl+C would kill the
	// process at once and the deferred Stop calls would never run.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(sigCh)

	fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)

	ticker := time.NewTicker(time.Duration(interval) * time.Second)
	defer ticker.Stop()

	// Initial display, shown immediately rather than after the first tick.
	statsStr, err := mon.FormatStats()
	if err != nil {
		fmt.Printf("Warning: %v\n", err)
		fmt.Println("Displaying minimal statistics...")
		statsStr = fmt.Sprintf("Monitoring tunnel #%d on port %d...", tunnelID, tunnel.LocalPort)
	}
	fmt.Println(statsStr)
	fmt.Println("\nUpdating every", interval, "seconds...")

	for {
		select {
		case <-sigCh:
			fmt.Printf("\nStopped monitoring tunnel #%d\n", tunnelID)
			return
		case <-ticker.C:
			// ANSI escape codes: move cursor home, then clear the screen.
			fmt.Print("\033[H\033[2J")
			fmt.Printf("Monitoring SSH tunnel #%d (Press Ctrl+C to stop)...\n\n", tunnelID)

			statsStr, err := mon.FormatStats()
			if err != nil {
				statsStr = fmt.Sprintf("Warning: %v\n\nStill monitoring tunnel #%d on port %d...", err, tunnelID, tunnel.LocalPort)
			}
			fmt.Println(statsStr)
			fmt.Println("\nUpdating every", interval, "seconds...")
		}
	}
}

// formatDuration renders d as a compact string such as "2d 3h 4m 5s".
// Leading zero units are omitted ("4m 5s", "5s"). Sub-second precision is
// truncated. Negative durations (e.g. from clock skew against a stored start
// time) are clamped to zero.
func formatDuration(d time.Duration) string {
	if d < 0 {
		d = 0
	}

	total := int64(d / time.Second)
	days := total / 86400
	hours := (total / 3600) % 24
	minutes := (total / 60) % 60
	seconds := total % 60

	switch {
	case days > 0:
		return fmt.Sprintf("%dd %dh %dm %ds", days, hours, minutes, seconds)
	case hours > 0:
		return fmt.Sprintf("%dh %dm %ds", hours, minutes, seconds)
	case minutes > 0:
		return fmt.Sprintf("%dm %ds", minutes, seconds)
	default:
		return fmt.Sprintf("%ds", seconds)
	}
}

// getTunnel returns the tunnel with the given ID from getTunnels. It returns
// an error if the tunnel list cannot be read or contains no such ID.
func getTunnel(id int) (*Tunnel, error) {
	tunnels, err := getTunnels()
	if err != nil {
		return nil, err
	}

	for _, t := range tunnels {
		if t.ID == id {
			// Returning immediately, so taking the loop variable's address is safe.
			return &t, nil
		}
	}

	return nil, fmt.Errorf("tunnel %d not found", id)
}