Add Azure AD package for remote write (#11944)

* Add Azure AD package for remote write
* Made AzurePublic default and updated configuration.md
* Updated config structure and removed getToken at initialization
* Changed passing context from request

Signed-off-by: Rakshith Padmanabha <rapadman@microsoft.com>
Signed-off-by: rakshith210 <rakshith.me@gmail.com>
This commit is contained in:
rakshith210 2023-06-01 14:20:10 -07:00 committed by GitHub
parent a8772a4178
commit b1675e23af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 568 additions and 7 deletions

View file

@ -34,6 +34,7 @@ import (
"github.com/prometheus/prometheus/discovery"
"github.com/prometheus/prometheus/model/labels"
"github.com/prometheus/prometheus/model/relabel"
"github.com/prometheus/prometheus/storage/remote/azuread"
)
var (
@ -907,6 +908,7 @@ type RemoteWriteConfig struct {
QueueConfig QueueConfig `yaml:"queue_config,omitempty"`
MetadataConfig MetadataConfig `yaml:"metadata_config,omitempty"`
SigV4Config *sigv4.SigV4Config `yaml:"sigv4,omitempty"`
AzureADConfig *azuread.AzureADConfig `yaml:"azuread,omitempty"`
}
// SetDirectory joins any relative file paths with dir.
@ -943,8 +945,12 @@ func (c *RemoteWriteConfig) UnmarshalYAML(unmarshal func(interface{}) error) err
httpClientConfigAuthEnabled := c.HTTPClientConfig.BasicAuth != nil ||
c.HTTPClientConfig.Authorization != nil || c.HTTPClientConfig.OAuth2 != nil
if httpClientConfigAuthEnabled && c.SigV4Config != nil {
return fmt.Errorf("at most one of basic_auth, authorization, oauth2, & sigv4 must be configured")
if httpClientConfigAuthEnabled && (c.SigV4Config != nil || c.AzureADConfig != nil) {
return fmt.Errorf("at most one of basic_auth, authorization, oauth2, sigv4, & azuread must be configured")
}
if c.SigV4Config != nil && c.AzureADConfig != nil {
return fmt.Errorf("at most one of basic_auth, authorization, oauth2, sigv4, & azuread must be configured")
}
return nil
@ -965,7 +971,7 @@ func validateHeadersForTracing(headers map[string]string) error {
func validateHeaders(headers map[string]string) error {
for header := range headers {
if strings.ToLower(header) == "authorization" {
return errors.New("authorization header must be changed via the basic_auth, authorization, oauth2, or sigv4 parameter")
return errors.New("authorization header must be changed via the basic_auth, authorization, oauth2, sigv4, or azuread parameter")
}
if _, ok := reservedHeaders[strings.ToLower(header)]; ok {
return fmt.Errorf("%s is a reserved header. It must not be changed", header)

View file

@ -1727,7 +1727,7 @@ var expectedErrors = []struct {
},
{
filename: "remote_write_authorization_header.bad.yml",
errMsg: `authorization header must be changed via the basic_auth, authorization, oauth2, or sigv4 parameter`,
errMsg: `authorization header must be changed via the basic_auth, authorization, oauth2, sigv4, or azuread parameter`,
},
{
filename: "remote_write_url_missing.bad.yml",

View file

@ -3466,7 +3466,7 @@ authorization:
[ credentials_file: <filename> ]
# Optionally configures AWS's Signature Verification 4 signing process to
# sign requests. Cannot be set at the same time as basic_auth, authorization, or oauth2.
# sign requests. Cannot be set at the same time as basic_auth, authorization, oauth2, or azuread.
# To use the default credentials from the AWS SDK, use `sigv4: {}`.
sigv4:
# The AWS region. If blank, the region from the default credentials chain
@ -3485,10 +3485,20 @@ sigv4:
[ role_arn: <string> ]
# Optional OAuth 2.0 configuration.
# Cannot be used at the same time as basic_auth, authorization, or sigv4.
# Cannot be used at the same time as basic_auth, authorization, sigv4, or azuread.
oauth2:
[ <oauth2> ]
# Optional AzureAD configuration.
# Cannot be used at the same time as basic_auth, authorization, oauth2, or sigv4.
azuread:
# The Azure Cloud. Options are 'AzurePublic', 'AzureChina', or 'AzureGovernment'.
[ cloud: <string> | default = AzurePublic ]
# Azure User-assigned Managed identity.
[ managed_identity:
[ client_id: <string> ]
# Configures the remote write request's TLS settings.
tls_config:
[ <tls_config> ]

9
go.mod
View file

@ -4,6 +4,8 @@ go 1.19
require (
github.com/Azure/azure-sdk-for-go v65.0.0+incompatible
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1
github.com/Azure/go-autorest/autorest v0.11.28
github.com/Azure/go-autorest/autorest/adal v0.9.23
github.com/alecthomas/kingpin/v2 v2.3.2
@ -83,10 +85,15 @@ require (
require (
cloud.google.com/go/compute/metadata v0.2.3 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 // indirect
github.com/rogpeppe/go-internal v1.10.0 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/xhit/go-str2duration/v2 v2.1.0 // indirect
google.golang.org/genproto v0.0.0-20230526203410-71b5a4ffd15e // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230526203410-71b5a4ffd15e // indirect
@ -135,7 +142,7 @@ require (
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/google/uuid v1.3.0
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
github.com/googleapis/gax-go/v2 v2.7.1 // indirect
github.com/gorilla/websocket v1.5.0 // indirect

12
go.sum
View file

@ -38,6 +38,12 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/Azure/azure-sdk-for-go v65.0.0+incompatible h1:HzKLt3kIwMm4KeJYTdx9EbjRYTySD/t8i1Ee/W5EGXw=
github.com/Azure/azure-sdk-for-go v65.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1 h1:gVXuXcWd1i4C2Ruxe321aU+IKGaStvGB/S90PUPB/W8=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1/go.mod h1:DffdKW9RFqa5VgmsjUOsS7UE7eiA5iAvYUs63bhKQ0M=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 h1:T8quHYlUGyb/oqtSTwqlCr1ilJHrDv+ZtpSfo+hm1BU=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1/go.mod h1:gLa1CL2RNE4s7M3yopJ/p0iq5DdY6Yv5ZUt9MTRZOQM=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs=
@ -60,6 +66,8 @@ github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+Z
github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo=
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1 h1:oPdPEZFSbl7oSPEAIPMPBMUmiL+mqgzBJwM/9qYcwNg=
github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1/go.mod h1:4qFor3D/HDsvBME35Xy9rwW9DecL+M2sNw1ybjPtwA0=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
@ -515,6 +523,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
github.com/linode/linodego v1.16.1 h1:5otq57M4PdHycPERRfSFZ0s1yz1ETVWGjCp3hh7+F9w=
@ -630,6 +640,8 @@ github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAv
github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac=
github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 h1:Qj1ukM4GlMWXNdMBuXcXfz/Kw9s1qm0CLY32QxuSImI=
github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=

View file

@ -0,0 +1,8 @@
azuread package
=========================================
azuread provides an http.RoundTripper that attaches an Azure AD accessToken
to remote write requests.
This module is considered internal to Prometheus, without any stability
guarantees for external usage.

View file

@ -0,0 +1,247 @@
// Copyright 2023 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package azuread
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/google/uuid"
)
const (
// Clouds.
AzureChina = "AzureChina"
AzureGovernment = "AzureGovernment"
AzurePublic = "AzurePublic"
// Audiences.
IngestionChinaAudience = "https://monitor.azure.cn//.default"
IngestionGovernmentAudience = "https://monitor.azure.us//.default"
IngestionPublicAudience = "https://monitor.azure.com//.default"
)
// ManagedIdentityConfig is used to store managed identity config values
type ManagedIdentityConfig struct {
// ClientID is the clientId of the managed identity that is being used to authenticate.
ClientID string `yaml:"client_id,omitempty"`
}
// AzureADConfig is used to store the config values.
type AzureADConfig struct { // nolint:revive
// ManagedIdentity is the managed identity that is being used to authenticate.
ManagedIdentity *ManagedIdentityConfig `yaml:"managed_identity,omitempty"`
// Cloud is the Azure cloud in which the service is running. Example: AzurePublic/AzureGovernment/AzureChina.
Cloud string `yaml:"cloud,omitempty"`
}
// azureADRoundTripper is used to store the roundtripper and the tokenprovider.
type azureADRoundTripper struct {
next http.RoundTripper
tokenProvider *tokenProvider
}
// tokenProvider is used to store and retrieve Azure AD accessToken.
type tokenProvider struct {
// token is member used to store the current valid accessToken.
token string
// mu guards access to token.
mu sync.Mutex
// refreshTime is used to store the refresh time of the current valid accessToken.
refreshTime time.Time
// credentialClient is the Azure AD credential client that is being used to retrieve accessToken.
credentialClient azcore.TokenCredential
options *policy.TokenRequestOptions
}
// Validate validates config values provided.
func (c *AzureADConfig) Validate() error {
if c.Cloud == "" {
c.Cloud = AzurePublic
}
if c.Cloud != AzureChina && c.Cloud != AzureGovernment && c.Cloud != AzurePublic {
return fmt.Errorf("must provide a cloud in the Azure AD config")
}
if c.ManagedIdentity == nil {
return fmt.Errorf("must provide an Azure Managed Identity in the Azure AD config")
}
if c.ManagedIdentity.ClientID == "" {
return fmt.Errorf("must provide an Azure Managed Identity client_id in the Azure AD config")
}
_, err := uuid.Parse(c.ManagedIdentity.ClientID)
if err != nil {
return fmt.Errorf("the provided Azure Managed Identity client_id provided is invalid")
}
return nil
}
// UnmarshalYAML unmarshal the Azure AD config yaml.
func (c *AzureADConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain AzureADConfig
*c = AzureADConfig{}
if err := unmarshal((*plain)(c)); err != nil {
return err
}
return c.Validate()
}
// NewAzureADRoundTripper creates round tripper adding Azure AD authorization to calls.
func NewAzureADRoundTripper(cfg *AzureADConfig, next http.RoundTripper) (http.RoundTripper, error) {
if next == nil {
next = http.DefaultTransport
}
cred, err := newTokenCredential(cfg)
if err != nil {
return nil, err
}
tokenProvider, err := newTokenProvider(cfg, cred)
if err != nil {
return nil, err
}
rt := &azureADRoundTripper{
next: next,
tokenProvider: tokenProvider,
}
return rt, nil
}
// RoundTrip sets Authorization header for requests.
func (rt *azureADRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
accessToken, err := rt.tokenProvider.getAccessToken(req.Context())
if err != nil {
return nil, err
}
bearerAccessToken := "Bearer " + accessToken
req.Header.Set("Authorization", bearerAccessToken)
return rt.next.RoundTrip(req)
}
// newTokenCredential returns a TokenCredential of different kinds like Azure Managed Identity and Azure AD application.
func newTokenCredential(cfg *AzureADConfig) (azcore.TokenCredential, error) {
cred, err := newManagedIdentityTokenCredential(cfg.ManagedIdentity.ClientID)
if err != nil {
return nil, err
}
return cred, nil
}
// newManagedIdentityTokenCredential returns new Managed Identity token credential.
func newManagedIdentityTokenCredential(managedIdentityClientID string) (azcore.TokenCredential, error) {
clientID := azidentity.ClientID(managedIdentityClientID)
opts := &azidentity.ManagedIdentityCredentialOptions{ID: clientID}
return azidentity.NewManagedIdentityCredential(opts)
}
// newTokenProvider helps to fetch accessToken for different types of credential. This also takes care of
// refreshing the accessToken before expiry. This accessToken is attached to the Authorization header while making requests.
func newTokenProvider(cfg *AzureADConfig, cred azcore.TokenCredential) (*tokenProvider, error) {
audience, err := getAudience(cfg.Cloud)
if err != nil {
return nil, err
}
tokenProvider := &tokenProvider{
credentialClient: cred,
options: &policy.TokenRequestOptions{Scopes: []string{audience}},
}
return tokenProvider, nil
}
// getAccessToken returns the current valid accessToken.
func (tokenProvider *tokenProvider) getAccessToken(ctx context.Context) (string, error) {
tokenProvider.mu.Lock()
defer tokenProvider.mu.Unlock()
if tokenProvider.valid() {
return tokenProvider.token, nil
}
err := tokenProvider.getToken(ctx)
if err != nil {
return "", errors.New("Failed to get access token: " + err.Error())
}
return tokenProvider.token, nil
}
// valid checks if the token in the token provider is valid and not expired.
func (tokenProvider *tokenProvider) valid() bool {
if len(tokenProvider.token) == 0 {
return false
}
if tokenProvider.refreshTime.After(time.Now().UTC()) {
return true
}
return false
}
// getToken retrieves a new accessToken and stores the newly retrieved token in the tokenProvider.
func (tokenProvider *tokenProvider) getToken(ctx context.Context) error {
accessToken, err := tokenProvider.credentialClient.GetToken(ctx, *tokenProvider.options)
if err != nil {
return err
}
if len(accessToken.Token) == 0 {
return errors.New("access token is empty")
}
tokenProvider.token = accessToken.Token
err = tokenProvider.updateRefreshTime(accessToken)
if err != nil {
return err
}
return nil
}
// updateRefreshTime handles logic to set refreshTime. The refreshTime is set at half the duration of the actual token expiry.
func (tokenProvider *tokenProvider) updateRefreshTime(accessToken azcore.AccessToken) error {
tokenExpiryTimestamp := accessToken.ExpiresOn.UTC()
deltaExpirytime := time.Now().Add(time.Until(tokenExpiryTimestamp) / 2)
if deltaExpirytime.After(time.Now().UTC()) {
tokenProvider.refreshTime = deltaExpirytime
} else {
return errors.New("access token expiry is less than the current time")
}
return nil
}
// getAudience returns audiences for different clouds.
func getAudience(cloud string) (string, error) {
switch strings.ToLower(cloud) {
case strings.ToLower(AzureChina):
return IngestionChinaAudience, nil
case strings.ToLower(AzureGovernment):
return IngestionGovernmentAudience, nil
case strings.ToLower(AzurePublic):
return IngestionPublicAudience, nil
default:
return "", errors.New("Cloud is not specified or is incorrect: " + cloud)
}
}

View file

@ -0,0 +1,252 @@
// Copyright 2023 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package azuread
import (
"context"
"net/http"
"os"
"strings"
"testing"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"gopkg.in/yaml.v2"
)
const (
dummyAudience = "dummyAudience"
dummyClientID = "00000000-0000-0000-0000-000000000000"
testTokenString = "testTokenString"
)
var testTokenExpiry = time.Now().Add(10 * time.Second)
type AzureAdTestSuite struct {
suite.Suite
mockCredential *mockCredential
}
type TokenProviderTestSuite struct {
suite.Suite
mockCredential *mockCredential
}
// mockCredential mocks azidentity TokenCredential interface.
type mockCredential struct {
mock.Mock
}
func (ad *AzureAdTestSuite) BeforeTest(_, _ string) {
ad.mockCredential = new(mockCredential)
}
func TestAzureAd(t *testing.T) {
suite.Run(t, new(AzureAdTestSuite))
}
func (ad *AzureAdTestSuite) TestAzureAdRoundTripper() {
var gotReq *http.Request
testToken := &azcore.AccessToken{
Token: testTokenString,
ExpiresOn: testTokenExpiry,
}
managedIdentityConfig := &ManagedIdentityConfig{
ClientID: dummyClientID,
}
azureAdConfig := &AzureADConfig{
Cloud: "AzurePublic",
ManagedIdentity: managedIdentityConfig,
}
ad.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil)
tokenProvider, err := newTokenProvider(azureAdConfig, ad.mockCredential)
ad.Assert().NoError(err)
rt := &azureADRoundTripper{
next: promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
gotReq = req
return &http.Response{StatusCode: http.StatusOK}, nil
}),
tokenProvider: tokenProvider,
}
cli := &http.Client{Transport: rt}
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
ad.Assert().NoError(err)
_, err = cli.Do(req)
ad.Assert().NoError(err)
ad.Assert().NotNil(gotReq)
origReq := gotReq
ad.Assert().NotEmpty(origReq.Header.Get("Authorization"))
ad.Assert().Equal("Bearer "+testTokenString, origReq.Header.Get("Authorization"))
}
func loadAzureAdConfig(filename string) (*AzureADConfig, error) {
content, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
cfg := AzureADConfig{}
if err = yaml.UnmarshalStrict(content, &cfg); err != nil {
return nil, err
}
return &cfg, nil
}
func testGoodConfig(t *testing.T, filename string) {
_, err := loadAzureAdConfig(filename)
if err != nil {
t.Fatalf("Unexpected error parsing %s: %s", filename, err)
}
}
func TestGoodAzureAdConfig(t *testing.T) {
filename := "testdata/azuread_good.yaml"
testGoodConfig(t, filename)
}
func TestGoodCloudMissingAzureAdConfig(t *testing.T) {
filename := "testdata/azuread_good_cloudmissing.yaml"
testGoodConfig(t, filename)
}
func TestBadClientIdMissingAzureAdConfig(t *testing.T) {
filename := "testdata/azuread_bad_clientidmissing.yaml"
_, err := loadAzureAdConfig(filename)
if err == nil {
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
}
if !strings.Contains(err.Error(), "must provide an Azure Managed Identity in the Azure AD config") {
t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error())
}
}
func TestBadInvalidClientIdAzureAdConfig(t *testing.T) {
filename := "testdata/azuread_bad_invalidclientid.yaml"
_, err := loadAzureAdConfig(filename)
if err == nil {
t.Fatalf("Did not receive expected error unmarshaling bad azuread config")
}
if !strings.Contains(err.Error(), "the provided Azure Managed Identity client_id provided is invalid") {
t.Errorf("Received unexpected error from unmarshal of %s: %s", filename, err.Error())
}
}
func (m *mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
args := m.MethodCalled("GetToken", ctx, options)
if args.Get(0) == nil {
return azcore.AccessToken{}, args.Error(1)
}
return args.Get(0).(azcore.AccessToken), nil
}
func (s *TokenProviderTestSuite) BeforeTest(_, _ string) {
s.mockCredential = new(mockCredential)
}
func TestTokenProvider(t *testing.T) {
suite.Run(t, new(TokenProviderTestSuite))
}
func (s *TokenProviderTestSuite) TestNewTokenProvider_NilAudience_Fail() {
managedIdentityConfig := &ManagedIdentityConfig{
ClientID: dummyClientID,
}
azureAdConfig := &AzureADConfig{
Cloud: "PublicAzure",
ManagedIdentity: managedIdentityConfig,
}
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
s.Assert().Nil(actualTokenProvider)
s.Assert().NotNil(actualErr)
s.Assert().Equal("Cloud is not specified or is incorrect: "+azureAdConfig.Cloud, actualErr.Error())
}
func (s *TokenProviderTestSuite) TestNewTokenProvider_Success() {
managedIdentityConfig := &ManagedIdentityConfig{
ClientID: dummyClientID,
}
azureAdConfig := &AzureADConfig{
Cloud: "AzurePublic",
ManagedIdentity: managedIdentityConfig,
}
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
s.Assert().NotNil(actualTokenProvider)
s.Assert().Nil(actualErr)
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
}
func (s *TokenProviderTestSuite) TestPeriodicTokenRefresh_Success() {
// setup
managedIdentityConfig := &ManagedIdentityConfig{
ClientID: dummyClientID,
}
azureAdConfig := &AzureADConfig{
Cloud: "AzurePublic",
ManagedIdentity: managedIdentityConfig,
}
testToken := &azcore.AccessToken{
Token: testTokenString,
ExpiresOn: testTokenExpiry,
}
s.mockCredential.On("GetToken", mock.Anything, mock.Anything).Return(*testToken, nil).Once().
On("GetToken", mock.Anything, mock.Anything).Return(getToken(), nil)
actualTokenProvider, actualErr := newTokenProvider(azureAdConfig, s.mockCredential)
s.Assert().NotNil(actualTokenProvider)
s.Assert().Nil(actualErr)
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
// Token set to refresh at half of the expiry time. The test tokens are set to expiry in 10s.
// Hence, the 6 seconds wait to check if the token is refreshed.
time.Sleep(6 * time.Second)
s.Assert().NotNil(actualTokenProvider.getAccessToken(context.Background()))
s.mockCredential.AssertNumberOfCalls(s.T(), "GetToken", 2)
accessToken, err := actualTokenProvider.getAccessToken(context.Background())
s.Assert().Nil(err)
s.Assert().NotEqual(accessToken, testTokenString)
}
func getToken() azcore.AccessToken {
return azcore.AccessToken{
Token: uuid.New().String(),
ExpiresOn: time.Now().Add(10 * time.Second),
}
}

View file

@ -0,0 +1 @@
cloud: AzurePublic

View file

@ -0,0 +1,3 @@
cloud: AzurePublic
managed_identity:
client_id: foo-foobar-bar-foo-00000000

View file

@ -0,0 +1,3 @@
cloud: AzurePublic
managed_identity:
client_id: 00000000-0000-0000-0000-000000000000

View file

@ -0,0 +1,2 @@
managed_identity:
client_id: 00000000-0000-0000-0000-000000000000

View file

@ -36,6 +36,7 @@ import (
"go.opentelemetry.io/otel/trace"
"github.com/prometheus/prometheus/prompb"
"github.com/prometheus/prometheus/storage/remote/azuread"
)
const maxErrMsgLen = 1024
@ -97,6 +98,7 @@ type ClientConfig struct {
Timeout model.Duration
HTTPClientConfig config_util.HTTPClientConfig
SigV4Config *sigv4.SigV4Config
AzureADConfig *azuread.AzureADConfig
Headers map[string]string
RetryOnRateLimit bool
}
@ -146,6 +148,13 @@ func NewWriteClient(name string, conf *ClientConfig) (WriteClient, error) {
}
}
if conf.AzureADConfig != nil {
t, err = azuread.NewAzureADRoundTripper(conf.AzureADConfig, httpClient.Transport)
if err != nil {
return nil, err
}
}
if len(conf.Headers) > 0 {
t = newInjectHeadersRoundTripper(conf.Headers, t)
}

View file

@ -158,6 +158,7 @@ func (rws *WriteStorage) ApplyConfig(conf *config.Config) error {
Timeout: rwConf.RemoteTimeout,
HTTPClientConfig: rwConf.HTTPClientConfig,
SigV4Config: rwConf.SigV4Config,
AzureADConfig: rwConf.AzureADConfig,
Headers: rwConf.Headers,
RetryOnRateLimit: rwConf.QueueConfig.RetryOnRateLimit,
})