diff --git a/src/segments/az.go b/src/segments/az.go index 8950c7e4..94ca6e6c 100644 --- a/src/segments/az.go +++ b/src/segments/az.go @@ -101,7 +101,7 @@ func (a *Az) FileContentWithoutBom(file string) string { func (a *Az) getAzureProfile() bool { var content string - profile := filepath.Join(a.env.Home(), ".azure", "azureProfile.json") + profile := filepath.Join(a.ConfigHome(), "azureProfile.json") if content = a.FileContentWithoutBom(profile); len(content) == 0 { return false } @@ -121,9 +121,9 @@ func (a *Az) getAzureProfile() bool { func (a *Az) getAzureRmContext() bool { var content string + cfgHome := a.ConfigHome() profiles := []string{ - filepath.Join(a.env.Home(), ".azure", "AzureRmContext.json"), - filepath.Join(a.env.Home(), ".Azure", "AzureRmContext.json"), + filepath.Join(cfgHome, "AzureRmContext.json"), } for _, profile := range profiles { if content = a.FileContentWithoutBom(profile); len(content) != 0 { @@ -153,3 +153,11 @@ func (a *Az) getAzureRmContext() bool { a.Origin = "PWSH" return true } + +func (a *Az) ConfigHome() string { + cfgHome := a.env.Getenv("AZURE_CONFIG_DIR") + if len(cfgHome) != 0 { + return cfgHome + } + return filepath.Join(a.env.Home(), ".azure") +} diff --git a/src/segments/az_test.go b/src/segments/az_test.go index 64b0838f..20a09c82 100644 --- a/src/segments/az_test.go +++ b/src/segments/az_test.go @@ -14,13 +14,12 @@ import ( func TestAzSegment(t *testing.T) { cases := []struct { - Case string - ExpectedEnabled bool - ExpectedString string - HasCLI bool - HasPowerShell bool - HasPowerShellUnix bool - Template string + Case string + ExpectedEnabled bool + ExpectedString string + HasCLI bool + HasPowerShell bool + Template string }{ { Case: "no config files found", @@ -41,11 +40,11 @@ func TestAzSegment(t *testing.T) { HasPowerShell: true, }, { - Case: "Az Pwsh Profile", - ExpectedEnabled: true, - ExpectedString: "AzurePoshCloud", - Template: "{{ .EnvironmentName }}", - HasPowerShellUnix: true, + Case: "Az Pwsh Profile", + ExpectedEnabled: true, + ExpectedString: "AzurePoshCloud", + Template: "{{ .EnvironmentName }}", + HasPowerShell: true, }, { Case: "Faulty template", @@ -74,7 +73,7 @@ func TestAzSegment(t *testing.T) { env := new(mock.MockedEnvironment) home := "/Users/posh" env.On("Home").Return(home) - var azureProfile, azureRmContext, azureRMContext string + var azureProfile, azureRmContext string if tc.HasCLI { content, _ := ioutil.ReadFile("../test/azureProfile.json") azureProfile = string(content) @@ -83,14 +82,10 @@ func TestAzSegment(t *testing.T) { content, _ := ioutil.ReadFile("../test/AzureRmContext.json") azureRmContext = string(content) } - if tc.HasPowerShellUnix { - content, _ := ioutil.ReadFile("../test/AzureRmContext.json") - azureRMContext = string(content) - } env.On("GOOS").Return(environment.LinuxPlatform) env.On("FileContent", filepath.Join(home, ".azure", "azureProfile.json")).Return(azureProfile) - env.On("FileContent", filepath.Join(home, ".Azure", "AzureRmContext.json")).Return(azureRmContext) - env.On("FileContent", filepath.Join(home, ".azure", "AzureRmContext.json")).Return(azureRMContext) + env.On("FileContent", filepath.Join(home, ".azure", "AzureRmContext.json")).Return(azureRmContext) + env.On("Getenv", "AZURE_CONFIG_DIR").Return("") az := &Az{ env: env, props: properties.Map{},