refactor(winreg): make windows registry query function more general

simplify config definition, requires one prop to define path and key
add variant return type
add template-accessible query functions on the segment for advanced use
This commit is contained in:
will 2021-12-04 21:11:25 +00:00 committed by Jan De Dobbeleer
parent 6b0b7c2b1b
commit 2ccaf90cbf
8 changed files with 187 additions and 72 deletions

View file

@ -12,6 +12,7 @@ Supported registry key types:
- String - String
- DWORD (displayed in upper-case 0x hex) - DWORD (displayed in upper-case 0x hex)
- QWORD (displayed in upper-case 0x hex)
## Sample Configuration ## Sample Configuration
@ -23,9 +24,9 @@ Supported registry key types:
"foreground": "#ffffff", "foreground": "#ffffff",
"background": "#444444", "background": "#444444",
"properties": { "properties": {
"path": "HKLM\\software\\microsoft\\windows nt\\currentversion", "path": "HKLM\\software\\microsoft\\windows nt\\currentversion\\buildlab",
"key":"buildlab", "fallback":"unknown",
"template":"{{ if .Value }}{{ .Value }}{{ else }}unknown{{ end }}", "template":"{{ .Value }}",
"prefix": " \uE62A " "prefix": " \uE62A "
} }
}, },
@ -34,14 +35,14 @@ Supported registry key types:
## Properties ## Properties
- path: `string` - registry path to the desired key using backslashes and with a valid root HKEY name. - path: `string` - registry path to the desired key using backslashes and with a valid root HKEY name.
- key: `string` - the key to read from the `path` location. Ending path with \ will get the (Default) key from that path.
- fallback: `string` - the value to fall back to if no entry is found - fallback: `string` - the value to fall back to if no entry is found
- template: `string` - a go [text/template][go-text-template] template extended - template: `string` - a go [text/template][go-text-template] template extended
with [sprig][sprig] utilizing the properties below. with [sprig][sprig] utilizing the properties below.
## Template Properties ## Template Properties
- .Value: `string` - The result of your query - .Value: `string` - The result of your query, or fallback if not found.
[go-text-template]: https://golang.org/pkg/text/template/ [go-text-template]: https://golang.org/pkg/text/template/
[sprig]: https://masterminds.github.io/sprig/ [sprig]: https://masterminds.github.io/sprig/

View file

@ -56,6 +56,21 @@ type cache interface {
set(key, value string, ttl int) set(key, value string, ttl int)
} }
type windowsRegistryValueType int
const (
regQword windowsRegistryValueType = iota
regDword
regString
)
type windowsRegistryValue struct {
valueType windowsRegistryValueType
qword uint64
dword uint32
str string
}
type environmentInfo interface { type environmentInfo interface {
getenv(key string) string getenv(key string) string
getcwd() string getcwd() string
@ -80,7 +95,7 @@ type environmentInfo interface {
getBatteryInfo() ([]*battery.Battery, error) getBatteryInfo() ([]*battery.Battery, error)
getShellName() string getShellName() string
getWindowTitle(imageName, windowTitleRegex string) (string, error) getWindowTitle(imageName, windowTitleRegex string) (string, error)
getWindowsRegistryKeyValue(regPath, regKey string) (string, error) getWindowsRegistryKeyValue(path string) (*windowsRegistryValue, error)
doGet(url string, timeout int) ([]byte, error) doGet(url string, timeout int) ([]byte, error)
hasParentFilePath(path string) (fileInfo *fileInfo, err error) hasParentFilePath(path string) (fileInfo *fileInfo, err error)
isWsl() bool isWsl() bool

View file

@ -60,6 +60,6 @@ func (env *environment) getCachePath() string {
return env.homeDir() return env.homeDir()
} }
func (env *environment) getWindowsRegistryKeyValue(regPath, regKey string) (string, error) { func (env *environment) getWindowsRegistryKeyValue(path string) (*windowsRegistryValue, error) {
return "", errors.New("not implemented") return nil, errors.New("not implemented")
} }

View file

@ -105,31 +105,70 @@ func (env *environment) getCachePath() string {
} }
// //
// Takes a registry path like "HKLM\Software\Microsoft\Windows NT\CurrentVersion" and a key under that path like "CurrentVersion" (or "" if the (Default) key is required). // Takes a registry path to a key like
// Returns a bool and string: // "HKLM\Software\Microsoft\Windows NT\CurrentVersion\EditionID"
// //
// true and the retrieved value formatted into a string if successful. // The last part of the path is the key to retrieve.
// false and the string will be the error
// //
func (env *environment) getWindowsRegistryKeyValue(regPath, regKey string) (string, error) { // If the path ends in "\", the "(Default)" key in that path is retrieved.
env.trace(time.Now(), "getWindowsRegistryKeyValue", regPath, regKey) //
// Returns a variant type if successful; nil and an error if not.
//
func (env *environment) getWindowsRegistryKeyValue(path string) (*windowsRegistryValue, error) {
env.trace(time.Now(), "getWindowsRegistryKeyValue", path)
// Extract root HK value and turn it into a windows.Handle to open the key. // Format:
regPathParts := strings.SplitN(regPath, "\\", 2) // "HKLM\Software\Microsoft\Windows NT\CurrentVersion\EditionID"
// 1 | 2 | 3
//
// Split into:
//
// 1. Root key - extract the root HKEY string and turn this into a handle to get started
// 2. Path - open this path
// 3. Key - get this key value
//
// If 3 is "" (i.e. the path ends with "\"), then get (Default) key.
//
regPathParts := strings.SplitN(path, "\\", 2)
if len(regPathParts) < 2 {
errorLogMsg := fmt.Sprintf("Error, malformed registry path: '%s'", path)
env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
}
regRootHKeyHandle := getHKEYHandleFromAbbrString(regPathParts[0]) regRootHKeyHandle := getHKEYHandleFromAbbrString(regPathParts[0])
if regRootHKeyHandle == 0 { if regRootHKeyHandle == 0 {
errorLogMsg := fmt.Sprintf("Error, Supplied root HKEY value not valid: '%s'", regPathParts[0]) errorLogMsg := fmt.Sprintf("Error, Supplied root HKEY value not valid: '%s'", regPathParts[0])
env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg) env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg)
return "", errors.New(errorLogMsg) return nil, errors.New(errorLogMsg)
} }
// Second part of split is registry path after HK part - needs to be UTF16 to pass to the windows. API // Strip key off the end.
regPathUTF16, regPathUTF16ConversionErr := windows.UTF16FromString(regPathParts[1]) lastSlash := strings.LastIndex(regPathParts[1], "\\")
if regPathUTF16ConversionErr != nil {
errorLogMsg := fmt.Sprintf("Error, Could not convert supplied path '%s' to UTF16, error: '%s'", regPathParts[1], regPathUTF16ConversionErr) if lastSlash < 0 {
errorLogMsg := fmt.Sprintf("Error, malformed registry path: '%s'", path)
env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg) env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg)
return "", errors.New(errorLogMsg) return nil, errors.New(errorLogMsg)
}
regKey := regPathParts[1][lastSlash+1:]
regPath := regPathParts[1][0:lastSlash]
// Just for debug log display.
regKeyLogged := regKey
if len(regKeyLogged) == 0 {
regKeyLogged = "(Default)"
}
env.log(Debug, "getWindowsRegistryKeyValue", fmt.Sprintf("getWindowsRegistryKeyValue: root:\"%s\", path:\"%s\", key:\"%s\"", regPathParts[0], regPath, regKeyLogged))
// Second part of split is registry path after HK part - needs to be UTF16 to pass to the windows. API
regPathUTF16, err := windows.UTF16FromString(regPath)
if err != nil {
errorLogMsg := fmt.Sprintf("Error, Could not convert supplied path '%s' to UTF16, error: '%s'", regPath, err)
env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
} }
// Ok - open it.. // Ok - open it..
@ -138,7 +177,7 @@ func (env *environment) getWindowsRegistryKeyValue(regPath, regKey string) (stri
if regOpenErr != nil { if regOpenErr != nil {
errorLogMsg := fmt.Sprintf("Error RegOpenKeyEx opening registry path to '%s', error: '%s'", regPath, regOpenErr) errorLogMsg := fmt.Sprintf("Error RegOpenKeyEx opening registry path to '%s', error: '%s'", regPath, regOpenErr)
env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg) env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg)
return "", errors.New(errorLogMsg) return nil, errors.New(errorLogMsg)
} }
// Success - from here on out, when returning make sure to close that reg key with a deferred call to close: // Success - from here on out, when returning make sure to close that reg key with a deferred call to close:
defer func() { defer func() {
@ -149,11 +188,11 @@ func (env *environment) getWindowsRegistryKeyValue(regPath, regKey string) (stri
}() }()
// Again - need UTF16 of the key for the API: // Again - need UTF16 of the key for the API:
regKeyUTF16, regKeyUTF16ConversionErr := windows.UTF16FromString(regKey) regKeyUTF16, err := windows.UTF16FromString(regKey)
if regKeyUTF16ConversionErr != nil { if err != nil {
errorLogMsg := fmt.Sprintf("Error, could not convert supplied key '%s' to UTF16, error: '%s'", regKey, regKeyUTF16ConversionErr) errorLogMsg := fmt.Sprintf("Error, could not convert supplied key '%s' to UTF16, error: '%s'", regKey, err)
env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg) env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg)
return "", errors.New(errorLogMsg) return nil, errors.New(errorLogMsg)
} }
// Two stage way to get the key value - query once to get size - then allocate and query again to fill it. // Two stage way to get the key value - query once to get size - then allocate and query again to fill it.
@ -164,7 +203,7 @@ func (env *environment) getWindowsRegistryKeyValue(regPath, regKey string) (stri
if regQueryErr != nil { if regQueryErr != nil {
errorLogMsg := fmt.Sprintf("Error calling RegQueryValueEx to retrieve key data size with error '%s'", regQueryErr) errorLogMsg := fmt.Sprintf("Error calling RegQueryValueEx to retrieve key data size with error '%s'", regQueryErr)
env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg) env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg)
return "", errors.New(errorLogMsg) return nil, errors.New(errorLogMsg)
} }
// Alloc and fill... // Alloc and fill...
@ -174,7 +213,7 @@ func (env *environment) getWindowsRegistryKeyValue(regPath, regKey string) (stri
if regQueryErr != nil { if regQueryErr != nil {
errorLogMsg := fmt.Sprintf("Error calling RegQueryValueEx to retrieve key data with error '%s'", regQueryErr) errorLogMsg := fmt.Sprintf("Error calling RegQueryValueEx to retrieve key data with error '%s'", regQueryErr)
env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg) env.log(Error, "getWindowsRegistryKeyValue", errorLogMsg)
return "", errors.New(errorLogMsg) return nil, errors.New(errorLogMsg)
} }
// Format result into a string, depending on type. (future refactor - move this out into it's own function) // Format result into a string, depending on type. (future refactor - move this out into it's own function)
@ -185,16 +224,22 @@ func (env *environment) getWindowsRegistryKeyValue(regPath, regKey string) (stri
valueString := windows.UTF16PtrToString(uint16p) valueString := windows.UTF16PtrToString(uint16p)
env.log(Debug, "getWindowsRegistryKeyValue", fmt.Sprintf("success, string: %s", valueString)) env.log(Debug, "getWindowsRegistryKeyValue", fmt.Sprintf("success, string: %s", valueString))
return valueString, nil
return &windowsRegistryValue{valueType: regString, str: valueString}, nil
case windows.REG_DWORD: case windows.REG_DWORD:
var uint32p *uint32 var uint32p *uint32
uint32p = (*uint32)(unsafe.Pointer(&keyBuf[0])) // more casting goodness uint32p = (*uint32)(unsafe.Pointer(&keyBuf[0])) // more casting goodness
valueString := fmt.Sprintf("0x%08X", *uint32p) env.log(Debug, "getWindowsRegistryKeyValue", fmt.Sprintf("success, DWORD, 0x%08X", *uint32p))
env.log(Debug, "getWindowsRegistryKeyValue", fmt.Sprintf("success, DWORD, formatted as string: %s", valueString)) return &windowsRegistryValue{valueType: regDword, dword: *uint32p}, nil
return valueString, nil case windows.REG_QWORD:
var uint64p *uint64
uint64p = (*uint64)(unsafe.Pointer(&keyBuf[0])) // more casting goodness
env.log(Debug, "getWindowsRegistryKeyValue", fmt.Sprintf("success, QWORD, 0x%016X", *uint64p))
return &windowsRegistryValue{valueType: regQword, qword: *uint64p}, nil
default: default:
errorLogMsg := fmt.Sprintf("Error, no formatter for REG_? type:%d, data size:%d bytes", keyBufType, keyBufSize) errorLogMsg := fmt.Sprintf("Error, no formatter for REG_? type:%d, data size:%d bytes", keyBufType, keyBufSize)
return "", errors.New(errorLogMsg) return nil, errors.New(errorLogMsg)
} }
} }

View file

@ -129,9 +129,9 @@ func (env *MockedEnvironment) getWindowTitle(imageName, windowTitleRegex string)
return args.String(0), args.Error(1) return args.String(0), args.Error(1)
} }
func (env *MockedEnvironment) getWindowsRegistryKeyValue(regPath, regKey string) (string, error) { func (env *MockedEnvironment) getWindowsRegistryKeyValue(path string) (*windowsRegistryValue, error) {
args := env.Called(regPath, regKey) args := env.Called(path)
return args.String(0), args.Error(1) return args.Get(0).(*windowsRegistryValue), args.Error(1)
} }
func (env *MockedEnvironment) doGet(url string, timeout int) ([]byte, error) { func (env *MockedEnvironment) doGet(url string, timeout int) ([]byte, error) {

View file

@ -1,5 +1,10 @@
package main package main
import (
"errors"
"fmt"
)
type winreg struct { type winreg struct {
props properties props properties
env environmentInfo env environmentInfo
@ -8,10 +13,8 @@ type winreg struct {
} }
const ( const (
// path from the supplied root under which the key exists // full path to the key; if ends in \, gets "(Default)" key in that path
RegistryPath Property = "path" RegistryPath Property = "path"
// key within full reg path formed from two above
RegistryKey Property = "key"
// Fallback is the text to display if the key is not found // Fallback is the text to display if the key is not found
Fallback Property = "fallback" Fallback Property = "fallback"
) )
@ -27,14 +30,23 @@ func (wr *winreg) enabled() bool {
} }
registryPath := wr.props.getString(RegistryPath, "") registryPath := wr.props.getString(RegistryPath, "")
registryKey := wr.props.getString(RegistryKey, "")
fallback := wr.props.getString(Fallback, "") fallback := wr.props.getString(Fallback, "")
var err error var regValue *windowsRegistryValue
wr.Value, err = wr.env.getWindowsRegistryKeyValue(registryPath, registryKey) regValue, _ = wr.env.getWindowsRegistryKeyValue(registryPath)
if err == nil { if regValue != nil {
switch regValue.valueType {
case regString:
wr.Value = regValue.str
return true return true
case regDword:
wr.Value = fmt.Sprintf("0x%08X", regValue.dword)
return true
case regQword:
wr.Value = fmt.Sprintf("0x%016X", regValue.qword)
return true
}
} }
if len(fallback) > 0 { if len(fallback) > 0 {
@ -62,3 +74,45 @@ func (wr *winreg) templateString(segmentTemplate string) string {
} }
return text return text
} }
func (wr winreg) GetRegistryString(path string) (string, error) {
regValue, err := wr.env.getWindowsRegistryKeyValue(path)
if regValue == nil {
return "", err
}
if regValue.valueType != regString {
return "", errors.New("type mismatch, registry value is not a string")
}
return regValue.str, nil
}
func (wr winreg) GetRegistryDword(path string) (uint32, error) {
regValue, err := wr.env.getWindowsRegistryKeyValue(path)
if regValue == nil {
return 0, err
}
if regValue.valueType != regDword {
return 0, errors.New("type mismatch, registry value is not a dword")
}
return regValue.dword, nil
}
func (wr winreg) GetRegistryQword(path string) (uint64, error) {
regValue, err := wr.env.getWindowsRegistryKeyValue(path)
if regValue == nil {
return 0, err
}
if regValue.valueType != regQword {
return 0, errors.New("type mismatch, registry value is not a qword")
}
return regValue.qword, nil
}

View file

@ -7,37 +7,32 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestRegQueryEnabled(t *testing.T) { func TestWinReg(t *testing.T) {
cases := []struct { cases := []struct {
CaseDescription string CaseDescription string
Path string Path string
Key string
Fallback string Fallback string
ExpectedSuccess bool ExpectedSuccess bool
ExpectedValue string ExpectedValue string
Output string getWRKVOutput *windowsRegistryValue
Err error Err error
}{ }{
{ {
CaseDescription: "Error", CaseDescription: "Error",
Path: "HKLLM\\Software\\Microsoft\\Windows NT\\CurrentVersion", Path: "HKLLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\ProductName",
Key: "ProductName",
Err: errors.New("No match"), Err: errors.New("No match"),
ExpectedSuccess: false, ExpectedSuccess: false,
}, },
{ {
CaseDescription: "Value", CaseDescription: "Value",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion", Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
Key: "InstallTime", getWRKVOutput: &windowsRegistryValue{valueType: regString, str: "xbox"},
Output: "xbox",
ExpectedSuccess: true, ExpectedSuccess: true,
ExpectedValue: "xbox", ExpectedValue: "xbox",
}, },
{ {
CaseDescription: "Fallback value", CaseDescription: "Fallback value",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion", Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
Key: "InstallTime",
Output: "no formatter",
Fallback: "cortana", Fallback: "cortana",
Err: errors.New("No match"), Err: errors.New("No match"),
ExpectedSuccess: true, ExpectedSuccess: true,
@ -45,32 +40,43 @@ func TestRegQueryEnabled(t *testing.T) {
}, },
{ {
CaseDescription: "Empty string value (no error) should display empty string even in presence of fallback", CaseDescription: "Empty string value (no error) should display empty string even in presence of fallback",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion", Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
Key: "InstallTime", getWRKVOutput: &windowsRegistryValue{valueType: regString, str: ""},
Output: "",
Fallback: "anaconda", Fallback: "anaconda",
ExpectedSuccess: true, ExpectedSuccess: true,
ExpectedValue: "", ExpectedValue: "",
}, },
{ {
CaseDescription: "Empty string value (no error) should display empty string", CaseDescription: "Empty string value (no error) should display empty string",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion", Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
Key: "InstallTime", getWRKVOutput: &windowsRegistryValue{valueType: regString, str: ""},
Output: "",
ExpectedSuccess: true, ExpectedSuccess: true,
ExpectedValue: "", ExpectedValue: "",
}, },
{
CaseDescription: "DWORD value",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
getWRKVOutput: &windowsRegistryValue{valueType: regDword, dword: 0xdeadbeef},
ExpectedSuccess: true,
ExpectedValue: "0xDEADBEEF",
},
{
CaseDescription: "QWORD value",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
getWRKVOutput: &windowsRegistryValue{valueType: regQword, qword: 0x7eb199e57fa1afe1},
ExpectedSuccess: true,
ExpectedValue: "0x7EB199E57FA1AFE1",
},
} }
for _, tc := range cases { for _, tc := range cases {
env := new(MockedEnvironment) env := new(MockedEnvironment)
env.On("getRuntimeGOOS", nil).Return(windowsPlatform) env.On("getRuntimeGOOS", nil).Return(windowsPlatform)
env.On("getWindowsRegistryKeyValue", tc.Path, tc.Key).Return(tc.Output, tc.Err) env.On("getWindowsRegistryKeyValue", tc.Path).Return(tc.getWRKVOutput, tc.Err)
r := &winreg{ r := &winreg{
env: env, env: env,
props: map[Property]interface{}{ props: map[Property]interface{}{
RegistryPath: tc.Path, RegistryPath: tc.Path,
RegistryKey: tc.Key,
Fallback: tc.Fallback, Fallback: tc.Fallback,
}, },
} }

View file

@ -1687,13 +1687,7 @@
"path": { "path": {
"type": "string", "type": "string",
"title": "Registry Path", "title": "Registry Path",
"description": "The path under which the registy key lives (case insensitive, must use backslashes), e.g. HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion", "description": "The path to the registry key (case insensitive, must use backslashes). Ending with \\ will retrieve \"(Default)\" key in that path.",
"default": ""
},
"key": {
"type": "string",
"title": "Registry Key",
"description": "The key under he registry path to get (case insensitive). If left blank, will get the value of the (Default) key in the registry_path",
"default": "" "default": ""
}, },
"fallback": { "fallback": {