mirror of
https://github.com/JanDeDobbeleer/oh-my-posh.git
synced 2025-01-30 04:21:19 -08:00
refactor: extract OAuth logic
This commit is contained in:
parent
a6e9a3561b
commit
e5bf5db9c2
146
src/http/oauth.go
Normal file
146
src/http/oauth.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"oh-my-posh/environment"
|
||||
"oh-my-posh/properties"
|
||||
)
|
||||
|
||||
const (
|
||||
Timeout = "timeout"
|
||||
InvalidRefreshToken = "invalid refresh token"
|
||||
TokenRefreshFailed = "token refresh error"
|
||||
DefaultRefreshToken = "111111111111111111111111111111"
|
||||
)
|
||||
|
||||
type tokenExchange struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type OAuthError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (a *OAuthError) Error() string {
|
||||
return a.message
|
||||
}
|
||||
|
||||
type OAuth struct {
|
||||
Props properties.Properties
|
||||
Env environment.Environment
|
||||
|
||||
AccessTokenKey string
|
||||
RefreshTokenKey string
|
||||
SegmentName string
|
||||
}
|
||||
|
||||
func (o *OAuth) error(err error) {
|
||||
o.Env.Log(environment.Error, "OAuth", err.Error())
|
||||
}
|
||||
|
||||
func (o *OAuth) getAccessToken() (string, error) {
|
||||
// get directly from cache
|
||||
if acccessToken, OK := o.Env.Cache().Get(o.AccessTokenKey); OK {
|
||||
return acccessToken, nil
|
||||
}
|
||||
// use cached refresh token to get new access token
|
||||
if refreshToken, OK := o.Env.Cache().Get(o.RefreshTokenKey); OK {
|
||||
if acccessToken, err := o.refreshToken(refreshToken); err == nil {
|
||||
return acccessToken, nil
|
||||
}
|
||||
}
|
||||
// use initial refresh token from property
|
||||
refreshToken := o.Props.GetString(properties.RefreshToken, "")
|
||||
// ignore an empty or default refresh token
|
||||
if len(refreshToken) == 0 || refreshToken == DefaultRefreshToken {
|
||||
return "", &OAuthError{
|
||||
message: InvalidRefreshToken,
|
||||
}
|
||||
}
|
||||
// no need to let the user provide access token, we'll always verify the refresh token
|
||||
acccessToken, err := o.refreshToken(refreshToken)
|
||||
return acccessToken, err
|
||||
}
|
||||
|
||||
func (o *OAuth) refreshToken(refreshToken string) (string, error) {
|
||||
httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
url := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=%s&token=%s", o.SegmentName, refreshToken)
|
||||
body, err := o.Env.HTTPRequest(url, httpTimeout)
|
||||
if err != nil {
|
||||
return "", &OAuthError{
|
||||
// This might happen if /api was asleep. Assume the user will just retry
|
||||
message: Timeout,
|
||||
}
|
||||
}
|
||||
tokens := &tokenExchange{}
|
||||
err = json.Unmarshal(body, &tokens)
|
||||
if err != nil {
|
||||
return "", &OAuthError{
|
||||
message: TokenRefreshFailed,
|
||||
}
|
||||
}
|
||||
// add tokens to cache
|
||||
o.Env.Cache().Set(o.AccessTokenKey, tokens.AccessToken, tokens.ExpiresIn/60)
|
||||
o.Env.Cache().Set(o.RefreshTokenKey, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year
|
||||
return tokens.AccessToken, nil
|
||||
}
|
||||
|
||||
func OauthResult[a any](o *OAuth, url string) (a, error) {
|
||||
var data a
|
||||
|
||||
getCacheValue := func(key string) (a, error) {
|
||||
if val, found := o.Env.Cache().Get(key); found {
|
||||
err := json.Unmarshal([]byte(val), &data)
|
||||
if err != nil {
|
||||
o.error(err)
|
||||
return data, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
err := errors.New("no data in cache")
|
||||
o.error(err)
|
||||
return data, err
|
||||
}
|
||||
|
||||
httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
|
||||
// No need to check more than every 30 minutes by default
|
||||
cacheTimeout := o.Props.GetInt(properties.CacheTimeout, 30)
|
||||
if cacheTimeout > 0 {
|
||||
if data, err := getCacheValue(url); err == nil {
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
accessToken, err := o.getAccessToken()
|
||||
if err != nil {
|
||||
return data, err
|
||||
}
|
||||
|
||||
// add token to header for authentication
|
||||
addAuthHeader := func(request *http.Request) {
|
||||
request.Header.Add("Authorization", "Bearer "+accessToken)
|
||||
}
|
||||
|
||||
body, err := o.Env.HTTPRequest(url, httpTimeout, addAuthHeader)
|
||||
if err != nil {
|
||||
o.error(err)
|
||||
return data, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, &data)
|
||||
if err != nil {
|
||||
o.error(err)
|
||||
return data, err
|
||||
}
|
||||
|
||||
if cacheTimeout > 0 {
|
||||
o.Env.Cache().Set(url, string(body), cacheTimeout)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
182
src/http/oauth_test.go
Normal file
182
src/http/oauth_test.go
Normal file
|
@ -0,0 +1,182 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"oh-my-posh/environment"
|
||||
"oh-my-posh/mock"
|
||||
"oh-my-posh/properties"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
mock2 "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type data struct {
|
||||
Hello string `json:"hello"`
|
||||
}
|
||||
|
||||
func TestOauthResult(t *testing.T) {
|
||||
accessTokenKey := "test_access_token"
|
||||
refreshTokenKey := "test_refresh_token"
|
||||
tokenResponse := `{ "access_token":"NEW_ACCESSTOKEN","refresh_token":"NEW_REFRESHTOKEN", "expires_in":1234 }`
|
||||
jsonResponse := `{ "hello":"world" }`
|
||||
successData := &data{Hello: "world"}
|
||||
|
||||
cases := []struct {
|
||||
Case string
|
||||
// tokens
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
TokenResponse string
|
||||
// API response
|
||||
JSONResponse string
|
||||
// Cache
|
||||
CacheJSONResponse string
|
||||
CacheTimeout int
|
||||
RefreshTokenFromCache bool
|
||||
AccessTokenFromCache bool
|
||||
ResponseCacheMiss bool
|
||||
// Errors
|
||||
Error error
|
||||
// Validations
|
||||
ExpectedErrorMessage string
|
||||
ExpectedData *data
|
||||
}{
|
||||
{
|
||||
Case: "No initial tokens",
|
||||
ExpectedErrorMessage: InvalidRefreshToken,
|
||||
},
|
||||
{
|
||||
Case: "Use config tokens",
|
||||
AccessToken: "INITIAL_ACCESSTOKEN",
|
||||
RefreshToken: "INITIAL_REFRESHTOKEN",
|
||||
TokenResponse: tokenResponse,
|
||||
JSONResponse: jsonResponse,
|
||||
ExpectedData: successData,
|
||||
},
|
||||
{
|
||||
Case: "Access token from cache",
|
||||
AccessToken: "ACCESSTOKEN",
|
||||
AccessTokenFromCache: true,
|
||||
JSONResponse: jsonResponse,
|
||||
ExpectedData: successData,
|
||||
},
|
||||
{
|
||||
Case: "Refresh token from cache",
|
||||
RefreshToken: "REFRESH_TOKEN",
|
||||
RefreshTokenFromCache: true,
|
||||
JSONResponse: jsonResponse,
|
||||
TokenResponse: tokenResponse,
|
||||
ExpectedData: successData,
|
||||
},
|
||||
{
|
||||
Case: "Refresh token from cache, success",
|
||||
RefreshToken: "REFRESH_TOKEN",
|
||||
RefreshTokenFromCache: true,
|
||||
JSONResponse: jsonResponse,
|
||||
TokenResponse: tokenResponse,
|
||||
ExpectedData: successData,
|
||||
},
|
||||
{
|
||||
Case: "Refresh API error",
|
||||
RefreshToken: "REFRESH_TOKEN",
|
||||
RefreshTokenFromCache: true,
|
||||
Error: fmt.Errorf("API error"),
|
||||
ExpectedErrorMessage: Timeout,
|
||||
},
|
||||
{
|
||||
Case: "Refresh API parse error",
|
||||
RefreshToken: "REFRESH_TOKEN",
|
||||
RefreshTokenFromCache: true,
|
||||
TokenResponse: "INVALID_JSON",
|
||||
ExpectedErrorMessage: TokenRefreshFailed,
|
||||
},
|
||||
{
|
||||
Case: "Default config token",
|
||||
RefreshToken: DefaultRefreshToken,
|
||||
ExpectedErrorMessage: InvalidRefreshToken,
|
||||
},
|
||||
{
|
||||
Case: "Cache data",
|
||||
CacheTimeout: 60,
|
||||
CacheJSONResponse: jsonResponse,
|
||||
ExpectedData: successData,
|
||||
},
|
||||
{
|
||||
Case: "Cache data, invalid data",
|
||||
CacheTimeout: 60,
|
||||
RefreshToken: "REFRESH_TOKEN",
|
||||
CacheJSONResponse: "ERR",
|
||||
TokenResponse: tokenResponse,
|
||||
JSONResponse: jsonResponse,
|
||||
ExpectedData: successData,
|
||||
},
|
||||
{
|
||||
Case: "Cache data, no cache",
|
||||
CacheTimeout: 60,
|
||||
RefreshToken: "REFRESH_TOKEN",
|
||||
ResponseCacheMiss: true,
|
||||
TokenResponse: tokenResponse,
|
||||
JSONResponse: jsonResponse,
|
||||
ExpectedData: successData,
|
||||
},
|
||||
{
|
||||
Case: "API body failure",
|
||||
AccessToken: "ACCESSTOKEN",
|
||||
AccessTokenFromCache: true,
|
||||
ResponseCacheMiss: true,
|
||||
JSONResponse: "ERR",
|
||||
ExpectedErrorMessage: "invalid character 'E' looking for beginning of value",
|
||||
},
|
||||
{
|
||||
Case: "API request failure",
|
||||
AccessToken: "ACCESSTOKEN",
|
||||
AccessTokenFromCache: true,
|
||||
ResponseCacheMiss: true,
|
||||
JSONResponse: "ERR",
|
||||
Error: fmt.Errorf("no response"),
|
||||
ExpectedErrorMessage: "no response",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
|
||||
tokenURL := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=test&token=%s", tc.RefreshToken)
|
||||
|
||||
var props properties.Map = map[properties.Property]interface{}{
|
||||
properties.CacheTimeout: tc.CacheTimeout,
|
||||
properties.AccessToken: tc.AccessToken,
|
||||
properties.RefreshToken: tc.RefreshToken,
|
||||
}
|
||||
|
||||
cache := &mock.MockedCache{}
|
||||
|
||||
cache.On("Get", url).Return(tc.CacheJSONResponse, !tc.ResponseCacheMiss)
|
||||
cache.On("Get", accessTokenKey).Return(tc.AccessToken, tc.AccessTokenFromCache)
|
||||
cache.On("Get", refreshTokenKey).Return(tc.RefreshToken, tc.RefreshTokenFromCache)
|
||||
cache.On("Set", mock2.Anything, mock2.Anything, mock2.Anything)
|
||||
|
||||
env := &mock.MockedEnvironment{}
|
||||
|
||||
env.On("Cache").Return(cache)
|
||||
env.On("HTTPRequest", url).Return([]byte(tc.JSONResponse), tc.Error)
|
||||
env.On("HTTPRequest", tokenURL).Return([]byte(tc.TokenResponse), tc.Error)
|
||||
env.On("Log", environment.Error, "OAuth", mock2.Anything).Return()
|
||||
|
||||
oauth := &OAuth{
|
||||
Props: props,
|
||||
Env: env,
|
||||
AccessTokenKey: accessTokenKey,
|
||||
RefreshTokenKey: refreshTokenKey,
|
||||
SegmentName: "test",
|
||||
}
|
||||
|
||||
got, err := OauthResult[*data](oauth, url)
|
||||
assert.Equal(t, tc.ExpectedData, got, tc.Case)
|
||||
if len(tc.ExpectedErrorMessage) == 0 {
|
||||
assert.Nil(t, err, tc.Case)
|
||||
} else {
|
||||
assert.Equal(t, tc.ExpectedErrorMessage, err.Error(), tc.Case)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -43,6 +43,14 @@ const (
|
|||
AccessToken Property = "access_token"
|
||||
// RefreshToken is the refresh token to use for an API
|
||||
RefreshToken Property = "refresh_token"
|
||||
// HTTPTimeout timeout used when executing http request
|
||||
HTTPTimeout Property = "http_timeout"
|
||||
// DefaultHTTPTimeout default timeout used when executing http request
|
||||
DefaultHTTPTimeout = 20
|
||||
// DefaultCacheTimeout default timeout used when caching data
|
||||
DefaultCacheTimeout = 10
|
||||
// CacheTimeout cache timeout
|
||||
CacheTimeout Property = "cache_timeout"
|
||||
)
|
||||
|
||||
type Map map[Property]interface{}
|
||||
|
|
|
@ -245,7 +245,7 @@ func (bf *Brewfather) getResult() (*Batch, error) {
|
|||
batchURL := fmt.Sprintf("https://api.brewfather.app/v1/batches/%s", batchID)
|
||||
batchReadingsURL := fmt.Sprintf("https://api.brewfather.app/v1/batches/%s/readings", batchID)
|
||||
|
||||
httpTimeout := bf.props.GetInt(HTTPTimeout, DefaultHTTPTimeout)
|
||||
httpTimeout := bf.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
cacheTimeout := bf.props.GetInt(BFCacheTimeout, 5)
|
||||
|
||||
if cacheTimeout > 0 {
|
||||
|
|
|
@ -140,10 +140,10 @@ func TestBrewfatherSegment(t *testing.T) {
|
|||
for _, tc := range cases {
|
||||
env := &mock.MockedEnvironment{}
|
||||
props := properties.Map{
|
||||
CacheTimeout: tc.CacheTimeout,
|
||||
BFBatchID: BFFakeBatchID,
|
||||
BFAPIKey: "FAKE",
|
||||
BFUserID: "FAKE",
|
||||
properties.CacheTimeout: tc.CacheTimeout,
|
||||
BFBatchID: BFFakeBatchID,
|
||||
BFAPIKey: "FAKE",
|
||||
BFUserID: "FAKE",
|
||||
}
|
||||
|
||||
cache := &mock.MockedCache{}
|
||||
|
|
|
@ -30,7 +30,7 @@ func (i *IPify) Enabled() bool {
|
|||
}
|
||||
|
||||
func (i *IPify) getResult() (string, error) {
|
||||
cacheTimeout := i.props.GetInt(CacheTimeout, DefaultCacheTimeout)
|
||||
cacheTimeout := i.props.GetInt(properties.CacheTimeout, properties.DefaultCacheTimeout)
|
||||
|
||||
url := i.props.GetString(IpifyURL, "https://api.ipify.org")
|
||||
|
||||
|
@ -43,7 +43,7 @@ func (i *IPify) getResult() (string, error) {
|
|||
}
|
||||
}
|
||||
|
||||
httpTimeout := i.props.GetInt(HTTPTimeout, DefaultHTTPTimeout)
|
||||
httpTimeout := i.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
|
||||
body, err := i.env.HTTPRequest(url, httpTimeout)
|
||||
if err != nil {
|
||||
|
|
|
@ -46,7 +46,7 @@ func TestIpifySegment(t *testing.T) {
|
|||
for _, tc := range cases {
|
||||
env := &mock.MockedEnvironment{}
|
||||
props := properties.Map{
|
||||
CacheTimeout: 0,
|
||||
properties.CacheTimeout: 0,
|
||||
}
|
||||
env.On("HTTPRequest", IPIFYAPIURL).Return([]byte(tc.Response), tc.Error)
|
||||
|
||||
|
|
|
@ -107,7 +107,7 @@ func (ns *Nightscout) getResult() (*NightscoutData, error) {
|
|||
}
|
||||
|
||||
url := ns.props.GetString(URL, "")
|
||||
httpTimeout := ns.props.GetInt(HTTPTimeout, DefaultHTTPTimeout)
|
||||
httpTimeout := ns.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
// natural and understood NS timeout is 5, anything else is unusual
|
||||
cacheTimeout := ns.props.GetInt(NSCacheTimeout, 5)
|
||||
|
||||
|
|
|
@ -134,8 +134,8 @@ func TestNSSegment(t *testing.T) {
|
|||
for _, tc := range cases {
|
||||
env := &mock.MockedEnvironment{}
|
||||
props := properties.Map{
|
||||
CacheTimeout: tc.CacheTimeout,
|
||||
URL: "FAKE",
|
||||
properties.CacheTimeout: tc.CacheTimeout,
|
||||
URL: "FAKE",
|
||||
}
|
||||
|
||||
cache := &mock.MockedCache{}
|
||||
|
|
|
@ -26,8 +26,6 @@ const (
|
|||
Location properties.Property = "location"
|
||||
// Units openweathermap units
|
||||
Units properties.Property = "units"
|
||||
// CacheTimeout cache timeout
|
||||
CacheTimeout properties.Property = "cache_timeout"
|
||||
// CacheKeyResponse key used when caching the response
|
||||
CacheKeyResponse string = "owm_response"
|
||||
// CacheKeyURL key used when caching the url responsible for the response
|
||||
|
@ -58,7 +56,7 @@ func (d *Owm) Template() string {
|
|||
}
|
||||
|
||||
func (d *Owm) getResult() (*owmDataResponse, error) {
|
||||
cacheTimeout := d.props.GetInt(CacheTimeout, DefaultCacheTimeout)
|
||||
cacheTimeout := d.props.GetInt(properties.CacheTimeout, properties.DefaultCacheTimeout)
|
||||
response := new(owmDataResponse)
|
||||
if cacheTimeout > 0 {
|
||||
// check if data stored in cache
|
||||
|
@ -77,7 +75,7 @@ func (d *Owm) getResult() (*owmDataResponse, error) {
|
|||
apikey := d.props.GetString(APIKey, ".")
|
||||
location := d.props.GetString(Location, "De Bilt,NL")
|
||||
units := d.props.GetString(Units, "standard")
|
||||
httpTimeout := d.props.GetInt(HTTPTimeout, DefaultHTTPTimeout)
|
||||
httpTimeout := d.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
d.URL = fmt.Sprintf("http://api.openweathermap.org/data/2.5/weather?q=%s&units=%s&appid=%s", location, units, apikey)
|
||||
|
||||
body, err := d.env.HTTPRequest(d.URL, httpTimeout)
|
||||
|
|
|
@ -54,10 +54,10 @@ func TestOWMSegmentSingle(t *testing.T) {
|
|||
for _, tc := range cases {
|
||||
env := &mock.MockedEnvironment{}
|
||||
props := properties.Map{
|
||||
APIKey: "key",
|
||||
Location: "AMSTERDAM,NL",
|
||||
Units: "metric",
|
||||
CacheTimeout: 0,
|
||||
APIKey: "key",
|
||||
Location: "AMSTERDAM,NL",
|
||||
Units: "metric",
|
||||
properties.CacheTimeout: 0,
|
||||
}
|
||||
|
||||
env.On("HTTPRequest", OWMAPIURL).Return([]byte(tc.JSONResponse), tc.Error)
|
||||
|
@ -189,10 +189,10 @@ func TestOWMSegmentIcons(t *testing.T) {
|
|||
|
||||
o := &Owm{
|
||||
props: properties.Map{
|
||||
APIKey: "key",
|
||||
Location: "AMSTERDAM,NL",
|
||||
Units: "metric",
|
||||
CacheTimeout: 0,
|
||||
APIKey: "key",
|
||||
Location: "AMSTERDAM,NL",
|
||||
Units: "metric",
|
||||
properties.CacheTimeout: 0,
|
||||
},
|
||||
env: env,
|
||||
}
|
||||
|
@ -212,10 +212,10 @@ func TestOWMSegmentIcons(t *testing.T) {
|
|||
|
||||
o := &Owm{
|
||||
props: properties.Map{
|
||||
APIKey: "key",
|
||||
Location: "AMSTERDAM,NL",
|
||||
Units: "metric",
|
||||
CacheTimeout: 0,
|
||||
APIKey: "key",
|
||||
Location: "AMSTERDAM,NL",
|
||||
Units: "metric",
|
||||
properties.CacheTimeout: 0,
|
||||
},
|
||||
env: env,
|
||||
}
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
package segments
|
||||
|
||||
import "oh-my-posh/properties"
|
||||
|
||||
const (
|
||||
// HTTPTimeout timeout used when executing http request
|
||||
HTTPTimeout properties.Property = "http_timeout"
|
||||
// DefaultHTTPTimeout default timeout used when executing http request
|
||||
DefaultHTTPTimeout = 20
|
||||
// DefaultCacheTimeout default timeout used when caching data
|
||||
DefaultCacheTimeout = 10
|
||||
)
|
|
@ -1,20 +1,31 @@
|
|||
package segments
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"oh-my-posh/environment"
|
||||
"oh-my-posh/http"
|
||||
"oh-my-posh/properties"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StravaAPI is a wrapper around http.Oauth
|
||||
type StravaAPI interface {
|
||||
GetActivities() ([]*StravaData, error)
|
||||
}
|
||||
|
||||
type stravaAPI struct {
|
||||
http.OAuth
|
||||
}
|
||||
|
||||
func (s *stravaAPI) GetActivities() ([]*StravaData, error) {
|
||||
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
|
||||
return http.OauthResult[[]*StravaData](&s.OAuth, url)
|
||||
}
|
||||
|
||||
// segment struct, makes templating easier
|
||||
type Strava struct {
|
||||
props properties.Properties
|
||||
env environment.Environment
|
||||
|
||||
StravaData
|
||||
Icon string
|
||||
|
@ -23,6 +34,8 @@ type Strava struct {
|
|||
Authenticate bool
|
||||
Error string
|
||||
URL string
|
||||
|
||||
api StravaAPI
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -32,12 +45,10 @@ const (
|
|||
WorkOutIcon properties.Property = "workout_icon"
|
||||
UnknownActivityIcon properties.Property = "unknown_activity_icon"
|
||||
|
||||
StravaAccessToken = "strava_access_token"
|
||||
StravaRefreshToken = "strava_refresh_token"
|
||||
StravaAccessTokenKey = "strava_access_token"
|
||||
StravaRefreshTokenKey = "strava_refresh_token"
|
||||
|
||||
Timeout = "timeout"
|
||||
InvalidRefreshToken = "invalid refresh token"
|
||||
TokenRefreshFailed = "token refresh error"
|
||||
noActivitiesFound = "No activities found"
|
||||
)
|
||||
|
||||
// StravaData struct contains the API data
|
||||
|
@ -56,36 +67,26 @@ type StravaData struct {
|
|||
KudosCount int `json:"kudos_count"`
|
||||
}
|
||||
|
||||
type TokenExchange struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type AuthError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (a *AuthError) Error() string {
|
||||
return a.message
|
||||
}
|
||||
|
||||
func (s *Strava) Template() string {
|
||||
return " {{ if .Error }}{{ .Error }}{{ else }}{{ .Ago }}{{ end }} "
|
||||
}
|
||||
|
||||
func (s *Strava) Enabled() bool {
|
||||
data, err := s.getResult()
|
||||
if err == nil {
|
||||
s.StravaData = *data
|
||||
data, err := s.api.GetActivities()
|
||||
if err == nil && len(data) > 0 {
|
||||
s.StravaData = *data[0]
|
||||
s.Icon = s.getActivityIcon()
|
||||
s.Hours = s.getHours()
|
||||
s.Ago = s.getAgo()
|
||||
s.URL = fmt.Sprintf("https://www.strava.com/activities/%d", s.ID)
|
||||
return true
|
||||
}
|
||||
if _, s.Authenticate = err.(*AuthError); s.Authenticate {
|
||||
s.Error = err.(*AuthError).Error()
|
||||
if err == nil && len(data) == 0 {
|
||||
s.Error = noActivitiesFound
|
||||
return true
|
||||
}
|
||||
if _, s.Authenticate = err.(*http.OAuthError); s.Authenticate {
|
||||
s.Error = err.(*http.OAuthError).Error()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
@ -124,115 +125,16 @@ func (s *Strava) getActivityIcon() string {
|
|||
return s.props.GetString(UnknownActivityIcon, "\ue213")
|
||||
}
|
||||
|
||||
func (s *Strava) getAccessToken() (string, error) {
|
||||
// get directly from cache
|
||||
if acccessToken, OK := s.env.Cache().Get(StravaAccessToken); OK {
|
||||
return acccessToken, nil
|
||||
}
|
||||
// use cached refresh token to get new access token
|
||||
if refreshToken, OK := s.env.Cache().Get(StravaRefreshToken); OK {
|
||||
if acccessToken, err := s.refreshToken(refreshToken); err == nil {
|
||||
return acccessToken, nil
|
||||
}
|
||||
}
|
||||
// use initial refresh token from property
|
||||
refreshToken := s.props.GetString(properties.RefreshToken, "")
|
||||
// ignore an empty or default refresh token
|
||||
if len(refreshToken) == 0 || refreshToken == "111111111111111111111111111111" {
|
||||
return "", &AuthError{
|
||||
message: InvalidRefreshToken,
|
||||
}
|
||||
}
|
||||
// no need to let the user provide access token, we'll always verify the refresh token
|
||||
acccessToken, err := s.refreshToken(refreshToken)
|
||||
return acccessToken, err
|
||||
}
|
||||
|
||||
func (s *Strava) refreshToken(refreshToken string) (string, error) {
|
||||
httpTimeout := s.props.GetInt(HTTPTimeout, DefaultHTTPTimeout)
|
||||
url := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=strava&token=%s", refreshToken)
|
||||
body, err := s.env.HTTPRequest(url, httpTimeout)
|
||||
if err != nil {
|
||||
return "", &AuthError{
|
||||
// This might happen if /api was asleep. Assume the user will just retry
|
||||
message: Timeout,
|
||||
}
|
||||
}
|
||||
tokens := &TokenExchange{}
|
||||
err = json.Unmarshal(body, &tokens)
|
||||
if err != nil {
|
||||
return "", &AuthError{
|
||||
message: TokenRefreshFailed,
|
||||
}
|
||||
}
|
||||
// add tokens to cache
|
||||
s.env.Cache().Set(StravaAccessToken, tokens.AccessToken, tokens.ExpiresIn/60)
|
||||
s.env.Cache().Set(StravaRefreshToken, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year
|
||||
return tokens.AccessToken, nil
|
||||
}
|
||||
|
||||
func (s *Strava) getResult() (*StravaData, error) {
|
||||
parseSingleElement := func(data []byte) (*StravaData, error) {
|
||||
var result []*StravaData
|
||||
err := json.Unmarshal(data, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, errors.New("no elements in the array")
|
||||
}
|
||||
return result[0], nil
|
||||
}
|
||||
getCacheValue := func(key string) (*StravaData, error) {
|
||||
val, found := s.env.Cache().Get(key)
|
||||
// we got something from the cache
|
||||
if found {
|
||||
if data, err := parseSingleElement([]byte(val)); err == nil {
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("no data in cache")
|
||||
}
|
||||
|
||||
// We only want the last activity
|
||||
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
|
||||
httpTimeout := s.props.GetInt(HTTPTimeout, DefaultHTTPTimeout)
|
||||
|
||||
// No need to check more the every 30 min
|
||||
cacheTimeout := s.props.GetInt(CacheTimeout, 30)
|
||||
if cacheTimeout > 0 {
|
||||
if data, err := getCacheValue(url); err == nil {
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
accessToken, err := s.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addAuthHeader := func(request *http.Request) {
|
||||
request.Header.Add("Authorization", "Bearer "+accessToken)
|
||||
}
|
||||
body, err := s.env.HTTPRequest(url, httpTimeout, addAuthHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var arr []*StravaData
|
||||
err = json.Unmarshal(body, &arr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := parseSingleElement(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cacheTimeout > 0 {
|
||||
// persist new sugars in cache
|
||||
s.env.Cache().Set(url, string(body), cacheTimeout)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *Strava) Init(props properties.Properties, env environment.Environment) {
|
||||
s.props = props
|
||||
s.env = env
|
||||
|
||||
s.api = &stravaAPI{
|
||||
OAuth: http.OAuth{
|
||||
Props: props,
|
||||
Env: env,
|
||||
AccessTokenKey: StravaAccessTokenKey,
|
||||
RefreshTokenKey: StravaRefreshTokenKey,
|
||||
SegmentName: "strava",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package segments
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"oh-my-posh/mock"
|
||||
"oh-my-posh/properties"
|
||||
"oh-my-posh/template"
|
||||
|
@ -10,192 +9,107 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
mock2 "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type mockedStravaAPI struct {
|
||||
mock2.Mock
|
||||
}
|
||||
|
||||
func (s *mockedStravaAPI) GetActivities() ([]*StravaData, error) {
|
||||
args := s.Called()
|
||||
return args.Get(0).([]*StravaData), args.Error(1)
|
||||
}
|
||||
|
||||
func TestStravaSegment(t *testing.T) {
|
||||
h, _ := time.ParseDuration("6h")
|
||||
sixHoursAgo := time.Now().Add(-h).Format(time.RFC3339)
|
||||
sixHoursAgo := time.Now().Add(-h)
|
||||
h, _ = time.ParseDuration("100h")
|
||||
fourDaysAgo := time.Now().Add(-h).Format(time.RFC3339)
|
||||
fourDaysAgo := time.Now().Add(-h)
|
||||
|
||||
cases := []struct {
|
||||
Case string
|
||||
JSONResponse string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
AccessTokenCacheFoundFail bool
|
||||
RefreshTokenCacheFoundFail bool
|
||||
InitialAccessToken string
|
||||
InitialRefreshToken string
|
||||
TokenRefreshToken string
|
||||
TokenResponse string
|
||||
TokenTest bool
|
||||
ExpectedString string
|
||||
ExpectedEnabled bool
|
||||
CacheTimeout int
|
||||
CacheFoundFail bool
|
||||
Template string
|
||||
Error error
|
||||
AuthDebugMsg string
|
||||
Case string
|
||||
ExpectedString string
|
||||
ExpectedEnabled bool
|
||||
Template string
|
||||
APIError error
|
||||
StravaData []*StravaData
|
||||
}{
|
||||
{
|
||||
Case: "No initial tokens",
|
||||
InitialAccessToken: "",
|
||||
AccessTokenCacheFoundFail: true,
|
||||
RefreshTokenCacheFoundFail: true,
|
||||
TokenTest: true,
|
||||
AuthDebugMsg: "invalid refresh token",
|
||||
},
|
||||
{
|
||||
Case: "Use initial tokens",
|
||||
AccessToken: "NEW_ACCESSTOKEN",
|
||||
InitialAccessToken: "INITIAL ACCESSTOKEN",
|
||||
InitialRefreshToken: "INITIAL REFRESHTOKEN",
|
||||
TokenRefreshToken: "INITIAL REFRESHTOKEN",
|
||||
TokenResponse: `{ "access_token":"NEW_ACCESSTOKEN","refresh_token":"NEW_REFRESHTOKEN", "expires_in":1234 }`,
|
||||
AccessTokenCacheFoundFail: true,
|
||||
RefreshTokenCacheFoundFail: true,
|
||||
TokenTest: true,
|
||||
},
|
||||
{
|
||||
Case: "Access token from cache",
|
||||
AccessToken: "ACCESSTOKEN",
|
||||
TokenTest: true,
|
||||
},
|
||||
{
|
||||
Case: "Refresh token from cache",
|
||||
AccessTokenCacheFoundFail: true,
|
||||
RefreshTokenCacheFoundFail: false,
|
||||
RefreshToken: "REFRESHTOKEN",
|
||||
TokenRefreshToken: "REFRESHTOKEN",
|
||||
TokenTest: true,
|
||||
AuthDebugMsg: "invalid refresh token",
|
||||
},
|
||||
{
|
||||
Case: "Ride 6",
|
||||
JSONResponse: `
|
||||
[{"type":"Ride","start_date":"` + sixHoursAgo + `","name":"Sesongens første på tjukkas","distance":16144.0}]`,
|
||||
StravaData: []*StravaData{
|
||||
{
|
||||
Type: "Ride",
|
||||
StartDate: sixHoursAgo,
|
||||
Name: "Sesongens første på tjukkas",
|
||||
Distance: 16144.0,
|
||||
},
|
||||
},
|
||||
Template: "{{.Ago}} {{.Icon}}",
|
||||
ExpectedString: "6h \uf5a2",
|
||||
ExpectedEnabled: true,
|
||||
},
|
||||
{
|
||||
Case: "Run 100",
|
||||
JSONResponse: `
|
||||
[{"type":"Run","start_date":"` + fourDaysAgo + `","name":"Sesongens første på tjukkas","distance":16144.0,"moving_time":7665}]`,
|
||||
StravaData: []*StravaData{
|
||||
{
|
||||
Type: "Run",
|
||||
StartDate: fourDaysAgo,
|
||||
Name: "Sesongens første på tjukkas",
|
||||
Distance: 16144.0,
|
||||
},
|
||||
},
|
||||
Template: "{{.Ago}} {{.Icon}}",
|
||||
ExpectedString: "4d \ufc0c",
|
||||
ExpectedEnabled: true,
|
||||
},
|
||||
{
|
||||
Case: "Error in retrieving data",
|
||||
JSONResponse: "nonsense",
|
||||
Error: errors.New("Something went wrong"),
|
||||
APIError: errors.New("Something went wrong"),
|
||||
ExpectedEnabled: false,
|
||||
},
|
||||
{
|
||||
Case: "Empty array",
|
||||
JSONResponse: "[]",
|
||||
ExpectedEnabled: false,
|
||||
},
|
||||
{
|
||||
Case: "Run from cache",
|
||||
JSONResponse: `
|
||||
[{"type":"Run","start_date":"` + fourDaysAgo + `","name":"Sesongens første på tjukkas","distance":16144.0,"moving_time":7665}]`,
|
||||
Template: "{{.Ago}} {{.Icon}}",
|
||||
ExpectedString: "4d \ufc0c",
|
||||
StravaData: []*StravaData{},
|
||||
ExpectedString: noActivitiesFound,
|
||||
ExpectedEnabled: true,
|
||||
CacheTimeout: 10,
|
||||
},
|
||||
{
|
||||
Case: "Run from not found cache",
|
||||
JSONResponse: `
|
||||
[{"type":"Run","start_date":"` + fourDaysAgo + `","name":"Morning ride","distance":16144.0,"moving_time":7665}]`,
|
||||
Template: "{{.Ago}} {{.Icon}} {{.Name}} {{.Hours}}h ago",
|
||||
ExpectedString: "4d \ufc0c Morning ride 100h ago",
|
||||
ExpectedEnabled: true,
|
||||
CacheTimeout: 10,
|
||||
CacheFoundFail: true,
|
||||
},
|
||||
{
|
||||
Case: "Error parsing response",
|
||||
JSONResponse: `
|
||||
4tffgt4e4567`,
|
||||
Template: "{{.Ago}}{{.Icon}}",
|
||||
ExpectedString: "50",
|
||||
ExpectedEnabled: false,
|
||||
CacheTimeout: 10,
|
||||
},
|
||||
{
|
||||
Case: "Faulty template",
|
||||
JSONResponse: `
|
||||
[{"sgv":50,"direction":"DoubleDown"}]`,
|
||||
StravaData: []*StravaData{
|
||||
{
|
||||
Type: "Run",
|
||||
StartDate: fourDaysAgo,
|
||||
Name: "Sesongens første på tjukkas",
|
||||
Distance: 16144.0,
|
||||
},
|
||||
},
|
||||
Template: "{{.Ago}}{{.Burp}}",
|
||||
ExpectedString: template.IncorrectTemplate,
|
||||
ExpectedEnabled: true,
|
||||
CacheTimeout: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
env := &mock.MockedEnvironment{}
|
||||
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
|
||||
tokenURL := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=strava&token=%s", tc.TokenRefreshToken)
|
||||
var props properties.Map = map[properties.Property]interface{}{
|
||||
CacheTimeout: tc.CacheTimeout,
|
||||
}
|
||||
cache := &mock.MockedCache{}
|
||||
cache.On("Get", url).Return(tc.JSONResponse, !tc.CacheFoundFail)
|
||||
cache.On("Set", url, tc.JSONResponse, tc.CacheTimeout).Return()
|
||||
api := &mockedStravaAPI{}
|
||||
api.On("GetActivities").Return(tc.StravaData, tc.APIError)
|
||||
|
||||
cache.On("Get", StravaAccessToken).Return(tc.AccessToken, !tc.AccessTokenCacheFoundFail)
|
||||
cache.On("Get", StravaRefreshToken).Return(tc.RefreshToken, !tc.RefreshTokenCacheFoundFail)
|
||||
|
||||
cache.On("Set", StravaRefreshToken, "NEW_REFRESHTOKEN", 2*525960)
|
||||
cache.On("Set", StravaAccessToken, "NEW_ACCESSTOKEN", 20)
|
||||
|
||||
env.On("HTTPRequest", url).Return([]byte(tc.JSONResponse), tc.Error)
|
||||
env.On("HTTPRequest", tokenURL).Return([]byte(tc.TokenResponse), tc.Error)
|
||||
env.On("Cache").Return(cache)
|
||||
|
||||
if tc.InitialAccessToken != "" {
|
||||
props[properties.AccessToken] = tc.InitialAccessToken
|
||||
}
|
||||
if tc.InitialRefreshToken != "" {
|
||||
props[properties.RefreshToken] = tc.InitialRefreshToken
|
||||
strava := &Strava{
|
||||
api: api,
|
||||
props: &properties.Map{},
|
||||
}
|
||||
|
||||
ns := &Strava{
|
||||
props: props,
|
||||
env: env,
|
||||
}
|
||||
|
||||
if tc.TokenTest {
|
||||
// continue
|
||||
at, err := ns.getAccessToken()
|
||||
if err != nil {
|
||||
if authErr, ok := err.(*AuthError); ok {
|
||||
assert.Equal(t, tc.AuthDebugMsg, authErr.Error(), tc.Case)
|
||||
} else {
|
||||
assert.Equal(t, tc.Error, err, tc.Case)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, tc.AccessToken, at, tc.Case)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
enabled := ns.Enabled()
|
||||
enabled := strava.Enabled()
|
||||
assert.Equal(t, tc.ExpectedEnabled, enabled, tc.Case)
|
||||
if !enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if tc.Template == "" {
|
||||
tc.Template = ns.Template()
|
||||
tc.Template = strava.Template()
|
||||
}
|
||||
var got = renderTemplate(env, tc.Template, ns)
|
||||
|
||||
var got = renderTemplate(&mock.MockedEnvironment{}, tc.Template, strava)
|
||||
assert.Equal(t, tc.ExpectedString, got, tc.Case)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ func (w *Wakatime) Enabled() bool {
|
|||
|
||||
func (w *Wakatime) setAPIData() error {
|
||||
url := w.props.GetString(URL, "")
|
||||
cacheTimeout := w.props.GetInt(CacheTimeout, DefaultCacheTimeout)
|
||||
cacheTimeout := w.props.GetInt(properties.CacheTimeout, properties.DefaultCacheTimeout)
|
||||
if cacheTimeout > 0 {
|
||||
// check if data stored in cache
|
||||
if val, found := w.env.Cache().Get(url); found {
|
||||
|
@ -47,7 +47,7 @@ func (w *Wakatime) setAPIData() error {
|
|||
}
|
||||
}
|
||||
|
||||
httpTimeout := w.props.GetInt(HTTPTimeout, DefaultHTTPTimeout)
|
||||
httpTimeout := w.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
|
||||
|
||||
body, err := w.env.HTTPRequest(url, httpTimeout)
|
||||
if err != nil {
|
||||
|
|
|
@ -82,9 +82,9 @@ func TestWTTrackedTime(t *testing.T) {
|
|||
|
||||
w := &Wakatime{
|
||||
props: properties.Map{
|
||||
APIKey: "key",
|
||||
CacheTimeout: tc.CacheTimeout,
|
||||
URL: FAKEAPIURL,
|
||||
APIKey: "key",
|
||||
properties.CacheTimeout: tc.CacheTimeout,
|
||||
URL: FAKEAPIURL,
|
||||
},
|
||||
env: env,
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ type track struct {
|
|||
func (y *Ytm) setStatus() error {
|
||||
// https://github.com/ytmdesktop/ytmdesktop/wiki/Remote-Control-API
|
||||
url := y.props.GetString(APIURL, "http://127.0.0.1:9863")
|
||||
httpTimeout := y.props.GetInt(APIURL, DefaultHTTPTimeout)
|
||||
httpTimeout := y.props.GetInt(APIURL, properties.DefaultHTTPTimeout)
|
||||
body, err := y.env.HTTPRequest(url+"/query", httpTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -47,7 +47,7 @@ if that color is visible against any of your backgrounds.
|
|||
"{{ if and (lt .Hours 100) (gt .Hours 50) }}#343a40{{ end }}",
|
||||
"{{ if lt .Hours 50 }}#FFFFFF{{ end }}"
|
||||
],
|
||||
"template": "{{.Name}} {{.Ago}} {{.Icon}}",
|
||||
"template": " {{.Name}} {{.Ago}} {{.Icon}} ",
|
||||
"properties": {
|
||||
"access_token":"11111111111111111",
|
||||
"refresh_token":"1111111111111111",
|
||||
|
|
Loading…
Reference in a new issue