mirror of
https://github.com/JanDeDobbeleer/oh-my-posh.git
synced 2025-03-05 20:49:04 -08:00
refactor: do generic http request
This commit is contained in:
parent
64231a790f
commit
cccb502989
|
@ -2,7 +2,6 @@ package http
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
@ -31,32 +30,27 @@ func (a *OAuthError) Error() string {
|
|||
return a.message
|
||||
}
|
||||
|
||||
type OAuth struct {
|
||||
Props properties.Properties
|
||||
Env environment.Environment
|
||||
type OAuthRequest struct {
|
||||
Request
|
||||
|
||||
AccessTokenKey string
|
||||
RefreshTokenKey string
|
||||
SegmentName string
|
||||
}
|
||||
|
||||
func (o *OAuth) error(err error) {
|
||||
o.Env.Log(environment.Error, "OAuth", err.Error())
|
||||
}
|
||||
|
||||
func (o *OAuth) getAccessToken() (string, error) {
|
||||
func (o *OAuthRequest) getAccessToken() (string, error) {
|
||||
// get directly from cache
|
||||
if acccessToken, OK := o.Env.Cache().Get(o.AccessTokenKey); OK && len(acccessToken) != 0 {
|
||||
if acccessToken, OK := o.env.Cache().Get(o.AccessTokenKey); OK && len(acccessToken) != 0 {
|
||||
return acccessToken, nil
|
||||
}
|
||||
// use cached refresh token to get new access token
|
||||
if refreshToken, OK := o.Env.Cache().Get(o.RefreshTokenKey); OK && len(refreshToken) != 0 {
|
||||
if refreshToken, OK := o.env.Cache().Get(o.RefreshTokenKey); OK && len(refreshToken) != 0 {
|
||||
if acccessToken, err := o.refreshToken(refreshToken); err == nil {
|
||||
return acccessToken, nil
|
||||
}
|
||||
}
|
||||
// use initial refresh token from property
|
||||
refreshToken := o.Props.GetString(properties.RefreshToken, "")
|
||||
refreshToken := o.props.GetString(properties.RefreshToken, "")
|
||||
// ignore an empty or default refresh token
|
||||
if len(refreshToken) == 0 || refreshToken == DefaultRefreshToken {
|
||||
return "", &OAuthError{
|
||||
|
@ -68,10 +62,10 @@ func (o *OAuth) getAccessToken() (string, error) {
|
|||
return acccessToken, err
|
||||
}
|
||||
|
||||
func (o *OAuth) refreshToken(refreshToken string) (string, error) {
|
||||
httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
func (o *OAuthRequest) refreshToken(refreshToken string) (string, error) {
|
||||
httpTimeout := o.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
url := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=%s&token=%s", o.SegmentName, refreshToken)
|
||||
body, err := o.Env.HTTPRequest(url, nil, httpTimeout)
|
||||
body, err := o.env.HTTPRequest(url, nil, httpTimeout)
|
||||
if err != nil {
|
||||
return "", &OAuthError{
|
||||
// This might happen if /api was asleep. Assume the user will just retry
|
||||
|
@ -86,68 +80,31 @@ func (o *OAuth) refreshToken(refreshToken string) (string, error) {
|
|||
}
|
||||
}
|
||||
// add tokens to cache
|
||||
o.Env.Cache().Set(o.AccessTokenKey, tokens.AccessToken, tokens.ExpiresIn/60)
|
||||
o.Env.Cache().Set(o.RefreshTokenKey, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year
|
||||
o.env.Cache().Set(o.AccessTokenKey, tokens.AccessToken, tokens.ExpiresIn/60)
|
||||
o.env.Cache().Set(o.RefreshTokenKey, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year
|
||||
return tokens.AccessToken, nil
|
||||
}
|
||||
|
||||
func OauthResult[a any](o *OAuth, url string, body io.Reader, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
|
||||
var data a
|
||||
|
||||
getCacheValue := func(key string) (a, error) {
|
||||
if val, found := o.Env.Cache().Get(key); found {
|
||||
err := json.Unmarshal([]byte(val), &data)
|
||||
if err != nil {
|
||||
o.error(err)
|
||||
return data, err
|
||||
}
|
||||
return data, nil
|
||||
func OauthResult[a any](o *OAuthRequest, url string, body io.Reader, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
|
||||
addToken := func() error {
|
||||
accessToken, err := o.getAccessToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := errors.New("no data in cache")
|
||||
o.error(err)
|
||||
return data, err
|
||||
}
|
||||
|
||||
httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
|
||||
// No need to check more than every 30 minutes by default
|
||||
cacheTimeout := o.Props.GetInt(properties.CacheTimeout, 30)
|
||||
if cacheTimeout > 0 {
|
||||
if data, err := getCacheValue(url); err == nil {
|
||||
return data, nil
|
||||
// add token to header for authentication
|
||||
addAuthHeader := func(request *http.Request) {
|
||||
request.Header.Add("Authorization", "Bearer "+accessToken)
|
||||
}
|
||||
}
|
||||
accessToken, err := o.getAccessToken()
|
||||
if err != nil {
|
||||
return data, err
|
||||
|
||||
if requestModifiers == nil {
|
||||
requestModifiers = []environment.HTTPRequestModifier{}
|
||||
}
|
||||
|
||||
requestModifiers = append(requestModifiers, addAuthHeader)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// add token to header for authentication
|
||||
addAuthHeader := func(request *http.Request) {
|
||||
request.Header.Add("Authorization", "Bearer "+accessToken)
|
||||
}
|
||||
|
||||
if requestModifiers == nil {
|
||||
requestModifiers = []environment.HTTPRequestModifier{}
|
||||
}
|
||||
|
||||
requestModifiers = append(requestModifiers, addAuthHeader)
|
||||
|
||||
responseBody, err := o.Env.HTTPRequest(url, body, httpTimeout, requestModifiers...)
|
||||
if err != nil {
|
||||
o.error(err)
|
||||
return data, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(responseBody, &data)
|
||||
if err != nil {
|
||||
o.error(err)
|
||||
return data, err
|
||||
}
|
||||
|
||||
if cacheTimeout > 0 {
|
||||
o.Env.Cache().Set(url, string(responseBody), cacheTimeout)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
return do[a](&o.Request, url, body, addToken, requestModifiers...)
|
||||
}
|
||||
|
|
|
@ -163,13 +163,12 @@ func TestOauthResult(t *testing.T) {
|
|||
env.On("HTTPRequest", tokenURL).Return([]byte(tc.TokenResponse), tc.Error)
|
||||
env.On("Log", environment.Error, "OAuth", mock2.Anything).Return()
|
||||
|
||||
oauth := &OAuth{
|
||||
Props: props,
|
||||
Env: env,
|
||||
oauth := &OAuthRequest{
|
||||
AccessTokenKey: accessTokenKey,
|
||||
RefreshTokenKey: refreshTokenKey,
|
||||
SegmentName: "test",
|
||||
}
|
||||
oauth.Init(env, props)
|
||||
|
||||
got, err := OauthResult[*data](oauth, url, nil)
|
||||
assert.Equal(t, tc.ExpectedData, got, tc.Case)
|
||||
|
|
75
src/http/request.go
Normal file
75
src/http/request.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"oh-my-posh/environment"
|
||||
"oh-my-posh/properties"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
props properties.Properties
|
||||
env environment.Environment
|
||||
}
|
||||
|
||||
func (r *Request) Init(env environment.Environment, props properties.Properties) {
|
||||
r.env = env
|
||||
r.props = props
|
||||
}
|
||||
|
||||
func Do[a any](r *Request, url string, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
|
||||
return do[a](r, url, nil, nil, requestModifiers...)
|
||||
}
|
||||
|
||||
func do[a any](r *Request, url string, body io.Reader, preRequestFunc func() error, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
|
||||
var data a
|
||||
|
||||
getCacheValue := func(key string) (a, error) {
|
||||
if val, found := r.env.Cache().Get(key); found {
|
||||
err := json.Unmarshal([]byte(val), &data)
|
||||
if err != nil {
|
||||
r.env.Log(environment.Error, "OAuth", err.Error())
|
||||
return data, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
err := errors.New("no data in cache")
|
||||
r.env.Log(environment.Error, "OAuth", err.Error())
|
||||
return data, err
|
||||
}
|
||||
|
||||
httpTimeout := r.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
|
||||
// No need to check more than every 30 minutes by default
|
||||
cacheTimeout := r.props.GetInt(properties.CacheTimeout, 30)
|
||||
if cacheTimeout > 0 {
|
||||
if data, err := getCacheValue(url); err == nil {
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
|
||||
if preRequestFunc != nil {
|
||||
if err := preRequestFunc(); err != nil {
|
||||
return data, err
|
||||
}
|
||||
}
|
||||
|
||||
responseBody, err := r.env.HTTPRequest(url, body, httpTimeout, requestModifiers...)
|
||||
if err != nil {
|
||||
r.env.Log(environment.Error, "OAuth", err.Error())
|
||||
return data, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(responseBody, &data)
|
||||
if err != nil {
|
||||
r.env.Log(environment.Error, "OAuth", err.Error())
|
||||
return data, err
|
||||
}
|
||||
|
||||
if cacheTimeout > 0 {
|
||||
r.env.Cache().Set(url, string(responseBody), cacheTimeout)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
91
src/http/request_test.go
Normal file
91
src/http/request_test.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"net"
|
||||
"oh-my-posh/environment"
|
||||
"oh-my-posh/mock"
|
||||
"oh-my-posh/properties"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
mock2 "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestRequestResult(t *testing.T) {
|
||||
successData := &data{Hello: "world"}
|
||||
jsonResponse := `{ "hello":"world" }`
|
||||
url := "https://google.com?q=hello"
|
||||
|
||||
cases := []struct {
|
||||
Case string
|
||||
// API response
|
||||
JSONResponse string
|
||||
// Cache
|
||||
CacheJSONResponse string
|
||||
CacheTimeout int
|
||||
ResponseCacheMiss bool
|
||||
// Errors
|
||||
Error error
|
||||
// Validations
|
||||
ExpectedErrorMessage string
|
||||
ExpectedData *data
|
||||
}{
|
||||
{
|
||||
Case: "No cache",
|
||||
JSONResponse: jsonResponse,
|
||||
ExpectedData: successData,
|
||||
},
|
||||
{
|
||||
Case: "Cache",
|
||||
CacheJSONResponse: `{ "hello":"mom" }`,
|
||||
ExpectedData: &data{Hello: "mom"},
|
||||
CacheTimeout: 10,
|
||||
},
|
||||
{
|
||||
Case: "Cache miss",
|
||||
ResponseCacheMiss: true,
|
||||
JSONResponse: jsonResponse,
|
||||
CacheJSONResponse: `{ "hello":"mom" }`,
|
||||
ExpectedData: successData,
|
||||
CacheTimeout: 10,
|
||||
},
|
||||
{
|
||||
Case: "DNS error",
|
||||
Error: &net.DNSError{IsNotFound: true},
|
||||
ExpectedErrorMessage: "lookup : ",
|
||||
},
|
||||
{
|
||||
Case: "Response incorrect",
|
||||
JSONResponse: `[`,
|
||||
ExpectedErrorMessage: "unexpected end of JSON input",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
var props properties.Map = map[properties.Property]interface{}{
|
||||
properties.CacheTimeout: tc.CacheTimeout,
|
||||
}
|
||||
|
||||
cache := &mock.MockedCache{}
|
||||
|
||||
cache.On("Get", url).Return(tc.CacheJSONResponse, !tc.ResponseCacheMiss)
|
||||
cache.On("Set", mock2.Anything, mock2.Anything, mock2.Anything)
|
||||
|
||||
env := &mock.MockedEnvironment{}
|
||||
|
||||
env.On("Cache").Return(cache)
|
||||
env.On("HTTPRequest", url).Return([]byte(tc.JSONResponse), tc.Error)
|
||||
env.On("Log", environment.Error, "OAuth", mock2.Anything).Return()
|
||||
|
||||
request := &Request{}
|
||||
request.Init(env, props)
|
||||
|
||||
got, err := Do[*data](request, url, nil)
|
||||
assert.Equal(t, tc.ExpectedData, got, tc.Case)
|
||||
if len(tc.ExpectedErrorMessage) == 0 {
|
||||
assert.Nil(t, err, tc.Case)
|
||||
} else {
|
||||
assert.Equal(t, tc.ExpectedErrorMessage, err.Error(), tc.Case)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -15,12 +15,12 @@ type StravaAPI interface {
|
|||
}
|
||||
|
||||
type stravaAPI struct {
|
||||
http.OAuth
|
||||
http.OAuthRequest
|
||||
}
|
||||
|
||||
func (s *stravaAPI) GetActivities() ([]*StravaData, error) {
|
||||
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
|
||||
return http.OauthResult[[]*StravaData](&s.OAuth, url, nil)
|
||||
return http.OauthResult[[]*StravaData](&s.OAuthRequest, url, nil)
|
||||
}
|
||||
|
||||
// segment struct, makes templating easier
|
||||
|
@ -128,13 +128,14 @@ func (s *Strava) getActivityIcon() string {
|
|||
func (s *Strava) Init(props properties.Properties, env environment.Environment) {
|
||||
s.props = props
|
||||
|
||||
oauth := &http.OAuthRequest{
|
||||
AccessTokenKey: StravaAccessTokenKey,
|
||||
RefreshTokenKey: StravaRefreshTokenKey,
|
||||
SegmentName: "strava",
|
||||
}
|
||||
oauth.Init(env, props)
|
||||
|
||||
s.api = &stravaAPI{
|
||||
OAuth: http.OAuth{
|
||||
Props: props,
|
||||
Env: env,
|
||||
AccessTokenKey: StravaAccessTokenKey,
|
||||
RefreshTokenKey: StravaRefreshTokenKey,
|
||||
SegmentName: "strava",
|
||||
},
|
||||
OAuthRequest: *oauth,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,7 +76,7 @@ type WithingsAPI interface {
|
|||
}
|
||||
|
||||
type withingsAPI struct {
|
||||
*http.OAuth
|
||||
*http.OAuthRequest
|
||||
}
|
||||
|
||||
func (w *withingsAPI) GetMeasures(meastypes string) (*WithingsData, error) {
|
||||
|
@ -124,7 +124,7 @@ func (w *withingsAPI) getWithingsData(endpoint string, formData url.Values) (*Wi
|
|||
request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
body := strings.NewReader(formData.Encode())
|
||||
data, err := http.OauthResult[*WithingsData](w.OAuth, endpoint, body, modifiers)
|
||||
data, err := http.OauthResult[*WithingsData](w.OAuthRequest, endpoint, body, modifiers)
|
||||
if data != nil && data.Status != 0 {
|
||||
return nil, errors.New("Withings API error: " + strconv.Itoa(data.Status))
|
||||
}
|
||||
|
@ -219,13 +219,14 @@ func (w *Withings) getSleep() bool {
|
|||
func (w *Withings) Init(props properties.Properties, env environment.Environment) {
|
||||
w.props = props
|
||||
|
||||
oauth := &http.OAuthRequest{
|
||||
AccessTokenKey: WithingsAccessTokenKey,
|
||||
RefreshTokenKey: WithingsRefreshTokenKey,
|
||||
SegmentName: "withings",
|
||||
}
|
||||
oauth.Init(env, props)
|
||||
|
||||
w.api = &withingsAPI{
|
||||
OAuth: &http.OAuth{
|
||||
Props: props,
|
||||
Env: env,
|
||||
AccessTokenKey: WithingsAccessTokenKey,
|
||||
RefreshTokenKey: WithingsRefreshTokenKey,
|
||||
SegmentName: "withings",
|
||||
},
|
||||
OAuthRequest: oauth,
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue