feat(winreg): additional value support

This commit is contained in:
Jan De Dobbeleer 2022-06-01 09:50:20 +02:00 committed by Jan De Dobbeleer
parent 1668949037
commit b22dd21fa5
13 changed files with 102 additions and 239 deletions

View file

@ -13,13 +13,13 @@ func GetAccentColor(env environment.Environment) (*RGB, error) {
}
// see https://stackoverflow.com/questions/3560890/vista-7-how-to-get-glass-color
value, err := env.WindowsRegistryKeyValue(`HKEY_CURRENT_USER\Software\Microsoft\Windows\DWM\ColorizationColor`)
if err != nil {
if err != nil || value.ValueType != environment.DWORD {
return nil, err
}
return &RGB{
R: byte(value.Dword >> 16),
G: byte(value.Dword >> 8),
B: byte(value.Dword),
R: byte(value.DWord >> 16),
G: byte(value.DWord >> 8),
B: byte(value.DWord),
}, nil
}

View file

@ -2,7 +2,7 @@ package environment
import (
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"time"
)
@ -27,7 +27,7 @@ func (fc *fileCache) Init(cachePath string) {
fc.cache = newConcurrentMap()
fc.cachePath = cachePath
cacheFilePath := filepath.Join(fc.cachePath, CacheFile)
content, err := ioutil.ReadFile(cacheFilePath)
content, err := os.ReadFile(cacheFilePath)
if err != nil {
return
}
@ -49,7 +49,7 @@ func (fc *fileCache) Close() {
cache := fc.cache.list()
if dump, err := json.MarshalIndent(cache, "", " "); err == nil {
cacheFilePath := filepath.Join(fc.cachePath, CacheFile)
_ = ioutil.WriteFile(cacheFilePath, dump, 0644)
_ = os.WriteFile(cacheFilePath, dump, 0644)
}
}

View file

@ -7,7 +7,6 @@ import (
"fmt"
"io"
"io/fs"
"io/ioutil"
"log"
"net/http"
"oh-my-posh/regex"
@ -84,19 +83,20 @@ type Cache interface {
type HTTPRequestModifier func(request *http.Request)
type WindowsRegistryValueType int
type WindowsRegistryValueType string
const (
RegQword WindowsRegistryValueType = iota
RegDword
RegString
DWORD = "DWORD"
QWORD = "QWORD"
BINARY = "BINARY"
STRING = "STRING"
)
type WindowsRegistryValue struct {
ValueType WindowsRegistryValueType
Qword uint64
Dword uint32
Str string
DWord uint64
QWord uint64
String string
}
type WifiType string
@ -433,7 +433,7 @@ func (env *ShellEnvironment) FileContent(file string) string {
if !filepath.IsAbs(file) {
file = filepath.Join(env.Pwd(), file)
}
content, err := ioutil.ReadFile(file)
content, err := os.ReadFile(file)
if err != nil {
env.Log(Error, "FileContent", err.Error())
return ""
@ -620,7 +620,7 @@ func (env *ShellEnvironment) HTTPRequest(targetURL string, timeout int, requestM
return nil, err
}
defer response.Body.Close()
body, err := ioutil.ReadAll(response.Body)
body, err := io.ReadAll(response.Body)
if err != nil {
env.Log(Error, "HTTPRequest", err.Error())
return nil, err

View file

@ -15,6 +15,7 @@ import (
"github.com/Azure/go-ansiterm/winterm"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
func (env *ShellEnvironment) Root() bool {
@ -134,16 +135,15 @@ func (env *ShellEnvironment) LookWinAppPath(file string) (string, error) {
return "", errors.New("no Windows Store App")
}
//
// Takes a registry path to a key like
// "HKLM\Software\Microsoft\Windows NT\CurrentVersion\EditionID"
//
// "HKLM\Software\Microsoft\Windows NT\CurrentVersion\EditionID"
//
// The last part of the path is the key to retrieve.
//
// If the path ends in "\", the "(Default)" key in that path is retrieved.
//
// Returns a variant type if successful; nil and an error if not.
//
func (env *ShellEnvironment) WindowsRegistryKeyValue(path string) (*WindowsRegistryValue, error) {
env.Trace(time.Now(), "WindowsRegistryKeyValue", path)
@ -159,119 +159,70 @@ func (env *ShellEnvironment) WindowsRegistryKeyValue(path string) (*WindowsRegis
//
// If 3 is "" (i.e. the path ends with "\"), then get (Default) key.
//
regPathParts := strings.SplitN(path, "\\", 2)
if len(regPathParts) < 2 {
rootKey, regPath, found := strings.Cut(path, `\`)
if !found {
errorLogMsg := fmt.Sprintf("Error, malformed registry path: '%s'", path)
env.Log(Error, "WindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
}
regRootHKeyHandle := getHKEYHandleFromAbbrString(regPathParts[0])
if regRootHKeyHandle == 0 {
errorLogMsg := fmt.Sprintf("Error, Supplied root HKEY value not valid: '%s'", regPathParts[0])
env.Log(Error, "WindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
regKey := Base(env, regPath)
if len(regKey) != 0 {
regPath = strings.TrimSuffix(regPath, `\`+regKey)
}
// Strip key off the end.
lastSlash := strings.LastIndex(regPathParts[1], "\\")
if lastSlash < 0 {
errorLogMsg := fmt.Sprintf("Error, malformed registry path: '%s'", path)
env.Log(Error, "WindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
}
regKey := regPathParts[1][lastSlash+1:]
regPath := regPathParts[1][0:lastSlash]
// Just for debug log display.
regKeyLogged := regKey
if len(regKeyLogged) == 0 {
regKeyLogged = "(Default)"
}
env.Log(Debug, "WindowsRegistryKeyValue", fmt.Sprintf("WindowsRegistryKeyValue: root:\"%s\", path:\"%s\", key:\"%s\"", regPathParts[0], regPath, regKeyLogged))
// Second part of split is registry path after HK part - needs to be UTF16 to pass to the windows. API
regPathUTF16, err := windows.UTF16FromString(regPath)
if err != nil {
errorLogMsg := fmt.Sprintf("Error, Could not convert supplied path '%s' to UTF16, error: '%s'", regPath, err)
env.Log(Error, "WindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
}
// Ok - open it..
var hKeyHandle windows.Handle
regOpenErr := windows.RegOpenKeyEx(regRootHKeyHandle, &regPathUTF16[0], 0, windows.KEY_READ, &hKeyHandle)
if regOpenErr != nil {
errorLogMsg := fmt.Sprintf("Error RegOpenKeyEx opening registry path to '%s', error: '%s'", regPath, regOpenErr)
env.Log(Error, "WindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
}
// Success - from here on out, when returning make sure to close that reg key with a deferred call to close:
defer func() {
err := windows.RegCloseKey(hKeyHandle)
if err != nil {
env.Log(Error, "WindowsRegistryKeyValue", fmt.Sprintf("Error closing registry key: %s", err))
}
}()
// Again - need UTF16 of the key for the API:
regKeyUTF16, err := windows.UTF16FromString(regKey)
if err != nil {
errorLogMsg := fmt.Sprintf("Error, could not convert supplied key '%s' to UTF16, error: '%s'", regKey, err)
env.Log(Error, "WindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
}
// Two stage way to get the key value - query once to get size - then allocate and query again to fill it.
var keyBufType uint32
var keyBufSize uint32
regQueryErr := windows.RegQueryValueEx(hKeyHandle, &regKeyUTF16[0], nil, &keyBufType, nil, &keyBufSize)
if regQueryErr != nil {
errorLogMsg := fmt.Sprintf("Error calling RegQueryValueEx to retrieve key data size with error '%s'", regQueryErr)
env.Log(Error, "WindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
}
// Alloc and fill...
var keyBuf = make([]byte, keyBufSize)
regQueryErr = windows.RegQueryValueEx(hKeyHandle, &regKeyUTF16[0], nil, &keyBufType, &keyBuf[0], &keyBufSize)
if regQueryErr != nil {
errorLogMsg := fmt.Sprintf("Error calling RegQueryValueEx to retrieve key data with error '%s'", regQueryErr)
env.Log(Error, "WindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
}
// Format result into a string, depending on type. (future refactor - move this out into it's own function)
switch keyBufType {
case windows.REG_SZ:
var uint16p *uint16
uint16p = (*uint16)(unsafe.Pointer(&keyBuf[0])) // nasty casty
valueString := windows.UTF16PtrToString(uint16p)
env.Log(Debug, "WindowsRegistryKeyValue", fmt.Sprintf("success, string: %s", valueString))
return &WindowsRegistryValue{ValueType: RegString, Str: valueString}, nil
case windows.REG_DWORD:
var uint32p *uint32
uint32p = (*uint32)(unsafe.Pointer(&keyBuf[0])) // more casting goodness
env.Log(Debug, "WindowsRegistryKeyValue", fmt.Sprintf("success, DWORD, 0x%08X", *uint32p))
return &WindowsRegistryValue{ValueType: RegDword, Dword: *uint32p}, nil
case windows.REG_QWORD:
var uint64p *uint64
uint64p = (*uint64)(unsafe.Pointer(&keyBuf[0])) // more casting goodness
env.Log(Debug, "WindowsRegistryKeyValue", fmt.Sprintf("success, QWORD, 0x%016X", *uint64p))
return &WindowsRegistryValue{ValueType: RegQword, Qword: *uint64p}, nil
var key registry.Key
switch rootKey {
case "HKCR", "HKEY_CLASSES_ROOT":
key = windows.HKEY_CLASSES_ROOT
case "HKCC", "HKEY_CURRENT_CONFIG":
key = windows.HKEY_CURRENT_CONFIG
case "HKCU", "HKEY_CURRENT_USER":
key = windows.HKEY_CURRENT_USER
case "HKLM", "HKEY_LOCAL_MACHINE":
key = windows.HKEY_LOCAL_MACHINE
case "HKU", "HKEY_USERS":
key = windows.HKEY_USERS
default:
errorLogMsg := fmt.Sprintf("Error, no formatter for REG_? type:%d, data size:%d bytes", keyBufType, keyBufSize)
errorLogMsg := fmt.Sprintf("Error, unknown registry key: '%s'", rootKey)
env.Log(Error, "WindowsRegistryKeyValue", errorLogMsg)
return nil, errors.New(errorLogMsg)
}
k, err := registry.OpenKey(key, regPath, registry.READ)
if err != nil {
env.Log(Error, "WindowsRegistryKeyValue", err.Error())
return nil, err
}
_, valType, err := k.GetValue(regKey, nil)
if err != nil {
env.Log(Error, "WindowsRegistryKeyValue", err.Error())
return nil, err
}
var regValue *WindowsRegistryValue
switch valType {
case windows.REG_SZ, windows.REG_EXPAND_SZ:
value, _, _ := k.GetStringValue(regKey)
regValue = &WindowsRegistryValue{ValueType: STRING, String: value}
case windows.REG_DWORD:
value, _, _ := k.GetIntegerValue(regKey)
regValue = &WindowsRegistryValue{ValueType: DWORD, DWord: value, String: fmt.Sprintf("0x%08X", value)}
case windows.REG_QWORD:
value, _, _ := k.GetIntegerValue(regKey)
regValue = &WindowsRegistryValue{ValueType: QWORD, QWord: value, String: fmt.Sprintf("0x%016X", value)}
case windows.REG_BINARY:
value, _, _ := k.GetBinaryValue(regKey)
regValue = &WindowsRegistryValue{ValueType: BINARY, String: string(value)}
}
if regValue == nil {
errorLogMsg := fmt.Sprintf("Error, no formatter for type: %d", valType)
return nil, errors.New(errorLogMsg)
}
env.Log(Debug, "WindowsRegistryKeyValue", fmt.Sprintf("%s(%s): %s", regKey, regValue.ValueType, regValue.String))
return regValue, nil
}
func (env *ShellEnvironment) InWSLSharedDrive() bool {

View file

@ -112,25 +112,6 @@ func queryWindowTitles(processName, windowTitleRegex string) (string, error) {
return title, nil
}
// Return the windows handles corresponding to the names of the root registry keys.
// A returned value of 0 means there was no match.
func getHKEYHandleFromAbbrString(abbr string) windows.Handle {
switch abbr {
case "HKCR", "HKEY_CLASSES_ROOT":
return windows.HKEY_CLASSES_ROOT
case "HKCC", "HKEY_CURRENT_CONFIG":
return windows.HKEY_CURRENT_CONFIG
case "HKCU", "HKEY_CURRENT_USER":
return windows.HKEY_CURRENT_USER
case "HKLM", "HKEY_LOCAL_MACHINE":
return windows.HKEY_LOCAL_MACHINE
case "HKU", "HKEY_USERS":
return windows.HKEY_USERS
}
return 0
}
type REPARSE_DATA_BUFFER struct { // nolint: revive
ReparseTag uint32
ReparseDataLength uint16

View file

@ -1,11 +1,11 @@
package segments
import (
"io/ioutil"
"oh-my-posh/environment"
"oh-my-posh/mock"
"oh-my-posh/properties"
"oh-my-posh/template"
"os"
"path/filepath"
"testing"
@ -113,11 +113,11 @@ func TestAzSegment(t *testing.T) {
env.On("Home").Return(home)
var azureProfile, azureRmContext string
if tc.HasCLI {
content, _ := ioutil.ReadFile("../test/azureProfile.json")
content, _ := os.ReadFile("../test/azureProfile.json")
azureProfile = string(content)
}
if tc.HasPowerShell {
content, _ := ioutil.ReadFile("../test/AzureRmContext.json")
content, _ := os.ReadFile("../test/AzureRmContext.json")
azureRmContext = string(content)
}
env.On("GOOS").Return(environment.LinuxPlatform)

View file

@ -3,10 +3,10 @@ package segments
import (
"errors"
"fmt"
"io/ioutil"
"oh-my-posh/environment"
"oh-my-posh/mock"
"oh-my-posh/properties"
"os"
"testing"
"github.com/stretchr/testify/assert"
@ -81,7 +81,7 @@ func TestGolang(t *testing.T) {
if tc.InvalidModfile {
content = "invalid go.mod file"
} else {
tmp, _ := ioutil.ReadFile(fileInfo.Path)
tmp, _ := os.ReadFile(fileInfo.Path)
content = string(tmp)
}
env.On("FileContent", fileInfo.Path).Return(content)

View file

@ -2,10 +2,10 @@ package segments
import (
"fmt"
"io/ioutil"
"oh-my-posh/environment"
"oh-my-posh/mock"
"oh-my-posh/properties"
"os"
"path/filepath"
"testing"
@ -111,7 +111,7 @@ func TestKubectlSegment(t *testing.T) {
env := new(mock.MockedEnvironment)
env.On("HasCommand", "kubectl").Return(tc.KubectlExists)
var kubeconfig string
content, err := ioutil.ReadFile("../test/kubectl.yml")
content, err := os.ReadFile("../test/kubectl.yml")
if err == nil {
kubeconfig = fmt.Sprintf(string(content), tc.Cluster, tc.UserName, tc.Namespace, tc.Context)
}

View file

@ -2,9 +2,9 @@ package segments
import (
"io/fs"
"io/ioutil"
"oh-my-posh/mock"
"oh-my-posh/properties"
"os"
"testing"
"github.com/alecthomas/assert"
@ -256,7 +256,7 @@ func TestNuspecPackage(t *testing.T) {
name: tc.FileName,
},
})
content, _ := ioutil.ReadFile(tc.FileName)
content, _ := os.ReadFile(tc.FileName)
env.On("FileContent", tc.FileName).Return(string(content))
pkg := &Project{}
pkg.Init(properties.Map{}, env)

View file

@ -1,9 +1,9 @@
package segments
import (
"io/ioutil"
"oh-my-posh/mock"
"oh-my-posh/properties"
"os"
"testing"
"github.com/stretchr/testify/assert"
@ -82,11 +82,11 @@ func TestTerraform(t *testing.T) {
env.On("HasFiles", "main.tf").Return(tc.HasTfFiles)
env.On("HasFiles", "terraform.tfstate").Return(tc.HasTfStateFile)
if tc.HasTfFiles {
content, _ := ioutil.ReadFile("../test/versions.tf")
content, _ := os.ReadFile("../test/versions.tf")
env.On("FileContent", "versions.tf").Return(string(content))
}
if tc.HasTfStateFile {
content, _ := ioutil.ReadFile("../test/terraform.tfstate")
content, _ := os.ReadFile("../test/terraform.tfstate")
env.On("FileContent", "terraform.tfstate").Return(string(content))
}
tf := &Terraform{

View file

@ -1,8 +1,6 @@
package segments
import (
"errors"
"fmt"
"oh-my-posh/environment"
"oh-my-posh/properties"
)
@ -36,71 +34,16 @@ func (wr *WindowsRegistry) Enabled() bool {
}
registryPath := wr.props.GetString(RegistryPath, "")
fallback := wr.props.GetString(Fallback, "")
wr.Value = wr.props.GetString(Fallback, "")
var regValue *environment.WindowsRegistryValue
regValue, _ = wr.env.WindowsRegistryKeyValue(registryPath)
if regValue != nil {
switch regValue.ValueType {
case environment.RegString:
wr.Value = regValue.Str
return true
case environment.RegDword:
wr.Value = fmt.Sprintf("0x%08X", regValue.Dword)
return true
case environment.RegQword:
wr.Value = fmt.Sprintf("0x%016X", regValue.Qword)
return true
}
}
if len(fallback) > 0 {
wr.Value = fallback
regValue, err := wr.env.WindowsRegistryKeyValue(registryPath)
if err == nil {
wr.Value = regValue.String
return true
}
if len(wr.Value) > 0 {
// we have fallback value
return true
}
return false
}
func (wr WindowsRegistry) GetRegistryString(path string) (string, error) {
regValue, err := wr.env.WindowsRegistryKeyValue(path)
if regValue == nil {
return "", err
}
if regValue.ValueType != environment.RegString {
return "", errors.New("type mismatch, registry value is not a string")
}
return regValue.Str, nil
}
func (wr WindowsRegistry) GetRegistryDword(path string) (uint32, error) {
regValue, err := wr.env.WindowsRegistryKeyValue(path)
if regValue == nil {
return 0, err
}
if regValue.ValueType != environment.RegDword {
return 0, errors.New("type mismatch, registry value is not a dword")
}
return regValue.Dword, nil
}
func (wr WindowsRegistry) GetRegistryQword(path string) (uint64, error) {
regValue, err := wr.env.WindowsRegistryKeyValue(path)
if regValue == nil {
return 0, err
}
if regValue.ValueType != environment.RegQword {
return 0, errors.New("type mismatch, registry value is not a qword")
}
return regValue.Qword, nil
}

View file

@ -29,7 +29,7 @@ func TestWinReg(t *testing.T) {
{
CaseDescription: "Value",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
getWRKVOutput: &environment.WindowsRegistryValue{ValueType: environment.RegString, Str: "xbox"},
getWRKVOutput: &environment.WindowsRegistryValue{ValueType: environment.STRING, String: "xbox"},
ExpectedSuccess: true,
ExpectedValue: "xbox",
},
@ -44,7 +44,7 @@ func TestWinReg(t *testing.T) {
{
CaseDescription: "Empty string value (no error) should display empty string even in presence of fallback",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
getWRKVOutput: &environment.WindowsRegistryValue{ValueType: environment.RegString, Str: ""},
getWRKVOutput: &environment.WindowsRegistryValue{ValueType: environment.STRING, String: ""},
Fallback: "anaconda",
ExpectedSuccess: true,
ExpectedValue: "",
@ -52,24 +52,10 @@ func TestWinReg(t *testing.T) {
{
CaseDescription: "Empty string value (no error) should display empty string",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
getWRKVOutput: &environment.WindowsRegistryValue{ValueType: environment.RegString, Str: ""},
getWRKVOutput: &environment.WindowsRegistryValue{ValueType: environment.STRING, String: ""},
ExpectedSuccess: true,
ExpectedValue: "",
},
{
CaseDescription: "DWORD value",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
getWRKVOutput: &environment.WindowsRegistryValue{ValueType: environment.RegDword, Dword: 0xdeadbeef},
ExpectedSuccess: true,
ExpectedValue: "0xDEADBEEF",
},
{
CaseDescription: "QWORD value",
Path: "HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\InstallTime",
getWRKVOutput: &environment.WindowsRegistryValue{ValueType: environment.RegQword, Qword: 0x7eb199e57fa1afe1},
ExpectedSuccess: true,
ExpectedValue: "0x7EB199E57FA1AFE1",
},
}
for _, tc := range cases {

View file

@ -10,9 +10,11 @@ Display the content of the requested Windows registry key.
Supported registry key types:
- String
- DWORD (displayed in upper-case 0x hex)
- QWORD (displayed in upper-case 0x hex)
- `SZ` (displayed as string value)
- `EXPAND_SZ` (displayed as string value)
- `BINARY` (displayed as string value)
- `DWORD` (displayed in upper-case 0x hex)
- `QWORD` (displayed in upper-case 0x hex)
## Sample Configuration