mirror of
https://github.com/JanDeDobbeleer/oh-my-posh.git
synced 2025-03-05 20:49:04 -08:00
refactor: do generic http request
This commit is contained in:
parent
64231a790f
commit
cccb502989
|
@ -2,7 +2,6 @@ package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -31,32 +30,27 @@ func (a *OAuthError) Error() string {
|
||||||
return a.message
|
return a.message
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuth struct {
|
type OAuthRequest struct {
|
||||||
Props properties.Properties
|
Request
|
||||||
Env environment.Environment
|
|
||||||
|
|
||||||
AccessTokenKey string
|
AccessTokenKey string
|
||||||
RefreshTokenKey string
|
RefreshTokenKey string
|
||||||
SegmentName string
|
SegmentName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OAuth) error(err error) {
|
func (o *OAuthRequest) getAccessToken() (string, error) {
|
||||||
o.Env.Log(environment.Error, "OAuth", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *OAuth) getAccessToken() (string, error) {
|
|
||||||
// get directly from cache
|
// get directly from cache
|
||||||
if acccessToken, OK := o.Env.Cache().Get(o.AccessTokenKey); OK && len(acccessToken) != 0 {
|
if acccessToken, OK := o.env.Cache().Get(o.AccessTokenKey); OK && len(acccessToken) != 0 {
|
||||||
return acccessToken, nil
|
return acccessToken, nil
|
||||||
}
|
}
|
||||||
// use cached refresh token to get new access token
|
// use cached refresh token to get new access token
|
||||||
if refreshToken, OK := o.Env.Cache().Get(o.RefreshTokenKey); OK && len(refreshToken) != 0 {
|
if refreshToken, OK := o.env.Cache().Get(o.RefreshTokenKey); OK && len(refreshToken) != 0 {
|
||||||
if acccessToken, err := o.refreshToken(refreshToken); err == nil {
|
if acccessToken, err := o.refreshToken(refreshToken); err == nil {
|
||||||
return acccessToken, nil
|
return acccessToken, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// use initial refresh token from property
|
// use initial refresh token from property
|
||||||
refreshToken := o.Props.GetString(properties.RefreshToken, "")
|
refreshToken := o.props.GetString(properties.RefreshToken, "")
|
||||||
// ignore an empty or default refresh token
|
// ignore an empty or default refresh token
|
||||||
if len(refreshToken) == 0 || refreshToken == DefaultRefreshToken {
|
if len(refreshToken) == 0 || refreshToken == DefaultRefreshToken {
|
||||||
return "", &OAuthError{
|
return "", &OAuthError{
|
||||||
|
@ -68,10 +62,10 @@ func (o *OAuth) getAccessToken() (string, error) {
|
||||||
return acccessToken, err
|
return acccessToken, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OAuth) refreshToken(refreshToken string) (string, error) {
|
func (o *OAuthRequest) refreshToken(refreshToken string) (string, error) {
|
||||||
httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
httpTimeout := o.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||||
url := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=%s&token=%s", o.SegmentName, refreshToken)
|
url := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=%s&token=%s", o.SegmentName, refreshToken)
|
||||||
body, err := o.Env.HTTPRequest(url, nil, httpTimeout)
|
body, err := o.env.HTTPRequest(url, nil, httpTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", &OAuthError{
|
return "", &OAuthError{
|
||||||
// This might happen if /api was asleep. Assume the user will just retry
|
// This might happen if /api was asleep. Assume the user will just retry
|
||||||
|
@ -86,68 +80,31 @@ func (o *OAuth) refreshToken(refreshToken string) (string, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// add tokens to cache
|
// add tokens to cache
|
||||||
o.Env.Cache().Set(o.AccessTokenKey, tokens.AccessToken, tokens.ExpiresIn/60)
|
o.env.Cache().Set(o.AccessTokenKey, tokens.AccessToken, tokens.ExpiresIn/60)
|
||||||
o.Env.Cache().Set(o.RefreshTokenKey, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year
|
o.env.Cache().Set(o.RefreshTokenKey, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year
|
||||||
return tokens.AccessToken, nil
|
return tokens.AccessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func OauthResult[a any](o *OAuth, url string, body io.Reader, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
|
func OauthResult[a any](o *OAuthRequest, url string, body io.Reader, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
|
||||||
var data a
|
addToken := func() error {
|
||||||
|
accessToken, err := o.getAccessToken()
|
||||||
getCacheValue := func(key string) (a, error) {
|
if err != nil {
|
||||||
if val, found := o.Env.Cache().Get(key); found {
|
return err
|
||||||
err := json.Unmarshal([]byte(val), &data)
|
|
||||||
if err != nil {
|
|
||||||
o.error(err)
|
|
||||||
return data, err
|
|
||||||
}
|
|
||||||
return data, nil
|
|
||||||
}
|
}
|
||||||
err := errors.New("no data in cache")
|
|
||||||
o.error(err)
|
|
||||||
return data, err
|
|
||||||
}
|
|
||||||
|
|
||||||
httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
// add token to header for authentication
|
||||||
|
addAuthHeader := func(request *http.Request) {
|
||||||
// No need to check more than every 30 minutes by default
|
request.Header.Add("Authorization", "Bearer "+accessToken)
|
||||||
cacheTimeout := o.Props.GetInt(properties.CacheTimeout, 30)
|
|
||||||
if cacheTimeout > 0 {
|
|
||||||
if data, err := getCacheValue(url); err == nil {
|
|
||||||
return data, nil
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
accessToken, err := o.getAccessToken()
|
if requestModifiers == nil {
|
||||||
if err != nil {
|
requestModifiers = []environment.HTTPRequestModifier{}
|
||||||
return data, err
|
}
|
||||||
|
|
||||||
|
requestModifiers = append(requestModifiers, addAuthHeader)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// add token to header for authentication
|
return do[a](&o.Request, url, body, addToken, requestModifiers...)
|
||||||
addAuthHeader := func(request *http.Request) {
|
|
||||||
request.Header.Add("Authorization", "Bearer "+accessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
if requestModifiers == nil {
|
|
||||||
requestModifiers = []environment.HTTPRequestModifier{}
|
|
||||||
}
|
|
||||||
|
|
||||||
requestModifiers = append(requestModifiers, addAuthHeader)
|
|
||||||
|
|
||||||
responseBody, err := o.Env.HTTPRequest(url, body, httpTimeout, requestModifiers...)
|
|
||||||
if err != nil {
|
|
||||||
o.error(err)
|
|
||||||
return data, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = json.Unmarshal(responseBody, &data)
|
|
||||||
if err != nil {
|
|
||||||
o.error(err)
|
|
||||||
return data, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if cacheTimeout > 0 {
|
|
||||||
o.Env.Cache().Set(url, string(responseBody), cacheTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
return data, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -163,13 +163,12 @@ func TestOauthResult(t *testing.T) {
|
||||||
env.On("HTTPRequest", tokenURL).Return([]byte(tc.TokenResponse), tc.Error)
|
env.On("HTTPRequest", tokenURL).Return([]byte(tc.TokenResponse), tc.Error)
|
||||||
env.On("Log", environment.Error, "OAuth", mock2.Anything).Return()
|
env.On("Log", environment.Error, "OAuth", mock2.Anything).Return()
|
||||||
|
|
||||||
oauth := &OAuth{
|
oauth := &OAuthRequest{
|
||||||
Props: props,
|
|
||||||
Env: env,
|
|
||||||
AccessTokenKey: accessTokenKey,
|
AccessTokenKey: accessTokenKey,
|
||||||
RefreshTokenKey: refreshTokenKey,
|
RefreshTokenKey: refreshTokenKey,
|
||||||
SegmentName: "test",
|
SegmentName: "test",
|
||||||
}
|
}
|
||||||
|
oauth.Init(env, props)
|
||||||
|
|
||||||
got, err := OauthResult[*data](oauth, url, nil)
|
got, err := OauthResult[*data](oauth, url, nil)
|
||||||
assert.Equal(t, tc.ExpectedData, got, tc.Case)
|
assert.Equal(t, tc.ExpectedData, got, tc.Case)
|
||||||
|
|
75
src/http/request.go
Normal file
75
src/http/request.go
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"oh-my-posh/environment"
|
||||||
|
"oh-my-posh/properties"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Request struct {
|
||||||
|
props properties.Properties
|
||||||
|
env environment.Environment
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) Init(env environment.Environment, props properties.Properties) {
|
||||||
|
r.env = env
|
||||||
|
r.props = props
|
||||||
|
}
|
||||||
|
|
||||||
|
func Do[a any](r *Request, url string, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
|
||||||
|
return do[a](r, url, nil, nil, requestModifiers...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func do[a any](r *Request, url string, body io.Reader, preRequestFunc func() error, requestModifiers ...environment.HTTPRequestModifier) (a, error) {
|
||||||
|
var data a
|
||||||
|
|
||||||
|
getCacheValue := func(key string) (a, error) {
|
||||||
|
if val, found := r.env.Cache().Get(key); found {
|
||||||
|
err := json.Unmarshal([]byte(val), &data)
|
||||||
|
if err != nil {
|
||||||
|
r.env.Log(environment.Error, "OAuth", err.Error())
|
||||||
|
return data, err
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
err := errors.New("no data in cache")
|
||||||
|
r.env.Log(environment.Error, "OAuth", err.Error())
|
||||||
|
return data, err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpTimeout := r.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||||
|
|
||||||
|
// No need to check more than every 30 minutes by default
|
||||||
|
cacheTimeout := r.props.GetInt(properties.CacheTimeout, 30)
|
||||||
|
if cacheTimeout > 0 {
|
||||||
|
if data, err := getCacheValue(url); err == nil {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if preRequestFunc != nil {
|
||||||
|
if err := preRequestFunc(); err != nil {
|
||||||
|
return data, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
responseBody, err := r.env.HTTPRequest(url, body, httpTimeout, requestModifiers...)
|
||||||
|
if err != nil {
|
||||||
|
r.env.Log(environment.Error, "OAuth", err.Error())
|
||||||
|
return data, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(responseBody, &data)
|
||||||
|
if err != nil {
|
||||||
|
r.env.Log(environment.Error, "OAuth", err.Error())
|
||||||
|
return data, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cacheTimeout > 0 {
|
||||||
|
r.env.Cache().Set(url, string(responseBody), cacheTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
91
src/http/request_test.go
Normal file
91
src/http/request_test.go
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"oh-my-posh/environment"
|
||||||
|
"oh-my-posh/mock"
|
||||||
|
"oh-my-posh/properties"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
mock2 "github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestResult(t *testing.T) {
|
||||||
|
successData := &data{Hello: "world"}
|
||||||
|
jsonResponse := `{ "hello":"world" }`
|
||||||
|
url := "https://google.com?q=hello"
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
Case string
|
||||||
|
// API response
|
||||||
|
JSONResponse string
|
||||||
|
// Cache
|
||||||
|
CacheJSONResponse string
|
||||||
|
CacheTimeout int
|
||||||
|
ResponseCacheMiss bool
|
||||||
|
// Errors
|
||||||
|
Error error
|
||||||
|
// Validations
|
||||||
|
ExpectedErrorMessage string
|
||||||
|
ExpectedData *data
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Case: "No cache",
|
||||||
|
JSONResponse: jsonResponse,
|
||||||
|
ExpectedData: successData,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Case: "Cache",
|
||||||
|
CacheJSONResponse: `{ "hello":"mom" }`,
|
||||||
|
ExpectedData: &data{Hello: "mom"},
|
||||||
|
CacheTimeout: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Case: "Cache miss",
|
||||||
|
ResponseCacheMiss: true,
|
||||||
|
JSONResponse: jsonResponse,
|
||||||
|
CacheJSONResponse: `{ "hello":"mom" }`,
|
||||||
|
ExpectedData: successData,
|
||||||
|
CacheTimeout: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Case: "DNS error",
|
||||||
|
Error: &net.DNSError{IsNotFound: true},
|
||||||
|
ExpectedErrorMessage: "lookup : ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Case: "Response incorrect",
|
||||||
|
JSONResponse: `[`,
|
||||||
|
ExpectedErrorMessage: "unexpected end of JSON input",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
var props properties.Map = map[properties.Property]interface{}{
|
||||||
|
properties.CacheTimeout: tc.CacheTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mock.MockedCache{}
|
||||||
|
|
||||||
|
cache.On("Get", url).Return(tc.CacheJSONResponse, !tc.ResponseCacheMiss)
|
||||||
|
cache.On("Set", mock2.Anything, mock2.Anything, mock2.Anything)
|
||||||
|
|
||||||
|
env := &mock.MockedEnvironment{}
|
||||||
|
|
||||||
|
env.On("Cache").Return(cache)
|
||||||
|
env.On("HTTPRequest", url).Return([]byte(tc.JSONResponse), tc.Error)
|
||||||
|
env.On("Log", environment.Error, "OAuth", mock2.Anything).Return()
|
||||||
|
|
||||||
|
request := &Request{}
|
||||||
|
request.Init(env, props)
|
||||||
|
|
||||||
|
got, err := Do[*data](request, url, nil)
|
||||||
|
assert.Equal(t, tc.ExpectedData, got, tc.Case)
|
||||||
|
if len(tc.ExpectedErrorMessage) == 0 {
|
||||||
|
assert.Nil(t, err, tc.Case)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tc.ExpectedErrorMessage, err.Error(), tc.Case)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -15,12 +15,12 @@ type StravaAPI interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type stravaAPI struct {
|
type stravaAPI struct {
|
||||||
http.OAuth
|
http.OAuthRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stravaAPI) GetActivities() ([]*StravaData, error) {
|
func (s *stravaAPI) GetActivities() ([]*StravaData, error) {
|
||||||
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
|
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
|
||||||
return http.OauthResult[[]*StravaData](&s.OAuth, url, nil)
|
return http.OauthResult[[]*StravaData](&s.OAuthRequest, url, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// segment struct, makes templating easier
|
// segment struct, makes templating easier
|
||||||
|
@ -128,13 +128,14 @@ func (s *Strava) getActivityIcon() string {
|
||||||
func (s *Strava) Init(props properties.Properties, env environment.Environment) {
|
func (s *Strava) Init(props properties.Properties, env environment.Environment) {
|
||||||
s.props = props
|
s.props = props
|
||||||
|
|
||||||
|
oauth := &http.OAuthRequest{
|
||||||
|
AccessTokenKey: StravaAccessTokenKey,
|
||||||
|
RefreshTokenKey: StravaRefreshTokenKey,
|
||||||
|
SegmentName: "strava",
|
||||||
|
}
|
||||||
|
oauth.Init(env, props)
|
||||||
|
|
||||||
s.api = &stravaAPI{
|
s.api = &stravaAPI{
|
||||||
OAuth: http.OAuth{
|
OAuthRequest: *oauth,
|
||||||
Props: props,
|
|
||||||
Env: env,
|
|
||||||
AccessTokenKey: StravaAccessTokenKey,
|
|
||||||
RefreshTokenKey: StravaRefreshTokenKey,
|
|
||||||
SegmentName: "strava",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,7 +76,7 @@ type WithingsAPI interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type withingsAPI struct {
|
type withingsAPI struct {
|
||||||
*http.OAuth
|
*http.OAuthRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *withingsAPI) GetMeasures(meastypes string) (*WithingsData, error) {
|
func (w *withingsAPI) GetMeasures(meastypes string) (*WithingsData, error) {
|
||||||
|
@ -124,7 +124,7 @@ func (w *withingsAPI) getWithingsData(endpoint string, formData url.Values) (*Wi
|
||||||
request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
}
|
}
|
||||||
body := strings.NewReader(formData.Encode())
|
body := strings.NewReader(formData.Encode())
|
||||||
data, err := http.OauthResult[*WithingsData](w.OAuth, endpoint, body, modifiers)
|
data, err := http.OauthResult[*WithingsData](w.OAuthRequest, endpoint, body, modifiers)
|
||||||
if data != nil && data.Status != 0 {
|
if data != nil && data.Status != 0 {
|
||||||
return nil, errors.New("Withings API error: " + strconv.Itoa(data.Status))
|
return nil, errors.New("Withings API error: " + strconv.Itoa(data.Status))
|
||||||
}
|
}
|
||||||
|
@ -219,13 +219,14 @@ func (w *Withings) getSleep() bool {
|
||||||
func (w *Withings) Init(props properties.Properties, env environment.Environment) {
|
func (w *Withings) Init(props properties.Properties, env environment.Environment) {
|
||||||
w.props = props
|
w.props = props
|
||||||
|
|
||||||
|
oauth := &http.OAuthRequest{
|
||||||
|
AccessTokenKey: WithingsAccessTokenKey,
|
||||||
|
RefreshTokenKey: WithingsRefreshTokenKey,
|
||||||
|
SegmentName: "withings",
|
||||||
|
}
|
||||||
|
oauth.Init(env, props)
|
||||||
|
|
||||||
w.api = &withingsAPI{
|
w.api = &withingsAPI{
|
||||||
OAuth: &http.OAuth{
|
OAuthRequest: oauth,
|
||||||
Props: props,
|
|
||||||
Env: env,
|
|
||||||
AccessTokenKey: WithingsAccessTokenKey,
|
|
||||||
RefreshTokenKey: WithingsRefreshTokenKey,
|
|
||||||
SegmentName: "withings",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue