refactor: do generic http request

This commit is contained in:
Jan De Dobbeleer 2022-08-05 21:07:35 +02:00 committed by Jan De Dobbeleer
parent 64231a790f
commit cccb502989
6 changed files with 216 additions and 92 deletions

View file

@ -2,7 +2,6 @@ package http
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@ -31,32 +30,27 @@ func (a *OAuthError) Error() string {
return a.message
}
type OAuth struct {
Props properties.Properties
Env environment.Environment
type OAuthRequest struct {
Request
AccessTokenKey string
RefreshTokenKey string
SegmentName string
}
func (o *OAuth) error(err error) {
o.Env.Log(environment.Error, "OAuth", err.Error())
}
func (o *OAuth) getAccessToken() (string, error) {
func (o *OAuthRequest) getAccessToken() (string, error) {
// get directly from cache
if acccessToken, OK := o.Env.Cache().Get(o.AccessTokenKey); OK && len(acccessToken) != 0 {
if acccessToken, OK := o.env.Cache().Get(o.AccessTokenKey); OK && len(acccessToken) != 0 {
return acccessToken, nil
}
// use cached refresh token to get new access token
if refreshToken, OK := o.Env.Cache().Get(o.RefreshTokenKey); OK && len(refreshToken) != 0 {
if refreshToken, OK := o.env.Cache().Get(o.RefreshTokenKey); OK && len(refreshToken) != 0 {
if acccessToken, err := o.refreshToken(refreshToken); err == nil {
return acccessToken, nil
}
}
// use initial refresh token from property
refreshToken := o.Props.GetString(properties.RefreshToken, "")
refreshToken := o.props.GetString(properties.RefreshToken, "")
// ignore an empty or default refresh token
if len(refreshToken) == 0 || refreshToken == DefaultRefreshToken {
return "", &OAuthError{
@ -68,10 +62,10 @@ func (o *OAuth) getAccessToken() (string, error) {
return acccessToken, err
}
func (o *OAuth) refreshToken(refreshToken string) (string, error) {
httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
func (o *OAuthRequest) refreshToken(refreshToken string) (string, error) {
httpTimeout := o.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
url := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=%s&token=%s", o.SegmentName, refreshToken)
body, err := o.Env.HTTPRequest(url, nil, httpTimeout)
body, err := o.env.HTTPRequest(url, nil, httpTimeout)
if err != nil {
return "", &OAuthError{
// This might happen if /api was asleep. Assume the user will just retry
@ -86,40 +80,16 @@ func (o *OAuth) refreshToken(refreshToken string) (string, error) {
}
}
// add tokens to cache
o.Env.Cache().Set(o.AccessTokenKey, tokens.AccessToken, tokens.ExpiresIn/60)
o.Env.Cache().Set(o.RefreshTokenKey, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year
o.env.Cache().Set(o.AccessTokenKey, tokens.AccessToken, tokens.ExpiresIn/60)
o.env.Cache().Set(o.RefreshTokenKey, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year
return tokens.AccessToken, nil
}
func OauthResult[a any](o *OAuth, url string, body io.Reader, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
var data a
getCacheValue := func(key string) (a, error) {
if val, found := o.Env.Cache().Get(key); found {
err := json.Unmarshal([]byte(val), &data)
if err != nil {
o.error(err)
return data, err
}
return data, nil
}
err := errors.New("no data in cache")
o.error(err)
return data, err
}
httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
// No need to check more than every 30 minutes by default
cacheTimeout := o.Props.GetInt(properties.CacheTimeout, 30)
if cacheTimeout > 0 {
if data, err := getCacheValue(url); err == nil {
return data, nil
}
}
func OauthResult[a any](o *OAuthRequest, url string, body io.Reader, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
addToken := func() error {
accessToken, err := o.getAccessToken()
if err != nil {
return data, err
return err
}
// add token to header for authentication
@ -133,21 +103,8 @@ func OauthResult[a any](o *OAuth, url string, body io.Reader, requestModifiers .
requestModifiers = append(requestModifiers, addAuthHeader)
responseBody, err := o.Env.HTTPRequest(url, body, httpTimeout, requestModifiers...)
if err != nil {
o.error(err)
return data, err
return nil
}
err = json.Unmarshal(responseBody, &data)
if err != nil {
o.error(err)
return data, err
}
if cacheTimeout > 0 {
o.Env.Cache().Set(url, string(responseBody), cacheTimeout)
}
return data, nil
return do[a](&o.Request, url, body, addToken, requestModifiers...)
}

View file

@ -163,13 +163,12 @@ func TestOauthResult(t *testing.T) {
env.On("HTTPRequest", tokenURL).Return([]byte(tc.TokenResponse), tc.Error)
env.On("Log", environment.Error, "OAuth", mock2.Anything).Return()
oauth := &OAuth{
Props: props,
Env: env,
oauth := &OAuthRequest{
AccessTokenKey: accessTokenKey,
RefreshTokenKey: refreshTokenKey,
SegmentName: "test",
}
oauth.Init(env, props)
got, err := OauthResult[*data](oauth, url, nil)
assert.Equal(t, tc.ExpectedData, got, tc.Case)

75
src/http/request.go Normal file
View file

@ -0,0 +1,75 @@
package http
import (
"encoding/json"
"errors"
"io"
"oh-my-posh/environment"
"oh-my-posh/properties"
)
type Request struct {
props properties.Properties
env environment.Environment
}
func (r *Request) Init(env environment.Environment, props properties.Properties) {
r.env = env
r.props = props
}
func Do[a any](r *Request, url string, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
return do[a](r, url, nil, nil, requestModifiers...)
}
func do[a any](r *Request, url string, body io.Reader, preRequestFunc func() error, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
var data a
getCacheValue := func(key string) (a, error) {
if val, found := r.env.Cache().Get(key); found {
err := json.Unmarshal([]byte(val), &data)
if err != nil {
r.env.Log(environment.Error, "OAuth", err.Error())
return data, err
}
return data, nil
}
err := errors.New("no data in cache")
r.env.Log(environment.Error, "OAuth", err.Error())
return data, err
}
httpTimeout := r.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
// No need to check more than every 30 minutes by default
cacheTimeout := r.props.GetInt(properties.CacheTimeout, 30)
if cacheTimeout > 0 {
if data, err := getCacheValue(url); err == nil {
return data, nil
}
}
if preRequestFunc != nil {
if err := preRequestFunc(); err != nil {
return data, err
}
}
responseBody, err := r.env.HTTPRequest(url, body, httpTimeout, requestModifiers...)
if err != nil {
r.env.Log(environment.Error, "OAuth", err.Error())
return data, err
}
err = json.Unmarshal(responseBody, &data)
if err != nil {
r.env.Log(environment.Error, "OAuth", err.Error())
return data, err
}
if cacheTimeout > 0 {
r.env.Cache().Set(url, string(responseBody), cacheTimeout)
}
return data, nil
}

91
src/http/request_test.go Normal file
View file

@ -0,0 +1,91 @@
package http
import (
"net"
"oh-my-posh/environment"
"oh-my-posh/mock"
"oh-my-posh/properties"
"testing"
"github.com/stretchr/testify/assert"
mock2 "github.com/stretchr/testify/mock"
)
func TestRequestResult(t *testing.T) {
successData := &data{Hello: "world"}
jsonResponse := `{ "hello":"world" }`
url := "https://google.com?q=hello"
cases := []struct {
Case string
// API response
JSONResponse string
// Cache
CacheJSONResponse string
CacheTimeout int
ResponseCacheMiss bool
// Errors
Error error
// Validations
ExpectedErrorMessage string
ExpectedData *data
}{
{
Case: "No cache",
JSONResponse: jsonResponse,
ExpectedData: successData,
},
{
Case: "Cache",
CacheJSONResponse: `{ "hello":"mom" }`,
ExpectedData: &data{Hello: "mom"},
CacheTimeout: 10,
},
{
Case: "Cache miss",
ResponseCacheMiss: true,
JSONResponse: jsonResponse,
CacheJSONResponse: `{ "hello":"mom" }`,
ExpectedData: successData,
CacheTimeout: 10,
},
{
Case: "DNS error",
Error: &net.DNSError{IsNotFound: true},
ExpectedErrorMessage: "lookup : ",
},
{
Case: "Response incorrect",
JSONResponse: `[`,
ExpectedErrorMessage: "unexpected end of JSON input",
},
}
for _, tc := range cases {
var props properties.Map = map[properties.Property]interface{}{
properties.CacheTimeout: tc.CacheTimeout,
}
cache := &mock.MockedCache{}
cache.On("Get", url).Return(tc.CacheJSONResponse, !tc.ResponseCacheMiss)
cache.On("Set", mock2.Anything, mock2.Anything, mock2.Anything)
env := &mock.MockedEnvironment{}
env.On("Cache").Return(cache)
env.On("HTTPRequest", url).Return([]byte(tc.JSONResponse), tc.Error)
env.On("Log", environment.Error, "OAuth", mock2.Anything).Return()
request := &Request{}
request.Init(env, props)
got, err := Do[*data](request, url, nil)
assert.Equal(t, tc.ExpectedData, got, tc.Case)
if len(tc.ExpectedErrorMessage) == 0 {
assert.Nil(t, err, tc.Case)
} else {
assert.Equal(t, tc.ExpectedErrorMessage, err.Error(), tc.Case)
}
}
}

View file

@ -15,12 +15,12 @@ type StravaAPI interface {
}
type stravaAPI struct {
http.OAuth
http.OAuthRequest
}
func (s *stravaAPI) GetActivities() ([]*StravaData, error) {
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
return http.OauthResult[[]*StravaData](&s.OAuth, url, nil)
return http.OauthResult[[]*StravaData](&s.OAuthRequest, url, nil)
}
// segment struct, makes templating easier
@ -128,13 +128,14 @@ func (s *Strava) getActivityIcon() string {
func (s *Strava) Init(props properties.Properties, env environment.Environment) {
s.props = props
s.api = &stravaAPI{
OAuth: http.OAuth{
Props: props,
Env: env,
oauth := &http.OAuthRequest{
AccessTokenKey: StravaAccessTokenKey,
RefreshTokenKey: StravaRefreshTokenKey,
SegmentName: "strava",
},
}
oauth.Init(env, props)
s.api = &stravaAPI{
OAuthRequest: *oauth,
}
}

View file

@ -76,7 +76,7 @@ type WithingsAPI interface {
}
type withingsAPI struct {
*http.OAuth
*http.OAuthRequest
}
func (w *withingsAPI) GetMeasures(meastypes string) (*WithingsData, error) {
@ -124,7 +124,7 @@ func (w *withingsAPI) getWithingsData(endpoint string, formData url.Values) (*Wi
request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
}
body := strings.NewReader(formData.Encode())
data, err := http.OauthResult[*WithingsData](w.OAuth, endpoint, body, modifiers)
data, err := http.OauthResult[*WithingsData](w.OAuthRequest, endpoint, body, modifiers)
if data != nil && data.Status != 0 {
return nil, errors.New("Withings API error: " + strconv.Itoa(data.Status))
}
@ -219,13 +219,14 @@ func (w *Withings) getSleep() bool {
func (w *Withings) Init(props properties.Properties, env environment.Environment) {
w.props = props
w.api = &withingsAPI{
OAuth: &http.OAuth{
Props: props,
Env: env,
oauth := &http.OAuthRequest{
AccessTokenKey: WithingsAccessTokenKey,
RefreshTokenKey: WithingsRefreshTokenKey,
SegmentName: "withings",
},
}
oauth.Init(env, props)
w.api = &withingsAPI{
OAuthRequest: oauth,
}
}