From cccb50298964eec06cb89b9c051dc12b646843e8 Mon Sep 17 00:00:00 2001 From: Jan De Dobbeleer Date: Fri, 5 Aug 2022 21:07:35 +0200 Subject: [PATCH] refactor: do generic http request --- src/http/oauth.go | 99 ++++++++++++---------------------------- src/http/oauth_test.go | 5 +- src/http/request.go | 75 ++++++++++++++++++++++++++++++ src/http/request_test.go | 91 ++++++++++++++++++++++++++++++++++++ src/segments/strava.go | 19 ++++---- src/segments/withings.go | 19 ++++---- 6 files changed, 216 insertions(+), 92 deletions(-) create mode 100644 src/http/request.go create mode 100644 src/http/request_test.go diff --git a/src/http/oauth.go b/src/http/oauth.go index f362b6a6..163b04e4 100644 --- a/src/http/oauth.go +++ b/src/http/oauth.go @@ -2,7 +2,6 @@ package http import ( "encoding/json" - "errors" "fmt" "io" "net/http" @@ -31,32 +30,27 @@ func (a *OAuthError) Error() string { return a.message } -type OAuth struct { - Props properties.Properties - Env environment.Environment +type OAuthRequest struct { + Request AccessTokenKey string RefreshTokenKey string SegmentName string } -func (o *OAuth) error(err error) { - o.Env.Log(environment.Error, "OAuth", err.Error()) -} - -func (o *OAuth) getAccessToken() (string, error) { +func (o *OAuthRequest) getAccessToken() (string, error) { // get directly from cache - if acccessToken, OK := o.Env.Cache().Get(o.AccessTokenKey); OK && len(acccessToken) != 0 { + if acccessToken, OK := o.env.Cache().Get(o.AccessTokenKey); OK && len(acccessToken) != 0 { return acccessToken, nil } // use cached refresh token to get new access token - if refreshToken, OK := o.Env.Cache().Get(o.RefreshTokenKey); OK && len(refreshToken) != 0 { + if refreshToken, OK := o.env.Cache().Get(o.RefreshTokenKey); OK && len(refreshToken) != 0 { if acccessToken, err := o.refreshToken(refreshToken); err == nil { return acccessToken, nil } } // use initial refresh token from property - refreshToken := o.Props.GetString(properties.RefreshToken, "") + refreshToken := o.props.GetString(properties.RefreshToken, "") // ignore an empty or default refresh token if len(refreshToken) == 0 || refreshToken == DefaultRefreshToken { return "", &OAuthError{ @@ -68,10 +62,10 @@ func (o *OAuth) getAccessToken() (string, error) { return acccessToken, err } -func (o *OAuth) refreshToken(refreshToken string) (string, error) { - httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout) +func (o *OAuthRequest) refreshToken(refreshToken string) (string, error) { + httpTimeout := o.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout) url := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=%s&token=%s", o.SegmentName, refreshToken) - body, err := o.Env.HTTPRequest(url, nil, httpTimeout) + body, err := o.env.HTTPRequest(url, nil, httpTimeout) if err != nil { return "", &OAuthError{ // This might happen if /api was asleep. Assume the user will just retry @@ -86,68 +80,31 @@ func (o *OAuth) refreshToken(refreshToken string) (string, error) { } } // add tokens to cache - o.Env.Cache().Set(o.AccessTokenKey, tokens.AccessToken, tokens.ExpiresIn/60) - o.Env.Cache().Set(o.RefreshTokenKey, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year + o.env.Cache().Set(o.AccessTokenKey, tokens.AccessToken, tokens.ExpiresIn/60) + o.env.Cache().Set(o.RefreshTokenKey, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year return tokens.AccessToken, nil } -func OauthResult[a any](o *OAuth, url string, body io.Reader, requestModifiers ...environment.HTTPRequestModifier) (a, error) { - var data a - - getCacheValue := func(key string) (a, error) { - if val, found := o.Env.Cache().Get(key); found { - err := json.Unmarshal([]byte(val), &data) - if err != nil { - o.error(err) - return data, err - } - return data, nil +func OauthResult[a any](o *OAuthRequest, url string, body io.Reader, requestModifiers ...environment.HTTPRequestModifier) (a, error) { + addToken := func() error { + accessToken, err := o.getAccessToken() + if err != nil { + return err } - err := errors.New("no data in cache") - o.error(err) - return data, err - } - httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout) - - // No need to check more than every 30 minutes by default - cacheTimeout := o.Props.GetInt(properties.CacheTimeout, 30) - if cacheTimeout > 0 { - if data, err := getCacheValue(url); err == nil { - return data, nil + // add token to header for authentication + addAuthHeader := func(request *http.Request) { + request.Header.Add("Authorization", "Bearer "+accessToken) } - } - accessToken, err := o.getAccessToken() - if err != nil { - return data, err + + if requestModifiers == nil { + requestModifiers = []environment.HTTPRequestModifier{} + } + + requestModifiers = append(requestModifiers, addAuthHeader) + + return nil } - // add token to header for authentication - addAuthHeader := func(request *http.Request) { - request.Header.Add("Authorization", "Bearer "+accessToken) - } - - if requestModifiers == nil { - requestModifiers = []environment.HTTPRequestModifier{} - } - - requestModifiers = append(requestModifiers, addAuthHeader) - - responseBody, err := o.Env.HTTPRequest(url, body, httpTimeout, requestModifiers...) - if err != nil { - o.error(err) - return data, err - } - - err = json.Unmarshal(responseBody, &data) - if err != nil { - o.error(err) - return data, err - } - - if cacheTimeout > 0 { - o.Env.Cache().Set(url, string(responseBody), cacheTimeout) - } - - return data, nil + return do[a](&o.Request, url, body, addToken, requestModifiers...) } diff --git a/src/http/oauth_test.go b/src/http/oauth_test.go index c8c5eef6..da076b52 100644 --- a/src/http/oauth_test.go +++ b/src/http/oauth_test.go @@ -163,13 +163,12 @@ func TestOauthResult(t *testing.T) { env.On("HTTPRequest", tokenURL).Return([]byte(tc.TokenResponse), tc.Error) env.On("Log", environment.Error, "OAuth", mock2.Anything).Return() - oauth := &OAuth{ - Props: props, - Env: env, + oauth := &OAuthRequest{ AccessTokenKey: accessTokenKey, RefreshTokenKey: refreshTokenKey, SegmentName: "test", } + oauth.Init(env, props) got, err := OauthResult[*data](oauth, url, nil) assert.Equal(t, tc.ExpectedData, got, tc.Case) diff --git a/src/http/request.go b/src/http/request.go new file mode 100644 index 00000000..7963ba3c --- /dev/null +++ b/src/http/request.go @@ -0,0 +1,75 @@ +package http + +import ( + "encoding/json" + "errors" + "io" + "oh-my-posh/environment" + "oh-my-posh/properties" +) + +type Request struct { + props properties.Properties + env environment.Environment +} + +func (r *Request) Init(env environment.Environment, props properties.Properties) { + r.env = env + r.props = props +} + +func Do[a any](r *Request, url string, requestModifiers ...environment.HTTPRequestModifier) (a, error) { + return do[a](r, url, nil, nil, requestModifiers...) +} + +func do[a any](r *Request, url string, body io.Reader, preRequestFunc func() error, requestModifiers ...environment.HTTPRequestModifier) (a, error) { + var data a + + getCacheValue := func(key string) (a, error) { + if val, found := r.env.Cache().Get(key); found { + err := json.Unmarshal([]byte(val), &data) + if err != nil { + r.env.Log(environment.Error, "OAuth", err.Error()) + return data, err + } + return data, nil + } + err := errors.New("no data in cache") + r.env.Log(environment.Error, "OAuth", err.Error()) + return data, err + } + + httpTimeout := r.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout) + + // No need to check more than every 30 minutes by default + cacheTimeout := r.props.GetInt(properties.CacheTimeout, 30) + if cacheTimeout > 0 { + if data, err := getCacheValue(url); err == nil { + return data, nil + } + } + + if preRequestFunc != nil { + if err := preRequestFunc(); err != nil { + return data, err + } + } + + responseBody, err := r.env.HTTPRequest(url, body, httpTimeout, requestModifiers...) + if err != nil { + r.env.Log(environment.Error, "OAuth", err.Error()) + return data, err + } + + err = json.Unmarshal(responseBody, &data) + if err != nil { + r.env.Log(environment.Error, "OAuth", err.Error()) + return data, err + } + + if cacheTimeout > 0 { + r.env.Cache().Set(url, string(responseBody), cacheTimeout) + } + + return data, nil +} diff --git a/src/http/request_test.go b/src/http/request_test.go new file mode 100644 index 00000000..717616be --- /dev/null +++ b/src/http/request_test.go @@ -0,0 +1,91 @@ +package http + +import ( + "net" + "oh-my-posh/environment" + "oh-my-posh/mock" + "oh-my-posh/properties" + "testing" + + "github.com/stretchr/testify/assert" + mock2 "github.com/stretchr/testify/mock" +) + +func TestRequestResult(t *testing.T) { + successData := &data{Hello: "world"} + jsonResponse := `{ "hello":"world" }` + url := "https://google.com?q=hello" + + cases := []struct { + Case string + // API response + JSONResponse string + // Cache + CacheJSONResponse string + CacheTimeout int + ResponseCacheMiss bool + // Errors + Error error + // Validations + ExpectedErrorMessage string + ExpectedData *data + }{ + { + Case: "No cache", + JSONResponse: jsonResponse, + ExpectedData: successData, + }, + { + Case: "Cache", + CacheJSONResponse: `{ "hello":"mom" }`, + ExpectedData: &data{Hello: "mom"}, + CacheTimeout: 10, + }, + { + Case: "Cache miss", + ResponseCacheMiss: true, + JSONResponse: jsonResponse, + CacheJSONResponse: `{ "hello":"mom" }`, + ExpectedData: successData, + CacheTimeout: 10, + }, + { + Case: "DNS error", + Error: &net.DNSError{IsNotFound: true}, + ExpectedErrorMessage: "lookup : ", + }, + { + Case: "Response incorrect", + JSONResponse: `[`, + ExpectedErrorMessage: "unexpected end of JSON input", + }, + } + + for _, tc := range cases { + var props properties.Map = map[properties.Property]interface{}{ + properties.CacheTimeout: tc.CacheTimeout, + } + + cache := &mock.MockedCache{} + + cache.On("Get", url).Return(tc.CacheJSONResponse, !tc.ResponseCacheMiss) + cache.On("Set", mock2.Anything, mock2.Anything, mock2.Anything) + + env := &mock.MockedEnvironment{} + + env.On("Cache").Return(cache) + env.On("HTTPRequest", url).Return([]byte(tc.JSONResponse), tc.Error) + env.On("Log", environment.Error, "OAuth", mock2.Anything).Return() + + request := &Request{} + request.Init(env, props) + + got, err := Do[*data](request, url, nil) + assert.Equal(t, tc.ExpectedData, got, tc.Case) + if len(tc.ExpectedErrorMessage) == 0 { + assert.Nil(t, err, tc.Case) + } else { + assert.Equal(t, tc.ExpectedErrorMessage, err.Error(), tc.Case) + } + } +} diff --git a/src/segments/strava.go b/src/segments/strava.go index 66fd8e82..7f839208 100644 --- a/src/segments/strava.go +++ b/src/segments/strava.go @@ -15,12 +15,12 @@ type StravaAPI interface { } type stravaAPI struct { - http.OAuth + http.OAuthRequest } func (s *stravaAPI) GetActivities() ([]*StravaData, error) { url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1" - return http.OauthResult[[]*StravaData](&s.OAuth, url, nil) + return http.OauthResult[[]*StravaData](&s.OAuthRequest, url, nil) } // segment struct, makes templating easier @@ -128,13 +128,14 @@ func (s *Strava) getActivityIcon() string { func (s *Strava) Init(props properties.Properties, env environment.Environment) { s.props = props + oauth := &http.OAuthRequest{ + AccessTokenKey: StravaAccessTokenKey, + RefreshTokenKey: StravaRefreshTokenKey, + SegmentName: "strava", + } + oauth.Init(env, props) + s.api = &stravaAPI{ - OAuth: http.OAuth{ - Props: props, - Env: env, - AccessTokenKey: StravaAccessTokenKey, - RefreshTokenKey: StravaRefreshTokenKey, - SegmentName: "strava", - }, + OAuthRequest: *oauth, } } diff --git a/src/segments/withings.go b/src/segments/withings.go index b156153d..97a1df91 100644 --- a/src/segments/withings.go +++ b/src/segments/withings.go @@ -76,7 +76,7 @@ type WithingsAPI interface { } type withingsAPI struct { - *http.OAuth + *http.OAuthRequest } func (w *withingsAPI) GetMeasures(meastypes string) (*WithingsData, error) { @@ -124,7 +124,7 @@ func (w *withingsAPI) getWithingsData(endpoint string, formData url.Values) (*Wi request.Header.Add("Content-Type", "application/x-www-form-urlencoded") } body := strings.NewReader(formData.Encode()) - data, err := http.OauthResult[*WithingsData](w.OAuth, endpoint, body, modifiers) + data, err := http.OauthResult[*WithingsData](w.OAuthRequest, endpoint, body, modifiers) if data != nil && data.Status != 0 { return nil, errors.New("Withings API error: " + strconv.Itoa(data.Status)) } @@ -219,13 +219,14 @@ func (w *Withings) getSleep() bool { func (w *Withings) Init(props properties.Properties, env environment.Environment) { w.props = props + oauth := &http.OAuthRequest{ + AccessTokenKey: WithingsAccessTokenKey, + RefreshTokenKey: WithingsRefreshTokenKey, + SegmentName: "withings", + } + oauth.Init(env, props) + w.api = &withingsAPI{ - OAuth: &http.OAuth{ - Props: props, - Env: env, - AccessTokenKey: WithingsAccessTokenKey, - RefreshTokenKey: WithingsRefreshTokenKey, - SegmentName: "withings", - }, + OAuthRequest: oauth, } }