diff --git a/src/segments/az.go b/src/segments/az.go index bb4657e5..b093eb97 100644 --- a/src/segments/az.go +++ b/src/segments/az.go @@ -2,6 +2,7 @@ package segments import ( "encoding/json" + "errors" "oh-my-posh/environment" "oh-my-posh/properties" "path/filepath" @@ -14,6 +15,8 @@ type Az struct { AzureSubscription Origin string + + configDir string } const ( @@ -100,6 +103,11 @@ func (a *Az) Init(props properties.Properties, env environment.Environment) { func (a *Az) Enabled() bool { source := a.props.GetString(Source, firstMatch) + var err error + a.configDir, err = a.ConfigDir() + if err != nil { + return false + } switch source { case firstMatch: return a.getCLISubscription() || a.getModuleSubscription() @@ -119,7 +127,7 @@ func (a *Az) FileContentWithoutBom(file string) string { func (a *Az) getCLISubscription() bool { var content string - profile := filepath.Join(a.ConfigHome(), "azureProfile.json") + profile := filepath.Join(a.configDir, "azureProfile.json") if content = a.FileContentWithoutBom(profile); len(content) == 0 { return false } @@ -139,9 +147,8 @@ func (a *Az) getCLISubscription() bool { func (a *Az) getModuleSubscription() bool { var content string - cfgHome := a.ConfigHome() profiles := []string{ - filepath.Join(cfgHome, "AzureRmContext.json"), + filepath.Join(a.configDir, "AzureRmContext.json"), } for _, profile := range profiles { if content = a.FileContentWithoutBom(profile); len(content) != 0 { @@ -173,10 +180,16 @@ func (a *Az) getModuleSubscription() bool { return true } -func (a *Az) ConfigHome() string { - cfgHome := a.env.Getenv("AZURE_CONFIG_DIR") - if len(cfgHome) != 0 { - return cfgHome +func (a *Az) ConfigDir() (string, error) { + configDirs := []string{ + a.env.Getenv("AZURE_CONFIG_DIR"), + filepath.Join(a.env.Home(), ".azure"), + filepath.Join(a.env.Home(), ".Azure"), } - return filepath.Join(a.env.Home(), ".azure") + for _, dir := range configDirs { + if len(dir) != 0 && a.env.HasFolder(dir) { + return dir, nil + } + } + return "", errors.New("azure config dir not found") } diff --git a/src/segments/az_test.go b/src/segments/az_test.go index 5e90848d..c396eedc 100644 --- a/src/segments/az_test.go +++ b/src/segments/az_test.go @@ -124,6 +124,7 @@ func TestAzSegment(t *testing.T) { env.On("FileContent", filepath.Join(home, ".azure", "azureProfile.json")).Return(azureProfile) env.On("FileContent", filepath.Join(home, ".azure", "AzureRmContext.json")).Return(azureRmContext) env.On("Getenv", "AZURE_CONFIG_DIR").Return("") + env.On("HasFolder", filepath.Clean("/Users/posh/.azure")).Return(true) if tc.Source == "" { tc.Source = firstMatch }