diff --git a/config/config.go b/config/config.go index 9f81bbfd57..d32fcc33c9 100644 --- a/config/config.go +++ b/config/config.go @@ -34,6 +34,7 @@ import ( "github.com/prometheus/prometheus/discovery" "github.com/prometheus/prometheus/model/labels" "github.com/prometheus/prometheus/model/relabel" + "github.com/prometheus/prometheus/storage/remote/azuread" ) var ( @@ -907,6 +908,7 @@ type RemoteWriteConfig struct { QueueConfig QueueConfig `yaml:"queue_config,omitempty"` MetadataConfig MetadataConfig `yaml:"metadata_config,omitempty"` SigV4Config *sigv4.SigV4Config `yaml:"sigv4,omitempty"` + AzureADConfig *azuread.AzureADConfig `yaml:"azuread,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -943,8 +945,12 @@ func (c *RemoteWriteConfig) UnmarshalYAML(unmarshal func(interface{}) error) err httpClientConfigAuthEnabled := c.HTTPClientConfig.BasicAuth != nil || c.HTTPClientConfig.Authorization != nil || c.HTTPClientConfig.OAuth2 != nil - if httpClientConfigAuthEnabled && c.SigV4Config != nil { - return fmt.Errorf("at most one of basic_auth, authorization, oauth2, & sigv4 must be configured") + if httpClientConfigAuthEnabled && (c.SigV4Config != nil || c.AzureADConfig != nil) { + return fmt.Errorf("at most one of basic_auth, authorization, oauth2, sigv4, & azuread must be configured") + } + + if c.SigV4Config != nil && c.AzureADConfig != nil { + return fmt.Errorf("at most one of basic_auth, authorization, oauth2, sigv4, & azuread must be configured") } return nil @@ -965,7 +971,7 @@ func validateHeadersForTracing(headers map[string]string) error { func validateHeaders(headers map[string]string) error { for header := range headers { if strings.ToLower(header) == "authorization" { - return errors.New("authorization header must be changed via the basic_auth, authorization, oauth2, or sigv4 parameter") + return errors.New("authorization header must be changed via the basic_auth, authorization, oauth2, sigv4, or azuread parameter") } if _, ok := reservedHeaders[strings.ToLower(header)]; ok { return fmt.Errorf("%s is a reserved header. It must not be changed", header) diff --git a/config/config_test.go b/config/config_test.go index d243d687c4..d3288cc90d 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1727,7 +1727,7 @@ var expectedErrors = []struct { }, { filename: "remote_write_authorization_header.bad.yml", - errMsg: `authorization header must be changed via the basic_auth, authorization, oauth2, or sigv4 parameter`, + errMsg: `authorization header must be changed via the basic_auth, authorization, oauth2, sigv4, or azuread parameter`, }, { filename: "remote_write_url_missing.bad.yml", diff --git a/docs/configuration/configuration.md b/docs/configuration/configuration.md index ff1449e34a..3a9ace2b6c 100644 --- a/docs/configuration/configuration.md +++ b/docs/configuration/configuration.md @@ -3466,7 +3466,7 @@ authorization: [ credentials_file: ] # Optionally configures AWS's Signature Verification 4 signing process to -# sign requests. Cannot be set at the same time as basic_auth, authorization, or oauth2. +# sign requests. Cannot be set at the same time as basic_auth, authorization, oauth2, or azuread. # To use the default credentials from the AWS SDK, use `sigv4: {}`. sigv4: # The AWS region. If blank, the region from the default credentials chain @@ -3485,10 +3485,20 @@ sigv4: [ role_arn: ] # Optional OAuth 2.0 configuration. -# Cannot be used at the same time as basic_auth, authorization, or sigv4. +# Cannot be used at the same time as basic_auth, authorization, sigv4, or azuread. oauth2: [ ] +# Optional AzureAD configuration. +# Cannot be used at the same time as basic_auth, authorization, oauth2, or sigv4. +azuread: + # The Azure Cloud. Options are 'AzurePublic', 'AzureChina', or 'AzureGovernment'. + [ cloud: | default = AzurePublic ] + + # Azure User-assigned Managed identity. + [ managed_identity: + [ client_id: ] + # Configures the remote write request's TLS settings. tls_config: [ ] diff --git a/go.mod b/go.mod index 83e4b96774..9e3f77c4af 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.19 require ( github.com/Azure/azure-sdk-for-go v65.0.0+incompatible + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 github.com/Azure/go-autorest/autorest v0.11.28 github.com/Azure/go-autorest/autorest/adal v0.9.23 github.com/DmitriyVTitov/size v1.5.0 @@ -85,11 +87,16 @@ require ( require ( cloud.google.com/go/compute/metadata v0.2.3 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/xhit/go-str2duration/v2 v2.1.0 // indirect google.golang.org/genproto v0.0.0-20230526203410-71b5a4ffd15e // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230526203410-71b5a4ffd15e // indirect @@ -138,7 +145,7 @@ require ( github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/gofuzz v1.2.0 // indirect - github.com/google/uuid v1.3.0 // indirect + github.com/google/uuid v1.3.0 github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.7.1 // indirect github.com/gorilla/websocket v1.5.0 // indirect diff --git a/go.sum b/go.sum index 6332e74709..92702a64df 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,12 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/Azure/azure-sdk-for-go v65.0.0+incompatible h1:HzKLt3kIwMm4KeJYTdx9EbjRYTySD/t8i1Ee/W5EGXw= github.com/Azure/azure-sdk-for-go v65.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1 h1:gVXuXcWd1i4C2Ruxe321aU+IKGaStvGB/S90PUPB/W8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1/go.mod h1:DffdKW9RFqa5VgmsjUOsS7UE7eiA5iAvYUs63bhKQ0M= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 h1:T8quHYlUGyb/oqtSTwqlCr1ilJHrDv+ZtpSfo+hm1BU= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1/go.mod h1:gLa1CL2RNE4s7M3yopJ/p0iq5DdY6Yv5ZUt9MTRZOQM= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= @@ -60,6 +66,8 @@ github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+Z github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= +github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1 h1:oPdPEZFSbl7oSPEAIPMPBMUmiL+mqgzBJwM/9qYcwNg= +github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1/go.mod h1:4qFor3D/HDsvBME35Xy9rwW9DecL+M2sNw1ybjPtwA0= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= @@ -523,6 +531,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/linode/linodego v1.16.1 h1:5otq57M4PdHycPERRfSFZ0s1yz1ETVWGjCp3hh7+F9w= @@ -638,6 +648,8 @@ github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAv github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 h1:Qj1ukM4GlMWXNdMBuXcXfz/Kw9s1qm0CLY32QxuSImI= +github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/storage/remote/azuread/README.md b/storage/remote/azuread/README.md new file mode 100644 index 0000000000..b3b4457a6e --- /dev/null +++ b/storage/remote/azuread/README.md @@ -0,0 +1,8 @@ +azuread package +========================================= + +azuread provides an http.RoundTripper that attaches an Azure AD accessToken +to remote write requests. + +This module is considered internal to Prometheus, without any stability +guarantees for external usage. diff --git a/storage/remote/azuread/azuread.go b/storage/remote/azuread/azuread.go new file mode 100644 index 0000000000..94b6144da1 --- /dev/null +++ b/storage/remote/azuread/azuread.go @@ -0,0 +1,247 @@ +// Copyright 2023 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package azuread + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/google/uuid" +) + +const ( + // Clouds. + AzureChina = "AzureChina" + AzureGovernment = "AzureGovernment" + AzurePublic = "AzurePublic" + + // Audiences. + IngestionChinaAudience = "https://monitor.azure.cn//.default" + IngestionGovernmentAudience = "https://monitor.azure.us//.default" + IngestionPublicAudience = "https://monitor.azure.com//.default" +) + +// ManagedIdentityConfig is used to store managed identity config values +type ManagedIdentityConfig struct { + // ClientID is the clientId of the managed identity that is being used to authenticate. + ClientID string `yaml:"client_id,omitempty"` +} + +// AzureADConfig is used to store the config values. +type AzureADConfig struct { // nolint:revive + // ManagedIdentity is the managed identity that is being used to authenticate. + ManagedIdentity *ManagedIdentityConfig `yaml:"managed_identity,omitempty"` + + // Cloud is the Azure cloud in which the service is running. Example: AzurePublic/AzureGovernment/AzureChina. + Cloud string `yaml:"cloud,omitempty"` +} + +// azureADRoundTripper is used to store the roundtripper and the tokenprovider. +type azureADRoundTripper struct { + next http.RoundTripper + tokenProvider *tokenProvider +} + +// tokenProvider is used to store and retrieve Azure AD accessToken. +type tokenProvider struct { + // token is member used to store the current valid accessToken. + token string + // mu guards access to token. + mu sync.Mutex + // refreshTime is used to store the refresh time of the current valid accessToken. + refreshTime time.Time + // credentialClient is the Azure AD credential client that is being used to retrieve accessToken. + credentialClient azcore.TokenCredential + options *policy.TokenRequestOptions +} + +// Validate validates config values provided. +func (c *AzureADConfig) Validate() error { + if c.Cloud == "" { + c.Cloud = AzurePublic + } + + if c.Cloud != AzureChina && c.Cloud != AzureGovernment && c.Cloud != AzurePublic { + return fmt.Errorf("must provide a cloud in the Azure AD config") + } + + if c.ManagedIdentity == nil { + return fmt.Errorf("must provide an Azure Managed Identity in the Azure AD config") + } + + if c.ManagedIdentity.ClientID == "" { + return fmt.Errorf("must provide an Azure Managed Identity client_id in the Azure AD config") + } + + _, err := uuid.Parse(c.ManagedIdentity.ClientID) + if err != nil { + return fmt.Errorf("the provided Azure Managed Identity client_id provided is invalid") + } + return nil +} + +// UnmarshalYAML unmarshal the Azure AD config yaml. +func (c *AzureADConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + type plain AzureADConfig + *c = AzureADConfig{} + if err := unmarshal((*plain)(c)); err != nil { + return err + } + return c.Validate() +} + +// NewAzureADRoundTripper creates round tripper adding Azure AD authorization to calls. +func NewAzureADRoundTripper(cfg *AzureADConfig, next http.RoundTripper) (http.RoundTripper, error) { + if next == nil { + next = http.DefaultTransport + } + + cred, err := newTokenCredential(cfg) + if err != nil { + return nil, err + } + + tokenProvider, err := newTokenProvider(cfg, cred) + if err != nil { + return nil, err + } + + rt := &azureADRoundTripper{ + next: next, + tokenProvider: tokenProvider, + } + return rt, nil +} + +// RoundTrip sets Authorization header for requests. +func (rt *azureADRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + accessToken, err := rt.tokenProvider.getAccessToken(req.Context()) + if err != nil { + return nil, err + } + bearerAccessToken := "Bearer " + accessToken + req.Header.Set("Authorization", bearerAccessToken) + + return rt.next.RoundTrip(req) +} + +// newTokenCredential returns a TokenCredential of different kinds like Azure Managed Identity and Azure AD application. +func newTokenCredential(cfg *AzureADConfig) (azcore.TokenCredential, error) { + cred, err := newManagedIdentityTokenCredential(cfg.ManagedIdentity.ClientID) + if err != nil { + return nil, err + } + + return cred, nil +} + +// newManagedIdentityTokenCredential returns new Managed Identity token credential. +func newManagedIdentityTokenCredential(managedIdentityClientID string) (azcore.TokenCredential, error) { + clientID := azidentity.ClientID(managedIdentityClientID) + opts := &azidentity.ManagedIdentityCredentialOptions{ID: clientID} + return azidentity.NewManagedIdentityCredential(opts) +} + +// newTokenProvider helps to fetch accessToken for different types of credential. This also takes care of +// refreshing the accessToken before expiry. This accessToken is attached to the Authorization header while making requests. +func newTokenProvider(cfg *AzureADConfig, cred azcore.TokenCredential) (*tokenProvider, error) { + audience, err := getAudience(cfg.Cloud) + if err != nil { + return nil, err + } + + tokenProvider := &tokenProvider{ + credentialClient: cred, + options: &policy.TokenRequestOptions{Scopes: []string{audience}}, + } + + return tokenProvider, nil +} + +// getAccessToken returns the current valid accessToken. +func (tokenProvider *tokenProvider) getAccessToken(ctx context.Context) (string, error) { + tokenProvider.mu.Lock() + defer tokenProvider.mu.Unlock() + if tokenProvider.valid() { + return tokenProvider.token, nil + } + err := tokenProvider.getToken(ctx) + if err != nil { + return "", errors.New("Failed to get access token: " + err.Error()) + } + return tokenProvider.token, nil +} + +// valid checks if the token in the token provider is valid and not expired. +func (tokenProvider *tokenProvider) valid() bool { + if len(tokenProvider.token) == 0 { + return false + } + if tokenProvider.refreshTime.After(time.Now().UTC()) { + return true + } + return false +} + +// getToken retrieves a new accessToken and stores the newly retrieved token in the tokenProvider. +func (tokenProvider *tokenProvider) getToken(ctx context.Context) error { + accessToken, err := tokenProvider.credentialClient.GetToken(ctx, *tokenProvider.options) + if err != nil { + return err + } + if len(accessToken.Token) == 0 { + return errors.New("access token is empty") + } + + tokenProvider.token = accessToken.Token + err = tokenProvider.updateRefreshTime(accessToken) + if err != nil { + return err + } + return nil +} + +// updateRefreshTime handles logic to set refreshTime. The refreshTime is set at half the duration of the actual token expiry. +func (tokenProvider *tokenProvider) updateRefreshTime(accessToken azcore.AccessToken) error { + tokenExpiryTimestamp := accessToken.ExpiresOn.UTC() + deltaExpirytime := time.Now().Add(time.Until(tokenExpiryTimestamp) / 2) + if deltaExpirytime.After(time.Now().UTC()) { + tokenProvider.refreshTime = deltaExpirytime + } else { + return errors.New("access token expiry is less than the current time") + } + return nil +} + +// getAudience returns audiences for different clouds. +func getAudience(cloud string) (string, error) { + switch strings.ToLower(cloud) { + case strings.ToLower(AzureChina): + return IngestionChinaAudience, nil + case strings.ToLower(AzureGovernment): + return IngestionGovernmentAudience, nil + case strings.ToLower(AzurePublic): + return IngestionPublicAudience, nil + default: + return "", errors.New("Cloud is not specified or is incorrect: " + cloud) + } +} diff --git a/storage/remote/azuread/azuread_test.go b/storage/remote/azuread/azuread_test.go new file mode 100644 index 0000000000..ebbd98958c --- /dev/null +++ b/storage/remote/azuread/azuread_test.go @@ -0,0 +1,252 @@ +// Copyright 2023 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package azuread + +import ( + "context" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "gopkg.in/yaml.v2" +) + +const ( + dummyAudience = "dummyAudience" + dummyClientID = "00000000-0000-0000-0000-000000000000" + testTokenString = "testTokenString" +) + +var testTokenExpiry = time.Now().Add(10 * time.Second) + +type AzureAdTestSuite struct { + suite.Suite + mockCredential *mockCredential +} + +type TokenProviderTestSuite struct { + suite.Suite + mockCredential *mockCredential +} + +// mockCredential mocks azidentity TokenCredential interface. +type mockCredential struct { + mock.Mock +} + +func (ad *AzureAdTestSuite) BeforeTest(_, _ string) { + ad.mockCredential = new(mockCredential) +} + +func TestAzureAd(t *testing.T) { + suite.Run(t, new(AzureAdTestSuite)) +} + +func (ad *AzureAdTestSuite) TestAzureAdRoundTripper() { + var gotReq *http.Request + + testToken := &azcore.AccessToken{ + Token: testTokenString, + ExpiresOn: testTokenExpiry, + } + + managedIdentityConfig := &ManagedIdentityConfig{ + ClientID: dummyClientID, + } + + azureAdConfig := &AzureADConfig{ + Cloud: "AzurePublic", + ManagedIdentity: managedIdentityConfig, + } + + ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil) + + tokenProvider, err := newTokenProvider(azureAdConfig, ad.mockCredential) + ad.Assert().NoError(err) + + rt := &azureADRoundTripper{ + next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + gotReq = req + return &http.Response{StatusCode: http.StatusOK}, nil + }), + tokenProvider: tokenProvider, + } + + cli := &http.Client{Transport: rt} + + req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!")) + ad.Assert().NoError(err) + + _, err = cli.Do(req) + ad.Assert().NoError(err) + ad.Assert().NotNil(gotReq) + + origReq := gotReq + ad.Assert().NotEmpty(origReq.Header.Get("Authorization")) + ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization")) +} + +func loadAzureAdConfig(filename string) (*AzureADConfig, error) { + content, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + cfg := AzureADConfig{} + if err = yaml.UnmarshalStrict(content, &cfg); err != nil { + return nil, err + } + return &cfg, nil +} + +func testGoodConfig(t *testing.T, filename string) { + _, err := loadAzureAdConfig(filename) + if err != nil { + t.Fatalf("Unexpected error parsing %s: %s", filename, err) + } +} + +func TestGoodAzureAdConfig(t *testing.T) { + filename := "testdata/azuread_good.yaml" + testGoodConfig(t, filename) +} + +func TestGoodCloudMissingAzureAdConfig(t *testing.T) { + filename := "testdata/azuread_good_cloudmissing.yaml" + testGoodConfig(t, filename) +} + +func TestBadClientIdMissingAzureAdConfig(t *testing.T) { + filename := "testdata/azuread_bad_clientidmissing.yaml" + _, err := loadAzureAdConfig(filename) + if err == nil { + t.Fatalf("Did not receive expected error unmarshaling bad azuread config") + } + if !strings.Contains(err.Error(), "must provide an Azure Managed Identity in the Azure AD config") { + t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error()) + } +} + +func TestBadInvalidClientIdAzureAdConfig(t *testing.T) { + filename := "testdata/azuread_bad_invalidclientid.yaml" + _, err := loadAzureAdConfig(filename) + if err == nil { + t.Fatalf("Did not receive expected error unmarshaling bad azuread config") + } + if !strings.Contains(err.Error(), "the provided Azure Managed Identity client_id provided is invalid") { + t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error()) + } +} + +func (m *mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + args := m.MethodCalled("GetToken", ctx, options) + if args.Get(0) == nil { + return azcore.AccessToken{}, args.Error(1) + } + + return args.Get(0).(azcore.AccessToken), nil +} + +func (s *TokenProviderTestSuite) BeforeTest(_, _ string) { + s.mockCredential = new(mockCredential) +} + +func TestTokenProvider(t *testing.T) { + suite.Run(t, new(TokenProviderTestSuite)) +} + +func (s *TokenProviderTestSuite) TestNewTokenProvider_NilAudience_Fail() { + managedIdentityConfig := &ManagedIdentityConfig{ + ClientID: dummyClientID, + } + + azureAdConfig := &AzureADConfig{ + Cloud: "PublicAzure", + ManagedIdentity: managedIdentityConfig, + } + + actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential) + + s.Assert().Nil(actualTokenProvider) + s.Assert().NotNil(actualErr) + s.Assert().Equal("Cloud is not specified or is incorrect: "+azureAdConfig.Cloud, actualErr.Error()) +} + +func (s *TokenProviderTestSuite) TestNewTokenProvider_Success() { + managedIdentityConfig := &ManagedIdentityConfig{ + ClientID: dummyClientID, + } + + azureAdConfig := &AzureADConfig{ + Cloud: "AzurePublic", + ManagedIdentity: managedIdentityConfig, + } + s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil) + + actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential) + + s.Assert().NotNil(actualTokenProvider) + s.Assert().Nil(actualErr) + s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background())) +} + +func (s *TokenProviderTestSuite) TestPeriodicTokenRefresh_Success() { + // setup + managedIdentityConfig := &ManagedIdentityConfig{ + ClientID: dummyClientID, + } + + azureAdConfig := &AzureADConfig{ + Cloud: "AzurePublic", + ManagedIdentity: managedIdentityConfig, + } + testToken := &azcore.AccessToken{ + Token: testTokenString, + ExpiresOn: testTokenExpiry, + } + + s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once(). + On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil) + + actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential) + + s.Assert().NotNil(actualTokenProvider) + s.Assert().Nil(actualErr) + s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background())) + + // Token set to refresh at half of the expiry time. The test tokens are set to expiry in 10s. + // Hence, the 6 seconds wait to check if the token is refreshed. + time.Sleep(6 * time.Second) + + s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background())) + + s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2) + accessToken, err := actualTokenProvider.getAccessToken(context.Background()) + s.Assert().Nil(err) + s.Assert().NotEqual(accessToken, testTokenString) +} + +func getToken() azcore.AccessToken { + return azcore.AccessToken{ + Token: uuid.New().String(), + ExpiresOn: time.Now().Add(10 * time.Second), + } +} diff --git a/storage/remote/azuread/testdata/azuread_bad_clientidmissing.yaml b/storage/remote/azuread/testdata/azuread_bad_clientidmissing.yaml new file mode 100644 index 0000000000..68b119cd42 --- /dev/null +++ b/storage/remote/azuread/testdata/azuread_bad_clientidmissing.yaml @@ -0,0 +1 @@ +cloud: AzurePublic diff --git a/storage/remote/azuread/testdata/azuread_bad_invalidclientid.yaml b/storage/remote/azuread/testdata/azuread_bad_invalidclientid.yaml new file mode 100644 index 0000000000..1f72fbb71f --- /dev/null +++ b/storage/remote/azuread/testdata/azuread_bad_invalidclientid.yaml @@ -0,0 +1,3 @@ +cloud: AzurePublic +managed_identity: + client_id: foo-foobar-bar-foo-00000000 diff --git a/storage/remote/azuread/testdata/azuread_good.yaml b/storage/remote/azuread/testdata/azuread_good.yaml new file mode 100644 index 0000000000..de39f0a060 --- /dev/null +++ b/storage/remote/azuread/testdata/azuread_good.yaml @@ -0,0 +1,3 @@ +cloud: AzurePublic +managed_identity: + client_id: 00000000-0000-0000-0000-000000000000 diff --git a/storage/remote/azuread/testdata/azuread_good_cloudmissing.yaml b/storage/remote/azuread/testdata/azuread_good_cloudmissing.yaml new file mode 100644 index 0000000000..bef6318743 --- /dev/null +++ b/storage/remote/azuread/testdata/azuread_good_cloudmissing.yaml @@ -0,0 +1,2 @@ +managed_identity: + client_id: 00000000-0000-0000-0000-000000000000 diff --git a/storage/remote/client.go b/storage/remote/client.go index 1625c9918d..33774203c5 100644 --- a/storage/remote/client.go +++ b/storage/remote/client.go @@ -36,6 +36,7 @@ import ( "go.opentelemetry.io/otel/trace" "github.com/prometheus/prometheus/prompb" + "github.com/prometheus/prometheus/storage/remote/azuread" ) const maxErrMsgLen = 1024 @@ -97,6 +98,7 @@ type ClientConfig struct { Timeout model.Duration HTTPClientConfig config_util.HTTPClientConfig SigV4Config *sigv4.SigV4Config + AzureADConfig *azuread.AzureADConfig Headers map[string]string RetryOnRateLimit bool } @@ -146,6 +148,13 @@ func NewWriteClient(name string, conf *ClientConfig) (WriteClient, error) { } } + if conf.AzureADConfig != nil { + t, err = azuread.NewAzureADRoundTripper(conf.AzureADConfig, httpClient.Transport) + if err != nil { + return nil, err + } + } + if len(conf.Headers) > 0 { t = newInjectHeadersRoundTripper(conf.Headers, t) } diff --git a/storage/remote/write.go b/storage/remote/write.go index 6a33c0adf9..4b0a249014 100644 --- a/storage/remote/write.go +++ b/storage/remote/write.go @@ -158,6 +158,7 @@ func (rws *WriteStorage) ApplyConfig(conf *config.Config) error { Timeout: rwConf.RemoteTimeout, HTTPClientConfig: rwConf.HTTPClientConfig, SigV4Config: rwConf.SigV4Config, + AzureADConfig: rwConf.AzureADConfig, Headers: rwConf.Headers, RetryOnRateLimit: rwConf.QueueConfig.RetryOnRateLimit, })