Update github.com/mdlayher/netlink with code simplifications

This commit is contained in:
Matt Layher 2017-03-10 12:32:29 -05:00
parent d92dc8cabe
commit db15c0e365
No known key found for this signature in database
GPG key ID: 77BFE531397EDE94
3 changed files with 31 additions and 38 deletions

View file

@ -3,7 +3,6 @@ package netlink
import ( import (
"errors" "errors"
"math/rand" "math/rand"
"sync"
"sync/atomic" "sync/atomic"
"golang.org/x/net/bpf" "golang.org/x/net/bpf"
@ -27,11 +26,8 @@ type Conn struct {
// numbers when Conn.Send is called. // numbers when Conn.Send is called.
seq *uint32 seq *uint32
// pid is an atomically set/loaded integer which is set to the PID assigned // pid is the PID assigned by netlink.
// by netlink, when netlink sends its first response message. pidOnce performs pid uint32
// the assignment exactl once.
pid *uint32
pidOnce sync.Once
} }
// An osConn is an operating-system specific implementation of netlink // An osConn is an operating-system specific implementation of netlink
@ -50,22 +46,22 @@ type osConn interface {
// configuration will be used. // configuration will be used.
func Dial(proto int, config *Config) (*Conn, error) { func Dial(proto int, config *Config) (*Conn, error) {
// Use OS-specific dial() to create osConn // Use OS-specific dial() to create osConn
c, err := dial(proto, config) c, pid, err := dial(proto, config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newConn(c), nil return newConn(c, pid), nil
} }
// newConn is the internal constructor for Conn, used in tests. // newConn is the internal constructor for Conn, used in tests.
func newConn(c osConn) *Conn { func newConn(c osConn, pid uint32) *Conn {
seq := rand.Uint32() seq := rand.Uint32()
return &Conn{ return &Conn{
c: c, c: c,
seq: &seq, seq: &seq,
pid: new(uint32), pid: pid,
} }
} }
@ -128,7 +124,7 @@ func (c *Conn) Send(m Message) (Message, error) {
} }
if m.Header.PID == 0 { if m.Header.PID == 0 {
m.Header.PID = atomic.LoadUint32(c.pid) m.Header.PID = c.pid
} }
if err := c.c.Send(m); err != nil { if err := c.c.Send(m); err != nil {
@ -142,10 +138,6 @@ func (c *Conn) Send(m Message) (Message, error) {
// handled transparently and returned as a single slice of Messages, with the // handled transparently and returned as a single slice of Messages, with the
// final empty "multi-part done" message removed. // final empty "multi-part done" message removed.
// //
// If a PID has not yet been assigned to this Conn by netlink, the PID will
// be set from the first received message. This PID will be used in all
// subsequent communications with netlink.
//
// If any of the messages indicate a netlink error, that error will be returned. // If any of the messages indicate a netlink error, that error will be returned.
func (c *Conn) Receive() ([]Message, error) { func (c *Conn) Receive() ([]Message, error) {
msgs, err := c.receive() msgs, err := c.receive()
@ -153,16 +145,6 @@ func (c *Conn) Receive() ([]Message, error) {
return nil, err return nil, err
} }
if len(msgs) > 0 {
// netlink multicast messages from kernel have PID of 0, so don't
// assign 0 as the expected PID for next messages
if pid := msgs[0].Header.PID; pid != 0 {
c.pidOnce.Do(func() {
atomic.StoreUint32(c.pid, pid)
})
}
}
// Trim the final message with multi-part done indicator if // Trim the final message with multi-part done indicator if
// present // present
if m := msgs[len(msgs)-1]; m.Header.Flags&HeaderFlagsMulti != 0 && m.Header.Type == HeaderTypeDone { if m := msgs[len(msgs)-1]; m.Header.Flags&HeaderFlagsMulti != 0 && m.Header.Type == HeaderTypeDone {

View file

@ -29,21 +29,22 @@ type conn struct {
type socket interface { type socket interface {
Bind(sa unix.Sockaddr) error Bind(sa unix.Sockaddr) error
Close() error Close() error
Getsockname() (unix.Sockaddr, error)
Recvmsg(p, oob []byte, flags int) (n int, oobn int, recvflags int, from unix.Sockaddr, err error) Recvmsg(p, oob []byte, flags int) (n int, oobn int, recvflags int, from unix.Sockaddr, err error)
Sendmsg(p, oob []byte, to unix.Sockaddr, flags int) error Sendmsg(p, oob []byte, to unix.Sockaddr, flags int) error
SetSockopt(level, name int, v unsafe.Pointer, l uint32) error SetSockopt(level, name int, v unsafe.Pointer, l uint32) error
} }
// dial is the entry point for Dial. dial opens a netlink socket using // dial is the entry point for Dial. dial opens a netlink socket using
// system calls. // system calls, and returns its PID.
func dial(family int, config *Config) (*conn, error) { func dial(family int, config *Config) (*conn, uint32, error) {
fd, err := unix.Socket( fd, err := unix.Socket(
unix.AF_NETLINK, unix.AF_NETLINK,
unix.SOCK_RAW, unix.SOCK_RAW,
family, family,
) )
if err != nil { if err != nil {
return nil, err return nil, 0, err
} }
return bind(&sysSocket{fd: fd}, config) return bind(&sysSocket{fd: fd}, config)
@ -51,7 +52,7 @@ func dial(family int, config *Config) (*conn, error) {
// bind binds a connection to netlink using the input socket, which may be // bind binds a connection to netlink using the input socket, which may be
// a system call implementation or a mocked one for tests. // a system call implementation or a mocked one for tests.
func bind(s socket, config *Config) (*conn, error) { func bind(s socket, config *Config) (*conn, uint32, error) {
if config == nil { if config == nil {
config = &Config{} config = &Config{}
} }
@ -61,17 +62,26 @@ func bind(s socket, config *Config) (*conn, error) {
Groups: config.Groups, Groups: config.Groups,
} }
// Socket must be closed in the event of any system call errors, to avoid
// leaking file descriptors.
if err := s.Bind(addr); err != nil { if err := s.Bind(addr); err != nil {
// Since this never returns conn (and as such, the caller cannot close it),
// close the socket here in the event of a failure to bind.
_ = s.Close() _ = s.Close()
return nil, err return nil, 0, err
} }
sa, err := s.Getsockname()
if err != nil {
_ = s.Close()
return nil, 0, err
}
pid := sa.(*unix.SockaddrNetlink).Pid
return &conn{ return &conn{
s: s, s: s,
sa: addr, sa: addr,
}, nil }, pid, nil
} }
// Send sends a single Message to netlink. // Send sends a single Message to netlink.
@ -199,8 +209,9 @@ type sysSocket struct {
fd int fd int
} }
func (s *sysSocket) Bind(sa unix.Sockaddr) error { return unix.Bind(s.fd, sa) } func (s *sysSocket) Bind(sa unix.Sockaddr) error { return unix.Bind(s.fd, sa) }
func (s *sysSocket) Close() error { return unix.Close(s.fd) } func (s *sysSocket) Close() error { return unix.Close(s.fd) }
func (s *sysSocket) Getsockname() (unix.Sockaddr, error) { return unix.Getsockname(s.fd) }
func (s *sysSocket) Recvmsg(p, oob []byte, flags int) (int, int, int, unix.Sockaddr, error) { func (s *sysSocket) Recvmsg(p, oob []byte, flags int) (int, int, int, unix.Sockaddr, error) {
return unix.Recvmsg(s.fd, p, oob, flags) return unix.Recvmsg(s.fd, p, oob, flags)
} }

6
vendor/vendor.json vendored
View file

@ -51,10 +51,10 @@
"revisionTime": "2016-04-24T11:30:07Z" "revisionTime": "2016-04-24T11:30:07Z"
}, },
{ {
"checksumSHA1": "yDvo49XwrEOOzk4g5eMtz7aZ1RY=", "checksumSHA1": "r3t+HDvOEQVCaLjuMT8rl6opbNQ=",
"path": "github.com/mdlayher/netlink", "path": "github.com/mdlayher/netlink",
"revision": "11047e3e3daa32f7b757bc9ab59c413cadeccfa1", "revision": "343c07bd16ebbc714f19c528a6deb6723ace06f3",
"revisionTime": "2017-03-02T15:49:27Z" "revisionTime": "2017-03-10T17:31:27Z"
}, },
{ {
"checksumSHA1": "+2roeIWCAjCC58tZcs12Vqgf1Io=", "checksumSHA1": "+2roeIWCAjCC58tZcs12Vqgf1Io=",