// Copyright 2020 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package moby

import (
	"crypto/sha1"
	"encoding/base64"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
	"gopkg.in/yaml.v2"

	"github.com/prometheus/prometheus/util/strutil"
)

// SDMock is the interface for the DigitalOcean mock.
type SDMock struct {
	t         *testing.T
	Server    *httptest.Server
	Mux       *http.ServeMux
	directory string
	calls     map[string]int
}

// NewSDMock returns a new SDMock.
func NewSDMock(t *testing.T, directory string) *SDMock {
	return &SDMock{
		t:         t,
		directory: directory,
		calls:     make(map[string]int),
	}
}

// Endpoint returns the URI to the mock server.
func (m *SDMock) Endpoint() string {
	return m.Server.URL + "/"
}

// Setup creates the mock server.
func (m *SDMock) Setup() {
	m.Mux = http.NewServeMux()
	m.Server = httptest.NewServer(m.Mux)
	m.t.Cleanup(m.Server.Close)
	m.SetupHandlers()
}

// HandleNodesList mocks nodes list.
func (m *SDMock) SetupHandlers() {
	headers := make(map[string]string)
	rawHeaders, err := os.ReadFile(filepath.Join("testdata", m.directory, "headers.yml"))
	require.NoError(m.t, err)
	yaml.Unmarshal(rawHeaders, &headers)

	prefix := "/"
	if v, ok := headers["Api-Version"]; ok {
		prefix += "v" + v + "/"
	}

	for _, path := range []string{"_ping", "networks/", "services/", "nodes/", "nodes", "services", "tasks", "containers/"} {
		p := path
		handler := prefix + p
		if p == "_ping" {
			handler = "/" + p
		}
		m.Mux.HandleFunc(handler, func(w http.ResponseWriter, r *http.Request) {
			// The discovery should only call each API endpoint once.
			m.calls[r.RequestURI]++
			if m.calls[r.RequestURI] != 1 {
				w.WriteHeader(http.StatusTooManyRequests)
				return
			}
			for k, v := range headers {
				w.Header().Add(k, v)
			}
			parts := strings.Split(r.RequestURI, "/")
			var f string
			if strings.HasSuffix(p, "/") {
				f = filepath.Join(p[:len(p)-1], strutil.SanitizeLabelName(parts[len(parts)-1]))
			} else {
				query := strings.Split(parts[len(parts)-1], "?")
				f = query[0] + ".json"
				if len(query) == 2 {
					h := sha1.New()
					h.Write([]byte(query[1]))
					// Avoiding long filenames for Windows.
					f += "__" + base64.URLEncoding.EncodeToString(h.Sum(nil))[:10]
				}
			}
			if response, err := os.ReadFile(filepath.Join("testdata", m.directory, f+".json")); err == nil {
				w.Header().Add("content-type", "application/json")
				w.WriteHeader(http.StatusOK)
				w.Write(response)
				return
			}
			if response, err := os.ReadFile(filepath.Join("testdata", m.directory, f)); err == nil {
				w.WriteHeader(http.StatusOK)
				w.Write(response)
				return
			}
			w.WriteHeader(http.StatusInternalServerError)
		})
	}
}