diff --git a/src/platform/cache.go b/src/platform/cache.go index 42c66f58..96341bd0 100644 --- a/src/platform/cache.go +++ b/src/platform/cache.go @@ -25,13 +25,13 @@ func (c *cacheObject) expired() bool { } type fileCache struct { - cache *concurrentMap + cache *ConcurrentMap cachePath string dirty bool } func (fc *fileCache) Init(cachePath string) { - fc.cache = newConcurrentMap() + fc.cache = NewConcurrentMap() fc.cachePath = cachePath cacheFilePath := filepath.Join(fc.cachePath, CacheFile) content, err := os.ReadFile(cacheFilePath) @@ -48,7 +48,7 @@ func (fc *fileCache) Init(cachePath string) { if co.expired() { continue } - fc.cache.set(key, co) + fc.cache.Set(key, co) } } @@ -56,7 +56,7 @@ func (fc *fileCache) Close() { if !fc.dirty { return } - cache := fc.cache.list() + cache := fc.cache.List() if dump, err := json.MarshalIndent(cache, "", " "); err == nil { cacheFilePath := filepath.Join(fc.cachePath, CacheFile) _ = os.WriteFile(cacheFilePath, dump, 0644) @@ -66,7 +66,7 @@ func (fc *fileCache) Close() { // returns the value for the given key as long as // the TTL (minutes) is not expired func (fc *fileCache) Get(key string) (string, bool) { - val, found := fc.cache.get(key) + val, found := fc.cache.Get(key) if !found { return "", false } @@ -78,7 +78,7 @@ func (fc *fileCache) Get(key string) (string, bool) { // sets the value for the given key with a TTL (minutes) func (fc *fileCache) Set(key, value string, ttl int) { - fc.cache.set(key, &cacheObject{ + fc.cache.Set(key, &cacheObject{ Value: value, Timestamp: time.Now().Unix(), TTL: ttl, diff --git a/src/platform/concurrent_map.go b/src/platform/concurrent_map.go index 144def24..5a8c47d4 100644 --- a/src/platform/concurrent_map.go +++ b/src/platform/concurrent_map.go @@ -1,30 +1,33 @@ package platform -type concurrentMap struct { +import "sync" + +type ConcurrentMap struct { values map[string]interface{} + sync.RWMutex } -func newConcurrentMap() *concurrentMap { - return &concurrentMap{ +func NewConcurrentMap() *ConcurrentMap { + return &ConcurrentMap{ values: make(map[string]interface{}), } } -func (c *concurrentMap) set(key string, value interface{}) { - lock.Lock() - defer lock.Unlock() +func (c *ConcurrentMap) Set(key string, value interface{}) { + c.Lock() + defer c.Unlock() c.values[key] = value } -func (c *concurrentMap) get(key string) (interface{}, bool) { - lock.RLock() - defer lock.RUnlock() +func (c *ConcurrentMap) Get(key string) (interface{}, bool) { + c.RLock() + defer c.RUnlock() if val, ok := c.values[key]; ok { return val, true } return "", false } -func (c *concurrentMap) list() map[string]interface{} { +func (c *ConcurrentMap) List() map[string]interface{} { return c.values } diff --git a/src/platform/shell.go b/src/platform/shell.go index 3732d667..1563fa44 100644 --- a/src/platform/shell.go +++ b/src/platform/shell.go @@ -42,7 +42,6 @@ func getPID() string { } var ( - lock = sync.RWMutex{} TEMPLATECACHE = fmt.Sprintf("template_cache_%s", getPID()) TOGGLECACHE = fmt.Sprintf("toggle_cache_%s", getPID()) ) @@ -131,8 +130,8 @@ type Connection struct { type SegmentsCache map[string]interface{} -func (c *SegmentsCache) Contains(key string) bool { - _, ok := (*c)[key] +func (s *SegmentsCache) Contains(key string) bool { + _, ok := (*s)[key] return ok } @@ -149,15 +148,14 @@ type TemplateCache struct { OS string WSL bool Segments SegmentsCache + + sync.RWMutex } func (t *TemplateCache) AddSegmentData(key string, value interface{}) { - lock.Lock() - defer lock.Unlock() - if t.Segments == nil { - t.Segments = make(map[string]interface{}) - } + t.Lock() t.Segments[key] = value + t.Unlock() } type Environment interface { @@ -212,15 +210,15 @@ type Environment interface { } type commandCache struct { - commands *concurrentMap + commands *ConcurrentMap } func (c *commandCache) set(command, path string) { - c.commands.set(command, path) + c.commands.Set(command, path) } func (c *commandCache) get(command string) (string, bool) { - cacheCommand, found := c.commands.get(command) + cacheCommand, found := c.commands.Get(command) if !found { return "", false } @@ -237,6 +235,8 @@ type Shell struct { fileCache *fileCache tmplCache *TemplateCache networks []*Connection + + sync.RWMutex } func (env *Shell) Init() { @@ -251,7 +251,7 @@ func (env *Shell) Init() { env.fileCache.Init(env.CachePath()) env.resolveConfigPath() env.cmdCache = &commandCache{ - commands: newConcurrentMap(), + commands: NewConcurrentMap(), } } @@ -331,9 +331,7 @@ func (env *Shell) Getenv(key string) string { func (env *Shell) Pwd() string { defer env.Trace(time.Now(), "Pwd") - lock.Lock() defer func() { - lock.Unlock() env.Debug("Pwd", env.cwd) }() if env.cwd != "" { @@ -708,7 +706,11 @@ func (env *Shell) Logs() string { } func (env *Shell) TemplateCache() *TemplateCache { - defer env.Trace(time.Now(), "TemplateCache") + env.Lock() + defer func() { + env.Trace(time.Now(), "TemplateCache") + env.Unlock() + }() if env.tmplCache != nil { return env.tmplCache } @@ -718,6 +720,7 @@ func (env *Shell) TemplateCache() *TemplateCache { ShellVersion: env.CmdFlags.ShellVersion, Code: env.ErrorCode(), WSL: env.IsWsl(), + Segments: make(map[string]interface{}), } tmplCache.Env = make(map[string]string) const separator = "=" @@ -751,8 +754,6 @@ func (env *Shell) TemplateCache() *TemplateCache { } func (env *Shell) DirMatchesOneOf(dir string, regexes []string) (match bool) { - lock.Lock() - defer lock.Unlock() // sometimes the function panics inside golang, we want to silence that error // and assume that there's no match. Not perfect, but better than crashing // for the time being until we figure out what the actual root cause is