From 0c6c188bf1730a4ea15a8f37d59e6265d331b143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Kadan=C4=9B?= Date: Tue, 12 Nov 2024 20:21:40 +0100 Subject: [PATCH] Add tcpstat connection states per port metric MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tomáš Kadaně --- collector/tcpstat_linux.go | 108 ++++++++++++++++++++++----- collector/tcpstat_linux_test.go | 125 +++++++++++++++++++++++++------- 2 files changed, 187 insertions(+), 46 deletions(-) diff --git a/collector/tcpstat_linux.go b/collector/tcpstat_linux.go index 476a9b47..35f601fc 100644 --- a/collector/tcpstat_linux.go +++ b/collector/tcpstat_linux.go @@ -17,12 +17,15 @@ package collector import ( + "encoding/binary" "fmt" "log/slog" "os" + "strconv" "syscall" "unsafe" + "github.com/alecthomas/kingpin/v2" "github.com/mdlayher/netlink" "github.com/prometheus/client_golang/prometheus" ) @@ -58,6 +61,11 @@ const ( tcpTxQueuedBytes ) +var ( + tcpstatSourcePorts = kingpin.Flag("collector.tcpstat.port.source", "List of tcpstat source ports").Strings() + tcpstatDestPorts = kingpin.Flag("collector.tcpstat.port.dest", "List of tcpstat destination ports").Strings() +) + type tcpStatCollector struct { desc typedDesc logger *slog.Logger @@ -73,7 +81,7 @@ func NewTCPStatCollector(logger *slog.Logger) (Collector, error) { desc: typedDesc{prometheus.NewDesc( prometheus.BuildFQName(namespace, "tcp", "connection_states"), "Number of connection states.", - []string{"state"}, nil, + []string{"state", "port", "direction"}, nil, ), prometheus.GaugeValue}, logger: logger, }, nil @@ -129,31 +137,97 @@ func parseInetDiagMsg(b []byte) *InetDiagMsg { } func (c *tcpStatCollector) Update(ch chan<- prometheus.Metric) error { - tcpStats, err := getTCPStats(syscall.AF_INET) + messages, err := getMessagesFromSocket(syscall.AF_INET) if err != nil { return fmt.Errorf("couldn't get tcpstats: %w", err) } - // if enabled ipv6 system + tcpStats, err := parseTCPStats(messages) + if err != nil { + return fmt.Errorf("couldn't parse tcpstats: %w", err) + } + if _, hasIPv6 := os.Stat(procFilePath("net/tcp6")); hasIPv6 == nil { - tcp6Stats, err := getTCPStats(syscall.AF_INET6) + messagesIPv6, err := getMessagesFromSocket(syscall.AF_INET6) if err != nil { return fmt.Errorf("couldn't get tcp6stats: %w", err) } + tcp6Stats, err := parseTCPStats(messagesIPv6) + if err != nil { + return fmt.Errorf("couldn't parse tcp6stats: %w", err) + } + for st, value := range tcp6Stats { tcpStats[st] += value } + + messages = append(messages, messagesIPv6...) } - for st, value := range tcpStats { - ch <- c.desc.mustNewConstMetric(value, st.String()) - } + emitTotalTCPStats(c, ch, tcpStats) + emitTCPStatsPerPort(c, ch, messages, *tcpstatSourcePorts, "source", true) + emitTCPStatsPerPort(c, ch, messages, *tcpstatDestPorts, "dest", false) return nil } -func getTCPStats(family uint8) (map[tcpConnectionState]float64, error) { +func emitTotalTCPStats(c *tcpStatCollector, ch chan<- prometheus.Metric, stats map[tcpConnectionState]float64) { + for st, value := range stats { + ch <- c.desc.mustNewConstMetric(value, st.String(), "0", "total") + } +} + +func emitTCPStatsPerPort( + c *tcpStatCollector, + ch chan<- prometheus.Metric, + messages []netlink.Message, + ports []string, + direction string, + isSource bool, +) { + if len(ports) == 0 { + return + } + + portSet := map[string]struct{}{} + for _, p := range ports { + portSet[p] = struct{}{} + } + + counts := map[string]map[string]float64{} + + for _, m := range messages { + msg := parseInetDiagMsg(m.Data) + + state := tcpConnectionState(msg.State).String() + + var rawPort uint16 + if isSource { + rawPort = binary.BigEndian.Uint16(msg.ID.SourcePort[:]) + } else { + rawPort = binary.BigEndian.Uint16(msg.ID.DestPort[:]) + } + + portStr := strconv.Itoa(int(rawPort)) + + if _, ok := portSet[portStr]; ok { + if _, ok := counts[state]; !ok { + counts[state] = make(map[string]float64) + } + + counts[state][portStr]++ + } + } + + for state, portMap := range counts { + for port, count := range portMap { + ch <- c.desc.mustNewConstMetric(count, state, port, direction) + } + } +} + +func getMessagesFromSocket(family uint8) ([]netlink.Message, error) { const TCPFAll = 0xFFF const InetDiagInfo = 2 const SockDiagByFamily = 20 @@ -177,26 +251,20 @@ func getTCPStats(family uint8) (map[tcpConnectionState]float64, error) { }).Serialize(), } - messages, err := conn.Execute(msg) - if err != nil { - return nil, err - } - - return parseTCPStats(messages) + return conn.Execute(msg) } func parseTCPStats(msgs []netlink.Message) (map[tcpConnectionState]float64, error) { - tcpStats := map[tcpConnectionState]float64{} + stats := make(map[tcpConnectionState]float64) for _, m := range msgs { msg := parseInetDiagMsg(m.Data) - - tcpStats[tcpTxQueuedBytes] += float64(msg.WQueue) - tcpStats[tcpRxQueuedBytes] += float64(msg.RQueue) - tcpStats[tcpConnectionState(msg.State)]++ + stats[tcpTxQueuedBytes] += float64(msg.WQueue) + stats[tcpRxQueuedBytes] += float64(msg.RQueue) + stats[tcpConnectionState(msg.State)]++ } - return tcpStats, nil + return stats, nil } func (st tcpConnectionState) String() string { diff --git a/collector/tcpstat_linux_test.go b/collector/tcpstat_linux_test.go index e1bd090a..110a2c35 100644 --- a/collector/tcpstat_linux_test.go +++ b/collector/tcpstat_linux_test.go @@ -24,21 +24,22 @@ import ( "github.com/josharian/native" "github.com/mdlayher/netlink" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" ) -func Test_parseTCPStats(t *testing.T) { - encode := func(m InetDiagMsg) []byte { - var buf bytes.Buffer - err := binary.Write(&buf, native.Endian, m) - if err != nil { - panic(err) - } - return buf.Bytes() +func encodeDiagMsg(m InetDiagMsg) []byte { + var buf bytes.Buffer + if err := binary.Write(&buf, native.Endian, m); err != nil { + panic(err) } + return buf.Bytes() +} +func Test_parseTCPStats(t *testing.T) { msg := []netlink.Message{ { - Data: encode(InetDiagMsg{ + Data: encodeDiagMsg(InetDiagMsg{ Family: syscall.AF_INET, State: uint8(tcpEstablished), Timer: 0, @@ -52,7 +53,7 @@ func Test_parseTCPStats(t *testing.T) { }), }, { - Data: encode(InetDiagMsg{ + Data: encodeDiagMsg(InetDiagMsg{ Family: syscall.AF_INET, State: uint8(tcpListen), Timer: 0, @@ -67,24 +68,96 @@ func Test_parseTCPStats(t *testing.T) { }, } - tcpStats, err := parseTCPStats(msg) + stats, err := parseTCPStats(msg) if err != nil { t.Fatal(err) } - if want, got := 1, int(tcpStats[tcpEstablished]); want != got { - t.Errorf("want tcpstat number of established state %d, got %d", want, got) - } - - if want, got := 1, int(tcpStats[tcpListen]); want != got { - t.Errorf("want tcpstat number of listen state %d, got %d", want, got) - } - - if want, got := 42, int(tcpStats[tcpTxQueuedBytes]); want != got { - t.Errorf("want tcpstat number of bytes in tx queue %d, got %d", want, got) - } - if want, got := 22, int(tcpStats[tcpRxQueuedBytes]); want != got { - t.Errorf("want tcpstat number of bytes in rx queue %d, got %d", want, got) - } - + assertStat(t, stats, tcpEstablished, 1) + assertStat(t, stats, tcpListen, 1) + assertStat(t, stats, tcpTxQueuedBytes, 42) + assertStat(t, stats, tcpRxQueuedBytes, 22) +} + +func assertStat(t *testing.T, stats map[tcpConnectionState]float64, state tcpConnectionState, expected int) { + t.Helper() + if got := int(stats[state]); got != expected { + t.Errorf("expected %s = %d, got %d", state.String(), expected, got) + } +} + +func Test_emitTCPStatsPerPort(t *testing.T) { + msg := []netlink.Message{ + { + Data: encodeDiagMsg(InetDiagMsg{ + State: uint8(tcpEstablished), + ID: InetDiagSockID{SourcePort: [2]byte{0, 80}}, + }), + }, + { + Data: encodeDiagMsg(InetDiagMsg{ + State: uint8(tcpListen), + ID: InetDiagSockID{DestPort: [2]byte{0, 123}}, + }), + }, + { + Data: encodeDiagMsg(InetDiagMsg{ + State: uint8(tcpTimeWait), + ID: InetDiagSockID{DestPort: [2]byte{0, 123}}, + }), + }, + } + + var metrics []string + + collector := &tcpStatCollector{ + desc: typedDesc{ + desc: prometheus.NewDesc("test_tcp_stat", "Test metric", []string{"state", "port", "direction"}, nil), + valueType: prometheus.GaugeValue, + }, + } + + ch := make(chan prometheus.Metric, 10) + + emitTCPStatsPerPort(collector, ch, msg, []string{"80"}, "source", true) + emitTCPStatsPerPort(collector, ch, msg, []string{"123"}, "dest", false) + + close(ch) + for m := range ch { + d := &dto.Metric{} + if err := m.Write(d); err != nil { + t.Fatalf("failed to write metric: %v", err) + } + + var state, port, direction string + for _, label := range d.Label { + switch label.GetName() { + case "state": + state = label.GetValue() + case "port": + port = label.GetValue() + case "direction": + direction = label.GetValue() + } + } + + metrics = append(metrics, state+"_"+port+"_"+direction) + } + + expected := map[string]bool{ + "established_80_source": true, + "listen_123_dest": true, + "time_wait_123_dest": true, + } + + for _, metric := range metrics { + if !expected[metric] { + t.Errorf("unexpected metric emitted: %s", metric) + } + delete(expected, metric) + } + + for k := range expected { + t.Errorf("expected metric missing: %s", k) + } }