From ce04362bb49f9d81f9c4d7e27c39345b7b00017d Mon Sep 17 00:00:00 2001 From: Khaos Date: Sun, 26 Dec 2021 17:17:44 +0100 Subject: [PATCH] feat(wifi): read wifi state using native methods --- docs/docs/segment-wifi.md | 8 +- src/constants_windows.go | 33 ++++++ src/environment.go | 17 +++ src/environment_unix.go | 4 + src/environment_windows.go | 212 +++++++++++++++++++++++++++++++++++++ src/segment_path_test.go | 5 + src/segment_wifi.go | 82 ++------------ src/segment_wifi_test.go | 110 ++++--------------- 8 files changed, 302 insertions(+), 169 deletions(-) create mode 100644 src/constants_windows.go diff --git a/docs/docs/segment-wifi.md b/docs/docs/segment-wifi.md index d403a961..8e7fef7e 100644 --- a/docs/docs/segment-wifi.md +++ b/docs/docs/segment-wifi.md @@ -21,24 +21,22 @@ Currently only supports Windows and WSL. Pull requests for Darwin and Linux supp "background": "#8822ee", "foreground": "#222222", "background_templates": [ - "{{ if (not .Connected) }}#FF1111{{ end }}", "{{ if (lt .Signal 60) }}#DDDD11{{ else if (lt .Signal 90) }}#DD6611{{ else }}#11CC11{{ end }}" ], "powerline_symbol": "\uE0B0", "properties": { - "template": "{{ if .Connected }}\uFAA8{{ else }}\uFAA9{{ end }} {{ if .Connected }}{{ .SSID }} {{ .Signal }}% {{ .ReceiveRate }}Mbps{{ else }}{{ .State }}{{ end }}" + "template": "\uFAA8 {{ .SSID }} {{ .Signal }}% {{ .ReceiveRate }}Mbps" } } ``` ## Properties -- template: `string` - A go [text/template][go-text-template] extended with [sprig][sprig] using the properties below +- template: `string` - A go [text/template][go-text-template] extended with [sprig][sprig] using the properties below. +Defaults to `{{ if .Error }}{{ .Error }}{{ else }}\uFAA8 {{ .SSID }} {{ .Signal }}% {{ .ReceiveRate }}Mbps{{ end }}` ## Template Properties -- `.Connected`: `bool` - if WiFi is currently connected -- `.State`: `string` - WiFi connection status - _e.g. connected or disconnected_ - `.SSID`: `string` - the SSID of the current wifi network - `.RadioType`: `string` - the radio type - _e.g. 802.11ac, 802.11ax, 802.11n, etc._ - `.Authentication`: `string` - the authentication type - _e.g. WPA2-Personal, WPA2-Enterprise, etc._ diff --git a/src/constants_windows.go b/src/constants_windows.go new file mode 100644 index 00000000..0585beb8 --- /dev/null +++ b/src/constants_windows.go @@ -0,0 +1,33 @@ +package main + +const ( + FHSS WifiType = "FHSS" + DSSS WifiType = "DSSS" + IR WifiType = "IR" + A WifiType = "802.11a" + HRDSSS WifiType = "HRDSSS" + G WifiType = "802.11g" + N WifiType = "802.11n" + AC WifiType = "802.11ac" + + Infrastructure WifiType = "Infrastructure" + Independent WifiType = "Independent" + Any WifiType = "Any" + + OpenSystem WifiType = "802.11 Open System" + SharedKey WifiType = "802.11 Shared Key" + WPA WifiType = "WPA" + WPAPSK WifiType = "WPA PSK" + WPANone WifiType = "WPA NONE" + WPA2 WifiType = "WPA2" + WPA2PSK WifiType = "WPA2 PSK" + Disabled WifiType = "disabled" + Unknown WifiType = "Unknown" + + None WifiType = "None" + WEP40 WifiType = "WEP40" + TKIP WifiType = "TKIP" + CCMP WifiType = "CCMP" + WEP104 WifiType = "WEP104" + WEP WifiType = "WEP" +) diff --git a/src/environment.go b/src/environment.go index 89aa7973..49949a47 100644 --- a/src/environment.go +++ b/src/environment.go @@ -72,6 +72,22 @@ type windowsRegistryValue struct { str string } +type WifiType string + +type wifiInfo struct { + SSID string + Interface string + RadioType WifiType + PhysType WifiType + Authentication WifiType + Cipher WifiType + Channel int + ReceiveRate int + TransmitRate int + Signal int + Error string +} + type environmentInfo interface { getenv(key string) string getcwd() string @@ -110,6 +126,7 @@ type environmentInfo interface { inWSLSharedDrive() bool convertToLinuxPath(path string) string convertToWindowsPath(path string) string + getWifiNetwork() (*wifiInfo, error) } type commandCache struct { diff --git a/src/environment_unix.go b/src/environment_unix.go index 17fa42e0..b05325b1 100644 --- a/src/environment_unix.go +++ b/src/environment_unix.go @@ -93,3 +93,7 @@ func (env *environment) convertToLinuxPath(path string) string { } return path } + +func (env *environment) getWifiNetwork() (*wifiInfo, error) { + return nil, errors.New("not implemented") +} diff --git a/src/environment_windows.go b/src/environment_windows.go index 40e28da6..6d77f3d6 100644 --- a/src/environment_windows.go +++ b/src/environment_windows.go @@ -9,6 +9,7 @@ import ( "strings" "syscall" "time" + "unicode/utf16" "unsafe" "github.com/Azure/go-ansiterm/winterm" @@ -260,3 +261,214 @@ func (env *environment) convertToWindowsPath(path string) string { func (env *environment) convertToLinuxPath(path string) string { return path } + +var ( + hapi = syscall.NewLazyDLL("wlanapi.dll") + hWlanOpenHandle = hapi.NewProc("WlanOpenHandle") + hWlanCloseHandle = hapi.NewProc("WlanCloseHandle") + hWlanEnumInterfaces = hapi.NewProc("WlanEnumInterfaces") + hWlanQueryInterface = hapi.NewProc("WlanQueryInterface") +) + +func (env *environment) getWifiNetwork() (*wifiInfo, error) { + env.trace(time.Now(), "getWifiNetwork") + // Open handle + var pdwNegotiatedVersion uint32 + var phClientHandle uint32 + e, _, err := hWlanOpenHandle.Call(uintptr(uint32(2)), uintptr(unsafe.Pointer(nil)), uintptr(unsafe.Pointer(&pdwNegotiatedVersion)), uintptr(unsafe.Pointer(&phClientHandle))) + if e != 0 { + return nil, err + } + + // defer closing handle + defer func() { + _, _, _ = hWlanCloseHandle.Call(uintptr(phClientHandle), uintptr(unsafe.Pointer(nil))) + }() + + // list interfaces + var interfaceList *WLAN_INTERFACE_INFO_LIST + e, _, err = hWlanEnumInterfaces.Call(uintptr(phClientHandle), uintptr(unsafe.Pointer(nil)), uintptr(unsafe.Pointer(&interfaceList))) + if e != 0 { + return nil, err + } + + // use first interface that is connected + numberOfInterfaces := int(interfaceList.dwNumberOfItems) + infoSize := unsafe.Sizeof(interfaceList.InterfaceInfo[0]) + for i := 0; i < numberOfInterfaces; i++ { + network := (*WLAN_INTERFACE_INFO)(unsafe.Pointer(uintptr(unsafe.Pointer(&interfaceList.InterfaceInfo[0])) + uintptr(i)*infoSize)) + if network.isState != 1 { + continue + } + return env.parseNetworkInterface(network, phClientHandle) + } + return nil, errors.New("Not connected") +} + +func (env *environment) parseNetworkInterface(network *WLAN_INTERFACE_INFO, clientHandle uint32) (*wifiInfo, error) { + info := wifiInfo{} + info.Interface = strings.TrimRight(string(utf16.Decode(network.strInterfaceDescription[:])), "\x00") + + // Query wifi connection state + var dataSize uint16 + var wlanAttr *WLAN_CONNECTION_ATTRIBUTES + e, _, err := hWlanQueryInterface.Call(uintptr(clientHandle), + uintptr(unsafe.Pointer(&network.InterfaceGuid)), + uintptr(7), // wlan_intf_opcode_current_connection + uintptr(unsafe.Pointer(nil)), + uintptr(unsafe.Pointer(&dataSize)), + uintptr(unsafe.Pointer(&wlanAttr)), + uintptr(unsafe.Pointer(nil))) + if e != 0 { + env.log(Error, "parseNetworkInterface", "wlan_intf_opcode_current_connection error") + return &info, err + } + + // SSID + ssid := wlanAttr.wlanAssociationAttributes.dot11Ssid + if ssid.uSSIDLength > 0 { + info.SSID = string(ssid.ucSSID[0:ssid.uSSIDLength]) + } + + // see https://docs.microsoft.com/en-us/windows/win32/nativewifi/dot11-phy-type + switch wlanAttr.wlanAssociationAttributes.dot11PhyType { + case 1: + info.PhysType = FHSS + case 2: + info.PhysType = DSSS + case 3: + info.PhysType = IR + case 4: + info.PhysType = A + case 5: + info.PhysType = HRDSSS + case 6: + info.PhysType = G + case 7: + info.PhysType = N + case 8: + info.PhysType = AC + default: + info.PhysType = Unknown + } + + // see https://docs.microsoft.com/en-us/windows/win32/nativewifi/dot11-bss-type + switch wlanAttr.wlanAssociationAttributes.dot11BssType { + case 1: + info.RadioType = Infrastructure + case 2: + info.RadioType = Independent + default: + info.RadioType = Any + } + + info.Signal = int(wlanAttr.wlanAssociationAttributes.wlanSignalQuality) + info.TransmitRate = int(wlanAttr.wlanAssociationAttributes.ulTxRate) / 1024 + info.ReceiveRate = int(wlanAttr.wlanAssociationAttributes.ulRxRate) / 1024 + + // Query wifi channel + dataSize = 0 + var channel *uint32 + e, _, err = hWlanQueryInterface.Call(uintptr(clientHandle), + uintptr(unsafe.Pointer(&network.InterfaceGuid)), + uintptr(8), // wlan_intf_opcode_channel_number + uintptr(unsafe.Pointer(nil)), + uintptr(unsafe.Pointer(&dataSize)), + uintptr(unsafe.Pointer(&channel)), + uintptr(unsafe.Pointer(nil))) + if e != 0 { + env.log(Error, "parseNetworkInterface", "wlan_intf_opcode_channel_number error") + return &info, err + } + info.Channel = int(*channel) + + if wlanAttr.wlanSecurityAttributes.bSecurityEnabled <= 0 { + info.Authentication = Disabled + return &info, nil + } + + // see https://docs.microsoft.com/en-us/windows/win32/nativewifi/dot11-auth-algorithm + switch wlanAttr.wlanSecurityAttributes.dot11AuthAlgorithm { + case 1: + info.Authentication = OpenSystem + case 2: + info.Authentication = SharedKey + case 3: + info.Authentication = WPA + case 4: + info.Authentication = WPAPSK + case 5: + info.Authentication = WPANone + case 6: + info.Authentication = WPA2 + case 7: + info.Authentication = WPA2PSK + default: + info.Authentication = Unknown + } + + // see https://docs.microsoft.com/en-us/windows/win32/nativewifi/dot11-cipher-algorithm + switch wlanAttr.wlanSecurityAttributes.dot11CipherAlgorithm { + case 0: + info.Cipher = None + case 0x1: + info.Cipher = WEP40 + case 0x2: + info.Cipher = TKIP + case 0x4: + info.Cipher = CCMP + case 0x5: + info.Cipher = WEP104 + case 0x100: + info.Cipher = WPA + case 0x101: + info.Cipher = WEP + default: + info.Cipher = Unknown + } + + return &info, nil +} + +type WLAN_INTERFACE_INFO_LIST struct { // nolint: revive + dwNumberOfItems uint32 + dwIndex uint32 // nolint: structcheck,unused + InterfaceInfo [1]WLAN_INTERFACE_INFO +} + +type WLAN_INTERFACE_INFO struct { // nolint: revive + InterfaceGuid syscall.GUID // nolint: revive + strInterfaceDescription [256]uint16 + isState uint32 +} + +type WLAN_CONNECTION_ATTRIBUTES struct { // nolint: revive + isState uint32 // nolint: structcheck,unused + wlanConnectionMode uint32 // nolint: structcheck,unused + strProfileName [256]uint16 // nolint: structcheck,unused + wlanAssociationAttributes WLAN_ASSOCIATION_ATTRIBUTES + wlanSecurityAttributes WLAN_SECURITY_ATTRIBUTES +} + +type WLAN_ASSOCIATION_ATTRIBUTES struct { // nolint: revive + dot11Ssid DOT11_SSID + dot11BssType uint32 + dot11Bssid [6]uint8 // nolint: structcheck,unused + dot11PhyType uint32 + uDot11PhyIndex uint32 // nolint: structcheck,unused + wlanSignalQuality uint32 + ulRxRate uint32 + ulTxRate uint32 +} + +type WLAN_SECURITY_ATTRIBUTES struct { // nolint: revive + bSecurityEnabled uint32 + bOneXEnabled uint32 // nolint: structcheck,unused + dot11AuthAlgorithm uint32 + dot11CipherAlgorithm uint32 +} + +type DOT11_SSID struct { // nolint: revive + uSSIDLength uint32 + ucSSID [32]uint8 +} diff --git a/src/segment_path_test.go b/src/segment_path_test.go index e3ee0b25..e293894f 100644 --- a/src/segment_path_test.go +++ b/src/segment_path_test.go @@ -198,6 +198,11 @@ func (env *MockedEnvironment) convertToLinuxPath(path string) string { return args.String(0) } +func (env *MockedEnvironment) getWifiNetwork() (*wifiInfo, error) { + args := env.Called(nil) + return args.Get(0).(*wifiInfo), args.Error(1) +} + const ( homeBill = "/home/bill" homeJan = "/usr/home/jan" diff --git a/src/segment_wifi.go b/src/segment_wifi.go index cce8cda6..1d3efa02 100644 --- a/src/segment_wifi.go +++ b/src/segment_wifi.go @@ -1,27 +1,14 @@ package main -import ( - "fmt" - "strconv" - "strings" -) - type wifi struct { - props properties - env environmentInfo - Connected bool - State string - SSID string - RadioType string - Authentication string - Channel int - ReceiveRate int - TransmitRate int - Signal int + props properties + env environmentInfo + + wifiInfo } const ( - defaultTemplate = "{{ if .Connected }}\uFAA8{{ else }}\uFAA9{{ end }}{{ if .Connected }}{{ .SSID }} {{ .Signal }}% {{ .ReceiveRate }}Mbps{{ else }}{{ .State }}{{ end }}" + defaultTemplate = "{{ if .Error }}{{ .Error }}{{ else }}\uFAA8 {{ .SSID }} {{ .Signal }}% {{ .ReceiveRate }}Mbps{{ end }}" ) func (w *wifi) enabled() bool { @@ -29,27 +16,16 @@ func (w *wifi) enabled() bool { if w.env.getPlatform() != windowsPlatform && !w.env.isWsl() { return false } - - // Bail out of no netsh command found - cmd := "netsh.exe" - if !w.env.hasCommand(cmd) { - return false - } - - // Attempt to retrieve output from netsh command - cmdResult, err := w.env.runCommand(cmd, "wlan", "show", "interfaces") + wifiInfo, err := w.env.getWifiNetwork() displayError := w.props.getBool(DisplayError, false) if err != nil && displayError { - w.State = fmt.Sprintf("WIFI ERR: %s", err) + w.Error = err.Error() return true } - if err != nil { + if err != nil || wifiInfo == nil { return false } - - // Extract data from netsh cmdResult - parseNetshCmdResult(cmdResult, w) - + w.wifiInfo = *wifiInfo return true } @@ -72,43 +48,3 @@ func (w *wifi) init(props properties, env environmentInfo) { w.props = props w.env = env } - -func parseNetshCmdResult(netshCmdResult string, w *wifi) { - lines := strings.Split(netshCmdResult, "\n") - for _, line := range lines { - matches := strings.Split(line, " : ") - if len(matches) != 2 { - continue - } - name := strings.TrimSpace(matches[0]) - value := strings.TrimSpace(matches[1]) - - switch name { - case "State": - w.State = value - w.Connected = value == "connected" - case "SSID": - w.SSID = value - case "Radio type": - w.RadioType = value - case "Authentication": - w.Authentication = value - case "Channel": - if intValue, err := strconv.Atoi(value); err == nil { - w.Channel = intValue - } - case "Receive rate (Mbps)": - if intValue, err := strconv.Atoi(strings.Split(value, ".")[0]); err == nil { - w.ReceiveRate = intValue - } - case "Transmit rate (Mbps)": - if intValue, err := strconv.Atoi(strings.Split(value, ".")[0]); err == nil { - w.TransmitRate = intValue - } - case "Signal": - if intValue, err := strconv.Atoi(strings.TrimRight(value, "%")); err == nil { - w.Signal = intValue - } - } - } -} diff --git a/src/segment_wifi_test.go b/src/segment_wifi_test.go index b7f474cd..086409de 100644 --- a/src/segment_wifi_test.go +++ b/src/segment_wifi_test.go @@ -2,114 +2,41 @@ package main import ( "errors" - "fmt" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) -type netshStringArgs struct { - state string - ssid string - radioType string - authentication string - channel int - receiveRate int - transmitRate int - signal int -} - -func getNetshString(args *netshStringArgs) string { - const netshString string = ` - There is 1 interface on the system: - - Name : Wi-Fi - Description : Intel(R) Wireless-AC 9560 160MHz - GUID : 6bb8def2-9af2-4bd4-8be2-6bd54e46bdc9 - Physical address : d4:3b:04:e6:10:40 - State : %s - SSID : %s - BSSID : 5c:7d:7d:82:c5:73 - Network type : Infrastructure - Radio type : %s - Authentication : %s - Cipher : CCMP - Connection mode : Profile - Channel : %d - Receive rate (Mbps) : %d - Transmit rate (Mbps) : %d - Signal : %d%% - Profile : ohsiggy - - Hosted network status : Not available` - - return fmt.Sprintf(netshString, args.state, args.ssid, args.radioType, args.authentication, args.channel, args.receiveRate, args.transmitRate, args.signal) -} - func TestWiFiSegment(t *testing.T) { cases := []struct { Case string ExpectedString string ExpectedEnabled bool - CommandNotFound bool - CommandOutput string - CommandError error + Network *wifiInfo + WifiError error DisplayError bool - Template string - ExpectedState string }{ { - Case: "not enabled on windows when netsh command not found", - ExpectedEnabled: false, - ExpectedString: "", - CommandNotFound: true, + Case: "No error and nil network", }, { - Case: "not enabled on windows when netsh command fails", - ExpectedEnabled: false, - ExpectedString: "", - CommandError: errors.New("intentional testing failure"), + Case: "Error and nil network", + WifiError: errors.New("oh noes"), }, { - Case: "enabled on windows with DisplayError=true", - ExpectedEnabled: true, - ExpectedString: "WIFI ERR: intentional testing failure", - CommandError: errors.New("intentional testing failure"), + Case: "Display error and nil network", + WifiError: errors.New("oh noes"), + ExpectedString: "oh noes", DisplayError: true, - Template: "{{.State}}", + ExpectedEnabled: true, }, { - Case: "enabled on windows with every property in template", + Case: "Display wifi state", + ExpectedString: "pretty fly for a wifi", ExpectedEnabled: true, - ExpectedString: "connected testing 802.11ac WPA2-Personal 99 500 400 80", - CommandOutput: getNetshString(&netshStringArgs{ - state: "connected", - ssid: "testing", - radioType: "802.11ac", - authentication: "WPA2-Personal", - channel: 99, - receiveRate: 500.0, - transmitRate: 400.0, - signal: 80, - }), - Template: "{{.State}} {{.SSID}} {{.RadioType}} {{.Authentication}} {{.Channel}} {{.ReceiveRate}} {{.TransmitRate}} {{.Signal}}", - }, - { - Case: "enabled on windows but wifi not connected", - ExpectedEnabled: true, - ExpectedString: "disconnected", - CommandOutput: getNetshString(&netshStringArgs{ - state: "disconnected", - }), - Template: "{{if not .Connected}}{{.State}}{{end}}", - }, - { - Case: "enabled on windows but template is invalid", - ExpectedEnabled: true, - ExpectedString: "unable to create text based on template", - CommandOutput: getNetshString(&netshStringArgs{}), - Template: "{{.DoesNotExist}}", + Network: &wifiInfo{ + SSID: "pretty fly for a wifi", + }, }, } @@ -117,18 +44,19 @@ func TestWiFiSegment(t *testing.T) { env := new(MockedEnvironment) env.On("getPlatform", nil).Return(windowsPlatform) env.On("isWsl", nil).Return(false) - env.On("hasCommand", "netsh.exe").Return(!tc.CommandNotFound) - env.On("runCommand", mock.Anything, mock.Anything).Return(tc.CommandOutput, tc.CommandError) + env.On("getWifiNetwork", nil).Return(tc.Network, tc.WifiError) w := &wifi{ env: env, props: map[Property]interface{}{ DisplayError: tc.DisplayError, - SegmentTemplate: tc.Template, + SegmentTemplate: "{{ if .Error }}{{ .Error }}{{ else }}{{ .SSID }}{{ end }}", }, } assert.Equal(t, tc.ExpectedEnabled, w.enabled(), tc.Case) - assert.Equal(t, tc.ExpectedString, w.string(), tc.Case) + if tc.Network != nil || tc.DisplayError { + assert.Equal(t, tc.ExpectedString, w.string(), tc.Case) + } } }