301 lines
8.1 KiB
Go
301 lines
8.1 KiB
Go
package stats
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Stats holds the accumulated traffic counters and timestamps recorded for a
// single tunnel.
//
// TunnelID identifies the tunnel and LocalPort is its local port number.
// BytesIn and BytesOut are running totals of incoming and outgoing bytes.
// StartTime is when the tunnel was started, and LastUpdated is the last time
// these figures were updated.
type Stats struct {
	TunnelID, LocalPort    int
	BytesIn, BytesOut      uint64
	StartTime, LastUpdated time.Time
}
|
|
|
|
// StatsManager reads and writes per-tunnel Stats records, each stored as its
// own file inside a single directory.
type StatsManager struct {
	// statsDir is the directory containing the tunnel-<id>.stats files.
	statsDir string
}
|
|
|
|
// NewStatsManager creates a new stats manager
|
|
func NewStatsManager() (*StatsManager, error) {
|
|
homeDir, err := os.UserHomeDir()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting home directory: %v", err)
|
|
}
|
|
|
|
statsDir := filepath.Join(homeDir, ".sshtunnels", "stats")
|
|
if _, err := os.Stat(statsDir); os.IsNotExist(err) {
|
|
err = os.MkdirAll(statsDir, 0755)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error creating stats directory: %v", err)
|
|
}
|
|
}
|
|
|
|
return &StatsManager{statsDir: statsDir}, nil
|
|
}
|
|
|
|
// InitStats initializes statistics for a new tunnel
|
|
func (sm *StatsManager) InitStats(tunnelID, localPort int) error {
|
|
// Check if stats already exist for this tunnel
|
|
existingStats, err := sm.GetStats(tunnelID)
|
|
if err == nil && (existingStats.BytesIn > 0 || existingStats.BytesOut > 0) {
|
|
// Stats already exist, just update the LastUpdated time
|
|
existingStats.LastUpdated = time.Now()
|
|
return sm.saveStats(existingStats)
|
|
}
|
|
|
|
// Create new stats
|
|
stats := Stats{
|
|
TunnelID: tunnelID,
|
|
LocalPort: localPort,
|
|
BytesIn: 0,
|
|
BytesOut: 0,
|
|
StartTime: time.Now(),
|
|
LastUpdated: time.Now(),
|
|
}
|
|
|
|
return sm.saveStats(stats)
|
|
}
|
|
|
|
// UpdateStats updates traffic statistics for a tunnel
|
|
func (sm *StatsManager) UpdateStats(tunnelID int, bytesIn, bytesOut uint64) error {
|
|
stats, err := sm.GetStats(tunnelID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
stats.BytesIn += bytesIn
|
|
stats.BytesOut += bytesOut
|
|
stats.LastUpdated = time.Now()
|
|
|
|
return sm.saveStats(stats)
|
|
}
|
|
|
|
// GetStats retrieves statistics for a specific tunnel
|
|
func (sm *StatsManager) GetStats(tunnelID int) (Stats, error) {
|
|
statPath := filepath.Join(sm.statsDir, fmt.Sprintf("tunnel-%d.stats", tunnelID))
|
|
data, err := os.ReadFile(statPath)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
// Initialize new stats if file doesn't exist
|
|
return Stats{
|
|
TunnelID: tunnelID,
|
|
BytesIn: 0,
|
|
BytesOut: 0,
|
|
StartTime: time.Now(),
|
|
LastUpdated: time.Now(),
|
|
}, nil
|
|
}
|
|
return Stats{}, fmt.Errorf("error reading stats file: %v", err)
|
|
}
|
|
|
|
stats, err := parseStatsData(data)
|
|
if err != nil {
|
|
// If parsing fails, create a new clean stats object
|
|
// rather than failing completely
|
|
fmt.Printf("Warning: Stats file for tunnel %d is corrupt, reinitializing: %v\n", tunnelID, err)
|
|
return Stats{
|
|
TunnelID: tunnelID,
|
|
BytesIn: 0,
|
|
BytesOut: 0,
|
|
StartTime: time.Now(),
|
|
LastUpdated: time.Now(),
|
|
}, nil
|
|
}
|
|
return stats, nil
|
|
}
|
|
|
|
// GetAllStats retrieves statistics for all tunnels
|
|
func (sm *StatsManager) GetAllStats() ([]Stats, error) {
|
|
entries, err := os.ReadDir(sm.statsDir)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return []Stats{}, nil
|
|
}
|
|
return nil, fmt.Errorf("error reading stats directory: %v", err)
|
|
}
|
|
|
|
var statsList []Stats
|
|
for _, entry := range entries {
|
|
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "tunnel-") && strings.HasSuffix(entry.Name(), ".stats") {
|
|
filePath := filepath.Join(sm.statsDir, entry.Name())
|
|
data, err := os.ReadFile(filePath)
|
|
if err != nil {
|
|
fmt.Printf("Error reading stats file %s: %v\n", filePath, err)
|
|
continue
|
|
}
|
|
|
|
stats, err := parseStatsData(data)
|
|
if err != nil {
|
|
fmt.Printf("Error parsing stats file %s: %v\n", filePath, err)
|
|
continue
|
|
}
|
|
|
|
statsList = append(statsList, stats)
|
|
}
|
|
}
|
|
|
|
return statsList, nil
|
|
}
|
|
|
|
// DeleteStats removes statistics for a tunnel
|
|
func (sm *StatsManager) DeleteStats(tunnelID int) error {
|
|
statPath := filepath.Join(sm.statsDir, fmt.Sprintf("tunnel-%d.stats", tunnelID))
|
|
err := os.Remove(statPath)
|
|
if err != nil && !os.IsNotExist(err) {
|
|
return fmt.Errorf("error deleting stats file: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// saveStats saves statistics to a file
|
|
func (sm *StatsManager) saveStats(stats Stats) error {
|
|
// Sanity checks before saving
|
|
if stats.TunnelID <= 0 {
|
|
return fmt.Errorf("invalid tunnel ID: %d", stats.TunnelID)
|
|
}
|
|
|
|
if stats.LocalPort <= 0 {
|
|
stats.LocalPort = 1024 // Use a fallback port
|
|
}
|
|
|
|
// Ensure timestamps are valid
|
|
now := time.Now()
|
|
if stats.StartTime.After(now) || stats.StartTime.Before(time.Unix(0, 0)) {
|
|
stats.StartTime = now
|
|
}
|
|
|
|
if stats.LastUpdated.After(now) || stats.LastUpdated.Before(time.Unix(0, 0)) {
|
|
stats.LastUpdated = now
|
|
}
|
|
|
|
// Create the directory if it doesn't exist (for robustness)
|
|
if _, err := os.Stat(sm.statsDir); os.IsNotExist(err) {
|
|
if err := os.MkdirAll(sm.statsDir, 0755); err != nil {
|
|
return fmt.Errorf("error creating stats directory: %v", err)
|
|
}
|
|
}
|
|
|
|
statPath := filepath.Join(sm.statsDir, fmt.Sprintf("tunnel-%d.stats", stats.TunnelID))
|
|
data := fmt.Sprintf("%d:%d:%d:%d:%d:%d",
|
|
stats.TunnelID,
|
|
stats.LocalPort,
|
|
stats.BytesIn,
|
|
stats.BytesOut,
|
|
stats.StartTime.Unix(),
|
|
stats.LastUpdated.Unix(),
|
|
)
|
|
|
|
tempPath := statPath + ".tmp"
|
|
|
|
// Write to a temporary file first to avoid corruption if the process is interrupted
|
|
if err := os.WriteFile(tempPath, []byte(data), 0644); err != nil {
|
|
return fmt.Errorf("error writing temporary stats file: %v", err)
|
|
}
|
|
|
|
// Then rename it to the final path (atomic on most filesystems)
|
|
return os.Rename(tempPath, statPath)
|
|
}
|
|
|
|
// parseStatsData parses raw statistics data from a file
|
|
func parseStatsData(data []byte) (Stats, error) {
|
|
parts := strings.Split(strings.TrimSpace(string(data)), ":")
|
|
if len(parts) != 6 {
|
|
return Stats{}, fmt.Errorf("invalid stats format (expected 6 parts, got %d)", len(parts))
|
|
}
|
|
|
|
// Defend against empty parts
|
|
for i, part := range parts {
|
|
if part == "" {
|
|
parts[i] = "0"
|
|
}
|
|
}
|
|
|
|
tunnelID, err := strconv.Atoi(parts[0])
|
|
if err != nil {
|
|
return Stats{}, fmt.Errorf("invalid tunnel ID: %v", err)
|
|
}
|
|
|
|
localPort, err := strconv.Atoi(parts[1])
|
|
if err != nil {
|
|
return Stats{}, fmt.Errorf("invalid local port: %v", err)
|
|
}
|
|
|
|
bytesIn, err := strconv.ParseUint(parts[2], 10, 64)
|
|
if err != nil {
|
|
return Stats{}, fmt.Errorf("invalid bytes in: %v", err)
|
|
}
|
|
|
|
bytesOut, err := strconv.ParseUint(parts[3], 10, 64)
|
|
if err != nil {
|
|
return Stats{}, fmt.Errorf("invalid bytes out: %v", err)
|
|
}
|
|
|
|
startTimestamp, err := strconv.ParseInt(parts[4], 10, 64)
|
|
if err != nil {
|
|
return Stats{}, fmt.Errorf("invalid start timestamp: %v", err)
|
|
}
|
|
|
|
lastUpdatedTimestamp, err := strconv.ParseInt(parts[5], 10, 64)
|
|
if err != nil {
|
|
return Stats{}, fmt.Errorf("invalid last updated timestamp: %v", err)
|
|
}
|
|
|
|
// Sanity checks
|
|
if tunnelID <= 0 || localPort <= 0 || startTimestamp <= 0 {
|
|
return Stats{}, fmt.Errorf("invalid stats data (negative or zero values)")
|
|
}
|
|
|
|
// Ensure timestamps make sense
|
|
now := time.Now().Unix()
|
|
if startTimestamp > now + 3600 || lastUpdatedTimestamp > now + 3600 {
|
|
// More than an hour in the future isn't right
|
|
startTimestamp = now
|
|
lastUpdatedTimestamp = now
|
|
}
|
|
|
|
return Stats{
|
|
TunnelID: tunnelID,
|
|
LocalPort: localPort,
|
|
BytesIn: bytesIn,
|
|
BytesOut: bytesOut,
|
|
StartTime: time.Unix(startTimestamp, 0),
|
|
LastUpdated: time.Unix(lastUpdatedTimestamp, 0),
|
|
}, nil
|
|
}
|
|
|
|
// FormatBytes renders a byte count using binary (IEC) units.
//
// Counts below 1024 are printed as a whole number of bytes, e.g. "512 B".
// Larger counts are scaled to the largest fitting unit from KiB up to EiB and
// printed with two decimal places, e.g. "1.50 KiB".
func FormatBytes(bytes uint64) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}

	const prefixes = "KMGTPE"
	divisor := uint64(unit)
	idx := 0
	// Step up a unit while at least one whole next-larger unit fits. The guard
	// implies divisor*unit <= bytes, so the multiplication never overflows.
	for bytes/divisor >= unit {
		divisor *= unit
		idx++
	}

	scaled := float64(bytes) / float64(divisor)
	return fmt.Sprintf("%.2f %ciB", scaled, prefixes[idx])
}
|
|
|
|
// CalculateRate returns the average transfer rate, in bytes per second, for
// bytes moved over duration. A zero or negative duration yields 0 instead of
// dividing by zero.
func CalculateRate(bytes uint64, duration time.Duration) float64 {
	if secs := duration.Seconds(); secs > 0 {
		return float64(bytes) / secs
	}
	return 0
}
|
|
|
|
// FormatRate formats a byte rate as human-readable string
|
|
func FormatRate(bytesPerSec float64) string {
|
|
return fmt.Sprintf("%s/s", FormatBytes(uint64(bytesPerSec)))
|
|
} |