feat(config): only download when changed

resolves #5176
This commit is contained in:
Jan De Dobbeleer 2024-06-27 21:34:42 +02:00 committed by Jan De Dobbeleer
parent 12a732d63a
commit 0449aa8a2d
8 changed files with 164 additions and 52 deletions

View file

@ -8,7 +8,7 @@ import (
"strings"
"time"
"github.com/jandedobbeleer/oh-my-posh/src/platform"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
)
type codePoints map[uint64]uint64
@ -24,7 +24,7 @@ func getGlyphCodePoints() (codePoints, error) {
return codePoints, &ConnectionError{reason: err.Error()}
}
response, err := platform.Client.Do(request)
response, err := net.HTTPClient.Do(request)
if err != nil {
return codePoints, err
}

View file

@ -11,7 +11,7 @@ import (
"net/http"
"net/url"
"github.com/jandedobbeleer/oh-my-posh/src/platform"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
)
func Download(fontPath string) ([]byte, error) {
@ -42,7 +42,7 @@ func getRemoteFile(location string) (data []byte, err error) {
if err != nil {
return nil, err
}
resp, err := platform.Client.Do(req)
resp, err := net.HTTPClient.Do(req)
if err != nil {
return
}

View file

@ -10,7 +10,7 @@ import (
"strings"
"time"
"github.com/jandedobbeleer/oh-my-posh/src/platform"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
)
type release struct {
@ -57,7 +57,7 @@ func fetchFontAssets(repo string) ([]*Asset, error) {
}
req.Header.Add("Accept", "application/vnd.github.v3+json")
response, err := platform.Client.Do(req)
response, err := net.HTTPClient.Do(req)
if err != nil || response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get %s release", repo)
}

View file

@ -22,30 +22,27 @@ func Plain() {
plain = true
}
func Info(message string) {
if !enabled {
return
}
log.WriteString(message)
}
func Trace(start time.Time, args ...string) {
if !enabled {
return
}
elapsed := time.Since(start)
fn, _ := funcSpec()
header := fmt.Sprintf("%s(%s) - %s", fn, strings.Join(args, " "), Text(elapsed.String()).Yellow().Plain())
printLn(trace, header)
}
func Debug(message string) {
func Debug(message ...string) {
if !enabled {
return
}
fn, line := funcSpec()
header := fmt.Sprintf("%s:%d", fn, line)
printLn(debug, header, message)
printLn(debug, header, strings.Join(message, " "))
}
func Error(err error) {
@ -54,6 +51,7 @@ func Error(err error) {
}
fn, line := funcSpec()
header := fmt.Sprintf("%s:%d", fn, line)
printLn(bug, header, err.Error())
}
@ -66,11 +64,14 @@ func funcSpec() (string, int) {
if !OK {
return "", 0
}
fn := runtime.FuncForPC(pc).Name()
fn = fn[strings.LastIndex(fn, ".")+1:]
file = filepath.Base(file)
if strings.HasPrefix(fn, "func") {
return file, line
}
return fmt.Sprintf("%s:%s", file, fn), line
}

View file

@ -0,0 +1,118 @@
package config
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/jandedobbeleer/oh-my-posh/src/log"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
)
func Download(cachePath, url string) (string, error) {
defer log.Trace(time.Now(), cachePath, url)
configPath, shouldUpdate := shouldUpdate(cachePath, url)
if !shouldUpdate {
return configPath, nil
}
log.Debug("downloading config from ", url, " to ", configPath)
ctx, cncl := context.WithTimeout(context.Background(), time.Second*time.Duration(5))
defer cncl()
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
log.Error(err)
return "", err
}
response, err := net.HTTPClient.Do(request)
if err != nil {
log.Error(err)
return "", err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
err := fmt.Errorf("unexpected status code: %d", response.StatusCode)
log.Error(err)
return "", err
}
if len(configPath) == 0 {
configPath = formatConfigPath(url, response.Header.Get("Etag"), cachePath)
log.Debug("config path not set yet, using ", configPath)
}
out, err := os.Create(configPath)
if err != nil {
log.Error(err)
return "", err
}
defer out.Close()
_, err = io.Copy(out, response.Body)
if err != nil {
log.Error(err)
return "", err
}
log.Debug("config updated to ", configPath)
return configPath, nil
}
func shouldUpdate(cachePath, url string) (string, bool) {
defer log.Trace(time.Now(), cachePath, url)
ctx, cncl := context.WithTimeout(context.Background(), time.Second*time.Duration(5))
defer cncl()
request, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err != nil {
log.Error(err)
return "", true
}
response, err := net.HTTPClient.Do(request)
if err != nil {
log.Error(err)
return "", true
}
defer response.Body.Close()
etag := response.Header.Get("Etag")
if len(etag) == 0 {
log.Debug("no etag found, updating config")
return "", true
}
configPath := formatConfigPath(url, etag, cachePath)
_, err = os.Stat(configPath)
if err != nil {
log.Debug("configfile ", configPath, " doest not exist, updating config")
return configPath, true
}
log.Debug("config found at", configPath, " skipping update")
return configPath, false
}
func formatConfigPath(url, etag, cachePath string) string {
ext := filepath.Ext(url)
etag = strings.TrimLeft(etag, `W/`)
etag = strings.Trim(etag, `"`)
filename := fmt.Sprintf("config.%s.omp%s", etag, ext)
return filepath.Join(cachePath, filename)
}

View file

@ -1,4 +1,4 @@
package platform
package net
import (
"net"
@ -31,5 +31,6 @@ var (
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
}
Client httpClient = &http.Client{Transport: defaultTransport}
HTTPClient httpClient = &http.Client{Transport: defaultTransport}
)

View file

@ -1,9 +1,7 @@
package platform
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@ -22,6 +20,8 @@ import (
"github.com/jandedobbeleer/oh-my-posh/src/log"
"github.com/jandedobbeleer/oh-my-posh/src/platform/battery"
"github.com/jandedobbeleer/oh-my-posh/src/platform/cmd"
"github.com/jandedobbeleer/oh-my-posh/src/platform/config"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
"github.com/jandedobbeleer/oh-my-posh/src/regex"
disk "github.com/shirou/gopsutil/v3/disk"
@ -239,67 +239,48 @@ func (env *Shell) Init() {
func (env *Shell) resolveConfigPath() {
defer env.Trace(time.Now())
if len(env.CmdFlags.Config) == 0 {
env.CmdFlags.Config = env.Getenv("POSH_THEME")
}
if len(env.CmdFlags.Config) == 0 {
env.Debug("No config set, fallback to default config")
return
}
if strings.HasPrefix(env.CmdFlags.Config, "https://") {
if err := env.downloadConfig(env.CmdFlags.Config); err != nil {
// make it use default config when download fails
filePath, err := config.Download(env.CachePath(), env.CmdFlags.Config)
if err != nil {
env.Error(err)
env.CmdFlags.Config = ""
return
}
env.CmdFlags.Config = filePath
return
}
// Cygwin path always needs the full path as we're on Windows but not really.
// Doing filepath actions will convert it to a Windows path and break the init script.
if env.Platform() == WINDOWS && env.Shell() == "bash" {
env.Debug("Cygwin detected, using full path for config")
return
}
configFile := env.CmdFlags.Config
if strings.HasPrefix(configFile, "~") {
configFile = strings.TrimPrefix(configFile, "~")
configFile = filepath.Join(env.Home(), configFile)
}
if !filepath.IsAbs(configFile) {
configFile = filepath.Join(env.Pwd(), configFile)
}
env.CmdFlags.Config = filepath.Clean(configFile)
}
func (env *Shell) downloadConfig(location string) error {
defer env.Trace(time.Now(), location)
ext := filepath.Ext(location)
fileHash := base64.StdEncoding.EncodeToString([]byte(location))
filename := fmt.Sprintf("config.%s.omp%s", fileHash, ext)
configPath := filepath.Join(env.CachePath(), filename)
cfg, err := env.HTTPRequest(location, nil, 5000)
if err != nil {
if _, osErr := os.Stat(configPath); !os.IsNotExist(osErr) {
// use the already cached config
env.CmdFlags.Config = configPath
return nil
}
return err
}
out, err := os.Create(configPath)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, bytes.NewReader(cfg))
if err != nil {
return err
}
env.CmdFlags.Config = configPath
return nil
}
func (env *Shell) Trace(start time.Time, args ...string) {
log.Trace(start, args...)
}
@ -639,24 +620,30 @@ func (env *Shell) unWrapError(err error) error {
func (env *Shell) HTTPRequest(targetURL string, body io.Reader, timeout int, requestModifiers ...HTTPRequestModifier) ([]byte, error) {
defer env.Trace(time.Now(), targetURL)
ctx, cncl := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(timeout))
defer cncl()
request, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, body)
if err != nil {
return nil, err
}
for _, modifier := range requestModifiers {
modifier(request)
}
if env.CmdFlags.Debug {
dump, _ := httputil.DumpRequestOut(request, true)
env.Debug(string(dump))
}
response, err := Client.Do(request)
response, err := net.HTTPClient.Do(request)
if err != nil {
env.Error(err)
return nil, env.unWrapError(err)
}
// anything inside the range [200, 299] is considered a success
if response.StatusCode < 200 || response.StatusCode >= 300 {
message := "HTTP status code " + strconv.Itoa(response.StatusCode)
@ -664,13 +651,17 @@ func (env *Shell) HTTPRequest(targetURL string, body io.Reader, timeout int, req
env.Error(err)
return nil, err
}
defer response.Body.Close()
responseBody, err := io.ReadAll(response.Body)
if err != nil {
env.Error(err)
return nil, err
}
env.Debug(string(responseBody))
return responseBody, nil
}

View file

@ -14,6 +14,7 @@ import (
"github.com/charmbracelet/lipgloss"
"github.com/jandedobbeleer/oh-my-posh/src/build"
"github.com/jandedobbeleer/oh-my-posh/src/platform"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
)
var (
@ -148,7 +149,7 @@ func downloadAsset(asset string) (io.ReadCloser, error) {
return nil, err
}
resp, err := platform.Client.Do(req)
resp, err := net.HTTPClient.Do(req)
if err != nil {
return nil, err
}