diff --git a/src/environment.go b/src/environment.go index b7df4e8a..72765b45 100644 --- a/src/environment.go +++ b/src/environment.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "io/ioutil" "net/http" "os" @@ -30,6 +31,12 @@ func (e *commandError) Error() string { return e.err } +type fileInfo struct { + parentFolder string + path string + isDir bool +} + type environmentInfo interface { getenv(key string) string getcwd() string @@ -54,6 +61,7 @@ type environmentInfo interface { getShellName() string getWindowTitle(imageName, windowTitleRegex string) (string, error) doGet(url string) ([]byte, error) + hasParentFilePath(path string) (fileInfo *fileInfo, err error) } type environment struct { @@ -239,6 +247,29 @@ func (env *environment) doGet(url string) ([]byte, error) { return body, nil } +func (env *environment) hasParentFilePath(path string) (*fileInfo, error) { + currentFolder := env.getcwd() + for { + searchPath := filepath.Join(currentFolder, path) + info, err := os.Stat(searchPath) + if err == nil { + return &fileInfo{ + parentFolder: currentFolder, + path: searchPath, + isDir: info.IsDir(), + }, nil + } + if !os.IsNotExist(err) { + return nil, err + } + if dir := filepath.Dir(currentFolder); dir != currentFolder { + currentFolder = dir + continue + } + return nil, errors.New("no match at root level") + } +} + func cleanHostName(hostName string) string { garbage := []string{ ".lan", diff --git a/src/segment_git.go b/src/segment_git.go index cc08f0ce..9724bef7 100644 --- a/src/segment_git.go +++ b/src/segment_git.go @@ -15,7 +15,7 @@ type gitRepo struct { HEAD string upstream string stashCount string - root string + gitFolder string } type gitStatus struct { @@ -121,8 +121,24 @@ func (g *git) enabled() bool { if !g.env.hasCommand("git") { return false } - output, _ := g.env.runCommand(gitCommand, "rev-parse", "--is-inside-work-tree") - return output == "true" + gitdir, err := g.env.hasParentFilePath(".git") + if err != nil { + return false + } + g.repo = &gitRepo{} + if gitdir.isDir { + g.repo.gitFolder = gitdir.path + return true + } + // handle worktree + dirPointer := g.env.getFileContent(gitdir.path) + dirPointer = strings.Trim(dirPointer, " \r\n") + matches := findNamedRegexMatch(`^gitdir: (?P.*)$`, dirPointer) + if matches != nil && matches["dir"] != "" { + g.repo.gitFolder = matches["dir"] + return true + } + return false } func (g *git) string() string { @@ -198,8 +214,6 @@ func (g *git) getUpstreamSymbol() string { } func (g *git) setGitStatus() { - g.repo = &gitRepo{} - g.repo.root = g.getGitCommandOutput("rev-parse", "--show-toplevel") output := g.getGitCommandOutput("status", "-unormal", "--short", "--branch") splittedOutput := strings.Split(output, "\n") g.repo.working = g.parseGitStats(splittedOutput, true) @@ -285,17 +299,17 @@ func (g *git) getGitHEADContext(ref string) string { } func (g *git) hasGitFile(file string) bool { - files := fmt.Sprintf(".git/%s", file) - return g.env.hasFilesInDir(g.repo.root, files) + return g.env.hasFilesInDir(g.repo.gitFolder, file) } func (g *git) hasGitFolder(folder string) bool { - path := fmt.Sprintf("%s/.git/%s", g.repo.root, folder) + path := g.repo.gitFolder + "/" + folder return g.env.hasFolder(path) } func (g *git) getGitFileContents(file string) string { - content := g.env.getFileContent(fmt.Sprintf("%s/.git/%s", g.repo.root, file)) + path := g.repo.gitFolder + "/" + file + content := g.env.getFileContent(path) return strings.Trim(content, " \r\n") } diff --git a/src/segment_git_test.go b/src/segment_git_test.go index dd4159ad..140aa2ea 100644 --- a/src/segment_git_test.go +++ b/src/segment_git_test.go @@ -22,11 +22,34 @@ func TestEnabledGitNotFound(t *testing.T) { func TestEnabledInWorkingDirectory(t *testing.T) { env := new(MockedEnvironment) env.On("hasCommand", "git").Return(true) - env.On("runCommand", "git", []string{"rev-parse", "--is-inside-work-tree"}).Return("true", nil) + fileInfo := &fileInfo{ + path: "/dir/hello", + parentFolder: "/dir", + isDir: true, + } + env.On("hasParentFilePath", ".git").Return(fileInfo, nil) g := &git{ env: env, } assert.True(t, g.enabled()) + assert.Equal(t, fileInfo.path, g.repo.gitFolder) +} + +func TestEnabledInWorkingTree(t *testing.T) { + env := new(MockedEnvironment) + env.On("hasCommand", "git").Return(true) + fileInfo := &fileInfo{ + path: "/dir/hello", + parentFolder: "/dir", + isDir: false, + } + env.On("hasParentFilePath", ".git").Return(fileInfo, nil) + env.On("getFileContent", "/dir/hello").Return("gitdir: /dir/hello/burp/burp") + g := &git{ + env: env, + } + assert.True(t, g.enabled()) + assert.Equal(t, "/dir/hello/burp/burp", g.repo.gitFolder) } func TestGetGitOutputForCommand(t *testing.T) { @@ -61,19 +84,19 @@ type detachedContext struct { func setupHEADContextEnv(context *detachedContext) *git { env := new(MockedEnvironment) - env.On("hasFolder", "/.git/rebase-merge").Return(context.rebaseMerge) - env.On("hasFolder", "/.git/rebase-apply").Return(context.rebaseApply) - env.On("getFileContent", "/.git/rebase-merge/orig-head").Return(context.origin) - env.On("getFileContent", "/.git/rebase-merge/onto").Return(context.onto) - env.On("getFileContent", "/.git/rebase-merge/msgnum").Return(context.step) - env.On("getFileContent", "/.git/rebase-apply/next").Return(context.step) - env.On("getFileContent", "/.git/rebase-merge/end").Return(context.total) - env.On("getFileContent", "/.git/rebase-apply/last").Return(context.total) - env.On("getFileContent", "/.git/rebase-apply/head-name").Return(context.origin) - env.On("getFileContent", "/.git/CHERRY_PICK_HEAD").Return(context.cherryPickSHA) - env.On("getFileContent", "/.git/MERGE_HEAD").Return(context.mergeHEAD) - env.On("hasFilesInDir", "", ".git/CHERRY_PICK_HEAD").Return(context.cherryPick) - env.On("hasFilesInDir", "", ".git/MERGE_HEAD").Return(context.merge) + env.On("hasFolder", "/rebase-merge").Return(context.rebaseMerge) + env.On("hasFolder", "/rebase-apply").Return(context.rebaseApply) + env.On("getFileContent", "/rebase-merge/orig-head").Return(context.origin) + env.On("getFileContent", "/rebase-merge/onto").Return(context.onto) + env.On("getFileContent", "/rebase-merge/msgnum").Return(context.step) + env.On("getFileContent", "/rebase-apply/next").Return(context.step) + env.On("getFileContent", "/rebase-merge/end").Return(context.total) + env.On("getFileContent", "/rebase-apply/last").Return(context.total) + env.On("getFileContent", "/rebase-apply/head-name").Return(context.origin) + env.On("getFileContent", "/CHERRY_PICK_HEAD").Return(context.cherryPickSHA) + env.On("getFileContent", "/MERGE_HEAD").Return(context.mergeHEAD) + env.On("hasFilesInDir", "", "CHERRY_PICK_HEAD").Return(context.cherryPick) + env.On("hasFilesInDir", "", "MERGE_HEAD").Return(context.merge) env.mockGitCommand(context.currentCommit, "rev-parse", "--short", "HEAD") env.mockGitCommand(context.tagName, "describe", "--tags", "--exact-match") env.mockGitCommand(context.origin, "name-rev", "--name-only", "--exclude=tags/*", context.origin) @@ -83,7 +106,7 @@ func setupHEADContextEnv(context *detachedContext) *git { g := &git{ env: env, repo: &gitRepo{ - root: "", + gitFolder: "", }, } return g diff --git a/src/segment_path_test.go b/src/segment_path_test.go index 21479259..ca546f8e 100644 --- a/src/segment_path_test.go +++ b/src/segment_path_test.go @@ -127,6 +127,11 @@ func (env *MockedEnvironment) doGet(url string) ([]byte, error) { return args.Get(0).([]byte), args.Error(1) } +func (env *MockedEnvironment) hasParentFilePath(path string) (*fileInfo, error) { + args := env.Called(path) + return args.Get(0).(*fileInfo), args.Error(1) +} + const ( homeBill = "/home/bill" homeJan = "/usr/home/jan"