refactor: extract OAuth logic

This commit is contained in:
Jan De Dobbeleer 2022-07-15 13:24:56 +02:00 committed by Jan De Dobbeleer
parent a6e9a3561b
commit e5bf5db9c2
18 changed files with 460 additions and 322 deletions

146
src/http/oauth.go Normal file
View file

@ -0,0 +1,146 @@
package http
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"oh-my-posh/environment"
"oh-my-posh/properties"
)
const (
Timeout = "timeout"
InvalidRefreshToken = "invalid refresh token"
TokenRefreshFailed = "token refresh error"
DefaultRefreshToken = "111111111111111111111111111111"
)
type tokenExchange struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}
type OAuthError struct {
message string
}
func (a *OAuthError) Error() string {
return a.message
}
type OAuth struct {
Props properties.Properties
Env environment.Environment
AccessTokenKey string
RefreshTokenKey string
SegmentName string
}
func (o *OAuth) error(err error) {
o.Env.Log(environment.Error, "OAuth", err.Error())
}
func (o *OAuth) getAccessToken() (string, error) {
// get directly from cache
if acccessToken, OK := o.Env.Cache().Get(o.AccessTokenKey); OK {
return acccessToken, nil
}
// use cached refresh token to get new access token
if refreshToken, OK := o.Env.Cache().Get(o.RefreshTokenKey); OK {
if acccessToken, err := o.refreshToken(refreshToken); err == nil {
return acccessToken, nil
}
}
// use initial refresh token from property
refreshToken := o.Props.GetString(properties.RefreshToken, "")
// ignore an empty or default refresh token
if len(refreshToken) == 0 || refreshToken == DefaultRefreshToken {
return "", &OAuthError{
message: InvalidRefreshToken,
}
}
// no need to let the user provide access token, we'll always verify the refresh token
acccessToken, err := o.refreshToken(refreshToken)
return acccessToken, err
}
func (o *OAuth) refreshToken(refreshToken string) (string, error) {
httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
url := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=%s&token=%s", o.SegmentName, refreshToken)
body, err := o.Env.HTTPRequest(url, httpTimeout)
if err != nil {
return "", &OAuthError{
// This might happen if /api was asleep. Assume the user will just retry
message: Timeout,
}
}
tokens := &tokenExchange{}
err = json.Unmarshal(body, &tokens)
if err != nil {
return "", &OAuthError{
message: TokenRefreshFailed,
}
}
// add tokens to cache
o.Env.Cache().Set(o.AccessTokenKey, tokens.AccessToken, tokens.ExpiresIn/60)
o.Env.Cache().Set(o.RefreshTokenKey, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year
return tokens.AccessToken, nil
}
func OauthResult[a any](o *OAuth, url string) (a, error) {
var data a
getCacheValue := func(key string) (a, error) {
if val, found := o.Env.Cache().Get(key); found {
err := json.Unmarshal([]byte(val), &data)
if err != nil {
o.error(err)
return data, err
}
return data, nil
}
err := errors.New("no data in cache")
o.error(err)
return data, err
}
httpTimeout := o.Props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
// No need to check more than every 30 minutes by default
cacheTimeout := o.Props.GetInt(properties.CacheTimeout, 30)
if cacheTimeout > 0 {
if data, err := getCacheValue(url); err == nil {
return data, nil
}
}
accessToken, err := o.getAccessToken()
if err != nil {
return data, err
}
// add token to header for authentication
addAuthHeader := func(request *http.Request) {
request.Header.Add("Authorization", "Bearer "+accessToken)
}
body, err := o.Env.HTTPRequest(url, httpTimeout, addAuthHeader)
if err != nil {
o.error(err)
return data, err
}
err = json.Unmarshal(body, &data)
if err != nil {
o.error(err)
return data, err
}
if cacheTimeout > 0 {
o.Env.Cache().Set(url, string(body), cacheTimeout)
}
return data, nil
}

182
src/http/oauth_test.go Normal file
View file

@ -0,0 +1,182 @@
package http
import (
"fmt"
"oh-my-posh/environment"
"oh-my-posh/mock"
"oh-my-posh/properties"
"testing"
"github.com/stretchr/testify/assert"
mock2 "github.com/stretchr/testify/mock"
)
type data struct {
Hello string `json:"hello"`
}
func TestOauthResult(t *testing.T) {
accessTokenKey := "test_access_token"
refreshTokenKey := "test_refresh_token"
tokenResponse := `{ "access_token":"NEW_ACCESSTOKEN","refresh_token":"NEW_REFRESHTOKEN", "expires_in":1234 }`
jsonResponse := `{ "hello":"world" }`
successData := &data{Hello: "world"}
cases := []struct {
Case string
// tokens
AccessToken string
RefreshToken string
TokenResponse string
// API response
JSONResponse string
// Cache
CacheJSONResponse string
CacheTimeout int
RefreshTokenFromCache bool
AccessTokenFromCache bool
ResponseCacheMiss bool
// Errors
Error error
// Validations
ExpectedErrorMessage string
ExpectedData *data
}{
{
Case: "No initial tokens",
ExpectedErrorMessage: InvalidRefreshToken,
},
{
Case: "Use config tokens",
AccessToken: "INITIAL_ACCESSTOKEN",
RefreshToken: "INITIAL_REFRESHTOKEN",
TokenResponse: tokenResponse,
JSONResponse: jsonResponse,
ExpectedData: successData,
},
{
Case: "Access token from cache",
AccessToken: "ACCESSTOKEN",
AccessTokenFromCache: true,
JSONResponse: jsonResponse,
ExpectedData: successData,
},
{
Case: "Refresh token from cache",
RefreshToken: "REFRESH_TOKEN",
RefreshTokenFromCache: true,
JSONResponse: jsonResponse,
TokenResponse: tokenResponse,
ExpectedData: successData,
},
{
Case: "Refresh token from cache, success",
RefreshToken: "REFRESH_TOKEN",
RefreshTokenFromCache: true,
JSONResponse: jsonResponse,
TokenResponse: tokenResponse,
ExpectedData: successData,
},
{
Case: "Refresh API error",
RefreshToken: "REFRESH_TOKEN",
RefreshTokenFromCache: true,
Error: fmt.Errorf("API error"),
ExpectedErrorMessage: Timeout,
},
{
Case: "Refresh API parse error",
RefreshToken: "REFRESH_TOKEN",
RefreshTokenFromCache: true,
TokenResponse: "INVALID_JSON",
ExpectedErrorMessage: TokenRefreshFailed,
},
{
Case: "Default config token",
RefreshToken: DefaultRefreshToken,
ExpectedErrorMessage: InvalidRefreshToken,
},
{
Case: "Cache data",
CacheTimeout: 60,
CacheJSONResponse: jsonResponse,
ExpectedData: successData,
},
{
Case: "Cache data, invalid data",
CacheTimeout: 60,
RefreshToken: "REFRESH_TOKEN",
CacheJSONResponse: "ERR",
TokenResponse: tokenResponse,
JSONResponse: jsonResponse,
ExpectedData: successData,
},
{
Case: "Cache data, no cache",
CacheTimeout: 60,
RefreshToken: "REFRESH_TOKEN",
ResponseCacheMiss: true,
TokenResponse: tokenResponse,
JSONResponse: jsonResponse,
ExpectedData: successData,
},
{
Case: "API body failure",
AccessToken: "ACCESSTOKEN",
AccessTokenFromCache: true,
ResponseCacheMiss: true,
JSONResponse: "ERR",
ExpectedErrorMessage: "invalid character 'E' looking for beginning of value",
},
{
Case: "API request failure",
AccessToken: "ACCESSTOKEN",
AccessTokenFromCache: true,
ResponseCacheMiss: true,
JSONResponse: "ERR",
Error: fmt.Errorf("no response"),
ExpectedErrorMessage: "no response",
},
}
for _, tc := range cases {
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
tokenURL := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=test&token=%s", tc.RefreshToken)
var props properties.Map = map[properties.Property]interface{}{
properties.CacheTimeout: tc.CacheTimeout,
properties.AccessToken: tc.AccessToken,
properties.RefreshToken: tc.RefreshToken,
}
cache := &mock.MockedCache{}
cache.On("Get", url).Return(tc.CacheJSONResponse, !tc.ResponseCacheMiss)
cache.On("Get", accessTokenKey).Return(tc.AccessToken, tc.AccessTokenFromCache)
cache.On("Get", refreshTokenKey).Return(tc.RefreshToken, tc.RefreshTokenFromCache)
cache.On("Set", mock2.Anything, mock2.Anything, mock2.Anything)
env := &mock.MockedEnvironment{}
env.On("Cache").Return(cache)
env.On("HTTPRequest", url).Return([]byte(tc.JSONResponse), tc.Error)
env.On("HTTPRequest", tokenURL).Return([]byte(tc.TokenResponse), tc.Error)
env.On("Log", environment.Error, "OAuth", mock2.Anything).Return()
oauth := &OAuth{
Props: props,
Env: env,
AccessTokenKey: accessTokenKey,
RefreshTokenKey: refreshTokenKey,
SegmentName: "test",
}
got, err := OauthResult[*data](oauth, url)
assert.Equal(t, tc.ExpectedData, got, tc.Case)
if len(tc.ExpectedErrorMessage) == 0 {
assert.Nil(t, err, tc.Case)
} else {
assert.Equal(t, tc.ExpectedErrorMessage, err.Error(), tc.Case)
}
}
}

View file

@ -43,6 +43,14 @@ const (
AccessToken Property = "access_token" AccessToken Property = "access_token"
// RefreshToken is the refresh token to use for an API // RefreshToken is the refresh token to use for an API
RefreshToken Property = "refresh_token" RefreshToken Property = "refresh_token"
// HTTPTimeout timeout used when executing http request
HTTPTimeout Property = "http_timeout"
// DefaultHTTPTimeout default timeout used when executing http request
DefaultHTTPTimeout = 20
// DefaultCacheTimeout default timeout used when caching data
DefaultCacheTimeout = 10
// CacheTimeout cache timeout
CacheTimeout Property = "cache_timeout"
) )
type Map map[Property]interface{} type Map map[Property]interface{}

View file

@ -245,7 +245,7 @@ func (bf *Brewfather) getResult() (*Batch, error) {
batchURL := fmt.Sprintf("https://api.brewfather.app/v1/batches/%s", batchID) batchURL := fmt.Sprintf("https://api.brewfather.app/v1/batches/%s", batchID)
batchReadingsURL := fmt.Sprintf("https://api.brewfather.app/v1/batches/%s/readings", batchID) batchReadingsURL := fmt.Sprintf("https://api.brewfather.app/v1/batches/%s/readings", batchID)
httpTimeout := bf.props.GetInt(HTTPTimeout, DefaultHTTPTimeout) httpTimeout := bf.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
cacheTimeout := bf.props.GetInt(BFCacheTimeout, 5) cacheTimeout := bf.props.GetInt(BFCacheTimeout, 5)
if cacheTimeout > 0 { if cacheTimeout > 0 {

View file

@ -140,10 +140,10 @@ func TestBrewfatherSegment(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
env := &mock.MockedEnvironment{} env := &mock.MockedEnvironment{}
props := properties.Map{ props := properties.Map{
CacheTimeout: tc.CacheTimeout, properties.CacheTimeout: tc.CacheTimeout,
BFBatchID: BFFakeBatchID, BFBatchID: BFFakeBatchID,
BFAPIKey: "FAKE", BFAPIKey: "FAKE",
BFUserID: "FAKE", BFUserID: "FAKE",
} }
cache := &mock.MockedCache{} cache := &mock.MockedCache{}

View file

@ -30,7 +30,7 @@ func (i *IPify) Enabled() bool {
} }
func (i *IPify) getResult() (string, error) { func (i *IPify) getResult() (string, error) {
cacheTimeout := i.props.GetInt(CacheTimeout, DefaultCacheTimeout) cacheTimeout := i.props.GetInt(properties.CacheTimeout, properties.DefaultCacheTimeout)
url := i.props.GetString(IpifyURL, "https://api.ipify.org") url := i.props.GetString(IpifyURL, "https://api.ipify.org")
@ -43,7 +43,7 @@ func (i *IPify) getResult() (string, error) {
} }
} }
httpTimeout := i.props.GetInt(HTTPTimeout, DefaultHTTPTimeout) httpTimeout := i.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
body, err := i.env.HTTPRequest(url, httpTimeout) body, err := i.env.HTTPRequest(url, httpTimeout)
if err != nil { if err != nil {

View file

@ -46,7 +46,7 @@ func TestIpifySegment(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
env := &mock.MockedEnvironment{} env := &mock.MockedEnvironment{}
props := properties.Map{ props := properties.Map{
CacheTimeout: 0, properties.CacheTimeout: 0,
} }
env.On("HTTPRequest", IPIFYAPIURL).Return([]byte(tc.Response), tc.Error) env.On("HTTPRequest", IPIFYAPIURL).Return([]byte(tc.Response), tc.Error)

View file

@ -107,7 +107,7 @@ func (ns *Nightscout) getResult() (*NightscoutData, error) {
} }
url := ns.props.GetString(URL, "") url := ns.props.GetString(URL, "")
httpTimeout := ns.props.GetInt(HTTPTimeout, DefaultHTTPTimeout) httpTimeout := ns.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
// natural and understood NS timeout is 5, anything else is unusual // natural and understood NS timeout is 5, anything else is unusual
cacheTimeout := ns.props.GetInt(NSCacheTimeout, 5) cacheTimeout := ns.props.GetInt(NSCacheTimeout, 5)

View file

@ -134,8 +134,8 @@ func TestNSSegment(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
env := &mock.MockedEnvironment{} env := &mock.MockedEnvironment{}
props := properties.Map{ props := properties.Map{
CacheTimeout: tc.CacheTimeout, properties.CacheTimeout: tc.CacheTimeout,
URL: "FAKE", URL: "FAKE",
} }
cache := &mock.MockedCache{} cache := &mock.MockedCache{}

View file

@ -26,8 +26,6 @@ const (
Location properties.Property = "location" Location properties.Property = "location"
// Units openweathermap units // Units openweathermap units
Units properties.Property = "units" Units properties.Property = "units"
// CacheTimeout cache timeout
CacheTimeout properties.Property = "cache_timeout"
// CacheKeyResponse key used when caching the response // CacheKeyResponse key used when caching the response
CacheKeyResponse string = "owm_response" CacheKeyResponse string = "owm_response"
// CacheKeyURL key used when caching the url responsible for the response // CacheKeyURL key used when caching the url responsible for the response
@ -58,7 +56,7 @@ func (d *Owm) Template() string {
} }
func (d *Owm) getResult() (*owmDataResponse, error) { func (d *Owm) getResult() (*owmDataResponse, error) {
cacheTimeout := d.props.GetInt(CacheTimeout, DefaultCacheTimeout) cacheTimeout := d.props.GetInt(properties.CacheTimeout, properties.DefaultCacheTimeout)
response := new(owmDataResponse) response := new(owmDataResponse)
if cacheTimeout > 0 { if cacheTimeout > 0 {
// check if data stored in cache // check if data stored in cache
@ -77,7 +75,7 @@ func (d *Owm) getResult() (*owmDataResponse, error) {
apikey := d.props.GetString(APIKey, ".") apikey := d.props.GetString(APIKey, ".")
location := d.props.GetString(Location, "De Bilt,NL") location := d.props.GetString(Location, "De Bilt,NL")
units := d.props.GetString(Units, "standard") units := d.props.GetString(Units, "standard")
httpTimeout := d.props.GetInt(HTTPTimeout, DefaultHTTPTimeout) httpTimeout := d.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
d.URL = fmt.Sprintf("http://api.openweathermap.org/data/2.5/weather?q=%s&units=%s&appid=%s", location, units, apikey) d.URL = fmt.Sprintf("http://api.openweathermap.org/data/2.5/weather?q=%s&units=%s&appid=%s", location, units, apikey)
body, err := d.env.HTTPRequest(d.URL, httpTimeout) body, err := d.env.HTTPRequest(d.URL, httpTimeout)

View file

@ -54,10 +54,10 @@ func TestOWMSegmentSingle(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
env := &mock.MockedEnvironment{} env := &mock.MockedEnvironment{}
props := properties.Map{ props := properties.Map{
APIKey: "key", APIKey: "key",
Location: "AMSTERDAM,NL", Location: "AMSTERDAM,NL",
Units: "metric", Units: "metric",
CacheTimeout: 0, properties.CacheTimeout: 0,
} }
env.On("HTTPRequest", OWMAPIURL).Return([]byte(tc.JSONResponse), tc.Error) env.On("HTTPRequest", OWMAPIURL).Return([]byte(tc.JSONResponse), tc.Error)
@ -189,10 +189,10 @@ func TestOWMSegmentIcons(t *testing.T) {
o := &Owm{ o := &Owm{
props: properties.Map{ props: properties.Map{
APIKey: "key", APIKey: "key",
Location: "AMSTERDAM,NL", Location: "AMSTERDAM,NL",
Units: "metric", Units: "metric",
CacheTimeout: 0, properties.CacheTimeout: 0,
}, },
env: env, env: env,
} }
@ -212,10 +212,10 @@ func TestOWMSegmentIcons(t *testing.T) {
o := &Owm{ o := &Owm{
props: properties.Map{ props: properties.Map{
APIKey: "key", APIKey: "key",
Location: "AMSTERDAM,NL", Location: "AMSTERDAM,NL",
Units: "metric", Units: "metric",
CacheTimeout: 0, properties.CacheTimeout: 0,
}, },
env: env, env: env,
} }

View file

@ -1,12 +0,0 @@
package segments
import "oh-my-posh/properties"
const (
// HTTPTimeout timeout used when executing http request
HTTPTimeout properties.Property = "http_timeout"
// DefaultHTTPTimeout default timeout used when executing http request
DefaultHTTPTimeout = 20
// DefaultCacheTimeout default timeout used when caching data
DefaultCacheTimeout = 10
)

View file

@ -1,20 +1,31 @@
package segments package segments
import ( import (
"encoding/json"
"errors"
"fmt" "fmt"
"math" "math"
"net/http"
"oh-my-posh/environment" "oh-my-posh/environment"
"oh-my-posh/http"
"oh-my-posh/properties" "oh-my-posh/properties"
"time" "time"
) )
// StravaAPI is a wrapper around http.Oauth
type StravaAPI interface {
GetActivities() ([]*StravaData, error)
}
type stravaAPI struct {
http.OAuth
}
func (s *stravaAPI) GetActivities() ([]*StravaData, error) {
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
return http.OauthResult[[]*StravaData](&s.OAuth, url)
}
// segment struct, makes templating easier // segment struct, makes templating easier
type Strava struct { type Strava struct {
props properties.Properties props properties.Properties
env environment.Environment
StravaData StravaData
Icon string Icon string
@ -23,6 +34,8 @@ type Strava struct {
Authenticate bool Authenticate bool
Error string Error string
URL string URL string
api StravaAPI
} }
const ( const (
@ -32,12 +45,10 @@ const (
WorkOutIcon properties.Property = "workout_icon" WorkOutIcon properties.Property = "workout_icon"
UnknownActivityIcon properties.Property = "unknown_activity_icon" UnknownActivityIcon properties.Property = "unknown_activity_icon"
StravaAccessToken = "strava_access_token" StravaAccessTokenKey = "strava_access_token"
StravaRefreshToken = "strava_refresh_token" StravaRefreshTokenKey = "strava_refresh_token"
Timeout = "timeout" noActivitiesFound = "No activities found"
InvalidRefreshToken = "invalid refresh token"
TokenRefreshFailed = "token refresh error"
) )
// StravaData struct contains the API data // StravaData struct contains the API data
@ -56,36 +67,26 @@ type StravaData struct {
KudosCount int `json:"kudos_count"` KudosCount int `json:"kudos_count"`
} }
type TokenExchange struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}
type AuthError struct {
message string
}
func (a *AuthError) Error() string {
return a.message
}
func (s *Strava) Template() string { func (s *Strava) Template() string {
return " {{ if .Error }}{{ .Error }}{{ else }}{{ .Ago }}{{ end }} " return " {{ if .Error }}{{ .Error }}{{ else }}{{ .Ago }}{{ end }} "
} }
func (s *Strava) Enabled() bool { func (s *Strava) Enabled() bool {
data, err := s.getResult() data, err := s.api.GetActivities()
if err == nil { if err == nil && len(data) > 0 {
s.StravaData = *data s.StravaData = *data[0]
s.Icon = s.getActivityIcon() s.Icon = s.getActivityIcon()
s.Hours = s.getHours() s.Hours = s.getHours()
s.Ago = s.getAgo() s.Ago = s.getAgo()
s.URL = fmt.Sprintf("https://www.strava.com/activities/%d", s.ID) s.URL = fmt.Sprintf("https://www.strava.com/activities/%d", s.ID)
return true return true
} }
if _, s.Authenticate = err.(*AuthError); s.Authenticate { if err == nil && len(data) == 0 {
s.Error = err.(*AuthError).Error() s.Error = noActivitiesFound
return true
}
if _, s.Authenticate = err.(*http.OAuthError); s.Authenticate {
s.Error = err.(*http.OAuthError).Error()
return true return true
} }
return false return false
@ -124,115 +125,16 @@ func (s *Strava) getActivityIcon() string {
return s.props.GetString(UnknownActivityIcon, "\ue213") return s.props.GetString(UnknownActivityIcon, "\ue213")
} }
func (s *Strava) getAccessToken() (string, error) {
// get directly from cache
if acccessToken, OK := s.env.Cache().Get(StravaAccessToken); OK {
return acccessToken, nil
}
// use cached refresh token to get new access token
if refreshToken, OK := s.env.Cache().Get(StravaRefreshToken); OK {
if acccessToken, err := s.refreshToken(refreshToken); err == nil {
return acccessToken, nil
}
}
// use initial refresh token from property
refreshToken := s.props.GetString(properties.RefreshToken, "")
// ignore an empty or default refresh token
if len(refreshToken) == 0 || refreshToken == "111111111111111111111111111111" {
return "", &AuthError{
message: InvalidRefreshToken,
}
}
// no need to let the user provide access token, we'll always verify the refresh token
acccessToken, err := s.refreshToken(refreshToken)
return acccessToken, err
}
func (s *Strava) refreshToken(refreshToken string) (string, error) {
httpTimeout := s.props.GetInt(HTTPTimeout, DefaultHTTPTimeout)
url := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=strava&token=%s", refreshToken)
body, err := s.env.HTTPRequest(url, httpTimeout)
if err != nil {
return "", &AuthError{
// This might happen if /api was asleep. Assume the user will just retry
message: Timeout,
}
}
tokens := &TokenExchange{}
err = json.Unmarshal(body, &tokens)
if err != nil {
return "", &AuthError{
message: TokenRefreshFailed,
}
}
// add tokens to cache
s.env.Cache().Set(StravaAccessToken, tokens.AccessToken, tokens.ExpiresIn/60)
s.env.Cache().Set(StravaRefreshToken, tokens.RefreshToken, 2*525960) // it should never expire unless revoked, default to 2 year
return tokens.AccessToken, nil
}
func (s *Strava) getResult() (*StravaData, error) {
parseSingleElement := func(data []byte) (*StravaData, error) {
var result []*StravaData
err := json.Unmarshal(data, &result)
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, errors.New("no elements in the array")
}
return result[0], nil
}
getCacheValue := func(key string) (*StravaData, error) {
val, found := s.env.Cache().Get(key)
// we got something from the cache
if found {
if data, err := parseSingleElement([]byte(val)); err == nil {
return data, nil
}
}
return nil, errors.New("no data in cache")
}
// We only want the last activity
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1"
httpTimeout := s.props.GetInt(HTTPTimeout, DefaultHTTPTimeout)
// No need to check more the every 30 min
cacheTimeout := s.props.GetInt(CacheTimeout, 30)
if cacheTimeout > 0 {
if data, err := getCacheValue(url); err == nil {
return data, nil
}
}
accessToken, err := s.getAccessToken()
if err != nil {
return nil, err
}
addAuthHeader := func(request *http.Request) {
request.Header.Add("Authorization", "Bearer "+accessToken)
}
body, err := s.env.HTTPRequest(url, httpTimeout, addAuthHeader)
if err != nil {
return nil, err
}
var arr []*StravaData
err = json.Unmarshal(body, &arr)
if err != nil {
return nil, err
}
data, err := parseSingleElement(body)
if err != nil {
return nil, err
}
if cacheTimeout > 0 {
// persist new sugars in cache
s.env.Cache().Set(url, string(body), cacheTimeout)
}
return data, nil
}
func (s *Strava) Init(props properties.Properties, env environment.Environment) { func (s *Strava) Init(props properties.Properties, env environment.Environment) {
s.props = props s.props = props
s.env = env
s.api = &stravaAPI{
OAuth: http.OAuth{
Props: props,
Env: env,
AccessTokenKey: StravaAccessTokenKey,
RefreshTokenKey: StravaRefreshTokenKey,
SegmentName: "strava",
},
}
} }

View file

@ -2,7 +2,6 @@ package segments
import ( import (
"errors" "errors"
"fmt"
"oh-my-posh/mock" "oh-my-posh/mock"
"oh-my-posh/properties" "oh-my-posh/properties"
"oh-my-posh/template" "oh-my-posh/template"
@ -10,192 +9,107 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
mock2 "github.com/stretchr/testify/mock"
) )
type mockedStravaAPI struct {
mock2.Mock
}
func (s *mockedStravaAPI) GetActivities() ([]*StravaData, error) {
args := s.Called()
return args.Get(0).([]*StravaData), args.Error(1)
}
func TestStravaSegment(t *testing.T) { func TestStravaSegment(t *testing.T) {
h, _ := time.ParseDuration("6h") h, _ := time.ParseDuration("6h")
sixHoursAgo := time.Now().Add(-h).Format(time.RFC3339) sixHoursAgo := time.Now().Add(-h)
h, _ = time.ParseDuration("100h") h, _ = time.ParseDuration("100h")
fourDaysAgo := time.Now().Add(-h).Format(time.RFC3339) fourDaysAgo := time.Now().Add(-h)
cases := []struct { cases := []struct {
Case string Case string
JSONResponse string ExpectedString string
AccessToken string ExpectedEnabled bool
RefreshToken string Template string
AccessTokenCacheFoundFail bool APIError error
RefreshTokenCacheFoundFail bool StravaData []*StravaData
InitialAccessToken string
InitialRefreshToken string
TokenRefreshToken string
TokenResponse string
TokenTest bool
ExpectedString string
ExpectedEnabled bool
CacheTimeout int
CacheFoundFail bool
Template string
Error error
AuthDebugMsg string
}{ }{
{
Case: "No initial tokens",
InitialAccessToken: "",
AccessTokenCacheFoundFail: true,
RefreshTokenCacheFoundFail: true,
TokenTest: true,
AuthDebugMsg: "invalid refresh token",
},
{
Case: "Use initial tokens",
AccessToken: "NEW_ACCESSTOKEN",
InitialAccessToken: "INITIAL ACCESSTOKEN",
InitialRefreshToken: "INITIAL REFRESHTOKEN",
TokenRefreshToken: "INITIAL REFRESHTOKEN",
TokenResponse: `{ "access_token":"NEW_ACCESSTOKEN","refresh_token":"NEW_REFRESHTOKEN", "expires_in":1234 }`,
AccessTokenCacheFoundFail: true,
RefreshTokenCacheFoundFail: true,
TokenTest: true,
},
{
Case: "Access token from cache",
AccessToken: "ACCESSTOKEN",
TokenTest: true,
},
{
Case: "Refresh token from cache",
AccessTokenCacheFoundFail: true,
RefreshTokenCacheFoundFail: false,
RefreshToken: "REFRESHTOKEN",
TokenRefreshToken: "REFRESHTOKEN",
TokenTest: true,
AuthDebugMsg: "invalid refresh token",
},
{ {
Case: "Ride 6", Case: "Ride 6",
JSONResponse: ` StravaData: []*StravaData{
[{"type":"Ride","start_date":"` + sixHoursAgo + `","name":"Sesongens første på tjukkas","distance":16144.0}]`, {
Type: "Ride",
StartDate: sixHoursAgo,
Name: "Sesongens første på tjukkas",
Distance: 16144.0,
},
},
Template: "{{.Ago}} {{.Icon}}", Template: "{{.Ago}} {{.Icon}}",
ExpectedString: "6h \uf5a2", ExpectedString: "6h \uf5a2",
ExpectedEnabled: true, ExpectedEnabled: true,
}, },
{ {
Case: "Run 100", Case: "Run 100",
JSONResponse: ` StravaData: []*StravaData{
[{"type":"Run","start_date":"` + fourDaysAgo + `","name":"Sesongens første på tjukkas","distance":16144.0,"moving_time":7665}]`, {
Type: "Run",
StartDate: fourDaysAgo,
Name: "Sesongens første på tjukkas",
Distance: 16144.0,
},
},
Template: "{{.Ago}} {{.Icon}}", Template: "{{.Ago}} {{.Icon}}",
ExpectedString: "4d \ufc0c", ExpectedString: "4d \ufc0c",
ExpectedEnabled: true, ExpectedEnabled: true,
}, },
{ {
Case: "Error in retrieving data", Case: "Error in retrieving data",
JSONResponse: "nonsense", APIError: errors.New("Something went wrong"),
Error: errors.New("Something went wrong"),
ExpectedEnabled: false, ExpectedEnabled: false,
}, },
{ {
Case: "Empty array", Case: "Empty array",
JSONResponse: "[]", StravaData: []*StravaData{},
ExpectedEnabled: false, ExpectedString: noActivitiesFound,
},
{
Case: "Run from cache",
JSONResponse: `
[{"type":"Run","start_date":"` + fourDaysAgo + `","name":"Sesongens første på tjukkas","distance":16144.0,"moving_time":7665}]`,
Template: "{{.Ago}} {{.Icon}}",
ExpectedString: "4d \ufc0c",
ExpectedEnabled: true, ExpectedEnabled: true,
CacheTimeout: 10,
},
{
Case: "Run from not found cache",
JSONResponse: `
[{"type":"Run","start_date":"` + fourDaysAgo + `","name":"Morning ride","distance":16144.0,"moving_time":7665}]`,
Template: "{{.Ago}} {{.Icon}} {{.Name}} {{.Hours}}h ago",
ExpectedString: "4d \ufc0c Morning ride 100h ago",
ExpectedEnabled: true,
CacheTimeout: 10,
CacheFoundFail: true,
},
{
Case: "Error parsing response",
JSONResponse: `
4tffgt4e4567`,
Template: "{{.Ago}}{{.Icon}}",
ExpectedString: "50",
ExpectedEnabled: false,
CacheTimeout: 10,
}, },
{ {
Case: "Faulty template", Case: "Faulty template",
JSONResponse: ` StravaData: []*StravaData{
[{"sgv":50,"direction":"DoubleDown"}]`, {
Type: "Run",
StartDate: fourDaysAgo,
Name: "Sesongens første på tjukkas",
Distance: 16144.0,
},
},
Template: "{{.Ago}}{{.Burp}}", Template: "{{.Ago}}{{.Burp}}",
ExpectedString: template.IncorrectTemplate, ExpectedString: template.IncorrectTemplate,
ExpectedEnabled: true, ExpectedEnabled: true,
CacheTimeout: 10,
}, },
} }
for _, tc := range cases { for _, tc := range cases {
env := &mock.MockedEnvironment{} api := &mockedStravaAPI{}
url := "https://www.strava.com/api/v3/athlete/activities?page=1&per_page=1" api.On("GetActivities").Return(tc.StravaData, tc.APIError)
tokenURL := fmt.Sprintf("https://ohmyposh.dev/api/refresh?segment=strava&token=%s", tc.TokenRefreshToken)
var props properties.Map = map[properties.Property]interface{}{
CacheTimeout: tc.CacheTimeout,
}
cache := &mock.MockedCache{}
cache.On("Get", url).Return(tc.JSONResponse, !tc.CacheFoundFail)
cache.On("Set", url, tc.JSONResponse, tc.CacheTimeout).Return()
cache.On("Get", StravaAccessToken).Return(tc.AccessToken, !tc.AccessTokenCacheFoundFail) strava := &Strava{
cache.On("Get", StravaRefreshToken).Return(tc.RefreshToken, !tc.RefreshTokenCacheFoundFail) api: api,
props: &properties.Map{},
cache.On("Set", StravaRefreshToken, "NEW_REFRESHTOKEN", 2*525960)
cache.On("Set", StravaAccessToken, "NEW_ACCESSTOKEN", 20)
env.On("HTTPRequest", url).Return([]byte(tc.JSONResponse), tc.Error)
env.On("HTTPRequest", tokenURL).Return([]byte(tc.TokenResponse), tc.Error)
env.On("Cache").Return(cache)
if tc.InitialAccessToken != "" {
props[properties.AccessToken] = tc.InitialAccessToken
}
if tc.InitialRefreshToken != "" {
props[properties.RefreshToken] = tc.InitialRefreshToken
} }
ns := &Strava{ enabled := strava.Enabled()
props: props,
env: env,
}
if tc.TokenTest {
// continue
at, err := ns.getAccessToken()
if err != nil {
if authErr, ok := err.(*AuthError); ok {
assert.Equal(t, tc.AuthDebugMsg, authErr.Error(), tc.Case)
} else {
assert.Equal(t, tc.Error, err, tc.Case)
}
} else {
assert.Equal(t, tc.AccessToken, at, tc.Case)
}
continue
}
enabled := ns.Enabled()
assert.Equal(t, tc.ExpectedEnabled, enabled, tc.Case) assert.Equal(t, tc.ExpectedEnabled, enabled, tc.Case)
if !enabled { if !enabled {
continue continue
} }
if tc.Template == "" { if tc.Template == "" {
tc.Template = ns.Template() tc.Template = strava.Template()
} }
var got = renderTemplate(env, tc.Template, ns)
var got = renderTemplate(&mock.MockedEnvironment{}, tc.Template, strava)
assert.Equal(t, tc.ExpectedString, got, tc.Case) assert.Equal(t, tc.ExpectedString, got, tc.Case)
} }
} }

View file

@ -35,7 +35,7 @@ func (w *Wakatime) Enabled() bool {
func (w *Wakatime) setAPIData() error { func (w *Wakatime) setAPIData() error {
url := w.props.GetString(URL, "") url := w.props.GetString(URL, "")
cacheTimeout := w.props.GetInt(CacheTimeout, DefaultCacheTimeout) cacheTimeout := w.props.GetInt(properties.CacheTimeout, properties.DefaultCacheTimeout)
if cacheTimeout > 0 { if cacheTimeout > 0 {
// check if data stored in cache // check if data stored in cache
if val, found := w.env.Cache().Get(url); found { if val, found := w.env.Cache().Get(url); found {
@ -47,7 +47,7 @@ func (w *Wakatime) setAPIData() error {
} }
} }
httpTimeout := w.props.GetInt(HTTPTimeout, DefaultHTTPTimeout) httpTimeout := w.props.GetInt(properties.HTTPTimeout, properties.DefaultHTTPTimeout)
body, err := w.env.HTTPRequest(url, httpTimeout) body, err := w.env.HTTPRequest(url, httpTimeout)
if err != nil { if err != nil {

View file

@ -82,9 +82,9 @@ func TestWTTrackedTime(t *testing.T) {
w := &Wakatime{ w := &Wakatime{
props: properties.Map{ props: properties.Map{
APIKey: "key", APIKey: "key",
CacheTimeout: tc.CacheTimeout, properties.CacheTimeout: tc.CacheTimeout,
URL: FAKEAPIURL, URL: FAKEAPIURL,
}, },
env: env, env: env,
} }

View file

@ -67,7 +67,7 @@ type track struct {
func (y *Ytm) setStatus() error { func (y *Ytm) setStatus() error {
// https://github.com/ytmdesktop/ytmdesktop/wiki/Remote-Control-API // https://github.com/ytmdesktop/ytmdesktop/wiki/Remote-Control-API
url := y.props.GetString(APIURL, "http://127.0.0.1:9863") url := y.props.GetString(APIURL, "http://127.0.0.1:9863")
httpTimeout := y.props.GetInt(APIURL, DefaultHTTPTimeout) httpTimeout := y.props.GetInt(APIURL, properties.DefaultHTTPTimeout)
body, err := y.env.HTTPRequest(url+"/query", httpTimeout) body, err := y.env.HTTPRequest(url+"/query", httpTimeout)
if err != nil { if err != nil {
return err return err

View file

@ -47,7 +47,7 @@ if that color is visible against any of your backgrounds.
"{{ if and (lt .Hours 100) (gt .Hours 50) }}#343a40{{ end }}", "{{ if and (lt .Hours 100) (gt .Hours 50) }}#343a40{{ end }}",
"{{ if lt .Hours 50 }}#FFFFFF{{ end }}" "{{ if lt .Hours 50 }}#FFFFFF{{ end }}"
], ],
"template": "{{.Name}} {{.Ago}} {{.Icon}}", "template": " {{.Name}} {{.Ago}} {{.Icon}} ",
"properties": { "properties": {
"access_token":"11111111111111111", "access_token":"11111111111111111",
"refresh_token":"1111111111111111", "refresh_token":"1111111111111111",