From a200ee73ff0c1459f4fe00edd14e13e7fdb1bd44 Mon Sep 17 00:00:00 2001 From: Jan De Dobbeleer Date: Sat, 15 Jan 2022 20:38:32 +0100 Subject: [PATCH] feat: add OS to template functionality --- src/environment.go | 33 +++++++++++++++++++++++++++------ src/segment_os.go | 13 +++++++------ src/segment_os_test.go | 6 +++--- src/segment_path_test.go | 20 ++++++++++++++++++++ src/template.go | 16 ++++++++++------ 5 files changed, 67 insertions(+), 21 deletions(-) diff --git a/src/environment.go b/src/environment.go index aaf5e941..ae5e26a7 100644 --- a/src/environment.go +++ b/src/environment.go @@ -90,6 +90,7 @@ type wifiInfo struct { type Environment interface { getenv(key string) string + environ() map[string]string getcwd() string homeDir() string hasFiles(pattern string) bool @@ -154,12 +155,13 @@ const ( ) type environment struct { - args *args - cwd string - cmdCache *commandCache - fileCache *fileCache - logBuilder strings.Builder - debug bool + args *args + cwd string + cmdCache *commandCache + fileCache *fileCache + environCache map[string]string + logBuilder strings.Builder + debug bool } func (env *environment) init(args *args) { @@ -222,6 +224,25 @@ func (env *environment) getenv(key string) string { return val } +func (env *environment) environ() map[string]string { + defer env.trace(time.Now(), "environ") + if env.environCache != nil { + return env.environCache + } + const separator = "=" + values := os.Environ() + for value := range values { + splitted := strings.Split(values[value], separator) + if len(splitted) != 2 { + continue + } + key := splitted[0] + val := splitted[1:] + env.environCache[key] = strings.Join(val, separator) + } + return env.environCache +} + func (env *environment) getcwd() string { defer env.trace(time.Now(), "getcwd") if env.cwd != "" { diff --git a/src/segment_os.go b/src/segment_os.go index 94aa3a0a..fcffea97 100644 --- a/src/segment_os.go +++ b/src/segment_os.go @@ -7,7 +7,8 @@ import ( type osInfo struct { props Properties env Environment - OS string + + os string } const ( @@ -71,25 +72,25 @@ func (n *osInfo) string() string { goos := n.env.getRuntimeGOOS() switch goos { case windowsPlatform: - n.OS = windowsPlatform + n.os = windowsPlatform return n.props.getString(Windows, "\uE62A") case darwinPlatform: - n.OS = darwinPlatform + n.os = darwinPlatform return n.props.getString(MacOS, "\uF179") case linuxPlatform: wsl := n.env.getenv("WSL_DISTRO_NAME") p := n.env.getPlatform() if len(wsl) == 0 { - n.OS = p + n.os = p return n.getDistroName(p, "") } - n.OS = wsl + n.os = wsl return fmt.Sprintf("%s%s%s", n.props.getString(WSL, "WSL"), n.props.getString(WSLSeparator, " - "), n.getDistroName(p, wsl)) default: - n.OS = goos + n.os = goos return goos } } diff --git a/src/segment_os_test.go b/src/segment_os_test.go index 6ecad947..d074e946 100644 --- a/src/segment_os_test.go +++ b/src/segment_os_test.go @@ -76,11 +76,11 @@ func TestOSInfo(t *testing.T) { } assert.Equal(t, tc.ExpectedString, osInfo.string(), tc.Case) if tc.WSLDistro != "" { - assert.Equal(t, tc.WSLDistro, osInfo.OS, tc.Case) + assert.Equal(t, tc.WSLDistro, osInfo.os, tc.Case) } else if tc.Platform != "" { - assert.Equal(t, tc.Platform, osInfo.OS, tc.Case) + assert.Equal(t, tc.Platform, osInfo.os, tc.Case) } else { - assert.Equal(t, tc.GOOS, osInfo.OS, tc.Case) + assert.Equal(t, tc.GOOS, osInfo.os, tc.Case) } } } diff --git a/src/segment_path_test.go b/src/segment_path_test.go index 1cf0e224..0ede23d9 100644 --- a/src/segment_path_test.go +++ b/src/segment_path_test.go @@ -19,6 +19,11 @@ func (env *MockedEnvironment) getenv(key string) string { return args.String(0) } +func (env *MockedEnvironment) environ() map[string]string { + args := env.Called() + return args.Get(0).(map[string]string) +} + func (env *MockedEnvironment) getcwd() string { args := env.Called() return args.String(0) @@ -212,6 +217,15 @@ func (env *MockedEnvironment) onTemplate() { } env.On(method).Return(returnArguments...) } + patchEnvVars := func() map[string]string { + keyValueArray := make(map[string]string) + for _, call := range env.Mock.ExpectedCalls { + if call.Method == "getenv" { + keyValueArray[call.Arguments.String(0)] = call.ReturnArguments.String(0) + } + } + return keyValueArray + } patchMethodIfNotSpecified("isRunningAsRoot", false) patchMethodIfNotSpecified("getcwd", "/usr/home/dev/my-app") patchMethodIfNotSpecified("homeDir", "/usr/home/dev") @@ -220,6 +234,12 @@ func (env *MockedEnvironment) onTemplate() { patchMethodIfNotSpecified("getCurrentUser", "dev") patchMethodIfNotSpecified("getHostName", "laptop", nil) patchMethodIfNotSpecified("lastErrorCode", 0) + patchMethodIfNotSpecified("getRuntimeGOOS", darwinPlatform) + if env.getRuntimeGOOS() == linuxPlatform { + env.On("getenv", "WSL_DISTRO_NAME").Return("ubuntu") + env.On("getPlatform").Return("ubuntu") + } + patchMethodIfNotSpecified("environ", patchEnvVars()) } const ( diff --git a/src/template.go b/src/template.go index fba6a1c2..cc24e81a 100644 --- a/src/template.go +++ b/src/template.go @@ -34,6 +34,7 @@ type Context struct { Host string Code int Env map[string]string + OS string // Simple container to hold ANY object Data @@ -55,13 +56,16 @@ func (c *Context) init(t *textTemplate) { c.Host = host } c.Code = t.Env.lastErrorCode() - if strings.Contains(t.Template, ".Env.") { - c.Env = map[string]string{} - matches := findAllNamedRegexMatch(templateEnvRegex, t.Template) - for _, match := range matches { - c.Env[match["ENV"]] = t.Env.getenv(match["ENV"]) + c.Env = t.Env.environ() + goos := t.Env.getRuntimeGOOS() + if goos == linuxPlatform { + wsl := t.Env.getenv("WSL_DISTRO_NAME") + goos = t.Env.getPlatform() + if len(wsl) != 0 { + goos = wsl } } + c.OS = goos } func (t *textTemplate) render() (string, error) { @@ -100,7 +104,7 @@ func (t *textTemplate) cleanTemplate() { *knownVariables = append(*knownVariables, splitted[0]) return splitted[0], true } - knownVariables := []string{"Root", "PWD", "Folder", "Shell", "User", "Host", "Env", "Data", "Code"} + knownVariables := []string{"Root", "PWD", "Folder", "Shell", "User", "Host", "Env", "Data", "Code", "OS"} matches := findAllNamedRegexMatch(`(?: |{)(?P(\.[a-zA-Z_][a-zA-Z0-9]*)+)`, t.Template) for _, match := range matches { if variable, OK := unknownVariable(match["var"], &knownVariables); OK {