package netlink import ( "errors" "io" "math/rand" "os" "sync/atomic" "golang.org/x/net/bpf" ) // Error messages which can be returned by Validate. var ( errMismatchedSequence = errors.New("mismatched sequence in netlink reply") errMismatchedPID = errors.New("mismatched PID in netlink reply") errShortErrorMessage = errors.New("not enough data for netlink error code") ) // Errors which can be returned by a Socket that does not implement // all exposed methods of Conn. var ( errReadWriteCloserNotSupported = errors.New("raw read/write/closer not supported") errMulticastGroupsNotSupported = errors.New("multicast groups not supported") errBPFFiltersNotSupported = errors.New("BPF filters not supported") ) // A Conn is a connection to netlink. A Conn can be used to send and // receives messages to and from netlink. // // A Conn is safe for concurrent use, but to avoid contention in // high-throughput applications, the caller should almost certainly create a // pool of Conns and distribute them among workers. type Conn struct { // sock is the operating system-specific implementation of // a netlink sockets connection. sock Socket // seq is an atomically incremented integer used to provide sequence // numbers when Conn.Send is called. seq *uint32 // pid is the PID assigned by netlink. pid uint32 } // A Socket is an operating-system specific implementation of netlink // sockets used by Conn. type Socket interface { Close() error Send(m Message) error Receive() ([]Message, error) } // Dial dials a connection to netlink, using the specified netlink family. // Config specifies optional configuration for Conn. If config is nil, a default // configuration will be used. func Dial(family int, config *Config) (*Conn, error) { // Use OS-specific dial() to create Socket c, pid, err := dial(family, config) if err != nil { return nil, err } return NewConn(c, pid), nil } // NewConn creates a Conn using the specified Socket and PID for netlink // communications. // // NewConn is primarily useful for tests. Most applications should use // Dial instead. func NewConn(c Socket, pid uint32) *Conn { seq := rand.Uint32() return &Conn{ sock: c, seq: &seq, pid: pid, } } // Close closes the connection. func (c *Conn) Close() error { return c.sock.Close() } // Execute sends a single Message to netlink using Conn.Send, receives one or more // replies using Conn.Receive, and then checks the validity of the replies against // the request using Validate. // // See the documentation of Conn.Send, Conn.Receive, and Validate for details about // each function. func (c *Conn) Execute(m Message) ([]Message, error) { req, err := c.Send(m) if err != nil { return nil, err } replies, err := c.Receive() if err != nil { return nil, err } if err := Validate(req, replies); err != nil { return nil, err } return replies, nil } // Send sends a single Message to netlink. In most cases, m.Header's Length, // Sequence, and PID fields should be set to 0, so they can be populated // automatically before the Message is sent. On success, Send returns a copy // of the Message with all parameters populated, for later validation. // // If m.Header.Length is 0, it will be automatically populated using the // correct length for the Message, including its payload. // // If m.Header.Sequence is 0, it will be automatically populated using the // next sequence number for this connection. // // If m.Header.PID is 0, it will be automatically populated using a PID // assigned by netlink. func (c *Conn) Send(m Message) (Message, error) { ml := nlmsgLength(len(m.Data)) // TODO(mdlayher): fine-tune this limit. if ml > (1024 * 32) { return Message{}, errors.New("netlink message data too large") } if m.Header.Length == 0 { m.Header.Length = uint32(nlmsgAlign(ml)) } if m.Header.Sequence == 0 { m.Header.Sequence = c.nextSequence() } if m.Header.PID == 0 { m.Header.PID = c.pid } if err := c.sock.Send(m); err != nil { return Message{}, err } return m, nil } // Receive receives one or more messages from netlink. Multi-part messages are // handled transparently and returned as a single slice of Messages, with the // final empty "multi-part done" message removed. // // If any of the messages indicate a netlink error, that error will be returned. func (c *Conn) Receive() ([]Message, error) { msgs, err := c.receive() if err != nil { return nil, err } // When using nltest, it's possible for zero messages to be returned by receive. if len(msgs) == 0 { return msgs, nil } // Trim the final message with multi-part done indicator if // present. if m := msgs[len(msgs)-1]; m.Header.Flags&HeaderFlagsMulti != 0 && m.Header.Type == HeaderTypeDone { return msgs[:len(msgs)-1], nil } return msgs, nil } // receive is the internal implementation of Conn.Receive, which can be called // recursively to handle multi-part messages. func (c *Conn) receive() ([]Message, error) { msgs, err := c.sock.Receive() if err != nil { return nil, err } // If this message is multi-part, we will need to perform an recursive call // to continue draining the socket var multi bool for _, m := range msgs { // Is this a multi-part message and is it not done yet? if m.Header.Flags&HeaderFlagsMulti != 0 && m.Header.Type != HeaderTypeDone { multi = true } if err := checkMessage(m); err != nil { return nil, err } } if !multi { return msgs, nil } // More messages waiting mmsgs, err := c.receive() if err != nil { return nil, err } return append(msgs, mmsgs...), nil } // An fder is a Socket that supports retrieving its raw file descriptor. type fder interface { Socket FD() int } var _ io.ReadWriteCloser = &fileReadWriteCloser{} // A fileReadWriteCloser is a limited *os.File which only allows access to its // Read and Write methods. type fileReadWriteCloser struct { f *os.File } // Read implements io.ReadWriteCloser. func (rwc *fileReadWriteCloser) Read(b []byte) (int, error) { return rwc.f.Read(b) } // Write implements io.ReadWriteCloser. func (rwc *fileReadWriteCloser) Write(b []byte) (int, error) { return rwc.f.Write(b) } // Close implements io.ReadWriteCloser. func (rwc *fileReadWriteCloser) Close() error { return rwc.f.Close() } // ReadWriteCloser returns a raw io.ReadWriteCloser backed by the connection // of the Conn. // // ReadWriteCloser is intended for advanced use cases, such as those that do // not involve standard netlink message passing. // // Once invoked, it is the caller's responsibility to ensure that operations // performed using Conn and the raw io.ReadWriteCloser do not conflict with // each other. In almost all scenarios, only one of the two should be used. func (c *Conn) ReadWriteCloser() (io.ReadWriteCloser, error) { fc, ok := c.sock.(fder) if !ok { return nil, errReadWriteCloserNotSupported } return &fileReadWriteCloser{ // Backing the io.ReadWriteCloser with an *os.File enables easy reading // and writing without more system call boilerplate. f: os.NewFile(uintptr(fc.FD()), "netlink"), }, nil } // A groupJoinLeaver is a Socket that supports joining and leaving // netlink multicast groups. type groupJoinLeaver interface { Socket JoinGroup(group uint32) error LeaveGroup(group uint32) error } // JoinGroup joins a netlink multicast group by its ID. func (c *Conn) JoinGroup(group uint32) error { gc, ok := c.sock.(groupJoinLeaver) if !ok { return errMulticastGroupsNotSupported } return gc.JoinGroup(group) } // LeaveGroup leaves a netlink multicast group by its ID. func (c *Conn) LeaveGroup(group uint32) error { gc, ok := c.sock.(groupJoinLeaver) if !ok { return errMulticastGroupsNotSupported } return gc.LeaveGroup(group) } // A bpfSetter is a Socket that supports setting BPF filters. type bpfSetter interface { Socket bpf.Setter } // SetBPF attaches an assembled BPF program to a Conn. func (c *Conn) SetBPF(filter []bpf.RawInstruction) error { bc, ok := c.sock.(bpfSetter) if !ok { return errBPFFiltersNotSupported } return bc.SetBPF(filter) } // nextSequence atomically increments Conn's sequence number and returns // the incremented value. func (c *Conn) nextSequence() uint32 { return atomic.AddUint32(c.seq, 1) } // Validate validates one or more reply Messages against a request Message, // ensuring that they contain matching sequence numbers and PIDs. func Validate(request Message, replies []Message) error { for _, m := range replies { // Check for mismatched sequence, unless: // - request had no sequence, meaning we are probably validating // a multicast reply if m.Header.Sequence != request.Header.Sequence && request.Header.Sequence != 0 { return errMismatchedSequence } // Check for mismatched PID, unless: // - request had no PID, meaning we are either: // - validating a multicast reply // - netlink has not yet assigned us a PID // - response had no PID, meaning it's from the kernel as a multicast reply if m.Header.PID != request.Header.PID && request.Header.PID != 0 && m.Header.PID != 0 { return errMismatchedPID } } return nil } // Config contains options for a Conn. type Config struct { // Groups is a bitmask which specifies multicast groups. If set to 0, // no multicast group subscriptions will be made. Groups uint32 // Experimental: do not lock the internal system call handling goroutine // to its OS thread. This may result in a speed-up of system call handling, // but may cause unexpected behavior when sending and receiving a large number // of messages. // // This should almost certainly be set to false, but if you come up with a // valid reason for using this, please file an issue at // https://github.com/mdlayher/netlink to discuss your thoughts. NoLockThread bool }