package stats import ( "fmt" "os" "path/filepath" "strconv" "strings" "time" ) // Stats represents the traffic statistics for a tunnel type Stats struct { TunnelID int // ID of the tunnel LocalPort int // Local port number BytesIn uint64 // Total incoming bytes BytesOut uint64 // Total outgoing bytes StartTime time.Time // When the tunnel was started LastUpdated time.Time // Last time stats were updated } // StatsManager handles the statistics collection and storage type StatsManager struct { statsDir string } // NewStatsManager creates a new stats manager func NewStatsManager() (*StatsManager, error) { homeDir, err := os.UserHomeDir() if err != nil { return nil, fmt.Errorf("error getting home directory: %v", err) } statsDir := filepath.Join(homeDir, ".sshtunnels", "stats") if _, err := os.Stat(statsDir); os.IsNotExist(err) { err = os.MkdirAll(statsDir, 0755) if err != nil { return nil, fmt.Errorf("error creating stats directory: %v", err) } } return &StatsManager{statsDir: statsDir}, nil } // InitStats initializes statistics for a new tunnel func (sm *StatsManager) InitStats(tunnelID, localPort int) error { // Check if stats already exist for this tunnel existingStats, err := sm.GetStats(tunnelID) if err == nil && (existingStats.BytesIn > 0 || existingStats.BytesOut > 0) { // Stats already exist, just update the LastUpdated time existingStats.LastUpdated = time.Now() return sm.saveStats(existingStats) } // Create new stats stats := Stats{ TunnelID: tunnelID, LocalPort: localPort, BytesIn: 0, BytesOut: 0, StartTime: time.Now(), LastUpdated: time.Now(), } return sm.saveStats(stats) } // UpdateStats updates traffic statistics for a tunnel func (sm *StatsManager) UpdateStats(tunnelID int, bytesIn, bytesOut uint64) error { stats, err := sm.GetStats(tunnelID) if err != nil { return err } stats.BytesIn += bytesIn stats.BytesOut += bytesOut stats.LastUpdated = time.Now() return sm.saveStats(stats) } // GetStats retrieves statistics for a specific tunnel func (sm *StatsManager) GetStats(tunnelID int) (Stats, error) { statPath := filepath.Join(sm.statsDir, fmt.Sprintf("tunnel-%d.stats", tunnelID)) data, err := os.ReadFile(statPath) if err != nil { if os.IsNotExist(err) { // Initialize new stats if file doesn't exist return Stats{ TunnelID: tunnelID, BytesIn: 0, BytesOut: 0, StartTime: time.Now(), LastUpdated: time.Now(), }, nil } return Stats{}, fmt.Errorf("error reading stats file: %v", err) } stats, err := parseStatsData(data) if err != nil { // If parsing fails, create a new clean stats object // rather than failing completely fmt.Printf("Warning: Stats file for tunnel %d is corrupt, reinitializing: %v\n", tunnelID, err) return Stats{ TunnelID: tunnelID, BytesIn: 0, BytesOut: 0, StartTime: time.Now(), LastUpdated: time.Now(), }, nil } return stats, nil } // GetAllStats retrieves statistics for all tunnels func (sm *StatsManager) GetAllStats() ([]Stats, error) { entries, err := os.ReadDir(sm.statsDir) if err != nil { if os.IsNotExist(err) { return []Stats{}, nil } return nil, fmt.Errorf("error reading stats directory: %v", err) } var statsList []Stats for _, entry := range entries { if !entry.IsDir() && strings.HasPrefix(entry.Name(), "tunnel-") && strings.HasSuffix(entry.Name(), ".stats") { filePath := filepath.Join(sm.statsDir, entry.Name()) data, err := os.ReadFile(filePath) if err != nil { fmt.Printf("Error reading stats file %s: %v\n", filePath, err) continue } stats, err := parseStatsData(data) if err != nil { fmt.Printf("Error parsing stats file %s: %v\n", filePath, err) continue } statsList = append(statsList, stats) } } return statsList, nil } // DeleteStats removes statistics for a tunnel func (sm *StatsManager) DeleteStats(tunnelID int) error { statPath := filepath.Join(sm.statsDir, fmt.Sprintf("tunnel-%d.stats", tunnelID)) err := os.Remove(statPath) if err != nil && !os.IsNotExist(err) { return fmt.Errorf("error deleting stats file: %v", err) } return nil } // saveStats saves statistics to a file func (sm *StatsManager) saveStats(stats Stats) error { // Sanity checks before saving if stats.TunnelID <= 0 { return fmt.Errorf("invalid tunnel ID: %d", stats.TunnelID) } if stats.LocalPort <= 0 { stats.LocalPort = 1024 // Use a fallback port } // Ensure timestamps are valid now := time.Now() if stats.StartTime.After(now) || stats.StartTime.Before(time.Unix(0, 0)) { stats.StartTime = now } if stats.LastUpdated.After(now) || stats.LastUpdated.Before(time.Unix(0, 0)) { stats.LastUpdated = now } // Create the directory if it doesn't exist (for robustness) if _, err := os.Stat(sm.statsDir); os.IsNotExist(err) { if err := os.MkdirAll(sm.statsDir, 0755); err != nil { return fmt.Errorf("error creating stats directory: %v", err) } } statPath := filepath.Join(sm.statsDir, fmt.Sprintf("tunnel-%d.stats", stats.TunnelID)) data := fmt.Sprintf("%d:%d:%d:%d:%d:%d", stats.TunnelID, stats.LocalPort, stats.BytesIn, stats.BytesOut, stats.StartTime.Unix(), stats.LastUpdated.Unix(), ) tempPath := statPath + ".tmp" // Write to a temporary file first to avoid corruption if the process is interrupted if err := os.WriteFile(tempPath, []byte(data), 0644); err != nil { return fmt.Errorf("error writing temporary stats file: %v", err) } // Then rename it to the final path (atomic on most filesystems) return os.Rename(tempPath, statPath) } // parseStatsData parses raw statistics data from a file func parseStatsData(data []byte) (Stats, error) { parts := strings.Split(strings.TrimSpace(string(data)), ":") if len(parts) != 6 { return Stats{}, fmt.Errorf("invalid stats format (expected 6 parts, got %d)", len(parts)) } // Defend against empty parts for i, part := range parts { if part == "" { parts[i] = "0" } } tunnelID, err := strconv.Atoi(parts[0]) if err != nil { return Stats{}, fmt.Errorf("invalid tunnel ID: %v", err) } localPort, err := strconv.Atoi(parts[1]) if err != nil { return Stats{}, fmt.Errorf("invalid local port: %v", err) } bytesIn, err := strconv.ParseUint(parts[2], 10, 64) if err != nil { return Stats{}, fmt.Errorf("invalid bytes in: %v", err) } bytesOut, err := strconv.ParseUint(parts[3], 10, 64) if err != nil { return Stats{}, fmt.Errorf("invalid bytes out: %v", err) } startTimestamp, err := strconv.ParseInt(parts[4], 10, 64) if err != nil { return Stats{}, fmt.Errorf("invalid start timestamp: %v", err) } lastUpdatedTimestamp, err := strconv.ParseInt(parts[5], 10, 64) if err != nil { return Stats{}, fmt.Errorf("invalid last updated timestamp: %v", err) } // Sanity checks if tunnelID <= 0 || localPort <= 0 || startTimestamp <= 0 { return Stats{}, fmt.Errorf("invalid stats data (negative or zero values)") } // Ensure timestamps make sense now := time.Now().Unix() if startTimestamp > now + 3600 || lastUpdatedTimestamp > now + 3600 { // More than an hour in the future isn't right startTimestamp = now lastUpdatedTimestamp = now } return Stats{ TunnelID: tunnelID, LocalPort: localPort, BytesIn: bytesIn, BytesOut: bytesOut, StartTime: time.Unix(startTimestamp, 0), LastUpdated: time.Unix(lastUpdatedTimestamp, 0), }, nil } // FormatBytes formats a byte count as human-readable string func FormatBytes(bytes uint64) string { const unit = 1024 if bytes < unit { return fmt.Sprintf("%d B", bytes) } div, exp := uint64(unit), 0 for n := bytes / unit; n >= unit; n /= unit { div *= unit exp++ } return fmt.Sprintf("%.2f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) } // CalculateRate calculates the transfer rate in bytes per second func CalculateRate(bytes uint64, duration time.Duration) float64 { seconds := duration.Seconds() if seconds <= 0 { return 0 } return float64(bytes) / seconds } // FormatRate formats a byte rate as human-readable string func FormatRate(bytesPerSec float64) string { return fmt.Sprintf("%s/s", FormatBytes(uint64(bytesPerSec))) }