diff --git a/promql/parser/functions.go b/promql/parser/functions.go index 450021328b..479c7f635d 100644 --- a/promql/parser/functions.go +++ b/promql/parser/functions.go @@ -387,7 +387,7 @@ var Functions = map[string]*Function{ } // getFunction returns a predefined Function object for the given name. -func getFunction(name string) (*Function, bool) { - function, ok := Functions[name] +func getFunction(name string, functions map[string]*Function) (*Function, bool) { + function, ok := functions[name] return function, ok } diff --git a/promql/parser/generated_parser.y b/promql/parser/generated_parser.y index b1c604eeca..b28e9d544c 100644 --- a/promql/parser/generated_parser.y +++ b/promql/parser/generated_parser.y @@ -339,7 +339,7 @@ grouping_label : maybe_label function_call : IDENTIFIER function_call_body { - fn, exist := getFunction($1.Val) + fn, exist := getFunction($1.Val, yylex.(*parser).functions) if !exist{ yylex.(*parser).addParseErrf($1.PositionRange(),"unknown function with name %q", $1.Val) } diff --git a/promql/parser/generated_parser.y.go b/promql/parser/generated_parser.y.go index 2cf3e06b9a..1274137988 100644 --- a/promql/parser/generated_parser.y.go +++ b/promql/parser/generated_parser.y.go @@ -1210,7 +1210,7 @@ yydefault: yyDollar = yyS[yypt-2 : yypt+1] //line promql/parser/generated_parser.y:341 { - fn, exist := getFunction(yyDollar[1].item.Val) + fn, exist := getFunction(yyDollar[1].item.Val, yylex.(*parser).functions) if !exist { yylex.(*parser).addParseErrf(yyDollar[1].item.PositionRange(), "unknown function with name %q", yyDollar[1].item.Val) } diff --git a/promql/parser/parse.go b/promql/parser/parse.go index e69ed4595c..b87dfaece3 100644 --- a/promql/parser/parse.go +++ b/promql/parser/parse.go @@ -37,12 +37,20 @@ var parserPool = sync.Pool{ }, } +type Parser interface { + ParseExpr() (Expr, error) + Close() +} + type parser struct { lex Lexer inject ItemType injecting bool + // functions contains all functions supported by the parser instance. + functions map[string]*Function + // Everytime an Item is lexed that could be the end // of certain expressions its end position is stored here. lastClosing Pos @@ -53,6 +61,62 @@ type parser struct { parseErrors ParseErrors } +type Opt func(p *parser) + +func WithFunctions(functions map[string]*Function) Opt { + return func(p *parser) { + p.functions = functions + } +} + +// NewParser returns a new parser. +func NewParser(input string, opts ...Opt) *parser { + p := parserPool.Get().(*parser) + + p.functions = Functions + p.injecting = false + p.parseErrors = nil + p.generatedParserResult = nil + + // Clear lexer struct before reusing. + p.lex = Lexer{ + input: input, + state: lexStatements, + } + + // Apply user define options. + for _, opt := range opts { + opt(p) + } + + return p +} + +func (p *parser) ParseExpr() (expr Expr, err error) { + defer p.recover(&err) + + parseResult := p.parseGenerated(START_EXPRESSION) + + if parseResult != nil { + expr = parseResult.(Expr) + } + + // Only typecheck when there are no syntax errors. + if len(p.parseErrors) == 0 { + p.checkAST(expr) + } + + if len(p.parseErrors) != 0 { + err = p.parseErrors + } + + return expr, err +} + +func (p *parser) Close() { + defer parserPool.Put(p) +} + // ParseErr wraps a parsing error with line and position context. type ParseErr struct { PositionRange PositionRange @@ -105,32 +169,15 @@ func (errs ParseErrors) Error() string { // ParseExpr returns the expression parsed from the input. func ParseExpr(input string) (expr Expr, err error) { - p := newParser(input) - defer parserPool.Put(p) - defer p.recover(&err) - - parseResult := p.parseGenerated(START_EXPRESSION) - - if parseResult != nil { - expr = parseResult.(Expr) - } - - // Only typecheck when there are no syntax errors. - if len(p.parseErrors) == 0 { - p.checkAST(expr) - } - - if len(p.parseErrors) != 0 { - err = p.parseErrors - } - - return expr, err + p := NewParser(input) + defer p.Close() + return p.ParseExpr() } // ParseMetric parses the input into a metric func ParseMetric(input string) (m labels.Labels, err error) { - p := newParser(input) - defer parserPool.Put(p) + p := NewParser(input) + defer p.Close() defer p.recover(&err) parseResult := p.parseGenerated(START_METRIC) @@ -148,8 +195,8 @@ func ParseMetric(input string) (m labels.Labels, err error) { // ParseMetricSelector parses the provided textual metric selector into a list of // label matchers. func ParseMetricSelector(input string) (m []*labels.Matcher, err error) { - p := newParser(input) - defer parserPool.Put(p) + p := NewParser(input) + defer p.Close() defer p.recover(&err) parseResult := p.parseGenerated(START_METRIC_SELECTOR) @@ -164,22 +211,6 @@ func ParseMetricSelector(input string) (m []*labels.Matcher, err error) { return m, err } -// newParser returns a new parser. -func newParser(input string) *parser { - p := parserPool.Get().(*parser) - - p.injecting = false - p.parseErrors = nil - p.generatedParserResult = nil - - // Clear lexer struct before reusing. - p.lex = Lexer{ - input: input, - state: lexStatements, - } - return p -} - // SequenceValue is an omittable value in a sequence of time series values. type SequenceValue struct { Value float64 @@ -200,10 +231,10 @@ type seriesDescription struct { // ParseSeriesDesc parses the description of a time series. func ParseSeriesDesc(input string) (labels labels.Labels, values []SequenceValue, err error) { - p := newParser(input) + p := NewParser(input) p.lex.seriesDesc = true - defer parserPool.Put(p) + defer p.Close() defer p.recover(&err) parseResult := p.parseGenerated(START_SERIES_DESCRIPTION) @@ -799,7 +830,7 @@ func MustLabelMatcher(mt labels.MatchType, name, val string) *labels.Matcher { } func MustGetFunction(name string) *Function { - f, ok := getFunction(name) + f, ok := getFunction(name, Functions) if !ok { panic(fmt.Errorf("function %q does not exist", name)) } diff --git a/promql/parser/parse_test.go b/promql/parser/parse_test.go index a336e96e8a..b16aed71d6 100644 --- a/promql/parser/parse_test.go +++ b/promql/parser/parse_test.go @@ -3739,7 +3739,7 @@ func TestParseSeries(t *testing.T) { } func TestRecoverParserRuntime(t *testing.T) { - p := newParser("foo bar") + p := NewParser("foo bar") var err error defer func() { @@ -3753,7 +3753,7 @@ func TestRecoverParserRuntime(t *testing.T) { } func TestRecoverParserError(t *testing.T) { - p := newParser("foo bar") + p := NewParser("foo bar") var err error e := errors.New("custom error") @@ -3801,3 +3801,20 @@ func TestExtractSelectors(t *testing.T) { require.Equal(t, expected, ExtractSelectors(expr)) } } + +func TestParseCustomFunctions(t *testing.T) { + funcs := Functions + funcs["custom_func"] = &Function{ + Name: "custom_func", + ArgTypes: []ValueType{ValueTypeMatrix}, + ReturnType: ValueTypeVector, + } + input := "custom_func(metric[1m])" + p := NewParser(input, WithFunctions(funcs)) + expr, err := p.ParseExpr() + require.NoError(t, err) + + call, ok := expr.(*Call) + require.True(t, ok) + require.Equal(t, "custom_func", call.Func.Name) +}