// Copyright 2023 The Prometheus Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package azuread import ( "context" "net/http" "os" "strings" "testing" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "gopkg.in/yaml.v2" ) const ( dummyAudience = "dummyAudience" dummyClientID = "00000000-0000-0000-0000-000000000000" dummyClientSecret = "Cl1ent$ecret!" dummyTenantID = "00000000-a12b-3cd4-e56f-000000000000" testTokenString = "testTokenString" ) var testTokenExpiry = time.Now().Add(5 * time.Second) type AzureAdTestSuite struct { suite.Suite mockCredential *mockCredential } type TokenProviderTestSuite struct { suite.Suite mockCredential *mockCredential } // mockCredential mocks azidentity TokenCredential interface. type mockCredential struct { mock.Mock } func (ad *AzureAdTestSuite) BeforeTest(_, _ string) { ad.mockCredential = new(mockCredential) } func TestAzureAd(t *testing.T) { suite.Run(t, new(AzureAdTestSuite)) } func (ad *AzureAdTestSuite) TestAzureAdRoundTripper() { cases := []struct { cfg *AzureADConfig }{ // AzureAd roundtripper with Managedidentity. { cfg: &AzureADConfig{ Cloud: "AzurePublic", ManagedIdentity: &ManagedIdentityConfig{ ClientID: dummyClientID, }, }, }, // AzureAd roundtripper with OAuth. { cfg: &AzureADConfig{ Cloud: "AzurePublic", OAuth: &OAuthConfig{ ClientID: dummyClientID, ClientSecret: dummyClientSecret, TenantID: dummyTenantID, }, }, }, } for _, c := range cases { var gotReq *http.Request testToken := &azcore.AccessToken{ Token: testTokenString, ExpiresOn: testTokenExpiry, } ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil) tokenProvider, err := newTokenProvider(c.cfg, ad.mockCredential) ad.Assert().NoError(err) rt := &azureADRoundTripper{ next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { gotReq = req return &http.Response{StatusCode: http.StatusOK}, nil }), tokenProvider: tokenProvider, } cli := &http.Client{Transport: rt} req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!")) ad.Assert().NoError(err) _, err = cli.Do(req) ad.Assert().NoError(err) ad.Assert().NotNil(gotReq) origReq := gotReq ad.Assert().NotEmpty(origReq.Header.Get("Authorization")) ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization")) } } func loadAzureAdConfig(filename string) (*AzureADConfig, error) { content, err := os.ReadFile(filename) if err != nil { return nil, err } cfg := AzureADConfig{} if err = yaml.UnmarshalStrict(content, &cfg); err != nil { return nil, err } return &cfg, nil } func TestAzureAdConfig(t *testing.T) { cases := []struct { filename string err string }{ // Missing managedidentiy or oauth field. { filename: "testdata/azuread_bad_configmissing.yaml", err: "must provide an Azure Managed Identity or Azure OAuth in the Azure AD config", }, // Invalid managedidentity client id. { filename: "testdata/azuread_bad_invalidclientid.yaml", err: "the provided Azure Managed Identity client_id is invalid", }, // Missing tenant id in oauth config. { filename: "testdata/azuread_bad_invalidoauthconfig.yaml", err: "must provide an Azure OAuth tenant_id in the Azure AD config", }, // Invalid config when both managedidentity and oauth is provided. { filename: "testdata/azuread_bad_twoconfig.yaml", err: "cannot provide both Azure Managed Identity and Azure OAuth in the Azure AD config", }, // Valid config with missing optionally cloud field. { filename: "testdata/azuread_good_cloudmissing.yaml", }, // Valid managed identity config. { filename: "testdata/azuread_good_managedidentity.yaml", }, // Valid Oauth config. { filename: "testdata/azuread_good_oauth.yaml", }, } for _, c := range cases { _, err := loadAzureAdConfig(c.filename) if c.err != "" { if err == nil { t.Fatalf("Did not receive expected error unmarshaling bad azuread config") } require.EqualError(t, err, c.err) } else { require.NoError(t, err) } } } func (m *mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { args := m.MethodCalled("GetToken", ctx, options) if args.Get(0) == nil { return azcore.AccessToken{}, args.Error(1) } return args.Get(0).(azcore.AccessToken), nil } func (s *TokenProviderTestSuite) BeforeTest(_, _ string) { s.mockCredential = new(mockCredential) } func TestTokenProvider(t *testing.T) { suite.Run(t, new(TokenProviderTestSuite)) } func (s *TokenProviderTestSuite) TestNewTokenProvider() { cases := []struct { cfg *AzureADConfig err string }{ // Invalid tokenProvider for managedidentity. { cfg: &AzureADConfig{ Cloud: "PublicAzure", ManagedIdentity: &ManagedIdentityConfig{ ClientID: dummyClientID, }, }, err: "Cloud is not specified or is incorrect: ", }, // Invalid tokenProvider for oauth. { cfg: &AzureADConfig{ Cloud: "PublicAzure", OAuth: &OAuthConfig{ ClientID: dummyClientID, ClientSecret: dummyClientSecret, TenantID: dummyTenantID, }, }, err: "Cloud is not specified or is incorrect: ", }, // Valid tokenProvider for managedidentity. { cfg: &AzureADConfig{ Cloud: "AzurePublic", ManagedIdentity: &ManagedIdentityConfig{ ClientID: dummyClientID, }, }, }, // Valid tokenProvider for oauth. { cfg: &AzureADConfig{ Cloud: "AzurePublic", OAuth: &OAuthConfig{ ClientID: dummyClientID, ClientSecret: dummyClientSecret, TenantID: dummyTenantID, }, }, }, } mockGetTokenCallCounter := 1 for _, c := range cases { if c.err != "" { actualTokenProvider, actualErr := newTokenProvider(c.cfg, s.mockCredential) s.Assert().Nil(actualTokenProvider) s.Assert().NotNil(actualErr) s.Assert().ErrorContains(actualErr, c.err) } else { testToken := &azcore.AccessToken{ Token: testTokenString, ExpiresOn: testTokenExpiry, } s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once(). On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil) actualTokenProvider, actualErr := newTokenProvider(c.cfg, s.mockCredential) s.Assert().NotNil(actualTokenProvider) s.Assert().Nil(actualErr) s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background())) // Token set to refresh at half of the expiry time. The test tokens are set to expiry in 5s. // Hence, the 4 seconds wait to check if the token is refreshed. time.Sleep(4 * time.Second) s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background())) s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2*mockGetTokenCallCounter) mockGetTokenCallCounter += 1 accessToken, err := actualTokenProvider.getAccessToken(context.Background()) s.Assert().Nil(err) s.Assert().NotEqual(accessToken, testTokenString) } } } func getToken() azcore.AccessToken { return azcore.AccessToken{ Token: uuid.New().String(), ExpiresOn: time.Now().Add(10 * time.Second), } }