Update vendored go-zookeeper client (#2778)

It is likely this will fix #2758.
This commit is contained in:
Stephan Erb 2017-05-29 15:59:30 +02:00 committed by Julius Volz
parent 51626f2573
commit 14eee34da3
9 changed files with 574 additions and 420 deletions

View file

@ -44,9 +44,9 @@ const (
type watchType int
const (
watchTypeData = iota
watchTypeExist = iota
watchTypeChild = iota
watchTypeData = iota
watchTypeExist
watchTypeChild
)
type watchPathType struct {
@ -61,37 +61,52 @@ type Logger interface {
Printf(string, ...interface{})
}
type Conn struct {
lastZxid int64
sessionID int64
state State // must be 32-bit aligned
xid uint32
timeout int32 // session timeout in milliseconds
passwd []byte
// authCreds is one scheme/auth pair that was submitted through AddAuth. It is
// kept so that the same credentials can be re-submitted after a reconnect.
type authCreds struct {
scheme string
auth []byte
}
dialer Dialer
servers []string
serverIndex int // remember last server that was tried during connect to round-robin attempts to servers
lastServerIndex int // index of the last server that was successfully connected to and authenticated with
conn net.Conn
eventChan chan Event
shouldQuit chan struct{}
pingInterval time.Duration
recvTimeout time.Duration
connectTimeout time.Duration
// Conn holds the state of one ZooKeeper client session: the session
// identity, the current network connection, pending requests, registered
// watchers and the credentials that must be replayed after a reconnect.
type Conn struct {
lastZxid int64
sessionID int64
state State // must be 32-bit aligned
xid uint32
sessionTimeoutMs int32 // session timeout in milliseconds
passwd []byte
dialer Dialer
hostProvider HostProvider
serverMu sync.Mutex // protects server
server string // remember the address/port of the current server
conn net.Conn
eventChan chan Event
eventCallback EventCallback // may be nil
shouldQuit chan struct{}
pingInterval time.Duration
recvTimeout time.Duration
connectTimeout time.Duration
creds []authCreds // credentials remembered by AddAuth for re-submission
credsMu sync.Mutex // protects creds
sendChan chan *request
requests map[int32]*request // Xid -> pending request
requestsLock sync.Mutex
watchers map[watchPathType][]chan Event
watchersLock sync.Mutex
closeChan chan struct{} // channel to tell send loop stop
// Debug (used by unit tests)
reconnectDelay time.Duration
logger Logger
buf []byte // scratch buffer used to encode outgoing packets
}
// connOption represents a connection option.
type connOption func(c *Conn)
type request struct {
xid int32
opcode int32
@ -122,26 +137,39 @@ type Event struct {
Server string // For connection events
}
// Connect establishes a new connection to a pool of zookeeper servers
// using the default net.Dialer. See ConnectWithDialer for further
// information about session timeout.
func Connect(servers []string, sessionTimeout time.Duration) (*Conn, <-chan Event, error) {
return ConnectWithDialer(servers, sessionTimeout, nil)
// HostProvider is used to represent a set of hosts a ZooKeeper client should connect to.
// It is an analog of the Java equivalent:
// http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup
type HostProvider interface {
// Init is called first, with the servers specified in the connection string.
// A non-nil error makes Connect fail with that error.
Init(servers []string) error
// Len returns the number of servers.
Len() int
// Next returns the next server to connect to. retryStart will be true if we've looped through
// all known servers without Connected() being called.
Next() (server string, retryStart bool)
// Connected notifies the HostProvider of a successful connection.
Connected()
}
// ConnectWithDialer establishes a new connection to a pool of zookeeper
// ConnectWithDialer establishes a new connection to a pool of zookeeper servers
// using a custom Dialer. See Connect for further information about session timeout.
// A nil dialer selects the default dialer.
//
// Deprecated: use Connect with the WithDialer option instead.
func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) {
	if dialer == nil {
		// WithDialer(nil) would replace the default dialer installed by
		// Connect with nil, and the first dial attempt would then call a nil
		// function. Skip the option so the default stays in place.
		return Connect(servers, sessionTimeout)
	}
	return Connect(servers, sessionTimeout, WithDialer(dialer))
}
// Connect establishes a new connection to a pool of zookeeper
// servers. The provided session timeout sets the amount of time for which
// a session is considered valid after losing connection to a server. Within
// the session timeout it's possible to reestablish a connection to a different
// server and keep the same session. This means any ephemeral nodes and
// watches are maintained.
func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) {
func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) {
if len(servers) == 0 {
return nil, nil, errors.New("zk: server list must not be empty")
}
recvTimeout := sessionTimeout * 2 / 3
srvs := make([]string, len(servers))
for i, addr := range servers {
@ -156,38 +184,69 @@ func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Di
stringShuffle(srvs)
ec := make(chan Event, eventChanSize)
if dialer == nil {
dialer = net.DialTimeout
}
conn := Conn{
dialer: dialer,
servers: srvs,
serverIndex: 0,
lastServerIndex: -1,
conn: nil,
state: StateDisconnected,
eventChan: ec,
shouldQuit: make(chan struct{}),
recvTimeout: recvTimeout,
pingInterval: recvTimeout / 2,
connectTimeout: 1 * time.Second,
sendChan: make(chan *request, sendChanSize),
requests: make(map[int32]*request),
watchers: make(map[watchPathType][]chan Event),
passwd: emptyPassword,
timeout: int32(sessionTimeout.Nanoseconds() / 1e6),
logger: DefaultLogger,
conn := &Conn{
dialer: net.DialTimeout,
hostProvider: &DNSHostProvider{},
conn: nil,
state: StateDisconnected,
eventChan: ec,
shouldQuit: make(chan struct{}),
connectTimeout: 1 * time.Second,
sendChan: make(chan *request, sendChanSize),
requests: make(map[int32]*request),
watchers: make(map[watchPathType][]chan Event),
passwd: emptyPassword,
logger: DefaultLogger,
buf: make([]byte, bufferSize),
// Debug
reconnectDelay: 0,
}
// Set provided options.
for _, option := range options {
option(conn)
}
if err := conn.hostProvider.Init(srvs); err != nil {
return nil, nil, err
}
conn.setTimeouts(int32(sessionTimeout / time.Millisecond))
go func() {
conn.loop()
conn.flushRequests(ErrClosing)
conn.invalidateWatches(ErrClosing)
close(conn.eventChan)
}()
return &conn, ec, nil
return conn, ec, nil
}
// WithDialer builds a connection option that replaces the connection's
// Dialer with the given one.
func WithDialer(dialer Dialer) connOption {
	apply := func(conn *Conn) { conn.dialer = dialer }
	return apply
}
// WithHostProvider builds a connection option that replaces the connection's
// HostProvider with the given one.
func WithHostProvider(hostProvider HostProvider) connOption {
	apply := func(conn *Conn) { conn.hostProvider = hostProvider }
	return apply
}
// EventCallback is a function that is called when an Event occurs.
type EventCallback func(Event)
// WithEventCallback builds a connection option that installs cb as the
// connection's event callback.
// The callback is invoked synchronously when an event is sent, so it must not
// block - doing so would delay the ZK go routines.
func WithEventCallback(cb EventCallback) connOption {
	apply := func(conn *Conn) { conn.eventCallback = cb }
	return apply
}
func (c *Conn) Close() {
@ -199,31 +258,54 @@ func (c *Conn) Close() {
}
}
// States returns the current state of the connection.
// State reports the connection's current state. The value is read atomically.
func (c *Conn) State() State {
	raw := atomic.LoadInt32((*int32)(&c.state))
	return State(raw)
}
// SessionID reports the connection's current session id. The value is read
// atomically.
func (c *Conn) SessionID() int64 {
	id := atomic.LoadInt64(&c.sessionID)
	return id
}
// SetLogger replaces the Logger this connection prints its messages to.
// Logger is the interface declared by this package.
func (c *Conn) SetLogger(l Logger) {
	c.logger = l
}
// setTimeouts records the session timeout (in milliseconds) and derives the
// other timers from it: the receive timeout is two thirds of the session
// timeout and the ping interval is half of the receive timeout.
func (c *Conn) setTimeouts(sessionTimeoutMs int32) {
	session := time.Duration(sessionTimeoutMs) * time.Millisecond
	recv := session * 2 / 3

	c.sessionTimeoutMs = sessionTimeoutMs
	c.recvTimeout = recv
	c.pingInterval = recv / 2
}
// setState atomically stores the new connection state and then publishes a
// session event carrying that state and the current server.
func (c *Conn) setState(state State) {
	atomic.StoreInt32((*int32)(&c.state), int32(state))

	evt := Event{Type: EventSession, State: state, Server: c.Server()}
	c.sendEvent(evt)
}
func (c *Conn) sendEvent(evt Event) {
if c.eventCallback != nil {
c.eventCallback(evt)
}
select {
case c.eventChan <- Event{Type: EventSession, State: state, Server: c.servers[c.serverIndex]}:
case c.eventChan <- evt:
default:
// panic("zk: event channel full - it must be monitored and never allowed to be full")
}
}
func (c *Conn) connect() error {
c.setState(StateConnecting)
var retryStart bool
for {
c.serverIndex = (c.serverIndex + 1) % len(c.servers)
if c.serverIndex == c.lastServerIndex {
c.serverMu.Lock()
c.server, retryStart = c.hostProvider.Next()
c.serverMu.Unlock()
c.setState(StateConnecting)
if retryStart {
c.flushUnsentRequests(ErrNoServer)
select {
case <-time.After(time.Second):
@ -233,22 +315,79 @@ func (c *Conn) connect() error {
c.flushUnsentRequests(ErrClosing)
return ErrClosing
}
} else if c.lastServerIndex < 0 {
// lastServerIndex defaults to -1 to avoid a delay on the initial connect
c.lastServerIndex = 0
}
zkConn, err := c.dialer("tcp", c.servers[c.serverIndex], c.connectTimeout)
zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout)
if err == nil {
c.conn = zkConn
c.setState(StateConnected)
c.logger.Printf("Connected to %s", c.Server())
return nil
}
c.logger.Printf("Failed to connect to %s: %+v", c.servers[c.serverIndex], err)
c.logger.Printf("Failed to connect to %s: %+v", c.Server(), err)
}
}
// resendZkAuth re-submits every credential remembered by AddAuth, one
// setAuth request at a time, waiting for each reply before sending the next.
// reauthReadyChan is closed when the function returns; the send loop
// goroutine waits on it before it starts sending.
func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) {
	c.credsMu.Lock()
	defer c.credsMu.Unlock()

	// Deferred last so it runs first: the channel is closed before the
	// credentials lock is released.
	defer close(reauthReadyChan)

	c.logger.Printf("Re-submitting `%d` credentials after reconnect",
		len(c.creds))

	for i := range c.creds {
		authReq := &setAuthRequest{
			Type:   0,
			Scheme: c.creds[i].scheme,
			Auth:   c.creds[i].auth,
		}

		replyCh, sendErr := c.sendRequest(opSetAuth, authReq, &setAuthResponse{}, nil)
		if sendErr != nil {
			// FIXME(prozlach): lets ignore errors for now
			c.logger.Printf("Call to sendRequest failed during credential resubmit: %s", sendErr)
			continue
		}

		if reply := <-replyCh; reply.err != nil {
			// FIXME(prozlach): lets ignore errors for now
			c.logger.Printf("Credential re-submit failed: %s", reply.err)
		}
	}
}
// sendRequest wraps req in a request with a fresh xid and writes it with
// sendData, bypassing the send channel.
//
// On success it returns the channel (buffered, capacity one) on which the
// response will be delivered. If sendData fails, the error is returned and
// the channel is nil.
func (c *Conn) sendRequest(
	opcode int32,
	req interface{},
	res interface{},
	recvFunc func(*request, *responseHeader, error),
) (
	<-chan response,
	error,
) {
	xid := c.nextXid()
	replies := make(chan response, 1)

	pending := &request{
		xid:        xid,
		opcode:     opcode,
		pkt:        req,
		recvStruct: res,
		recvChan:   replies,
		recvFunc:   recvFunc,
	}

	err := c.sendData(pending)
	if err != nil {
		return nil, err
	}
	return pending.recvChan, nil
}
func (c *Conn) loop() {
for {
if err := c.connect(); err != nil {
@ -259,41 +398,46 @@ func (c *Conn) loop() {
err := c.authenticate()
switch {
case err == ErrSessionExpired:
c.logger.Printf("Authentication failed: %s", err)
c.invalidateWatches(err)
case err != nil && c.conn != nil:
c.logger.Printf("Authentication failed: %s", err)
c.conn.Close()
case err == nil:
c.lastServerIndex = c.serverIndex
closeChan := make(chan struct{}) // channel to tell send loop stop
var wg sync.WaitGroup
c.logger.Printf("Authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs)
c.hostProvider.Connected() // mark success
c.closeChan = make(chan struct{}) // channel to tell send loop stop
reauthChan := make(chan struct{}) // channel to tell send loop that authdata has been resubmitted
var wg sync.WaitGroup
wg.Add(1)
go func() {
c.sendLoop(c.conn, closeChan)
<-reauthChan
err := c.sendLoop()
c.logger.Printf("Send loop terminated: err=%v", err)
c.conn.Close() // causes recv loop to EOF/exit
wg.Done()
}()
wg.Add(1)
go func() {
err = c.recvLoop(c.conn)
err := c.recvLoop(c.conn)
c.logger.Printf("Recv loop terminated: err=%v", err)
if err == nil {
panic("zk: recvLoop should never return nil error")
}
close(closeChan) // tell send loop to exit
close(c.closeChan) // tell send loop to exit
wg.Done()
}()
c.resendZkAuth(reauthChan)
c.sendSetWatches()
wg.Wait()
}
c.setState(StateDisconnected)
// Yeesh
if err != io.EOF && err != ErrSessionExpired && !strings.Contains(err.Error(), "use of closed network connection") {
c.logger.Printf(err.Error())
}
select {
case <-c.shouldQuit:
c.flushRequests(ErrClosing)
@ -399,13 +543,12 @@ func (c *Conn) sendSetWatches() {
func (c *Conn) authenticate() error {
buf := make([]byte, 256)
// connect request
// Encode and send a connect request.
n, err := encodePacket(buf[4:], &connectRequest{
ProtocolVersion: protocolVersion,
LastZxidSeen: c.lastZxid,
TimeOut: c.timeout,
SessionID: c.sessionID,
TimeOut: c.sessionTimeoutMs,
SessionID: c.SessionID(),
Passwd: c.passwd,
})
if err != nil {
@ -421,23 +564,12 @@ func (c *Conn) authenticate() error {
return err
}
c.sendSetWatches()
// connect response
// package length
// Receive and decode a connect response.
c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10))
_, err = io.ReadFull(c.conn, buf[:4])
c.conn.SetReadDeadline(time.Time{})
if err != nil {
// Sometimes zookeeper just drops connection on invalid session data,
// we prefer to drop session and start from scratch when that event
// occurs instead of dropping into loop of connect/disconnect attempts
c.sessionID = 0
c.passwd = emptyPassword
c.lastZxid = 0
c.setState(StateExpired)
return ErrSessionExpired
return err
}
blen := int(binary.BigEndian.Uint32(buf[:4]))
@ -456,81 +588,88 @@ func (c *Conn) authenticate() error {
return err
}
if r.SessionID == 0 {
c.sessionID = 0
atomic.StoreInt64(&c.sessionID, int64(0))
c.passwd = emptyPassword
c.lastZxid = 0
c.setState(StateExpired)
return ErrSessionExpired
}
c.timeout = r.TimeOut
c.sessionID = r.SessionID
atomic.StoreInt64(&c.sessionID, r.SessionID)
c.setTimeouts(r.TimeOut)
c.passwd = r.Passwd
c.setState(StateHasSession)
return nil
}
func (c *Conn) sendLoop(conn net.Conn, closeChan <-chan struct{}) error {
// sendData encodes req into the connection's shared buffer c.buf, registers
// it in c.requests and writes the packet to c.conn.
//
// Encoding failures are reported to the requester through req.recvChan and
// the function returns nil. A non-nil return (ErrConnectionClosed or a write
// error) is also reported through req.recvChan; sendLoop stops when it sees
// a non-nil return.
func (c *Conn) sendData(req *request) error {
header := &requestHeader{req.xid, req.opcode}
// The first four bytes of the buffer are reserved for the length prefix.
n, err := encodePacket(c.buf[4:], header)
if err != nil {
req.recvChan <- response{-1, err}
return nil
}
n2, err := encodePacket(c.buf[4+n:], req.pkt)
if err != nil {
req.recvChan <- response{-1, err}
return nil
}
n += n2
// Big-endian length of header plus payload, excluding the prefix itself.
binary.BigEndian.PutUint32(c.buf[:4], uint32(n))
c.requestsLock.Lock()
// Non-blocking check under the lock: once closeChan is closed the request
// is rejected instead of being registered.
select {
case <-c.closeChan:
req.recvChan <- response{-1, ErrConnectionClosed}
c.requestsLock.Unlock()
return ErrConnectionClosed
default:
}
c.requests[req.xid] = req
c.requestsLock.Unlock()
c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
_, err = c.conn.Write(c.buf[:n+4])
// Clear the deadline again so it does not apply to later writes.
c.conn.SetWriteDeadline(time.Time{})
if err != nil {
req.recvChan <- response{-1, err}
c.conn.Close()
return err
}
return nil
}
func (c *Conn) sendLoop() error {
pingTicker := time.NewTicker(c.pingInterval)
defer pingTicker.Stop()
buf := make([]byte, bufferSize)
for {
select {
case req := <-c.sendChan:
header := &requestHeader{req.xid, req.opcode}
n, err := encodePacket(buf[4:], header)
if err != nil {
req.recvChan <- response{-1, err}
continue
}
n2, err := encodePacket(buf[4+n:], req.pkt)
if err != nil {
req.recvChan <- response{-1, err}
continue
}
n += n2
binary.BigEndian.PutUint32(buf[:4], uint32(n))
c.requestsLock.Lock()
select {
case <-closeChan:
req.recvChan <- response{-1, ErrConnectionClosed}
c.requestsLock.Unlock()
return ErrConnectionClosed
default:
}
c.requests[req.xid] = req
c.requestsLock.Unlock()
conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
_, err = conn.Write(buf[:n+4])
conn.SetWriteDeadline(time.Time{})
if err != nil {
req.recvChan <- response{-1, err}
conn.Close()
if err := c.sendData(req); err != nil {
return err
}
case <-pingTicker.C:
n, err := encodePacket(buf[4:], &requestHeader{Xid: -2, Opcode: opPing})
n, err := encodePacket(c.buf[4:], &requestHeader{Xid: -2, Opcode: opPing})
if err != nil {
panic("zk: opPing should never fail to serialize")
}
binary.BigEndian.PutUint32(buf[:4], uint32(n))
binary.BigEndian.PutUint32(c.buf[:4], uint32(n))
conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
_, err = conn.Write(buf[:n+4])
conn.SetWriteDeadline(time.Time{})
c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
_, err = c.conn.Write(c.buf[:n+4])
c.conn.SetWriteDeadline(time.Time{})
if err != nil {
conn.Close()
c.conn.Close()
return err
}
case <-closeChan:
case <-c.closeChan:
return nil
}
}
@ -565,7 +704,7 @@ func (c *Conn) recvLoop(conn net.Conn) error {
if res.Xid == -1 {
res := &watcherEvent{}
_, err := decodePacket(buf[16:16+blen], res)
_, err := decodePacket(buf[16:blen], res)
if err != nil {
return err
}
@ -575,10 +714,7 @@ func (c *Conn) recvLoop(conn net.Conn) error {
Path: res.Path,
Err: nil,
}
select {
case c.eventChan <- ev:
default:
}
c.sendEvent(ev)
wTypes := make([]watchType, 0, 2)
switch res.Type {
case EventNodeCreated:
@ -622,7 +758,7 @@ func (c *Conn) recvLoop(conn net.Conn) error {
if res.Err != 0 {
err = res.Err.toError()
} else {
_, err = decodePacket(buf[16:16+blen], req.recvStruct)
_, err = decodePacket(buf[16:blen], req.recvStruct)
}
if req.recvFunc != nil {
req.recvFunc(req, &res, err)
@ -670,7 +806,28 @@ func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc
func (c *Conn) AddAuth(scheme string, auth []byte) error {
_, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil)
return err
if err != nil {
return err
}
// Remember authdata so that it can be re-submitted on reconnect
//
// FIXME(prozlach): For now we treat "userfoo:passbar" and "userfoo:passbar2"
// as two different entries, which will be re-submitted on reconnect. Some
// research is needed on how ZK treats these cases and
// then maybe switch to something like "map[username] = password" to allow
// only single password for given user with users being unique.
obj := authCreds{
scheme: scheme,
auth: auth,
}
c.credsMu.Lock()
c.creds = append(c.creds, obj)
c.credsMu.Unlock()
return nil
}
func (c *Conn) Children(path string) ([]string, *Stat, error) {
@ -816,7 +973,6 @@ func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) {
_, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil)
return res.Acl, &res.Stat, err
}
func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) {
res := &setAclResponse{}
_, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil)
@ -832,6 +988,7 @@ func (c *Conn) Sync(path string) (string, error) {
type MultiResponse struct {
Stat *Stat
String string
Error error
}
// Multi executes multiple ZooKeeper operations or none of them. The provided
@ -854,7 +1011,7 @@ func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) {
case *CheckVersionRequest:
opCode = opCheck
default:
return nil, fmt.Errorf("uknown operation type %T", op)
return nil, fmt.Errorf("unknown operation type %T", op)
}
req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op})
}
@ -862,7 +1019,14 @@ func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) {
_, err := c.request(opMulti, req, res, nil)
mr := make([]MultiResponse, len(res.Ops))
for i, op := range res.Ops {
mr[i] = MultiResponse{Stat: op.Stat, String: op.String}
mr[i] = MultiResponse{Stat: op.Stat, String: op.String, Error: op.Err.toError()}
}
return mr, err
}
// Server reports the name of the current or last-connected server. The value
// is read under serverMu.
func (c *Conn) Server() string {
	c.serverMu.Lock()
	addr := c.server
	c.serverMu.Unlock()
	return addr
}

View file

@ -28,18 +28,19 @@ const (
opClose = -11
opSetAuth = 100
opSetWatches = 101
opError = -1
// Not in protocol, used internally
opWatcherEvent = -2
)
const (
EventNodeCreated = EventType(1)
EventNodeDeleted = EventType(2)
EventNodeDataChanged = EventType(3)
EventNodeChildrenChanged = EventType(4)
EventNodeCreated EventType = 1
EventNodeDeleted EventType = 2
EventNodeDataChanged EventType = 3
EventNodeChildrenChanged EventType = 4
EventSession = EventType(-1)
EventNotWatching = EventType(-2)
EventSession EventType = -1
EventNotWatching EventType = -2
)
var (
@ -54,14 +55,13 @@ var (
)
const (
StateUnknown = State(-1)
StateDisconnected = State(0)
StateConnecting = State(1)
StateAuthFailed = State(4)
StateConnectedReadOnly = State(5)
StateSaslAuthenticated = State(6)
StateExpired = State(-112)
// StateAuthFailed = State(-113)
StateUnknown State = -1
StateDisconnected State = 0
StateConnecting State = 1
StateAuthFailed State = 4
StateConnectedReadOnly State = 5
StateSaslAuthenticated State = 6
StateExpired State = -112
StateConnected = State(100)
StateHasSession = State(101)
@ -154,20 +154,20 @@ const (
errBadArguments = -8
errInvalidState = -9
// API errors
errAPIError = ErrCode(-100)
errNoNode = ErrCode(-101) // *
errNoAuth = ErrCode(-102)
errBadVersion = ErrCode(-103) // *
errNoChildrenForEphemerals = ErrCode(-108)
errNodeExists = ErrCode(-110) // *
errNotEmpty = ErrCode(-111)
errSessionExpired = ErrCode(-112)
errInvalidCallback = ErrCode(-113)
errInvalidAcl = ErrCode(-114)
errAuthFailed = ErrCode(-115)
errClosing = ErrCode(-116)
errNothing = ErrCode(-117)
errSessionMoved = ErrCode(-118)
errAPIError ErrCode = -100
errNoNode ErrCode = -101 // *
errNoAuth ErrCode = -102
errBadVersion ErrCode = -103 // *
errNoChildrenForEphemerals ErrCode = -108
errNodeExists ErrCode = -110 // *
errNotEmpty ErrCode = -111
errSessionExpired ErrCode = -112
errInvalidCallback ErrCode = -113
errInvalidAcl ErrCode = -114
errAuthFailed ErrCode = -115
errClosing ErrCode = -116
errNothing ErrCode = -117
errSessionMoved ErrCode = -118
)
// Constants for ACL permissions

View file

@ -0,0 +1,88 @@
package zk
import (
"fmt"
"net"
"sync"
)
// DNSHostProvider is the default HostProvider. It currently matches
// the Java StaticHostProvider, resolving hosts from DNS once during
// the call to Init. It could be easily extended to re-query DNS
// periodically or if there is trouble connecting.
type DNSHostProvider struct {
mu sync.Mutex // Protects everything, so we can add asynchronous updates later.
servers []string // Resolved "address:port" entries, filled in by Init.
curr int // Index of the server most recently returned by Next; -1 after Init.
last int // Index recorded by Connected; -1 after Init, set to 0 by the first Next.
lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing.
}
// Init resolves every "host:port" entry of servers to its addresses, joins
// each address with the entry's port, shuffles the combined list and resets
// the round-robin position.
//
// It returns the first error from splitting an entry or from the lookup, or
// an error when no address was found at all. On error the provider's
// previous state is left untouched.
func (hp *DNSHostProvider) Init(servers []string) error {
	hp.mu.Lock()
	defer hp.mu.Unlock()

	// Use the test override when one is installed.
	resolve := hp.lookupHost
	if resolve == nil {
		resolve = net.LookupHost
	}

	var resolved []string
	for i := range servers {
		host, port, err := net.SplitHostPort(servers[i])
		if err != nil {
			return err
		}
		ips, err := resolve(host)
		if err != nil {
			return err
		}
		for _, ip := range ips {
			resolved = append(resolved, net.JoinHostPort(ip, port))
		}
	}

	if len(resolved) == 0 {
		return fmt.Errorf("No hosts found for addresses %q", servers)
	}

	// Randomize the order of the servers to avoid creating hotspots
	stringShuffle(resolved)

	hp.servers, hp.curr, hp.last = resolved, -1, -1
	return nil
}
// Len reports how many servers are currently held by the provider.
func (hp *DNSHostProvider) Len() int {
	hp.mu.Lock()
	n := len(hp.servers)
	hp.mu.Unlock()
	return n
}
// Next advances the round-robin position by one (wrapping around) and
// returns the server at the new position. retryStart is true when the new
// position equals the one recorded by the last call to Connected, i.e. all
// known servers have been tried since then.
func (hp *DNSHostProvider) Next() (server string, retryStart bool) {
	hp.mu.Lock()
	defer hp.mu.Unlock()

	idx := (hp.curr + 1) % len(hp.servers)
	hp.curr = idx

	wrapped := idx == hp.last
	// After Init no position has been recorded yet (-1); use position 0
	// from now on.
	if hp.last == -1 {
		hp.last = 0
	}
	return hp.servers[idx], wrapped
}
// Connected records the current round-robin position as the last one that
// led to a successful connection.
func (hp *DNSHostProvider) Connected() {
	hp.mu.Lock()
	hp.last = hp.curr
	hp.mu.Unlock()
}

View file

@ -5,10 +5,10 @@ import (
"bytes"
"fmt"
"io/ioutil"
"math/big"
"net"
"regexp"
"strconv"
"strings"
"time"
)
@ -22,7 +22,7 @@ import (
// which server had the issue.
func FLWSrvr(servers []string, timeout time.Duration) ([]*ServerStats, bool) {
// different parts of the regular expression that are required to parse the srvr output
var (
const (
zrVer = `^Zookeeper version: ([A-Za-z0-9\.\-]+), built on (\d\d/\d\d/\d\d\d\d \d\d:\d\d [A-Za-z0-9:\+\-]+)`
zrLat = `^Latency min/avg/max: (\d+)/(\d+)/(\d+)`
zrNet = `^Received: (\d+).*\n^Sent: (\d+).*\n^Connections: (\d+).*\n^Outstanding: (\d+)`
@ -31,7 +31,6 @@ func FLWSrvr(servers []string, timeout time.Duration) ([]*ServerStats, bool) {
// build the regex from the pieces above
re, err := regexp.Compile(fmt.Sprintf(`(?m:\A%v.*\n%v.*\n%v.*\n%v)`, zrVer, zrLat, zrNet, zrState))
if err != nil {
return nil, false
}
@ -152,14 +151,13 @@ func FLWRuok(servers []string, timeout time.Duration) []bool {
// As with FLWSrvr, the boolean value indicates whether one of the requests had
// an issue. The Clients struct has an Error value that can be checked.
func FLWCons(servers []string, timeout time.Duration) ([]*ServerClients, bool) {
var (
const (
zrAddr = `^ /((?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?):(?:\d+))\[\d+\]`
zrPac = `\(queued=(\d+),recved=(\d+),sent=(\d+),sid=(0x[A-Za-z0-9]+),lop=(\w+),est=(\d+),to=(\d+),`
zrSesh = `lcxid=(0x[A-Za-z0-9]+),lzxid=(0x[A-Za-z0-9]+),lresp=(\d+),llat=(\d+),minlat=(\d+),avglat=(\d+),maxlat=(\d+)\)`
)
re, err := regexp.Compile(fmt.Sprintf("%v%v%v", zrAddr, zrPac, zrSesh))
if err != nil {
return nil, false
}
@ -205,41 +203,21 @@ func FLWCons(servers []string, timeout time.Duration) ([]*ServerClients, bool) {
sid, _ := strconv.ParseInt(match[4], 0, 64)
est, _ := strconv.ParseInt(match[6], 0, 64)
timeout, _ := strconv.ParseInt(match[7], 0, 32)
lcxid, _ := parseInt64(match[8])
lzxid, _ := parseInt64(match[9])
lresp, _ := strconv.ParseInt(match[10], 0, 64)
llat, _ := strconv.ParseInt(match[11], 0, 32)
minlat, _ := strconv.ParseInt(match[12], 0, 32)
avglat, _ := strconv.ParseInt(match[13], 0, 32)
maxlat, _ := strconv.ParseInt(match[14], 0, 32)
// zookeeper returns a value, '0xffffffffffffffff', as the
// Lzxid for PING requests in the 'cons' output.
// unfortunately, in Go that is an invalid int64 and is not represented
// as -1.
// However, converting the string value to a big.Int and then back to
// and int64 properly sets the value to -1
lzxid, ok := new(big.Int).SetString(match[9], 0)
var errVal error
if !ok {
errVal = fmt.Errorf("failed to convert lzxid value to big.Int")
imOk = false
}
lcxid, ok := new(big.Int).SetString(match[8], 0)
if !ok && errVal == nil {
errVal = fmt.Errorf("failed to convert lcxid value to big.Int")
imOk = false
}
clients = append(clients, &ServerClient{
Queued: queued,
Received: recvd,
Sent: sent,
SessionID: sid,
Lcxid: lcxid.Int64(),
Lzxid: lzxid.Int64(),
Lcxid: int64(lcxid),
Lzxid: int64(lzxid),
Timeout: int32(timeout),
LastLatency: int32(llat),
MinLatency: int32(minlat),
@ -249,7 +227,6 @@ func FLWCons(servers []string, timeout time.Duration) ([]*ServerClients, bool) {
LastResponse: time.Unix(lresp, 0),
Addr: match[0],
LastOperation: match[5],
Error: errVal,
})
}
@ -259,9 +236,17 @@ func FLWCons(servers []string, timeout time.Duration) ([]*ServerClients, bool) {
return sc, imOk
}
// parseInt64 behaves like strconv.ParseInt with base 0 and 64 bits, except
// that input starting with "0x" is parsed as an unsigned 64-bit value and
// then reinterpreted as int64, so "0xffffffffffffffff" yields -1 instead of
// a range error.
func parseInt64(s string) (int64, error) {
	if !strings.HasPrefix(s, "0x") {
		return strconv.ParseInt(s, 0, 64)
	}
	u, err := strconv.ParseUint(s, 0, 64)
	return int64(u), err
}
func fourLetterWord(server, command string, timeout time.Duration) ([]byte, error) {
conn, err := net.DialTimeout("tcp", server, timeout)
if err != nil {
return nil, err
}
@ -271,20 +256,11 @@ func fourLetterWord(server, command string, timeout time.Duration) ([]byte, erro
defer conn.Close()
conn.SetWriteDeadline(time.Now().Add(timeout))
_, err = conn.Write([]byte(command))
if err != nil {
return nil, err
}
conn.SetReadDeadline(time.Now().Add(timeout))
resp, err := ioutil.ReadAll(conn)
if err != nil {
return nil, err
}
return resp, nil
return ioutil.ReadAll(conn)
}

View file

@ -58,8 +58,16 @@ func (l *Lock) Lock() error {
parts := strings.Split(l.path, "/")
pth := ""
for _, p := range parts[1:] {
var exists bool
pth += "/" + p
_, err := l.c.Create(pth, []byte{}, 0, l.acl)
exists, _, err = l.c.Exists(pth)
if err != nil {
return err
}
if exists == true {
continue
}
_, err = l.c.Create(pth, []byte{}, 0, l.acl)
if err != nil && err != ErrNodeExists {
return err
}
@ -86,7 +94,7 @@ func (l *Lock) Lock() error {
}
lowestSeq := seq
prevSeq := 0
prevSeq := -1
prevSeqPath := ""
for _, p := range children {
s, err := parseSeq(p)

View file

@ -7,9 +7,14 @@ import (
"math/rand"
"os"
"path/filepath"
"strings"
"time"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
type TestServer struct {
Port int
Path string
@ -87,33 +92,125 @@ func StartTestCluster(size int, stdout, stderr io.Writer) (*TestCluster, error)
Srv: srv,
})
}
if err := cluster.waitForStart(10, time.Second); err != nil {
return nil, err
}
success = true
time.Sleep(time.Second) // Give the server time to become active. Should probably actually attempt to connect to verify.
return cluster, nil
}
func (ts *TestCluster) Connect(idx int) (*Conn, error) {
zk, _, err := Connect([]string{fmt.Sprintf("127.0.0.1:%d", ts.Servers[idx].Port)}, time.Second*15)
// Connect opens a connection to the idx-th server of the cluster on
// 127.0.0.1 with a 15 second session timeout. The event channel returned by
// the underlying Connect call is discarded.
func (tc *TestCluster) Connect(idx int) (*Conn, error) {
	addr := fmt.Sprintf("127.0.0.1:%d", tc.Servers[idx].Port)
	conn, _, err := Connect([]string{addr}, 15*time.Second)
	return conn, err
}
func (ts *TestCluster) ConnectAll() (*Conn, <-chan Event, error) {
return ts.ConnectAllTimeout(time.Second * 15)
// ConnectAll connects to all servers of the cluster with a 15 second session
// timeout. It delegates to ConnectAllTimeout.
func (tc *TestCluster) ConnectAll() (*Conn, <-chan Event, error) {
	const sessionTimeout = 15 * time.Second
	return tc.ConnectAllTimeout(sessionTimeout)
}
func (ts *TestCluster) ConnectAllTimeout(sessionTimeout time.Duration) (*Conn, <-chan Event, error) {
hosts := make([]string, len(ts.Servers))
for i, srv := range ts.Servers {
// ConnectAllTimeout connects to all servers of the cluster with the given
// session timeout. It delegates to ConnectWithOptions without any options.
func (tc *TestCluster) ConnectAllTimeout(sessionTimeout time.Duration) (*Conn, <-chan Event, error) {
	conn, events, err := tc.ConnectWithOptions(sessionTimeout)
	return conn, events, err
}
func (tc *TestCluster) ConnectWithOptions(sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) {
hosts := make([]string, len(tc.Servers))
for i, srv := range tc.Servers {
hosts[i] = fmt.Sprintf("127.0.0.1:%d", srv.Port)
}
zk, ch, err := Connect(hosts, sessionTimeout)
zk, ch, err := Connect(hosts, sessionTimeout, options...)
return zk, ch, err
}
func (ts *TestCluster) Stop() error {
for _, srv := range ts.Servers {
func (tc *TestCluster) Stop() error {
for _, srv := range tc.Servers {
srv.Srv.Stop()
}
defer os.RemoveAll(ts.Path)
defer os.RemoveAll(tc.Path)
return tc.waitForStop(5, time.Second)
}
// waitForStart blocks until the cluster is up. It polls all servers with
// FLWSrvr (one second timeout per call) up to maxRetry times, sleeping
// interval after every unsuccessful poll, and returns an error if no poll
// reported success.
func (tc *TestCluster) waitForStart(maxRetry int, interval time.Duration) error {
	addrs := make([]string, 0, len(tc.Servers))
	for _, srv := range tc.Servers {
		addrs = append(addrs, fmt.Sprintf("127.0.0.1:%d", srv.Port))
	}

	for attempt := 0; attempt < maxRetry; attempt++ {
		if _, healthy := FLWSrvr(addrs, time.Second); healthy {
			return nil
		}
		time.Sleep(interval)
	}

	return fmt.Errorf("unable to verify health of servers")
}
// waitForStop blocks until the cluster is down. It polls all servers with
// FLWRuok (one second timeout per call) up to maxRetry times and treats the
// cluster as down once no server answers positively. It sleeps interval
// after every poll in which at least one server still answered, and returns
// an error if the servers never all went silent.
func (tc *TestCluster) waitForStop(maxRetry int, interval time.Duration) error {
	addrs := make([]string, 0, len(tc.Servers))
	for _, srv := range tc.Servers {
		addrs = append(addrs, fmt.Sprintf("127.0.0.1:%d", srv.Port))
	}

	allDown := false
	for attempt := 0; attempt < maxRetry && !allDown; attempt++ {
		allDown = true
		for _, alive := range FLWRuok(addrs, time.Second) {
			if alive {
				allDown = false
			}
		}
		if !allDown {
			time.Sleep(interval)
		}
	}

	if allDown {
		return nil
	}
	return fmt.Errorf("unable to verify servers are down")
}
// StartServer starts the first cluster server whose ":port" is a suffix of
// server. The result of Start is not checked. It panics when no server of
// the cluster matches.
func (tc *TestCluster) StartServer(server string) {
	for _, s := range tc.Servers {
		portSuffix := fmt.Sprintf(":%d", s.Port)
		if !strings.HasSuffix(server, portSuffix) {
			continue
		}
		s.Srv.Start()
		return
	}
	panic(fmt.Sprintf("Unknown server: %s", server))
}
// StopServer stops the first cluster server whose ":port" is a suffix of
// server. The result of Stop is not checked. It panics when no server of the
// cluster matches.
func (tc *TestCluster) StopServer(server string) {
	for _, s := range tc.Servers {
		portSuffix := fmt.Sprintf(":%d", s.Port)
		if !strings.HasSuffix(server, portSuffix) {
			continue
		}
		s.Srv.Stop()
		return
	}
	panic(fmt.Sprintf("Unknown server: %s", server))
}
// StartAllServers starts the servers of the cluster in order. It stops at
// the first server that fails to start and returns an error naming that
// server's port; servers after it are not started.
func (tc *TestCluster) StartAllServers() error {
	for _, s := range tc.Servers {
		err := s.Srv.Start()
		if err == nil {
			continue
		}
		return fmt.Errorf(
			"Failed to start server listening on port `%d` : %+v", s.Port, err)
	}
	return nil
}
// StopAllServers stops the servers of the cluster in order. It stops at the
// first server that fails to stop and returns an error naming that server's
// port; servers after it are left untouched.
func (tc *TestCluster) StopAllServers() error {
	for _, s := range tc.Servers {
		err := s.Srv.Stop()
		if err == nil {
			continue
		}
		return fmt.Errorf(
			"Failed to stop server listening on port `%d` : %+v", s.Port, err)
	}
	return nil
}

View file

@ -270,6 +270,7 @@ type multiResponseOp struct {
Header multiHeader
String string
Stat *Stat
Err ErrCode
}
type multiResponse struct {
Ops []multiResponseOp
@ -327,6 +328,8 @@ func (r *multiRequest) Decode(buf []byte) (int, error) {
}
func (r *multiResponse) Decode(buf []byte) (int, error) {
var multiErr error
r.Ops = make([]multiResponseOp, 0)
r.DoneHeader = multiHeader{-1, true, -1}
total := 0
@ -347,6 +350,8 @@ func (r *multiResponse) Decode(buf []byte) (int, error) {
switch header.Type {
default:
return total, ErrAPIError
case opError:
w = reflect.ValueOf(&res.Err)
case opCreate:
w = reflect.ValueOf(&res.String)
case opSetData:
@ -362,8 +367,12 @@ func (r *multiResponse) Decode(buf []byte) (int, error) {
total += n
}
r.Ops = append(r.Ops, res)
if multiErr == nil && res.Err != errOk {
// Use the first error as the error returned from Multi().
multiErr = res.Err.toError()
}
}
return total, nil
return total, multiErr
}
type watcherEvent struct {
@ -598,43 +607,3 @@ func requestStructForOp(op int32) interface{} {
}
return nil
}
// responseStructForOp returns a pointer to a newly allocated, zero-valued
// response struct for the given opcode, or nil when the opcode has no
// response type listed here.
func responseStructForOp(op int32) interface{} {
	// resp stays a nil interface unless one of the cases below matches.
	var resp interface{}
	switch op {
	case opClose:
		resp = new(closeResponse)
	case opCreate:
		resp = new(createResponse)
	case opDelete:
		resp = new(deleteResponse)
	case opExists:
		resp = new(existsResponse)
	case opGetAcl:
		resp = new(getAclResponse)
	case opGetChildren:
		resp = new(getChildrenResponse)
	case opGetChildren2:
		resp = new(getChildren2Response)
	case opGetData:
		resp = new(getDataResponse)
	case opPing:
		resp = new(pingResponse)
	case opSetAcl:
		resp = new(setAclResponse)
	case opSetData:
		resp = new(setDataResponse)
	case opSetWatches:
		resp = new(setWatchesResponse)
	case opSync:
		resp = new(syncResponse)
	case opWatcherEvent:
		resp = new(watcherEvent)
	case opSetAuth:
		resp = new(setAuthResponse)
	// case opCheck:
	// 	return &checkVersionResponse{}
	case opMulti:
		resp = new(multiResponse)
	}
	return resp
}

View file

@ -1,148 +0,0 @@
package zk
import (
"encoding/binary"
"fmt"
"io"
"net"
"sync"
)
// Bookkeeping that trace uses to pair a traced response with the request
// that caused it.
var (
// requests is filled in when a client request is traced and the entry is
// read and deleted when a response with the same xid is traced.
// NOTE(review): this is one package-level map shared by every traced
// connection, so xids from different connections share a keyspace — confirm
// that is acceptable.
requests = make(map[int32]int32) // Map of Xid -> Opcode
// requestsLock guards requests; trace takes it around every access.
requestsLock = &sync.Mutex{}
)
// trace relays length-prefixed packets from conn1 to conn2, printing a
// decoded dump of each packet to stdout before forwarding it unchanged.
//
// client selects how packets are interpreted: true means conn1 carries
// client requests (connectRequest first, then requestHeader + request body),
// false means it carries server responses (connectResponse first, then
// responseHeader + response body). Opcodes of traced requests are recorded in
// the package-level requests map so that the response side can pick the
// matching response struct by xid.
//
// Both connections are closed when trace returns, which happens on the first
// read or write error or on a frame whose length prefix is out of range.
func trace(conn1, conn2 net.Conn, client bool) {
	defer conn1.Close()
	defer conn2.Close()

	// Upper bound for a single frame. A corrupt length prefix must not make
	// the tracer allocate gigabytes; frames above this are treated as fatal.
	const maxPacketLen = 16 << 20

	buf := make([]byte, 10*1024)
	handshake := true // the first packet in each direction is the connect exchange
	for {
		_, err := io.ReadFull(conn1, buf[:4])
		if err != nil {
			fmt.Println("1>", client, err)
			return
		}

		blen := int(binary.BigEndian.Uint32(buf[:4]))
		// blen can only be negative where int is 32 bits wide.
		if blen < 0 || blen > maxPacketLen {
			fmt.Printf("Packet length %d out of range\n", blen)
			return
		}
		end := 4 + blen // index just past the last byte of this packet
		if end > len(buf) {
			// Grow the buffer instead of slicing past its end; keep the
			// length prefix so the packet can be forwarded verbatim.
			grown := make([]byte, end)
			copy(grown, buf[:4])
			buf = grown
		}

		_, err = io.ReadFull(conn1, buf[4:end])
		if err != nil {
			fmt.Println("2>", client, err)
			return
		}

		var cr interface{}
		opcode := int32(-1)
		readHeader := true
		if client {
			if handshake {
				cr = &connectRequest{}
				readHeader = false
			} else {
				xid := int32(binary.BigEndian.Uint32(buf[4:8]))
				opcode = int32(binary.BigEndian.Uint32(buf[8:12]))
				// Remember the opcode so the response can be decoded later.
				requestsLock.Lock()
				requests[xid] = opcode
				requestsLock.Unlock()
				cr = requestStructForOp(opcode)
				if cr == nil {
					fmt.Printf("Unknown opcode %d\n", opcode)
				}
			}
		} else {
			if handshake {
				cr = &connectResponse{}
				readHeader = false
			} else {
				xid := int32(binary.BigEndian.Uint32(buf[4:8]))
				zxid := int64(binary.BigEndian.Uint64(buf[8:16]))
				errnum := int32(binary.BigEndian.Uint32(buf[16:20]))
				if xid != -1 || zxid != -1 {
					requestsLock.Lock()
					// A missing entry yields the zero value, i.e. opcode 0.
					opcode = requests[xid]
					delete(requests, xid)
					requestsLock.Unlock()
				} else {
					// xid and zxid both -1 is handled as a watcher event.
					opcode = opWatcherEvent
				}
				cr = responseStructForOp(opcode)
				if cr == nil {
					fmt.Printf("Unknown opcode %d\n", opcode)
				}
				if errnum != 0 {
					// Error responses are decoded into an empty struct.
					cr = &struct{}{}
				}
			}
		}

		opname := "."
		if opcode != -1 {
			opname = opNames[opcode]
		}
		if cr == nil {
			fmt.Printf("%+v %s %+v\n", client, opname, buf[4:end])
		} else {
			n := 4
			hdrStr := ""
			if readHeader {
				var hdr interface{}
				if client {
					hdr = &requestHeader{}
				} else {
					hdr = &responseHeader{}
				}
				if n2, err := decodePacket(buf[n:end], hdr); err != nil {
					fmt.Println(err)
				} else {
					n += n2
				}
				hdrStr = fmt.Sprintf(" %+v", hdr)
			}
			// Defensive: never let the body slice start past the packet end.
			if n > end {
				n = end
			}
			// Decode only the bytes that belong to this packet; the slice
			// must stop at end, not at n+blen, once the header is consumed.
			if _, err := decodePacket(buf[n:end], cr); err != nil {
				fmt.Println(err)
			}
			fmt.Printf("%+v %s%s %+v\n", client, opname, hdrStr, cr)
		}

		handshake = false

		written, err := conn2.Write(buf[:end])
		if err != nil {
			fmt.Println("3>", client, err)
			return
		} else if written != end {
			// Report the same quantity that was compared (prefix + payload).
			fmt.Printf("Written != read: %d != %d\n", written, end)
			return
		}
	}
}
// handleConnection connects to the server at addr and traces traffic between
// it and the already accepted conn in both directions: client-to-server in a
// new goroutine, server-to-client in the calling goroutine. It returns when
// the server-to-client trace ends.
//
// If the server cannot be dialed, the error is printed and conn is closed,
// because no trace call will take ownership of it.
func handleConnection(addr string, conn net.Conn) {
	zkConn, err := net.Dial("tcp", addr)
	if err != nil {
		fmt.Println(err)
		// Without this the accepted connection would stay open and leak.
		conn.Close()
		return
	}
	go trace(conn, zkConn, true)
	trace(zkConn, conn, false)
}
// StartTracer listens on listenAddr and hands every accepted connection to
// handleConnection in its own goroutine, with serverAddr as the upstream
// address. It panics if the listener cannot be created and otherwise never
// returns; accept errors are printed and the loop carries on.
func StartTracer(listenAddr, serverAddr string) {
	listener, listenErr := net.Listen("tcp", listenAddr)
	if listenErr != nil {
		panic(listenErr)
	}

	for {
		client, acceptErr := listener.Accept()
		if acceptErr == nil {
			go handleConnection(serverAddr, client)
			continue
		}
		fmt.Println(acceptErr)
	}
}

6
vendor/vendor.json vendored
View file

@ -531,10 +531,10 @@
"revisionTime": "2016-04-11T19:08:41Z"
},
{
"checksumSHA1": "+49Vr4Me28p3cR+gxX5SUQHbbas=",
"checksumSHA1": "5SYLEhADhdBVZAGPVHWggQl7H8k=",
"path": "github.com/samuel/go-zookeeper/zk",
"revision": "177002e16a0061912f02377e2dd8951a8b3551bc",
"revisionTime": "2015-08-17T10:50:50-07:00"
"revision": "1d7be4effb13d2d908342d349d71a284a7542693",
"revisionTime": "2016-10-28T23:23:40Z"
},
{
"checksumSHA1": "YuPBOVkkE3uuBh4RcRUTF0n+frs=",