diff --git a/src/platform/shell_unix.go b/src/platform/shell_unix.go index 594ca6cc..b8fe655c 100644 --- a/src/platform/shell_unix.go +++ b/src/platform/shell_unix.go @@ -6,11 +6,11 @@ import ( "errors" "os" "strings" - "syscall" "time" "github.com/shirou/gopsutil/v3/host" terminal "github.com/wayneashleyberry/terminal-dimensions" + "golang.org/x/sys/unix" ) func (env *Shell) Root() bool { @@ -126,35 +126,7 @@ func (env *Shell) LookWinAppPath(file string) (string, error) { func (env *Shell) DirIsWritable(path string) bool { defer env.Trace(time.Now(), "DirIsWritable", path) - info, err := os.Stat(path) - if err != nil { - env.Log(Error, "DirIsWritable", err.Error()) - return false - } - - if !info.IsDir() { - env.Log(Error, "DirIsWritable", "Path isn't a directory") - return false - } - - // Check if the user bit is enabled in file permission - if info.Mode().Perm()&(1<<(uint(7))) == 0 { - env.Log(Error, "DirIsWritable", "Write permission bit is not set on this file for user") - return false - } - - var stat syscall.Stat_t - if err = syscall.Stat(path, &stat); err != nil { - env.Log(Error, "DirIsWritable", err.Error()) - return false - } - - if uint32(os.Geteuid()) != stat.Uid { - env.Log(Error, "DirIsWritable", "User doesn't have permission to write to this directory") - return false - } - - return true + return unix.Access(path, unix.W_OK) == nil } func (env *Shell) Connection(connectionType ConnectionType) (*Connection, error) { diff --git a/src/platform/shell_windows.go b/src/platform/shell_windows.go index e230a040..fc1f2fb7 100644 --- a/src/platform/shell_windows.go +++ b/src/platform/shell_windows.go @@ -235,24 +235,7 @@ func (env *Shell) ConvertToLinuxPath(path string) string { func (env *Shell) DirIsWritable(path string) bool { defer env.Trace(time.Now(), "DirIsWritable") - info, err := os.Stat(path) - if err != nil { - env.Log(Error, "DirIsWritable", err.Error()) - return false - } - - if !info.IsDir() { - env.Log(Error, "DirIsWritable", "Path isn't a directory") - return false - } - - // Check if the user bit is enabled in file permission - if info.Mode().Perm()&(1<<(uint(7))) == 0 { - env.Log(Error, "DirIsWritable", "Write permission bit is not set on this file for user") - return false - } - - return true + return isWriteable(path) } func (env *Shell) Connection(connectionType ConnectionType) (*Connection, error) { diff --git a/src/platform/win32_windows.go b/src/platform/win32_windows.go index 8db44be4..6534f6dc 100644 --- a/src/platform/win32_windows.go +++ b/src/platform/win32_windows.go @@ -3,6 +3,7 @@ package platform import ( "errors" "oh-my-posh/regex" + "reflect" "strings" "syscall" "unicode/utf16" @@ -194,3 +195,76 @@ func readWinAppLink(path string) (string, error) { } return appExecLink.Path() } + +var ( + advapi = syscall.NewLazyDLL("advapi32.dll") + procGetAce = advapi.NewProc("GetAce") +) + +const ( + ACCESS_DENIED_ACE_TYPE = 1 //nolint: revive +) + +type AccessAllowedAce struct { + AceType uint8 + AceFlags uint8 + AceSize uint16 + AccessMask uint32 + SidStart uint32 +} + +func getCurrentUser() (sid *windows.SID, err error) { + token := windows.GetCurrentProcessToken() + defer token.Close() + + tokenuser, err := token.GetTokenUser() + sid = tokenuser.User.Sid + return +} + +func isWriteable(folder string) bool { + cu, err := getCurrentUser() + if err != nil { + // unable to get current user + return false + } + + si, err := windows.GetNamedSecurityInfo(folder, windows.SE_FILE_OBJECT, windows.DACL_SECURITY_INFORMATION) + if err != nil { + return false + } + + dacl, _, err := si.DACL() + if err != nil || dacl == nil { + // no dacl implies full access + return true + } + + rs := reflect.ValueOf(dacl).Elem() + aceCount := rs.Field(3).Uint() + + for i := uint64(0); i < aceCount; i++ { + ace := &AccessAllowedAce{} + + ret, _, _ := procGetAce.Call(uintptr(unsafe.Pointer(dacl)), uintptr(i), uintptr(unsafe.Pointer(&ace))) + if ret == 0 { + return false + } + + if ace.AceType == ACCESS_DENIED_ACE_TYPE { + continue + } + + aceSid := (*windows.SID)(unsafe.Pointer(&ace.SidStart)) + + if !aceSid.Equals(cu) { + continue + } + + allowMask := ^(windows.GENERIC_WRITE | windows.GENERIC_ALL) + if ace.AccessMask&uint32(allowMask) != 0 { + return true + } + } + return false +}