Update github.com/mdlayher/netlink with code simplifications

This commit is contained in:
Matt Layher 2017-03-10 12:32:29 -05:00
parent d92dc8cabe
commit db15c0e365
No known key found for this signature in database
GPG key ID: 77BFE531397EDE94
3 changed files with 31 additions and 38 deletions

View file

@ -3,7 +3,6 @@ package netlink
import (
"errors"
"math/rand"
"sync"
"sync/atomic"
"golang.org/x/net/bpf"
@ -27,11 +26,8 @@ type Conn struct {
// numbers when Conn.Send is called.
seq *uint32
// pid is an atomically set/loaded integer which is set to the PID assigned
// by netlink, when netlink sends its first response message. pidOnce performs
// the assignment exactl once.
pid *uint32
pidOnce sync.Once
// pid is the PID assigned by netlink.
pid uint32
}
// An osConn is an operating-system specific implementation of netlink
@ -50,22 +46,22 @@ type osConn interface {
// configuration will be used.
func Dial(proto int, config *Config) (*Conn, error) {
// Use OS-specific dial() to create osConn
c, err := dial(proto, config)
c, pid, err := dial(proto, config)
if err != nil {
return nil, err
}
return newConn(c), nil
return newConn(c, pid), nil
}
// newConn is the internal constructor for Conn, used in tests.
func newConn(c osConn) *Conn {
func newConn(c osConn, pid uint32) *Conn {
seq := rand.Uint32()
return &Conn{
c: c,
seq: &seq,
pid: new(uint32),
pid: pid,
}
}
@ -128,7 +124,7 @@ func (c *Conn) Send(m Message) (Message, error) {
}
if m.Header.PID == 0 {
m.Header.PID = atomic.LoadUint32(c.pid)
m.Header.PID = c.pid
}
if err := c.c.Send(m); err != nil {
@ -142,10 +138,6 @@ func (c *Conn) Send(m Message) (Message, error) {
// handled transparently and returned as a single slice of Messages, with the
// final empty "multi-part done" message removed.
//
// If a PID has not yet been assigned to this Conn by netlink, the PID will
// be set from the first received message. This PID will be used in all
// subsequent communications with netlink.
//
// If any of the messages indicate a netlink error, that error will be returned.
func (c *Conn) Receive() ([]Message, error) {
msgs, err := c.receive()
@ -153,16 +145,6 @@ func (c *Conn) Receive() ([]Message, error) {
return nil, err
}
if len(msgs) > 0 {
// netlink multicast messages from kernel have PID of 0, so don't
// assign 0 as the expected PID for next messages
if pid := msgs[0].Header.PID; pid != 0 {
c.pidOnce.Do(func() {
atomic.StoreUint32(c.pid, pid)
})
}
}
// Trim the final message with multi-part done indicator if
// present
if m := msgs[len(msgs)-1]; m.Header.Flags&HeaderFlagsMulti != 0 && m.Header.Type == HeaderTypeDone {

View file

@ -29,21 +29,22 @@ type conn struct {
type socket interface {
Bind(sa unix.Sockaddr) error
Close() error
Getsockname() (unix.Sockaddr, error)
Recvmsg(p, oob []byte, flags int) (n int, oobn int, recvflags int, from unix.Sockaddr, err error)
Sendmsg(p, oob []byte, to unix.Sockaddr, flags int) error
SetSockopt(level, name int, v unsafe.Pointer, l uint32) error
}
// dial is the entry point for Dial. dial opens a netlink socket using
// system calls.
func dial(family int, config *Config) (*conn, error) {
// system calls, and returns its PID.
func dial(family int, config *Config) (*conn, uint32, error) {
fd, err := unix.Socket(
unix.AF_NETLINK,
unix.SOCK_RAW,
family,
)
if err != nil {
return nil, err
return nil, 0, err
}
return bind(&sysSocket{fd: fd}, config)
@ -51,7 +52,7 @@ func dial(family int, config *Config) (*conn, error) {
// bind binds a connection to netlink using the input socket, which may be
// a system call implementation or a mocked one for tests.
func bind(s socket, config *Config) (*conn, error) {
func bind(s socket, config *Config) (*conn, uint32, error) {
if config == nil {
config = &Config{}
}
@ -61,17 +62,26 @@ func bind(s socket, config *Config) (*conn, error) {
Groups: config.Groups,
}
// Socket must be closed in the event of any system call errors, to avoid
// leaking file descriptors.
if err := s.Bind(addr); err != nil {
// Since this never returns conn (and as such, the caller cannot close it),
// close the socket here in the event of a failure to bind.
_ = s.Close()
return nil, err
return nil, 0, err
}
sa, err := s.Getsockname()
if err != nil {
_ = s.Close()
return nil, 0, err
}
pid := sa.(*unix.SockaddrNetlink).Pid
return &conn{
s: s,
sa: addr,
}, nil
}, pid, nil
}
// Send sends a single Message to netlink.
@ -199,8 +209,9 @@ type sysSocket struct {
fd int
}
func (s *sysSocket) Bind(sa unix.Sockaddr) error { return unix.Bind(s.fd, sa) }
func (s *sysSocket) Close() error { return unix.Close(s.fd) }
func (s *sysSocket) Bind(sa unix.Sockaddr) error { return unix.Bind(s.fd, sa) }
func (s *sysSocket) Close() error { return unix.Close(s.fd) }
func (s *sysSocket) Getsockname() (unix.Sockaddr, error) { return unix.Getsockname(s.fd) }
func (s *sysSocket) Recvmsg(p, oob []byte, flags int) (int, int, int, unix.Sockaddr, error) {
return unix.Recvmsg(s.fd, p, oob, flags)
}

6
vendor/vendor.json vendored
View file

@ -51,10 +51,10 @@
"revisionTime": "2016-04-24T11:30:07Z"
},
{
"checksumSHA1": "yDvo49XwrEOOzk4g5eMtz7aZ1RY=",
"checksumSHA1": "r3t+HDvOEQVCaLjuMT8rl6opbNQ=",
"path": "github.com/mdlayher/netlink",
"revision": "11047e3e3daa32f7b757bc9ab59c413cadeccfa1",
"revisionTime": "2017-03-02T15:49:27Z"
"revision": "343c07bd16ebbc714f19c528a6deb6723ace06f3",
"revisionTime": "2017-03-10T17:31:27Z"
},
{
"checksumSHA1": "+2roeIWCAjCC58tZcs12Vqgf1Io=",