Merge pull request #441 from prometheus/vendor-dependencies

Vendor external dependencies with godep.
This commit is contained in:
stuart nelson 2015-01-09 12:47:23 +01:00
commit 5201cbf3c6
180 changed files with 53689 additions and 1 deletions

31
Godeps/Godeps.json generated Normal file
View file

@ -0,0 +1,31 @@
{
"ImportPath": "github.com/prometheus/prometheus",
"GoVersion": "go1.4",
"Deps": [
{
"ImportPath": "code.google.com/p/goprotobuf/proto",
"Comment": "go.r60-152",
"Rev": "36be16571e14f67e114bb0af619e5de2c1591679"
},
{
"ImportPath": "github.com/golang/glog",
"Rev": "44145f04b68cf362d9c4df2182967c2275eaefed"
},
{
"ImportPath": "github.com/matttproud/golang_protobuf_extensions/ext",
"Rev": "7a864a042e844af638df17ebbabf8183dace556a"
},
{
"ImportPath": "github.com/miekg/dns",
"Rev": "6b75215519f9916839204d80413bb178b94ef769"
},
{
"ImportPath": "github.com/syndtr/goleveldb/leveldb",
"Rev": "63c9e642efad852f49e20a6f90194cae112fd2ac"
},
{
"ImportPath": "github.com/syndtr/gosnappy/snappy",
"Rev": "ce8acff4829e0c2458a67ead32390ac0a381c862"
}
]
}

5
Godeps/Readme generated Normal file
View file

@ -0,0 +1,5 @@
This directory tree is generated automatically by godep.
Please do not edit.
See https://github.com/tools/godep for more information.

2
Godeps/_workspace/.gitignore generated vendored Normal file
View file

@ -0,0 +1,2 @@
/pkg
/bin

View file

@ -0,0 +1,40 @@
# Go support for Protocol Buffers - Google's data interchange format
#
# Copyright 2010 The Go Authors. All rights reserved.
# http://code.google.com/p/goprotobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
install:
go install
test: install generate-test-pbs
go test
generate-test-pbs:
make install && cd testdata && make

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,169 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Protocol buffer deep copy.
// TODO: MessageSet and RawMessage.
package proto
import (
"log"
"reflect"
"strings"
)
// Clone returns a deep copy of a protocol buffer.
func Clone(pb Message) Message {
in := reflect.ValueOf(pb)
if in.IsNil() {
return pb
}
out := reflect.New(in.Type().Elem())
// out is empty so a merge is a deep copy.
mergeStruct(out.Elem(), in.Elem())
return out.Interface().(Message)
}
// Merge merges src into dst.
// Required and optional fields that are set in src will be set to that value in dst.
// Elements of repeated fields will be appended.
// Merge panics if src and dst are not the same type, or if dst is nil.
func Merge(dst, src Message) {
in := reflect.ValueOf(src)
out := reflect.ValueOf(dst)
if out.IsNil() {
panic("proto: nil destination")
}
if in.Type() != out.Type() {
// Explicit test prior to mergeStruct so that mistyped nils will fail
panic("proto: type mismatch")
}
if in.IsNil() {
// Merging nil into non-nil is a quiet no-op
return
}
mergeStruct(out.Elem(), in.Elem())
}
func mergeStruct(out, in reflect.Value) {
for i := 0; i < in.NumField(); i++ {
f := in.Type().Field(i)
if strings.HasPrefix(f.Name, "XXX_") {
continue
}
mergeAny(out.Field(i), in.Field(i))
}
if emIn, ok := in.Addr().Interface().(extendableProto); ok {
emOut := out.Addr().Interface().(extendableProto)
mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap())
}
uf := in.FieldByName("XXX_unrecognized")
if !uf.IsValid() {
return
}
uin := uf.Bytes()
if len(uin) > 0 {
out.FieldByName("XXX_unrecognized").SetBytes(append([]byte(nil), uin...))
}
}
func mergeAny(out, in reflect.Value) {
if in.Type() == protoMessageType {
if !in.IsNil() {
if out.IsNil() {
out.Set(reflect.ValueOf(Clone(in.Interface().(Message))))
} else {
Merge(out.Interface().(Message), in.Interface().(Message))
}
}
return
}
switch in.Kind() {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64:
out.Set(in)
case reflect.Ptr:
if in.IsNil() {
return
}
if out.IsNil() {
out.Set(reflect.New(in.Elem().Type()))
}
mergeAny(out.Elem(), in.Elem())
case reflect.Slice:
if in.IsNil() {
return
}
n := in.Len()
if out.IsNil() {
out.Set(reflect.MakeSlice(in.Type(), 0, n))
}
switch in.Type().Elem().Kind() {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64:
out.Set(reflect.AppendSlice(out, in))
case reflect.Uint8:
// []byte is a scalar bytes field.
out.Set(in)
default:
for i := 0; i < n; i++ {
x := reflect.Indirect(reflect.New(in.Type().Elem()))
mergeAny(x, in.Index(i))
out.Set(reflect.Append(out, x))
}
}
case reflect.Struct:
mergeStruct(out, in)
default:
// unknown type, so not a protocol buffer
log.Printf("proto: don't know how to copy %v", in)
}
}
func mergeExtension(out, in map[int32]Extension) {
for extNum, eIn := range in {
eOut := Extension{desc: eIn.desc}
if eIn.value != nil {
v := reflect.New(reflect.TypeOf(eIn.value)).Elem()
mergeAny(v, reflect.ValueOf(eIn.value))
eOut.value = v.Interface()
}
if eIn.enc != nil {
eOut.enc = make([]byte, len(eIn.enc))
copy(eOut.enc, eIn.enc)
}
out[extNum] = eOut
}
}

View file

@ -0,0 +1,186 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"testing"
"code.google.com/p/goprotobuf/proto"
pb "./testdata"
)
var cloneTestMessage = &pb.MyMessage{
Count: proto.Int32(42),
Name: proto.String("Dave"),
Pet: []string{"bunny", "kitty", "horsey"},
Inner: &pb.InnerMessage{
Host: proto.String("niles"),
Port: proto.Int32(9099),
Connected: proto.Bool(true),
},
Others: []*pb.OtherMessage{
{
Value: []byte("some bytes"),
},
},
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: proto.Int32(6),
},
RepBytes: [][]byte{[]byte("sham"), []byte("wow")},
}
func init() {
ext := &pb.Ext{
Data: proto.String("extension"),
}
if err := proto.SetExtension(cloneTestMessage, pb.E_Ext_More, ext); err != nil {
panic("SetExtension: " + err.Error())
}
}
func TestClone(t *testing.T) {
m := proto.Clone(cloneTestMessage).(*pb.MyMessage)
if !proto.Equal(m, cloneTestMessage) {
t.Errorf("Clone(%v) = %v", cloneTestMessage, m)
}
// Verify it was a deep copy.
*m.Inner.Port++
if proto.Equal(m, cloneTestMessage) {
t.Error("Mutating clone changed the original")
}
}
func TestCloneNil(t *testing.T) {
var m *pb.MyMessage
if c := proto.Clone(m); !proto.Equal(m, c) {
t.Errorf("Clone(%v) = %v", m, c)
}
}
var mergeTests = []struct {
src, dst, want proto.Message
}{
{
src: &pb.MyMessage{
Count: proto.Int32(42),
},
dst: &pb.MyMessage{
Name: proto.String("Dave"),
},
want: &pb.MyMessage{
Count: proto.Int32(42),
Name: proto.String("Dave"),
},
},
{
src: &pb.MyMessage{
Inner: &pb.InnerMessage{
Host: proto.String("hey"),
Connected: proto.Bool(true),
},
Pet: []string{"horsey"},
Others: []*pb.OtherMessage{
{
Value: []byte("some bytes"),
},
},
},
dst: &pb.MyMessage{
Inner: &pb.InnerMessage{
Host: proto.String("niles"),
Port: proto.Int32(9099),
},
Pet: []string{"bunny", "kitty"},
Others: []*pb.OtherMessage{
{
Key: proto.Int64(31415926535),
},
{
// Explicitly test a src=nil field
Inner: nil,
},
},
},
want: &pb.MyMessage{
Inner: &pb.InnerMessage{
Host: proto.String("hey"),
Connected: proto.Bool(true),
Port: proto.Int32(9099),
},
Pet: []string{"bunny", "kitty", "horsey"},
Others: []*pb.OtherMessage{
{
Key: proto.Int64(31415926535),
},
{},
{
Value: []byte("some bytes"),
},
},
},
},
{
src: &pb.MyMessage{
RepBytes: [][]byte{[]byte("wow")},
},
dst: &pb.MyMessage{
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: proto.Int32(6),
},
RepBytes: [][]byte{[]byte("sham")},
},
want: &pb.MyMessage{
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: proto.Int32(6),
},
RepBytes: [][]byte{[]byte("sham"), []byte("wow")},
},
},
// Check that a scalar bytes field replaces rather than appends.
{
src: &pb.OtherMessage{Value: []byte("foo")},
dst: &pb.OtherMessage{Value: []byte("bar")},
want: &pb.OtherMessage{Value: []byte("foo")},
},
}
func TestMerge(t *testing.T) {
for _, m := range mergeTests {
got := proto.Clone(m.dst)
proto.Merge(got, m.src)
if !proto.Equal(got, m.want) {
t.Errorf("Merge(%v, %v)\n got %v\nwant %v\n", m.dst, m.src, got, m.want)
}
}
}

View file

@ -0,0 +1,721 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
/*
* Routines for decoding protocol buffer data to construct in-memory representations.
*/
import (
"errors"
"fmt"
"io"
"os"
"reflect"
)
// errOverflow is returned when an integer is too large to be represented.
var errOverflow = errors.New("proto: integer overflow")
// The fundamental decoders that interpret bytes on the wire.
// Those that take integer types all return uint64 and are
// therefore of type valueDecoder.
// DecodeVarint reads a varint-encoded integer from the slice.
// It returns the integer and the number of bytes consumed, or
// zero if there is not enough.
// This is the format for the
// int32, int64, uint32, uint64, bool, and enum
// protocol buffer types.
func DecodeVarint(buf []byte) (x uint64, n int) {
// x, n already 0
for shift := uint(0); shift < 64; shift += 7 {
if n >= len(buf) {
return 0, 0
}
b := uint64(buf[n])
n++
x |= (b & 0x7F) << shift
if (b & 0x80) == 0 {
return x, n
}
}
// The number is too large to represent in a 64-bit value.
return 0, 0
}
// DecodeVarint reads a varint-encoded integer from the Buffer.
// This is the format for the
// int32, int64, uint32, uint64, bool, and enum
// protocol buffer types.
func (p *Buffer) DecodeVarint() (x uint64, err error) {
// x, err already 0
i := p.index
l := len(p.buf)
for shift := uint(0); shift < 64; shift += 7 {
if i >= l {
err = io.ErrUnexpectedEOF
return
}
b := p.buf[i]
i++
x |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
p.index = i
return
}
}
// The number is too large to represent in a 64-bit value.
err = errOverflow
return
}
// DecodeFixed64 reads a 64-bit integer from the Buffer.
// This is the format for the
// fixed64, sfixed64, and double protocol buffer types.
func (p *Buffer) DecodeFixed64() (x uint64, err error) {
// x, err already 0
i := p.index + 8
if i < 0 || i > len(p.buf) {
err = io.ErrUnexpectedEOF
return
}
p.index = i
x = uint64(p.buf[i-8])
x |= uint64(p.buf[i-7]) << 8
x |= uint64(p.buf[i-6]) << 16
x |= uint64(p.buf[i-5]) << 24
x |= uint64(p.buf[i-4]) << 32
x |= uint64(p.buf[i-3]) << 40
x |= uint64(p.buf[i-2]) << 48
x |= uint64(p.buf[i-1]) << 56
return
}
// DecodeFixed32 reads a 32-bit integer from the Buffer.
// This is the format for the
// fixed32, sfixed32, and float protocol buffer types.
func (p *Buffer) DecodeFixed32() (x uint64, err error) {
// x, err already 0
i := p.index + 4
if i < 0 || i > len(p.buf) {
err = io.ErrUnexpectedEOF
return
}
p.index = i
x = uint64(p.buf[i-4])
x |= uint64(p.buf[i-3]) << 8
x |= uint64(p.buf[i-2]) << 16
x |= uint64(p.buf[i-1]) << 24
return
}
// DecodeZigzag64 reads a zigzag-encoded 64-bit integer
// from the Buffer.
// This is the format used for the sint64 protocol buffer type.
func (p *Buffer) DecodeZigzag64() (x uint64, err error) {
x, err = p.DecodeVarint()
if err != nil {
return
}
x = (x >> 1) ^ uint64((int64(x&1)<<63)>>63)
return
}
// DecodeZigzag32 reads a zigzag-encoded 32-bit integer
// from the Buffer.
// This is the format used for the sint32 protocol buffer type.
func (p *Buffer) DecodeZigzag32() (x uint64, err error) {
x, err = p.DecodeVarint()
if err != nil {
return
}
x = uint64((uint32(x) >> 1) ^ uint32((int32(x&1)<<31)>>31))
return
}
// These are not ValueDecoders: they produce an array of bytes or a string.
// bytes, embedded messages
// DecodeRawBytes reads a count-delimited byte buffer from the Buffer.
// This is the format used for the bytes protocol buffer
// type and for embedded messages.
func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
n, err := p.DecodeVarint()
if err != nil {
return
}
nb := int(n)
if nb < 0 {
return nil, fmt.Errorf("proto: bad byte length %d", nb)
}
end := p.index + nb
if end < p.index || end > len(p.buf) {
return nil, io.ErrUnexpectedEOF
}
if !alloc {
// todo: check if can get more uses of alloc=false
buf = p.buf[p.index:end]
p.index += nb
return
}
buf = make([]byte, nb)
copy(buf, p.buf[p.index:])
p.index += nb
return
}
// DecodeStringBytes reads an encoded string from the Buffer.
// This is the format used for the proto2 string type.
func (p *Buffer) DecodeStringBytes() (s string, err error) {
buf, err := p.DecodeRawBytes(false)
if err != nil {
return
}
return string(buf), nil
}
// Skip the next item in the buffer. Its wire type is decoded and presented as an argument.
// If the protocol buffer has extensions, and the field matches, add it as an extension.
// Otherwise, if the XXX_unrecognized field exists, append the skipped data there.
func (o *Buffer) skipAndSave(t reflect.Type, tag, wire int, base structPointer, unrecField field) error {
oi := o.index
err := o.skip(t, tag, wire)
if err != nil {
return err
}
if !unrecField.IsValid() {
return nil
}
ptr := structPointer_Bytes(base, unrecField)
// Add the skipped field to struct field
obuf := o.buf
o.buf = *ptr
o.EncodeVarint(uint64(tag<<3 | wire))
*ptr = append(o.buf, obuf[oi:o.index]...)
o.buf = obuf
return nil
}
// Skip the next item in the buffer. Its wire type is decoded and presented as an argument.
func (o *Buffer) skip(t reflect.Type, tag, wire int) error {
var u uint64
var err error
switch wire {
case WireVarint:
_, err = o.DecodeVarint()
case WireFixed64:
_, err = o.DecodeFixed64()
case WireBytes:
_, err = o.DecodeRawBytes(false)
case WireFixed32:
_, err = o.DecodeFixed32()
case WireStartGroup:
for {
u, err = o.DecodeVarint()
if err != nil {
break
}
fwire := int(u & 0x7)
if fwire == WireEndGroup {
break
}
ftag := int(u >> 3)
err = o.skip(t, ftag, fwire)
if err != nil {
break
}
}
default:
err = fmt.Errorf("proto: can't skip unknown wire type %d for %s", wire, t)
}
return err
}
// Unmarshaler is the interface representing objects that can
// unmarshal themselves. The method should reset the receiver before
// decoding starts. The argument points to data that may be
// overwritten, so implementations should not keep references to the
// buffer.
type Unmarshaler interface {
Unmarshal([]byte) error
}
// Unmarshal parses the protocol buffer representation in buf and places the
// decoded result in pb. If the struct underlying pb does not match
// the data in buf, the results can be unpredictable.
//
// Unmarshal resets pb before starting to unmarshal, so any
// existing data in pb is always removed. Use UnmarshalMerge
// to preserve and append to existing data.
func Unmarshal(buf []byte, pb Message) error {
pb.Reset()
return UnmarshalMerge(buf, pb)
}
// UnmarshalMerge parses the protocol buffer representation in buf and
// writes the decoded result to pb. If the struct underlying pb does not match
// the data in buf, the results can be unpredictable.
//
// UnmarshalMerge merges into existing data in pb.
// Most code should use Unmarshal instead.
func UnmarshalMerge(buf []byte, pb Message) error {
// If the object can unmarshal itself, let it.
if u, ok := pb.(Unmarshaler); ok {
return u.Unmarshal(buf)
}
return NewBuffer(buf).Unmarshal(pb)
}
// Unmarshal parses the protocol buffer representation in the
// Buffer and places the decoded result in pb. If the struct
// underlying pb does not match the data in the buffer, the results can be
// unpredictable.
func (p *Buffer) Unmarshal(pb Message) error {
// If the object can unmarshal itself, let it.
if u, ok := pb.(Unmarshaler); ok {
err := u.Unmarshal(p.buf[p.index:])
p.index = len(p.buf)
return err
}
typ, base, err := getbase(pb)
if err != nil {
return err
}
err = p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), false, base)
if collectStats {
stats.Decode++
}
return err
}
// unmarshalType does the work of unmarshaling a structure.
func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error {
var state errorState
required, reqFields := prop.reqCount, uint64(0)
var err error
for err == nil && o.index < len(o.buf) {
oi := o.index
var u uint64
u, err = o.DecodeVarint()
if err != nil {
break
}
wire := int(u & 0x7)
if wire == WireEndGroup {
if is_group {
return nil // input is satisfied
}
return fmt.Errorf("proto: %s: wiretype end group for non-group", st)
}
tag := int(u >> 3)
if tag <= 0 {
return fmt.Errorf("proto: %s: illegal tag %d", st, tag)
}
fieldnum, ok := prop.decoderTags.get(tag)
if !ok {
// Maybe it's an extension?
if prop.extendable {
if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) {
if err = o.skip(st, tag, wire); err == nil {
ext := e.ExtensionMap()[int32(tag)] // may be missing
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
e.ExtensionMap()[int32(tag)] = ext
}
continue
}
}
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
continue
}
p := prop.Prop[fieldnum]
if p.dec == nil {
fmt.Fprintf(os.Stderr, "proto: no protobuf decoder for %s.%s\n", st, st.Field(fieldnum).Name)
continue
}
dec := p.dec
if wire != WireStartGroup && wire != p.WireType {
if wire == WireBytes && p.packedDec != nil {
// a packable field
dec = p.packedDec
} else {
err = fmt.Errorf("proto: bad wiretype for field %s.%s: got wiretype %d, want %d", st, st.Field(fieldnum).Name, wire, p.WireType)
continue
}
}
decErr := dec(o, p, base)
if decErr != nil && !state.shouldContinue(decErr, p) {
err = decErr
}
if err == nil && p.Required {
// Successfully decoded a required field.
if tag <= 64 {
// use bitmap for fields 1-64 to catch field reuse.
var mask uint64 = 1 << uint64(tag-1)
if reqFields&mask == 0 {
// new required field
reqFields |= mask
required--
}
} else {
// This is imprecise. It can be fooled by a required field
// with a tag > 64 that is encoded twice; that's very rare.
// A fully correct implementation would require allocating
// a data structure, which we would like to avoid.
required--
}
}
}
if err == nil {
if is_group {
return io.ErrUnexpectedEOF
}
if state.err != nil {
return state.err
}
if required > 0 {
// Not enough information to determine the exact field. If we use extra
// CPU, we could determine the field only if the missing required field
// has a tag <= 64 and we check reqFields.
return &RequiredNotSetError{"{Unknown}"}
}
}
return err
}
// Individual type decoders
// For each,
// u is the decoded value,
// v is a pointer to the field (pointer) in the struct
// Sizes of the pools to allocate inside the Buffer.
// The goal is modest amortization and allocation
// on at least 16-byte boundaries.
const (
boolPoolSize = 16
uint32PoolSize = 8
uint64PoolSize = 4
)
// Decode a bool.
func (o *Buffer) dec_bool(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
if len(o.bools) == 0 {
o.bools = make([]bool, boolPoolSize)
}
o.bools[0] = u != 0
*structPointer_Bool(base, p.field) = &o.bools[0]
o.bools = o.bools[1:]
return nil
}
// Decode an int32.
func (o *Buffer) dec_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word32_Set(structPointer_Word32(base, p.field), o, uint32(u))
return nil
}
// Decode an int64.
func (o *Buffer) dec_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word64_Set(structPointer_Word64(base, p.field), o, u)
return nil
}
// Decode a string.
func (o *Buffer) dec_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
sp := new(string)
*sp = s
*structPointer_String(base, p.field) = sp
return nil
}
// Decode a slice of bytes ([]byte).
func (o *Buffer) dec_slice_byte(p *Properties, base structPointer) error {
b, err := o.DecodeRawBytes(true)
if err != nil {
return err
}
*structPointer_Bytes(base, p.field) = b
return nil
}
// Decode a slice of bools ([]bool).
func (o *Buffer) dec_slice_bool(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
v := structPointer_BoolSlice(base, p.field)
*v = append(*v, u != 0)
return nil
}
// Decode a slice of bools ([]bool) in packed format.
func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error {
v := structPointer_BoolSlice(base, p.field)
nn, err := o.DecodeVarint()
if err != nil {
return err
}
nb := int(nn) // number of bytes of encoded bools
y := *v
for i := 0; i < nb; i++ {
u, err := p.valDec(o)
if err != nil {
return err
}
y = append(y, u != 0)
}
*v = y
return nil
}
// Decode a slice of int32s ([]int32).
func (o *Buffer) dec_slice_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
structPointer_Word32Slice(base, p.field).Append(uint32(u))
return nil
}
// Decode a slice of int32s ([]int32) in packed format.
func (o *Buffer) dec_slice_packed_int32(p *Properties, base structPointer) error {
v := structPointer_Word32Slice(base, p.field)
nn, err := o.DecodeVarint()
if err != nil {
return err
}
nb := int(nn) // number of bytes of encoded int32s
fin := o.index + nb
if fin < o.index {
return errOverflow
}
for o.index < fin {
u, err := p.valDec(o)
if err != nil {
return err
}
v.Append(uint32(u))
}
return nil
}
// Decode a slice of int64s ([]int64).
func (o *Buffer) dec_slice_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
structPointer_Word64Slice(base, p.field).Append(u)
return nil
}
// Decode a slice of int64s ([]int64) in packed format.
func (o *Buffer) dec_slice_packed_int64(p *Properties, base structPointer) error {
v := structPointer_Word64Slice(base, p.field)
nn, err := o.DecodeVarint()
if err != nil {
return err
}
nb := int(nn) // number of bytes of encoded int64s
fin := o.index + nb
if fin < o.index {
return errOverflow
}
for o.index < fin {
u, err := p.valDec(o)
if err != nil {
return err
}
v.Append(u)
}
return nil
}
// Decode a slice of strings ([]string).
func (o *Buffer) dec_slice_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
v := structPointer_StringSlice(base, p.field)
*v = append(*v, s)
return nil
}
// Decode a slice of slice of bytes ([][]byte).
func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error {
b, err := o.DecodeRawBytes(true)
if err != nil {
return err
}
v := structPointer_BytesSlice(base, p.field)
*v = append(*v, b)
return nil
}
// Decode a group.
func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error {
bas := structPointer_GetStructPointer(base, p.field)
if structPointer_IsNil(bas) {
// allocate new nested message
bas = toStructPointer(reflect.New(p.stype))
structPointer_SetStructPointer(base, p.field, bas)
}
return o.unmarshalType(p.stype, p.sprop, true, bas)
}
// Decode an embedded message.
func (o *Buffer) dec_struct_message(p *Properties, base structPointer) (err error) {
raw, e := o.DecodeRawBytes(false)
if e != nil {
return e
}
bas := structPointer_GetStructPointer(base, p.field)
if structPointer_IsNil(bas) {
// allocate new nested message
bas = toStructPointer(reflect.New(p.stype))
structPointer_SetStructPointer(base, p.field, bas)
}
// If the object can unmarshal itself, let it.
if p.isUnmarshaler {
iv := structPointer_Interface(bas, p.stype)
return iv.(Unmarshaler).Unmarshal(raw)
}
obuf := o.buf
oi := o.index
o.buf = raw
o.index = 0
err = o.unmarshalType(p.stype, p.sprop, false, bas)
o.buf = obuf
o.index = oi
return err
}
// Decode a slice of embedded messages.
func (o *Buffer) dec_slice_struct_message(p *Properties, base structPointer) error {
return o.dec_slice_struct(p, false, base)
}
// Decode a slice of embedded groups.
func (o *Buffer) dec_slice_struct_group(p *Properties, base structPointer) error {
return o.dec_slice_struct(p, true, base)
}
// Decode a slice of structs ([]*struct).
func (o *Buffer) dec_slice_struct(p *Properties, is_group bool, base structPointer) error {
v := reflect.New(p.stype)
bas := toStructPointer(v)
structPointer_StructPointerSlice(base, p.field).Append(bas)
if is_group {
err := o.unmarshalType(p.stype, p.sprop, is_group, bas)
return err
}
raw, err := o.DecodeRawBytes(false)
if err != nil {
return err
}
// If the object can unmarshal itself, let it.
if p.isUnmarshaler {
iv := v.Interface()
return iv.(Unmarshaler).Unmarshal(raw)
}
obuf := o.buf
oi := o.index
o.buf = raw
o.index = 0
err = o.unmarshalType(p.stype, p.sprop, is_group, bas)
o.buf = obuf
o.index = oi
return err
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,241 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Protocol buffer comparison.
// TODO: MessageSet.
package proto
import (
"bytes"
"log"
"reflect"
"strings"
)
/*
Equal returns true iff protocol buffers a and b are equal.
The arguments must both be pointers to protocol buffer structs.
Equality is defined in this way:
- Two messages are equal iff they are the same type,
corresponding fields are equal, unknown field sets
are equal, and extensions sets are equal.
- Two set scalar fields are equal iff their values are equal.
If the fields are of a floating-point type, remember that
NaN != x for all x, including NaN.
- Two repeated fields are equal iff their lengths are the same,
and their corresponding elements are equal (a "bytes" field,
although represented by []byte, is not a repeated field)
- Two unset fields are equal.
- Two unknown field sets are equal if their current
encoded state is equal. (TODO)
- Two extension sets are equal iff they have corresponding
elements that are pairwise equal.
- Every other combination of things are not equal.
The return value is undefined if a and b are not protocol buffers.
*/
func Equal(a, b Message) bool {
if a == nil || b == nil {
return a == b
}
v1, v2 := reflect.ValueOf(a), reflect.ValueOf(b)
if v1.Type() != v2.Type() {
return false
}
if v1.Kind() == reflect.Ptr {
if v1.IsNil() {
return v2.IsNil()
}
if v2.IsNil() {
return false
}
v1, v2 = v1.Elem(), v2.Elem()
}
if v1.Kind() != reflect.Struct {
return false
}
return equalStruct(v1, v2)
}
// v1 and v2 are known to have the same type.
func equalStruct(v1, v2 reflect.Value) bool {
for i := 0; i < v1.NumField(); i++ {
f := v1.Type().Field(i)
if strings.HasPrefix(f.Name, "XXX_") {
continue
}
f1, f2 := v1.Field(i), v2.Field(i)
if f.Type.Kind() == reflect.Ptr {
if n1, n2 := f1.IsNil(), f2.IsNil(); n1 && n2 {
// both unset
continue
} else if n1 != n2 {
// set/unset mismatch
return false
}
b1, ok := f1.Interface().(raw)
if ok {
b2 := f2.Interface().(raw)
// RawMessage
if !bytes.Equal(b1.Bytes(), b2.Bytes()) {
return false
}
continue
}
f1, f2 = f1.Elem(), f2.Elem()
}
if !equalAny(f1, f2) {
return false
}
}
if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
em2 := v2.FieldByName("XXX_extensions")
if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
return false
}
}
uf := v1.FieldByName("XXX_unrecognized")
if !uf.IsValid() {
return true
}
u1 := uf.Bytes()
u2 := v2.FieldByName("XXX_unrecognized").Bytes()
if !bytes.Equal(u1, u2) {
return false
}
return true
}
// v1 and v2 are known to have the same type.
func equalAny(v1, v2 reflect.Value) bool {
if v1.Type() == protoMessageType {
m1, _ := v1.Interface().(Message)
m2, _ := v2.Interface().(Message)
return Equal(m1, m2)
}
switch v1.Kind() {
case reflect.Bool:
return v1.Bool() == v2.Bool()
case reflect.Float32, reflect.Float64:
return v1.Float() == v2.Float()
case reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int()
case reflect.Ptr:
return equalAny(v1.Elem(), v2.Elem())
case reflect.Slice:
if v1.Type().Elem().Kind() == reflect.Uint8 {
// short circuit: []byte
if v1.IsNil() != v2.IsNil() {
return false
}
return bytes.Equal(v1.Interface().([]byte), v2.Interface().([]byte))
}
if v1.Len() != v2.Len() {
return false
}
for i := 0; i < v1.Len(); i++ {
if !equalAny(v1.Index(i), v2.Index(i)) {
return false
}
}
return true
case reflect.String:
return v1.Interface().(string) == v2.Interface().(string)
case reflect.Struct:
return equalStruct(v1, v2)
case reflect.Uint32, reflect.Uint64:
return v1.Uint() == v2.Uint()
}
// unknown type, so not a protocol buffer
log.Printf("proto: don't know how to compare %v", v1)
return false
}
// base is the struct type that the extensions are based on.
// em1 and em2 are extension maps.
func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
if len(em1) != len(em2) {
return false
}
for extNum, e1 := range em1 {
e2, ok := em2[extNum]
if !ok {
return false
}
m1, m2 := e1.value, e2.value
if m1 != nil && m2 != nil {
// Both are unencoded.
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
return false
}
continue
}
// At least one is encoded. To do a semantically correct comparison
// we need to unmarshal them first.
var desc *ExtensionDesc
if m := extensionMaps[base]; m != nil {
desc = m[extNum]
}
if desc == nil {
log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
continue
}
var err error
if m1 == nil {
m1, err = decodeExtension(e1.enc, desc)
}
if m2 == nil && err == nil {
m2, err = decodeExtension(e2.enc, desc)
}
if err != nil {
// The encoded form is invalid.
log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
return false
}
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
return false
}
}
return true
}

View file

@ -0,0 +1,166 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"testing"
pb "./testdata"
. "code.google.com/p/goprotobuf/proto"
)
// Four identical base messages.
// The init function adds extensions to some of them.
var messageWithoutExtension = &pb.MyMessage{Count: Int32(7)}
var messageWithExtension1a = &pb.MyMessage{Count: Int32(7)}
var messageWithExtension1b = &pb.MyMessage{Count: Int32(7)}
var messageWithExtension2 = &pb.MyMessage{Count: Int32(7)}
// Two messages with non-message extensions.
var messageWithInt32Extension1 = &pb.MyMessage{Count: Int32(8)}
var messageWithInt32Extension2 = &pb.MyMessage{Count: Int32(8)}
func init() {
ext1 := &pb.Ext{Data: String("Kirk")}
ext2 := &pb.Ext{Data: String("Picard")}
// messageWithExtension1a has ext1, but never marshals it.
if err := SetExtension(messageWithExtension1a, pb.E_Ext_More, ext1); err != nil {
panic("SetExtension on 1a failed: " + err.Error())
}
// messageWithExtension1b is the unmarshaled form of messageWithExtension1a.
if err := SetExtension(messageWithExtension1b, pb.E_Ext_More, ext1); err != nil {
panic("SetExtension on 1b failed: " + err.Error())
}
buf, err := Marshal(messageWithExtension1b)
if err != nil {
panic("Marshal of 1b failed: " + err.Error())
}
messageWithExtension1b.Reset()
if err := Unmarshal(buf, messageWithExtension1b); err != nil {
panic("Unmarshal of 1b failed: " + err.Error())
}
// messageWithExtension2 has ext2.
if err := SetExtension(messageWithExtension2, pb.E_Ext_More, ext2); err != nil {
panic("SetExtension on 2 failed: " + err.Error())
}
if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(23)); err != nil {
panic("SetExtension on Int32-1 failed: " + err.Error())
}
if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(24)); err != nil {
panic("SetExtension on Int32-2 failed: " + err.Error())
}
}
var EqualTests = []struct {
desc string
a, b Message
exp bool
}{
{"different types", &pb.GoEnum{}, &pb.GoTestField{}, false},
{"equal empty", &pb.GoEnum{}, &pb.GoEnum{}, true},
{"nil vs nil", nil, nil, true},
{"typed nil vs typed nil", (*pb.GoEnum)(nil), (*pb.GoEnum)(nil), true},
{"typed nil vs empty", (*pb.GoEnum)(nil), &pb.GoEnum{}, false},
{"different typed nil", (*pb.GoEnum)(nil), (*pb.GoTestField)(nil), false},
{"one set field, one unset field", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{}, false},
{"one set field zero, one unset field", &pb.GoTest{Param: Int32(0)}, &pb.GoTest{}, false},
{"different set fields", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{Label: String("bar")}, false},
{"equal set", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{Label: String("foo")}, true},
{"repeated, one set", &pb.GoTest{F_Int32Repeated: []int32{2, 3}}, &pb.GoTest{}, false},
{"repeated, different length", &pb.GoTest{F_Int32Repeated: []int32{2, 3}}, &pb.GoTest{F_Int32Repeated: []int32{2}}, false},
{"repeated, different value", &pb.GoTest{F_Int32Repeated: []int32{2}}, &pb.GoTest{F_Int32Repeated: []int32{3}}, false},
{"repeated, equal", &pb.GoTest{F_Int32Repeated: []int32{2, 4}}, &pb.GoTest{F_Int32Repeated: []int32{2, 4}}, true},
{"repeated, nil equal nil", &pb.GoTest{F_Int32Repeated: nil}, &pb.GoTest{F_Int32Repeated: nil}, true},
{"repeated, nil equal empty", &pb.GoTest{F_Int32Repeated: nil}, &pb.GoTest{F_Int32Repeated: []int32{}}, true},
{"repeated, empty equal nil", &pb.GoTest{F_Int32Repeated: []int32{}}, &pb.GoTest{F_Int32Repeated: nil}, true},
{
"nested, different",
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("foo")}},
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("bar")}},
false,
},
{
"nested, equal",
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
true,
},
{"bytes", &pb.OtherMessage{Value: []byte("foo")}, &pb.OtherMessage{Value: []byte("foo")}, true},
{"bytes, empty", &pb.OtherMessage{Value: []byte{}}, &pb.OtherMessage{Value: []byte{}}, true},
{"bytes, empty vs nil", &pb.OtherMessage{Value: []byte{}}, &pb.OtherMessage{Value: nil}, false},
{
"repeated bytes",
&pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}},
&pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}},
true,
},
{"extension vs. no extension", messageWithoutExtension, messageWithExtension1a, false},
{"extension vs. same extension", messageWithExtension1a, messageWithExtension1b, true},
{"extension vs. different extension", messageWithExtension1a, messageWithExtension2, false},
{"int32 extension vs. itself", messageWithInt32Extension1, messageWithInt32Extension1, true},
{"int32 extension vs. a different int32", messageWithInt32Extension1, messageWithInt32Extension2, false},
{
"message with group",
&pb.MyMessage{
Count: Int32(1),
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: Int32(5),
},
},
&pb.MyMessage{
Count: Int32(1),
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: Int32(5),
},
},
true,
},
}
func TestEqual(t *testing.T) {
for _, tc := range EqualTests {
if res := Equal(tc.a, tc.b); res != tc.exp {
t.Errorf("%v: Equal(%v, %v) = %v, want %v", tc.desc, tc.a, tc.b, res, tc.exp)
}
}
}

View file

@ -0,0 +1,351 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
/*
* Types and routines for supporting protocol buffer extensions.
*/
import (
"errors"
"reflect"
"strconv"
"sync"
)
// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message.
var ErrMissingExtension = errors.New("proto: missing extension")
// ExtensionRange represents a range of message extensions for a protocol buffer.
// Used in code generated by the protocol compiler.
type ExtensionRange struct {
Start, End int32 // both inclusive
}
// extendableProto is an interface implemented by any protocol buffer that may be extended.
type extendableProto interface {
Message
ExtensionRangeArray() []ExtensionRange
ExtensionMap() map[int32]Extension
}
var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
// ExtensionDesc represents an extension specification.
// Used in generated code from the protocol compiler.
type ExtensionDesc struct {
ExtendedType Message // nil pointer to the type that is being extended
ExtensionType interface{} // nil pointer to the extension type
Field int32 // field number
Name string // fully-qualified name of extension, for text formatting
Tag string // protobuf tag style
}
func (ed *ExtensionDesc) repeated() bool {
t := reflect.TypeOf(ed.ExtensionType)
return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
}
// Extension represents an extension in a message.
type Extension struct {
// When an extension is stored in a message using SetExtension
// only desc and value are set. When the message is marshaled
// enc will be set to the encoded form of the message.
//
// When a message is unmarshaled and contains extensions, each
// extension will have only enc set. When such an extension is
// accessed using GetExtension (or GetExtensions) desc and value
// will be set.
desc *ExtensionDesc
value interface{}
enc []byte
}
// SetRawExtension is for testing only.
func SetRawExtension(base extendableProto, id int32, b []byte) {
base.ExtensionMap()[id] = Extension{enc: b}
}
// isExtensionField returns true iff the given field number is in an extension range.
func isExtensionField(pb extendableProto, field int32) bool {
for _, er := range pb.ExtensionRangeArray() {
if er.Start <= field && field <= er.End {
return true
}
}
return false
}
// checkExtensionTypes checks that the given extension is valid for pb.
func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
// Check the extended type.
if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
}
// Check the range.
if !isExtensionField(pb, extension.Field) {
return errors.New("proto: bad extension number; not in declared ranges")
}
return nil
}
// extPropKey is sufficient to uniquely identify an extension.
type extPropKey struct {
base reflect.Type
field int32
}
var extProp = struct {
sync.RWMutex
m map[extPropKey]*Properties
}{
m: make(map[extPropKey]*Properties),
}
func extensionProperties(ed *ExtensionDesc) *Properties {
key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
extProp.RLock()
if prop, ok := extProp.m[key]; ok {
extProp.RUnlock()
return prop
}
extProp.RUnlock()
extProp.Lock()
defer extProp.Unlock()
// Check again.
if prop, ok := extProp.m[key]; ok {
return prop
}
prop := new(Properties)
prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
extProp.m[key] = prop
return prop
}
// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m.
func encodeExtensionMap(m map[int32]Extension) error {
for k, e := range m {
if e.value == nil || e.desc == nil {
// Extension is only in its encoded form.
continue
}
// We don't skip extensions that have an encoded form set,
// because the extension value may have been mutated after
// the last time this function was called.
et := reflect.TypeOf(e.desc.ExtensionType)
props := extensionProperties(e.desc)
p := NewBuffer(nil)
// If e.value has type T, the encoder expects a *struct{ X T }.
// Pass a *T with a zero field and hope it all works out.
x := reflect.New(et)
x.Elem().Set(reflect.ValueOf(e.value))
if err := props.enc(p, props, toStructPointer(x)); err != nil {
return err
}
e.enc = p.buf
m[k] = e
}
return nil
}
func sizeExtensionMap(m map[int32]Extension) (n int) {
for _, e := range m {
if e.value == nil || e.desc == nil {
// Extension is only in its encoded form.
n += len(e.enc)
continue
}
// We don't skip extensions that have an encoded form set,
// because the extension value may have been mutated after
// the last time this function was called.
et := reflect.TypeOf(e.desc.ExtensionType)
props := extensionProperties(e.desc)
// If e.value has type T, the encoder expects a *struct{ X T }.
// Pass a *T with a zero field and hope it all works out.
x := reflect.New(et)
x.Elem().Set(reflect.ValueOf(e.value))
n += props.size(props, toStructPointer(x))
}
return
}
// HasExtension returns whether the given extension is present in pb.
func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
// TODO: Check types, field numbers, etc.?
_, ok := pb.ExtensionMap()[extension.Field]
return ok
}
// ClearExtension removes the given extension from pb.
func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
// TODO: Check types, field numbers, etc.?
delete(pb.ExtensionMap(), extension.Field)
}
// GetExtension parses and returns the given extension of pb.
// If the extension is not present it returns ErrMissingExtension.
func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) {
if err := checkExtensionTypes(pb, extension); err != nil {
return nil, err
}
e, ok := pb.ExtensionMap()[extension.Field]
if !ok {
return nil, ErrMissingExtension
}
if e.value != nil {
// Already decoded. Check the descriptor, though.
if e.desc != extension {
// This shouldn't happen. If it does, it means that
// GetExtension was called twice with two different
// descriptors with the same field number.
return nil, errors.New("proto: descriptor conflict")
}
return e.value, nil
}
v, err := decodeExtension(e.enc, extension)
if err != nil {
return nil, err
}
// Remember the decoded version and drop the encoded version.
// That way it is safe to mutate what we return.
e.value = v
e.desc = extension
e.enc = nil
return e.value, nil
}
// decodeExtension decodes an extension encoded in b.
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
o := NewBuffer(b)
t := reflect.TypeOf(extension.ExtensionType)
rep := extension.repeated()
props := extensionProperties(extension)
// t is a pointer to a struct, pointer to basic type or a slice.
// Allocate a "field" to store the pointer/slice itself; the
// pointer/slice will be stored here. We pass
// the address of this field to props.dec.
// This passes a zero field and a *t and lets props.dec
// interpret it as a *struct{ x t }.
value := reflect.New(t).Elem()
for {
// Discard wire type and field number varint. It isn't needed.
if _, err := o.DecodeVarint(); err != nil {
return nil, err
}
if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
return nil, err
}
if !rep || o.index >= len(o.buf) {
break
}
}
return value.Interface(), nil
}
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
// The returned slice has the same length as es; missing extensions will appear as nil elements.
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
epb, ok := pb.(extendableProto)
if !ok {
err = errors.New("proto: not an extendable proto")
return
}
extensions = make([]interface{}, len(es))
for i, e := range es {
extensions[i], err = GetExtension(epb, e)
if err == ErrMissingExtension {
err = nil
}
if err != nil {
return
}
}
return
}
// SetExtension sets the specified extension of pb to the specified value.
func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error {
if err := checkExtensionTypes(pb, extension); err != nil {
return err
}
typ := reflect.TypeOf(extension.ExtensionType)
if typ != reflect.TypeOf(value) {
return errors.New("proto: bad extension value type")
}
pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
return nil
}
// A global registry of extensions.
// The generated code will register the generated descriptors by calling RegisterExtension.
var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
// RegisterExtension is called from the generated code.
func RegisterExtension(desc *ExtensionDesc) {
st := reflect.TypeOf(desc.ExtendedType).Elem()
m := extensionMaps[st]
if m == nil {
m = make(map[int32]*ExtensionDesc)
extensionMaps[st] = m
}
if _, ok := m[desc.Field]; ok {
panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
}
m[desc.Field] = desc
}
// RegisteredExtensions returns a map of the registered extensions of a
// protocol buffer struct, indexed by the extension number.
// The argument pb should be a nil pointer to the struct type.
func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
return extensionMaps[reflect.TypeOf(pb).Elem()]
}

View file

@ -0,0 +1,60 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"testing"
pb "./testdata"
"code.google.com/p/goprotobuf/proto"
)
func TestGetExtensionsWithMissingExtensions(t *testing.T) {
msg := &pb.MyMessage{}
ext1 := &pb.Ext{}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
t.Fatalf("Could not set ext1: %s", ext1)
}
exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
pb.E_Ext_More,
pb.E_Ext_Text,
})
if err != nil {
t.Fatalf("GetExtensions() failed: %s", err)
}
if exts[0] != ext1 {
t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
}
if exts[1] != nil {
t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
}
}

View file

@ -0,0 +1,740 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/*
Package proto converts data structures to and from the wire format of
protocol buffers. It works in concert with the Go source code generated
for .proto files by the protocol compiler.
A summary of the properties of the protocol buffer interface
for a protocol buffer variable v:
- Names are turned from camel_case to CamelCase for export.
- There are no methods on v to set fields; just treat
them as structure fields.
- There are getters that return a field's value if set,
and return the field's default value if unset.
The getters work even if the receiver is a nil message.
- The zero value for a struct is its correct initialization state.
All desired fields must be set before marshaling.
- A Reset() method will restore a protobuf struct to its zero state.
- Non-repeated fields are pointers to the values; nil means unset.
That is, optional or required field int32 f becomes F *int32.
- Repeated fields are slices.
- Helper functions are available to aid the setting of fields.
Helpers for getting values are superseded by the
GetFoo methods and their use is deprecated.
msg.Foo = proto.String("hello") // set field
- Constants are defined to hold the default values of all fields that
have them. They have the form Default_StructName_FieldName.
Because the getter methods handle defaulted values,
direct use of these constants should be rare.
- Enums are given type names and maps from names to values.
Enum values are prefixed with the enum's type name. Enum types have
a String method, and a Enum method to assist in message construction.
- Nested groups and enums have type names prefixed with the name of
the surrounding message type.
- Extensions are given descriptor names that start with E_,
followed by an underscore-delimited list of the nested messages
that contain it (if any) followed by the CamelCased name of the
extension field itself. HasExtension, ClearExtension, GetExtension
and SetExtension are functions for manipulating extensions.
- Marshal and Unmarshal are functions to encode and decode the wire format.
The simplest way to describe this is to see an example.
Given file test.proto, containing
package example;
enum FOO { X = 17; };
message Test {
required string label = 1;
optional int32 type = 2 [default=77];
repeated int64 reps = 3;
optional group OptionalGroup = 4 {
required string RequiredField = 5;
}
}
The resulting file, test.pb.go, is:
package example
import "code.google.com/p/goprotobuf/proto"
type FOO int32
const (
FOO_X FOO = 17
)
var FOO_name = map[int32]string{
17: "X",
}
var FOO_value = map[string]int32{
"X": 17,
}
func (x FOO) Enum() *FOO {
p := new(FOO)
*p = x
return p
}
func (x FOO) String() string {
return proto.EnumName(FOO_name, int32(x))
}
type Test struct {
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (this *Test) Reset() { *this = Test{} }
func (this *Test) String() string { return proto.CompactTextString(this) }
const Default_Test_Type int32 = 77
func (this *Test) GetLabel() string {
if this != nil && this.Label != nil {
return *this.Label
}
return ""
}
func (this *Test) GetType() int32 {
if this != nil && this.Type != nil {
return *this.Type
}
return Default_Test_Type
}
func (this *Test) GetOptionalgroup() *Test_OptionalGroup {
if this != nil {
return this.Optionalgroup
}
return nil
}
type Test_OptionalGroup struct {
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (this *Test_OptionalGroup) Reset() { *this = Test_OptionalGroup{} }
func (this *Test_OptionalGroup) String() string { return proto.CompactTextString(this) }
func (this *Test_OptionalGroup) GetRequiredField() string {
if this != nil && this.RequiredField != nil {
return *this.RequiredField
}
return ""
}
func init() {
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
}
To create and play with a Test object:
package main
import (
"log"
"code.google.com/p/goprotobuf/proto"
"./example.pb"
)
func main() {
test := &example.Test{
Label: proto.String("hello"),
Type: proto.Int32(17),
Optionalgroup: &example.Test_OptionalGroup{
RequiredField: proto.String("good bye"),
},
}
data, err := proto.Marshal(test)
if err != nil {
log.Fatal("marshaling error: ", err)
}
newTest := new(example.Test)
err = proto.Unmarshal(data, newTest)
if err != nil {
log.Fatal("unmarshaling error: ", err)
}
// Now test and newTest contain the same data.
if test.GetLabel() != newTest.GetLabel() {
log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel())
}
// etc.
}
*/
package proto
import (
"encoding/json"
"fmt"
"log"
"reflect"
"strconv"
"sync"
)
// Message is implemented by generated protocol buffer messages.
type Message interface {
Reset()
String() string
ProtoMessage()
}
// Stats records allocation details about the protocol buffer encoders
// and decoders. Useful for tuning the library itself.
type Stats struct {
Emalloc uint64 // mallocs in encode
Dmalloc uint64 // mallocs in decode
Encode uint64 // number of encodes
Decode uint64 // number of decodes
Chit uint64 // number of cache hits
Cmiss uint64 // number of cache misses
Size uint64 // number of sizes
}
// Set to true to enable stats collection.
const collectStats = false
var stats Stats
// GetStats returns a copy of the global Stats structure.
func GetStats() Stats { return stats }
// A Buffer is a buffer manager for marshaling and unmarshaling
// protocol buffers. It may be reused between invocations to
// reduce memory usage. It is not necessary to use a Buffer;
// the global functions Marshal and Unmarshal create a
// temporary Buffer and are fine for most applications.
type Buffer struct {
buf []byte // encode/decode byte stream
index int // write point
// pools of basic types to amortize allocation.
bools []bool
uint32s []uint32
uint64s []uint64
// extra pools, only used with pointer_reflect.go
int32s []int32
int64s []int64
float32s []float32
float64s []float64
}
// NewBuffer allocates a new Buffer and initializes its internal data to
// the contents of the argument slice.
func NewBuffer(e []byte) *Buffer {
return &Buffer{buf: e}
}
// Reset resets the Buffer, ready for marshaling a new protocol buffer.
func (p *Buffer) Reset() {
p.buf = p.buf[0:0] // for reading/writing
p.index = 0 // for reading
}
// SetBuf replaces the internal buffer with the slice,
// ready for unmarshaling the contents of the slice.
func (p *Buffer) SetBuf(s []byte) {
p.buf = s
p.index = 0
}
// Bytes returns the contents of the Buffer.
func (p *Buffer) Bytes() []byte { return p.buf }
/*
* Helper routines for simplifying the creation of optional fields of basic type.
*/
// Bool is a helper routine that allocates a new bool value
// to store v and returns a pointer to it.
func Bool(v bool) *bool {
return &v
}
// Int32 is a helper routine that allocates a new int32 value
// to store v and returns a pointer to it.
func Int32(v int32) *int32 {
return &v
}
// Int is a helper routine that allocates a new int32 value
// to store v and returns a pointer to it, but unlike Int32
// its argument value is an int.
func Int(v int) *int32 {
p := new(int32)
*p = int32(v)
return p
}
// Int64 is a helper routine that allocates a new int64 value
// to store v and returns a pointer to it.
func Int64(v int64) *int64 {
return &v
}
// Float32 is a helper routine that allocates a new float32 value
// to store v and returns a pointer to it.
func Float32(v float32) *float32 {
return &v
}
// Float64 is a helper routine that allocates a new float64 value
// to store v and returns a pointer to it.
func Float64(v float64) *float64 {
return &v
}
// Uint32 is a helper routine that allocates a new uint32 value
// to store v and returns a pointer to it.
func Uint32(v uint32) *uint32 {
p := new(uint32)
*p = v
return p
}
// Uint64 is a helper routine that allocates a new uint64 value
// to store v and returns a pointer to it.
func Uint64(v uint64) *uint64 {
return &v
}
// String is a helper routine that allocates a new string value
// to store v and returns a pointer to it.
func String(v string) *string {
return &v
}
// EnumName is a helper function to simplify printing protocol buffer enums
// by name. Given an enum map and a value, it returns a useful string.
func EnumName(m map[int32]string, v int32) string {
s, ok := m[v]
if ok {
return s
}
return strconv.Itoa(int(v))
}
// UnmarshalJSONEnum is a helper function to simplify recovering enum int values
// from their JSON-encoded representation. Given a map from the enum's symbolic
// names to its int values, and a byte buffer containing the JSON-encoded
// value, it returns an int32 that can be cast to the enum type by the caller.
//
// The function can deal with both JSON representations, numeric and symbolic.
func UnmarshalJSONEnum(m map[string]int32, data []byte, enumName string) (int32, error) {
if data[0] == '"' {
// New style: enums are strings.
var repr string
if err := json.Unmarshal(data, &repr); err != nil {
return -1, err
}
val, ok := m[repr]
if !ok {
return 0, fmt.Errorf("unrecognized enum %s value %q", enumName, repr)
}
return val, nil
}
// Old style: enums are ints.
var val int32
if err := json.Unmarshal(data, &val); err != nil {
return 0, fmt.Errorf("cannot unmarshal %#q into enum %s", data, enumName)
}
return val, nil
}
// DebugPrint dumps the encoded data in b in a debugging format with a header
// including the string s. Used in testing but made available for general debugging.
func (o *Buffer) DebugPrint(s string, b []byte) {
var u uint64
obuf := o.buf
index := o.index
o.buf = b
o.index = 0
depth := 0
fmt.Printf("\n--- %s ---\n", s)
out:
for {
for i := 0; i < depth; i++ {
fmt.Print(" ")
}
index := o.index
if index == len(o.buf) {
break
}
op, err := o.DecodeVarint()
if err != nil {
fmt.Printf("%3d: fetching op err %v\n", index, err)
break out
}
tag := op >> 3
wire := op & 7
switch wire {
default:
fmt.Printf("%3d: t=%3d unknown wire=%d\n",
index, tag, wire)
break out
case WireBytes:
var r []byte
r, err = o.DecodeRawBytes(false)
if err != nil {
break out
}
fmt.Printf("%3d: t=%3d bytes [%d]", index, tag, len(r))
if len(r) <= 6 {
for i := 0; i < len(r); i++ {
fmt.Printf(" %.2x", r[i])
}
} else {
for i := 0; i < 3; i++ {
fmt.Printf(" %.2x", r[i])
}
fmt.Printf(" ..")
for i := len(r) - 3; i < len(r); i++ {
fmt.Printf(" %.2x", r[i])
}
}
fmt.Printf("\n")
case WireFixed32:
u, err = o.DecodeFixed32()
if err != nil {
fmt.Printf("%3d: t=%3d fix32 err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d fix32 %d\n", index, tag, u)
case WireFixed64:
u, err = o.DecodeFixed64()
if err != nil {
fmt.Printf("%3d: t=%3d fix64 err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d fix64 %d\n", index, tag, u)
break
case WireVarint:
u, err = o.DecodeVarint()
if err != nil {
fmt.Printf("%3d: t=%3d varint err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d varint %d\n", index, tag, u)
case WireStartGroup:
if err != nil {
fmt.Printf("%3d: t=%3d start err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d start\n", index, tag)
depth++
case WireEndGroup:
depth--
if err != nil {
fmt.Printf("%3d: t=%3d end err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d end\n", index, tag)
}
}
if depth != 0 {
fmt.Printf("%3d: start-end not balanced %d\n", o.index, depth)
}
fmt.Printf("\n")
o.buf = obuf
o.index = index
}
// SetDefaults sets unset protocol buffer fields to their default values.
// It only modifies fields that are both unset and have defined defaults.
// It recursively sets default values in any non-nil sub-messages.
func SetDefaults(pb Message) {
setDefaults(reflect.ValueOf(pb), true, false)
}
// v is a pointer to a struct.
func setDefaults(v reflect.Value, recur, zeros bool) {
v = v.Elem()
defaultMu.RLock()
dm, ok := defaults[v.Type()]
defaultMu.RUnlock()
if !ok {
dm = buildDefaultMessage(v.Type())
defaultMu.Lock()
defaults[v.Type()] = dm
defaultMu.Unlock()
}
for _, sf := range dm.scalars {
f := v.Field(sf.index)
if !f.IsNil() {
// field already set
continue
}
dv := sf.value
if dv == nil && !zeros {
// no explicit default, and don't want to set zeros
continue
}
fptr := f.Addr().Interface() // **T
// TODO: Consider batching the allocations we do here.
switch sf.kind {
case reflect.Bool:
b := new(bool)
if dv != nil {
*b = dv.(bool)
}
*(fptr.(**bool)) = b
case reflect.Float32:
f := new(float32)
if dv != nil {
*f = dv.(float32)
}
*(fptr.(**float32)) = f
case reflect.Float64:
f := new(float64)
if dv != nil {
*f = dv.(float64)
}
*(fptr.(**float64)) = f
case reflect.Int32:
// might be an enum
if ft := f.Type(); ft != int32PtrType {
// enum
f.Set(reflect.New(ft.Elem()))
if dv != nil {
f.Elem().SetInt(int64(dv.(int32)))
}
} else {
// int32 field
i := new(int32)
if dv != nil {
*i = dv.(int32)
}
*(fptr.(**int32)) = i
}
case reflect.Int64:
i := new(int64)
if dv != nil {
*i = dv.(int64)
}
*(fptr.(**int64)) = i
case reflect.String:
s := new(string)
if dv != nil {
*s = dv.(string)
}
*(fptr.(**string)) = s
case reflect.Uint8:
// exceptional case: []byte
var b []byte
if dv != nil {
db := dv.([]byte)
b = make([]byte, len(db))
copy(b, db)
} else {
b = []byte{}
}
*(fptr.(*[]byte)) = b
case reflect.Uint32:
u := new(uint32)
if dv != nil {
*u = dv.(uint32)
}
*(fptr.(**uint32)) = u
case reflect.Uint64:
u := new(uint64)
if dv != nil {
*u = dv.(uint64)
}
*(fptr.(**uint64)) = u
default:
log.Printf("proto: can't set default for field %v (sf.kind=%v)", f, sf.kind)
}
}
for _, ni := range dm.nested {
f := v.Field(ni)
if f.IsNil() {
continue
}
// f is *T or []*T
if f.Kind() == reflect.Ptr {
setDefaults(f, recur, zeros)
} else {
for i := 0; i < f.Len(); i++ {
e := f.Index(i)
if e.IsNil() {
continue
}
setDefaults(e, recur, zeros)
}
}
}
}
var (
// defaults maps a protocol buffer struct type to a slice of the fields,
// with its scalar fields set to their proto-declared non-zero default values.
defaultMu sync.RWMutex
defaults = make(map[reflect.Type]defaultMessage)
int32PtrType = reflect.TypeOf((*int32)(nil))
)
// defaultMessage represents information about the default values of a message.
type defaultMessage struct {
scalars []scalarField
nested []int // struct field index of nested messages
}
type scalarField struct {
index int // struct field index
kind reflect.Kind // element type (the T in *T or []T)
value interface{} // the proto-declared default value, or nil
}
func ptrToStruct(t reflect.Type) bool {
return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
}
// t is a struct type.
func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
sprop := GetProperties(t)
for _, prop := range sprop.Prop {
fi, ok := sprop.decoderTags.get(prop.Tag)
if !ok {
// XXX_unrecognized
continue
}
ft := t.Field(fi).Type
// nested messages
if ptrToStruct(ft) || (ft.Kind() == reflect.Slice && ptrToStruct(ft.Elem())) {
dm.nested = append(dm.nested, fi)
continue
}
sf := scalarField{
index: fi,
kind: ft.Elem().Kind(),
}
// scalar fields without defaults
if !prop.HasDefault {
dm.scalars = append(dm.scalars, sf)
continue
}
// a scalar field: either *T or []byte
switch ft.Elem().Kind() {
case reflect.Bool:
x, err := strconv.ParseBool(prop.Default)
if err != nil {
log.Printf("proto: bad default bool %q: %v", prop.Default, err)
continue
}
sf.value = x
case reflect.Float32:
x, err := strconv.ParseFloat(prop.Default, 32)
if err != nil {
log.Printf("proto: bad default float32 %q: %v", prop.Default, err)
continue
}
sf.value = float32(x)
case reflect.Float64:
x, err := strconv.ParseFloat(prop.Default, 64)
if err != nil {
log.Printf("proto: bad default float64 %q: %v", prop.Default, err)
continue
}
sf.value = x
case reflect.Int32:
x, err := strconv.ParseInt(prop.Default, 10, 32)
if err != nil {
log.Printf("proto: bad default int32 %q: %v", prop.Default, err)
continue
}
sf.value = int32(x)
case reflect.Int64:
x, err := strconv.ParseInt(prop.Default, 10, 64)
if err != nil {
log.Printf("proto: bad default int64 %q: %v", prop.Default, err)
continue
}
sf.value = x
case reflect.String:
sf.value = prop.Default
case reflect.Uint8:
// []byte (not *uint8)
sf.value = []byte(prop.Default)
case reflect.Uint32:
x, err := strconv.ParseUint(prop.Default, 10, 32)
if err != nil {
log.Printf("proto: bad default uint32 %q: %v", prop.Default, err)
continue
}
sf.value = uint32(x)
case reflect.Uint64:
x, err := strconv.ParseUint(prop.Default, 10, 64)
if err != nil {
log.Printf("proto: bad default uint64 %q: %v", prop.Default, err)
continue
}
sf.value = x
default:
log.Printf("proto: unhandled def kind %v", ft.Elem().Kind())
continue
}
dm.scalars = append(dm.scalars, sf)
}
return dm
}

View file

@ -0,0 +1,229 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
/*
* Support for message sets.
*/
import (
"errors"
"reflect"
"sort"
)
// ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID.
// A message type ID is required for storing a protocol buffer in a message set.
var ErrNoMessageTypeId = errors.New("proto does not have a message type ID")
// The first two types (_MessageSet_Item and MessageSet)
// model what the protocol compiler produces for the following protocol message:
// message MessageSet {
// repeated group Item = 1 {
// required int32 type_id = 2;
// required string message = 3;
// };
// }
// That is the MessageSet wire format. We can't use a proto to generate these
// because that would introduce a circular dependency between it and this package.
//
// When a proto1 proto has a field that looks like:
// optional message<MessageSet> info = 3;
// the protocol compiler produces a field in the generated struct that looks like:
// Info *_proto_.MessageSet `protobuf:"bytes,3,opt,name=info"`
// The package is automatically inserted so there is no need for that proto file to
// import this package.
type _MessageSet_Item struct {
TypeId *int32 `protobuf:"varint,2,req,name=type_id"`
Message []byte `protobuf:"bytes,3,req,name=message"`
}
type MessageSet struct {
Item []*_MessageSet_Item `protobuf:"group,1,rep"`
XXX_unrecognized []byte
// TODO: caching?
}
// Make sure MessageSet is a Message.
var _ Message = (*MessageSet)(nil)
// messageTypeIder is an interface satisfied by a protocol buffer type
// that may be stored in a MessageSet.
type messageTypeIder interface {
MessageTypeId() int32
}
func (ms *MessageSet) find(pb Message) *_MessageSet_Item {
mti, ok := pb.(messageTypeIder)
if !ok {
return nil
}
id := mti.MessageTypeId()
for _, item := range ms.Item {
if *item.TypeId == id {
return item
}
}
return nil
}
func (ms *MessageSet) Has(pb Message) bool {
if ms.find(pb) != nil {
return true
}
return false
}
func (ms *MessageSet) Unmarshal(pb Message) error {
if item := ms.find(pb); item != nil {
return Unmarshal(item.Message, pb)
}
if _, ok := pb.(messageTypeIder); !ok {
return ErrNoMessageTypeId
}
return nil // TODO: return error instead?
}
func (ms *MessageSet) Marshal(pb Message) error {
msg, err := Marshal(pb)
if err != nil {
return err
}
if item := ms.find(pb); item != nil {
// reuse existing item
item.Message = msg
return nil
}
mti, ok := pb.(messageTypeIder)
if !ok {
return ErrNoMessageTypeId
}
mtid := mti.MessageTypeId()
ms.Item = append(ms.Item, &_MessageSet_Item{
TypeId: &mtid,
Message: msg,
})
return nil
}
func (ms *MessageSet) Reset() { *ms = MessageSet{} }
func (ms *MessageSet) String() string { return CompactTextString(ms) }
func (*MessageSet) ProtoMessage() {}
// Support for the message_set_wire_format message option.
func skipVarint(buf []byte) []byte {
i := 0
for ; buf[i]&0x80 != 0; i++ {
}
return buf[i+1:]
}
// MarshalMessageSet encodes the extension map represented by m in the message set wire format.
// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
if err := encodeExtensionMap(m); err != nil {
return nil, err
}
// Sort extension IDs to provide a deterministic encoding.
// See also enc_map in encode.go.
ids := make([]int, 0, len(m))
for id := range m {
ids = append(ids, int(id))
}
sort.Ints(ids)
ms := &MessageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
for _, id := range ids {
e := m[int32(id)]
// Remove the wire type and field number varint, as well as the length varint.
msg := skipVarint(skipVarint(e.enc))
ms.Item = append(ms.Item, &_MessageSet_Item{
TypeId: Int32(int32(id)),
Message: msg,
})
}
return Marshal(ms)
}
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
ms := new(MessageSet)
if err := Unmarshal(buf, ms); err != nil {
return err
}
for _, item := range ms.Item {
id := *item.TypeId
msg := item.Message
// Restore wire type and field number varint, plus length varint.
// Be careful to preserve duplicate items.
b := EncodeVarint(uint64(id)<<3 | WireBytes)
if ext, ok := m[id]; ok {
// Existing data; rip off the tag and length varint
// so we join the new data correctly.
// We can assume that ext.enc is set because we are unmarshaling.
o := ext.enc[len(b):] // skip wire type and field number
_, n := DecodeVarint(o) // calculate length of length varint
o = o[n:] // skip length varint
msg = append(o, msg...) // join old data and new data
}
b = append(b, EncodeVarint(uint64(len(msg)))...)
b = append(b, msg...)
m[id] = Extension{enc: b}
}
return nil
}
// A global registry of types that can be used in a MessageSet.
var messageSetMap = make(map[int32]messageSetDesc)
type messageSetDesc struct {
t reflect.Type // pointer to struct
name string
}
// RegisterMessageSetType is called from the generated code.
func RegisterMessageSetType(i messageTypeIder, name string) {
messageSetMap[i.MessageTypeId()] = messageSetDesc{
t: reflect.TypeOf(i),
name: name,
}
}

View file

@ -0,0 +1,66 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
import (
"bytes"
"testing"
)
func TestUnmarshalMessageSetWithDuplicate(t *testing.T) {
// Check that a repeated message set entry will be concatenated.
in := &MessageSet{
Item: []*_MessageSet_Item{
{TypeId: Int32(12345), Message: []byte("hoo")},
{TypeId: Int32(12345), Message: []byte("hah")},
},
}
b, err := Marshal(in)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
t.Logf("Marshaled bytes: %q", b)
m := make(map[int32]Extension)
if err := UnmarshalMessageSet(b, m); err != nil {
t.Fatalf("UnmarshalMessageSet: %v", err)
}
ext, ok := m[12345]
if !ok {
t.Fatalf("Didn't retrieve extension 12345; map is %v", m)
}
// Skip wire type/field number and length varints.
got := skipVarint(skipVarint(ext.enc))
if want := []byte("hoohah"); !bytes.Equal(got, want) {
t.Errorf("Combined extension is %q, want %q", got, want)
}
}

View file

@ -0,0 +1,384 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// +build appengine,!appenginevm
// This file contains an implementation of proto field accesses using package reflect.
// It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can
// be used on App Engine.
package proto
import (
"math"
"reflect"
)
// A structPointer is a pointer to a struct.
type structPointer struct {
v reflect.Value
}
// toStructPointer returns a structPointer equivalent to the given reflect value.
// The reflect value must itself be a pointer to a struct.
func toStructPointer(v reflect.Value) structPointer {
return structPointer{v}
}
// IsNil reports whether p is nil.
func structPointer_IsNil(p structPointer) bool {
return p.v.IsNil()
}
// Interface returns the struct pointer as an interface value.
func structPointer_Interface(p structPointer, _ reflect.Type) interface{} {
return p.v.Interface()
}
// A field identifies a field in a struct, accessible from a structPointer.
// In this implementation, a field is identified by the sequence of field indices
// passed to reflect's FieldByIndex.
type field []int
// toField returns a field equivalent to the given reflect field.
func toField(f *reflect.StructField) field {
return f.Index
}
// invalidField is an invalid field identifier.
var invalidField = field(nil)
// IsValid reports whether the field identifier is valid.
func (f field) IsValid() bool { return f != nil }
// field returns the given field in the struct as a reflect value.
func structPointer_field(p structPointer, f field) reflect.Value {
// Special case: an extension map entry with a value of type T
// passes a *T to the struct-handling code with a zero field,
// expecting that it will be treated as equivalent to *struct{ X T },
// which has the same memory layout. We have to handle that case
// specially, because reflect will panic if we call FieldByIndex on a
// non-struct.
if f == nil {
return p.v.Elem()
}
return p.v.Elem().FieldByIndex(f)
}
// ifield returns the given field in the struct as an interface value.
func structPointer_ifield(p structPointer, f field) interface{} {
return structPointer_field(p, f).Addr().Interface()
}
// Bytes returns the address of a []byte field in the struct.
func structPointer_Bytes(p structPointer, f field) *[]byte {
return structPointer_ifield(p, f).(*[]byte)
}
// BytesSlice returns the address of a [][]byte field in the struct.
func structPointer_BytesSlice(p structPointer, f field) *[][]byte {
return structPointer_ifield(p, f).(*[][]byte)
}
// Bool returns the address of a *bool field in the struct.
func structPointer_Bool(p structPointer, f field) **bool {
return structPointer_ifield(p, f).(**bool)
}
// BoolSlice returns the address of a []bool field in the struct.
func structPointer_BoolSlice(p structPointer, f field) *[]bool {
return structPointer_ifield(p, f).(*[]bool)
}
// String returns the address of a *string field in the struct.
func structPointer_String(p structPointer, f field) **string {
return structPointer_ifield(p, f).(**string)
}
// StringSlice returns the address of a []string field in the struct.
func structPointer_StringSlice(p structPointer, f field) *[]string {
return structPointer_ifield(p, f).(*[]string)
}
// ExtMap returns the address of an extension map field in the struct.
func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return structPointer_ifield(p, f).(*map[int32]Extension)
}
// SetStructPointer writes a *struct field in the struct.
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
structPointer_field(p, f).Set(q.v)
}
// GetStructPointer reads a *struct field in the struct.
func structPointer_GetStructPointer(p structPointer, f field) structPointer {
return structPointer{structPointer_field(p, f)}
}
// StructPointerSlice the address of a []*struct field in the struct.
func structPointer_StructPointerSlice(p structPointer, f field) structPointerSlice {
return structPointerSlice{structPointer_field(p, f)}
}
// A structPointerSlice represents the address of a slice of pointers to structs
// (themselves messages or groups). That is, v.Type() is *[]*struct{...}.
type structPointerSlice struct {
v reflect.Value
}
func (p structPointerSlice) Len() int { return p.v.Len() }
func (p structPointerSlice) Index(i int) structPointer { return structPointer{p.v.Index(i)} }
func (p structPointerSlice) Append(q structPointer) {
p.v.Set(reflect.Append(p.v, q.v))
}
var (
int32Type = reflect.TypeOf(int32(0))
uint32Type = reflect.TypeOf(uint32(0))
float32Type = reflect.TypeOf(float32(0))
int64Type = reflect.TypeOf(int64(0))
uint64Type = reflect.TypeOf(uint64(0))
float64Type = reflect.TypeOf(float64(0))
)
// A word32 represents a field of type *int32, *uint32, *float32, or *enum.
// That is, v.Type() is *int32, *uint32, *float32, or *enum and v is assignable.
type word32 struct {
v reflect.Value
}
// IsNil reports whether p is nil.
func word32_IsNil(p word32) bool {
return p.v.IsNil()
}
// Set sets p to point at a newly allocated word with bits set to x.
func word32_Set(p word32, o *Buffer, x uint32) {
t := p.v.Type().Elem()
switch t {
case int32Type:
if len(o.int32s) == 0 {
o.int32s = make([]int32, uint32PoolSize)
}
o.int32s[0] = int32(x)
p.v.Set(reflect.ValueOf(&o.int32s[0]))
o.int32s = o.int32s[1:]
return
case uint32Type:
if len(o.uint32s) == 0 {
o.uint32s = make([]uint32, uint32PoolSize)
}
o.uint32s[0] = x
p.v.Set(reflect.ValueOf(&o.uint32s[0]))
o.uint32s = o.uint32s[1:]
return
case float32Type:
if len(o.float32s) == 0 {
o.float32s = make([]float32, uint32PoolSize)
}
o.float32s[0] = math.Float32frombits(x)
p.v.Set(reflect.ValueOf(&o.float32s[0]))
o.float32s = o.float32s[1:]
return
}
// must be enum
p.v.Set(reflect.New(t))
p.v.Elem().SetInt(int64(int32(x)))
}
// Get gets the bits pointed at by p, as a uint32.
func word32_Get(p word32) uint32 {
elem := p.v.Elem()
switch elem.Kind() {
case reflect.Int32:
return uint32(elem.Int())
case reflect.Uint32:
return uint32(elem.Uint())
case reflect.Float32:
return math.Float32bits(float32(elem.Float()))
}
panic("unreachable")
}
// Word32 returns a reference to a *int32, *uint32, *float32, or *enum field in the struct.
func structPointer_Word32(p structPointer, f field) word32 {
return word32{structPointer_field(p, f)}
}
// A word32Slice is a slice of 32-bit values.
// That is, v.Type() is []int32, []uint32, []float32, or []enum.
type word32Slice struct {
v reflect.Value
}
func (p word32Slice) Append(x uint32) {
n, m := p.v.Len(), p.v.Cap()
if n < m {
p.v.SetLen(n + 1)
} else {
t := p.v.Type().Elem()
p.v.Set(reflect.Append(p.v, reflect.Zero(t)))
}
elem := p.v.Index(n)
switch elem.Kind() {
case reflect.Int32:
elem.SetInt(int64(int32(x)))
case reflect.Uint32:
elem.SetUint(uint64(x))
case reflect.Float32:
elem.SetFloat(float64(math.Float32frombits(x)))
}
}
func (p word32Slice) Len() int {
return p.v.Len()
}
func (p word32Slice) Index(i int) uint32 {
elem := p.v.Index(i)
switch elem.Kind() {
case reflect.Int32:
return uint32(elem.Int())
case reflect.Uint32:
return uint32(elem.Uint())
case reflect.Float32:
return math.Float32bits(float32(elem.Float()))
}
panic("unreachable")
}
// Word32Slice returns a reference to a []int32, []uint32, []float32, or []enum field in the struct.
func structPointer_Word32Slice(p structPointer, f field) word32Slice {
return word32Slice{structPointer_field(p, f)}
}
// word64 is like word32 but for 64-bit values.
type word64 struct {
v reflect.Value
}
func word64_Set(p word64, o *Buffer, x uint64) {
t := p.v.Type().Elem()
switch t {
case int64Type:
if len(o.int64s) == 0 {
o.int64s = make([]int64, uint64PoolSize)
}
o.int64s[0] = int64(x)
p.v.Set(reflect.ValueOf(&o.int64s[0]))
o.int64s = o.int64s[1:]
return
case uint64Type:
if len(o.uint64s) == 0 {
o.uint64s = make([]uint64, uint64PoolSize)
}
o.uint64s[0] = x
p.v.Set(reflect.ValueOf(&o.uint64s[0]))
o.uint64s = o.uint64s[1:]
return
case float64Type:
if len(o.float64s) == 0 {
o.float64s = make([]float64, uint64PoolSize)
}
o.float64s[0] = math.Float64frombits(x)
p.v.Set(reflect.ValueOf(&o.float64s[0]))
o.float64s = o.float64s[1:]
return
}
panic("unreachable")
}
func word64_IsNil(p word64) bool {
return p.v.IsNil()
}
func word64_Get(p word64) uint64 {
elem := p.v.Elem()
switch elem.Kind() {
case reflect.Int64:
return uint64(elem.Int())
case reflect.Uint64:
return elem.Uint()
case reflect.Float64:
return math.Float64bits(elem.Float())
}
panic("unreachable")
}
func structPointer_Word64(p structPointer, f field) word64 {
return word64{structPointer_field(p, f)}
}
type word64Slice struct {
v reflect.Value
}
func (p word64Slice) Append(x uint64) {
n, m := p.v.Len(), p.v.Cap()
if n < m {
p.v.SetLen(n + 1)
} else {
t := p.v.Type().Elem()
p.v.Set(reflect.Append(p.v, reflect.Zero(t)))
}
elem := p.v.Index(n)
switch elem.Kind() {
case reflect.Int64:
elem.SetInt(int64(int64(x)))
case reflect.Uint64:
elem.SetUint(uint64(x))
case reflect.Float64:
elem.SetFloat(float64(math.Float64frombits(x)))
}
}
func (p word64Slice) Len() int {
return p.v.Len()
}
func (p word64Slice) Index(i int) uint64 {
elem := p.v.Index(i)
switch elem.Kind() {
case reflect.Int64:
return uint64(elem.Int())
case reflect.Uint64:
return uint64(elem.Uint())
case reflect.Float64:
return math.Float64bits(float64(elem.Float()))
}
panic("unreachable")
}
func structPointer_Word64Slice(p structPointer, f field) word64Slice {
return word64Slice{structPointer_field(p, f)}
}

View file

@ -0,0 +1,218 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// +build !appengine appenginevm
// This file contains the implementation of the proto field accesses using package unsafe.
package proto
import (
"reflect"
"unsafe"
)
// NOTE: These type_Foo functions would more idiomatically be methods,
// but Go does not allow methods on pointer types, and we must preserve
// some pointer type for the garbage collector. We use these
// funcs with clunky names as our poor approximation to methods.
//
// An alternative would be
// type structPointer struct { p unsafe.Pointer }
// but that does not registerize as well.
// A structPointer is a pointer to a struct.
type structPointer unsafe.Pointer
// toStructPointer returns a structPointer equivalent to the given reflect value.
func toStructPointer(v reflect.Value) structPointer {
return structPointer(unsafe.Pointer(v.Pointer()))
}
// IsNil reports whether p is nil.
func structPointer_IsNil(p structPointer) bool {
return p == nil
}
// Interface returns the struct pointer, assumed to have element type t,
// as an interface value.
func structPointer_Interface(p structPointer, t reflect.Type) interface{} {
return reflect.NewAt(t, unsafe.Pointer(p)).Interface()
}
// A field identifies a field in a struct, accessible from a structPointer.
// In this implementation, a field is identified by its byte offset from the start of the struct.
type field uintptr
// toField returns a field equivalent to the given reflect field.
func toField(f *reflect.StructField) field {
return field(f.Offset)
}
// invalidField is an invalid field identifier.
const invalidField = ^field(0)
// IsValid reports whether the field identifier is valid.
func (f field) IsValid() bool {
return f != ^field(0)
}
// Bytes returns the address of a []byte field in the struct.
func structPointer_Bytes(p structPointer, f field) *[]byte {
return (*[]byte)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// BytesSlice returns the address of a [][]byte field in the struct.
func structPointer_BytesSlice(p structPointer, f field) *[][]byte {
return (*[][]byte)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// Bool returns the address of a *bool field in the struct.
func structPointer_Bool(p structPointer, f field) **bool {
return (**bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// BoolSlice returns the address of a []bool field in the struct.
func structPointer_BoolSlice(p structPointer, f field) *[]bool {
return (*[]bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// String returns the address of a *string field in the struct.
func structPointer_String(p structPointer, f field) **string {
return (**string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// StringSlice returns the address of a []string field in the struct.
func structPointer_StringSlice(p structPointer, f field) *[]string {
return (*[]string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// ExtMap returns the address of an extension map field in the struct.
func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// SetStructPointer writes a *struct field in the struct.
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
*(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q
}
// GetStructPointer reads a *struct field in the struct.
func structPointer_GetStructPointer(p structPointer, f field) structPointer {
return *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// StructPointerSlice the address of a []*struct field in the struct.
func structPointer_StructPointerSlice(p structPointer, f field) *structPointerSlice {
return (*structPointerSlice)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// A structPointerSlice represents a slice of pointers to structs (themselves submessages or groups).
type structPointerSlice []structPointer
func (v *structPointerSlice) Len() int { return len(*v) }
func (v *structPointerSlice) Index(i int) structPointer { return (*v)[i] }
func (v *structPointerSlice) Append(p structPointer) { *v = append(*v, p) }
// A word32 is the address of a "pointer to 32-bit value" field.
type word32 **uint32
// IsNil reports whether *v is nil.
func word32_IsNil(p word32) bool {
return *p == nil
}
// Set sets *v to point at a newly allocated word set to x.
func word32_Set(p word32, o *Buffer, x uint32) {
if len(o.uint32s) == 0 {
o.uint32s = make([]uint32, uint32PoolSize)
}
o.uint32s[0] = x
*p = &o.uint32s[0]
o.uint32s = o.uint32s[1:]
}
// Get gets the value pointed at by *v.
func word32_Get(p word32) uint32 {
return **p
}
// Word32 returns the address of a *int32, *uint32, *float32, or *enum field in the struct.
func structPointer_Word32(p structPointer, f field) word32 {
return word32((**uint32)(unsafe.Pointer(uintptr(p) + uintptr(f))))
}
// A word32Slice is a slice of 32-bit values.
type word32Slice []uint32
func (v *word32Slice) Append(x uint32) { *v = append(*v, x) }
func (v *word32Slice) Len() int { return len(*v) }
func (v *word32Slice) Index(i int) uint32 { return (*v)[i] }
// Word32Slice returns the address of a []int32, []uint32, []float32, or []enum field in the struct.
func structPointer_Word32Slice(p structPointer, f field) *word32Slice {
return (*word32Slice)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// word64 is like word32 but for 64-bit values.
type word64 **uint64
func word64_Set(p word64, o *Buffer, x uint64) {
if len(o.uint64s) == 0 {
o.uint64s = make([]uint64, uint64PoolSize)
}
o.uint64s[0] = x
*p = &o.uint64s[0]
o.uint64s = o.uint64s[1:]
}
func word64_IsNil(p word64) bool {
return *p == nil
}
func word64_Get(p word64) uint64 {
return **p
}
func structPointer_Word64(p structPointer, f field) word64 {
return word64((**uint64)(unsafe.Pointer(uintptr(p) + uintptr(f))))
}
// word64Slice is like word32Slice but for 64-bit values.
type word64Slice []uint64
func (v *word64Slice) Append(x uint64) { *v = append(*v, x) }
func (v *word64Slice) Len() int { return len(*v) }
func (v *word64Slice) Index(i int) uint64 { return (*v)[i] }
func structPointer_Word64Slice(p structPointer, f field) *word64Slice {
return (*word64Slice)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}

View file

@ -0,0 +1,658 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
/*
* Routines for encoding data into the wire format for protocol buffers.
*/
import (
"fmt"
"os"
"reflect"
"sort"
"strconv"
"strings"
"sync"
)
const debug bool = false
// Constants that identify the encoding of a value on the wire.
const (
WireVarint = 0
WireFixed64 = 1
WireBytes = 2
WireStartGroup = 3
WireEndGroup = 4
WireFixed32 = 5
)
const startSize = 10 // initial slice/string sizes
// Encoders are defined in encode.go
// An encoder outputs the full representation of a field, including its
// tag and encoder type.
type encoder func(p *Buffer, prop *Properties, base structPointer) error
// A valueEncoder encodes a single integer in a particular encoding.
type valueEncoder func(o *Buffer, x uint64) error
// Sizers are defined in encode.go
// A sizer returns the encoded size of a field, including its tag and encoder
// type.
type sizer func(prop *Properties, base structPointer) int
// A valueSizer returns the encoded size of a single integer in a particular
// encoding.
type valueSizer func(x uint64) int
// Decoders are defined in decode.go
// A decoder creates a value from its wire representation.
// Unrecognized subelements are saved in unrec.
type decoder func(p *Buffer, prop *Properties, base structPointer) error
// A valueDecoder decodes a single integer in a particular encoding.
type valueDecoder func(o *Buffer) (x uint64, err error)
// tagMap is an optimization over map[int]int for typical protocol buffer
// use-cases. Encoded protocol buffers are often in tag order with small tag
// numbers.
type tagMap struct {
fastTags []int
slowTags map[int]int
}
// tagMapFastLimit is the upper bound on the tag number that will be stored in
// the tagMap slice rather than its map.
const tagMapFastLimit = 1024
func (p *tagMap) get(t int) (int, bool) {
if t > 0 && t < tagMapFastLimit {
if t >= len(p.fastTags) {
return 0, false
}
fi := p.fastTags[t]
return fi, fi >= 0
}
fi, ok := p.slowTags[t]
return fi, ok
}
func (p *tagMap) put(t int, fi int) {
if t > 0 && t < tagMapFastLimit {
for len(p.fastTags) < t+1 {
p.fastTags = append(p.fastTags, -1)
}
p.fastTags[t] = fi
return
}
if p.slowTags == nil {
p.slowTags = make(map[int]int)
}
p.slowTags[t] = fi
}
// StructProperties represents properties for all the fields of a struct.
// decoderTags and decoderOrigNames should only be used by the decoder.
type StructProperties struct {
Prop []*Properties // properties for each field
reqCount int // required count
decoderTags tagMap // map from proto tag to struct field number
decoderOrigNames map[string]int // map from original name to struct field number
order []int // list of struct field numbers in tag order
unrecField field // field id of the XXX_unrecognized []byte field
extendable bool // is this an extendable proto
}
// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec.
// See encode.go, (*Buffer).enc_struct.
func (sp *StructProperties) Len() int { return len(sp.order) }
func (sp *StructProperties) Less(i, j int) bool {
return sp.Prop[sp.order[i]].Tag < sp.Prop[sp.order[j]].Tag
}
func (sp *StructProperties) Swap(i, j int) { sp.order[i], sp.order[j] = sp.order[j], sp.order[i] }
// Properties represents the protocol-specific behavior of a single struct field.
type Properties struct {
Name string // name of the field, for error messages
OrigName string // original name before protocol compiler (always set)
Wire string
WireType int
Tag int
Required bool
Optional bool
Repeated bool
Packed bool // relevant for repeated primitives only
Enum string // set for enum types only
Default string // default value
HasDefault bool // whether an explicit default was provided
def_uint64 uint64
enc encoder
valEnc valueEncoder // set for bool and numeric types only
field field
tagcode []byte // encoding of EncodeVarint((Tag<<3)|WireType)
tagbuf [8]byte
stype reflect.Type // set for struct types only
sprop *StructProperties // set for struct types only
isMarshaler bool
isUnmarshaler bool
size sizer
valSize valueSizer // set for bool and numeric types only
dec decoder
valDec valueDecoder // set for bool and numeric types only
// If this is a packable field, this will be the decoder for the packed version of the field.
packedDec decoder
}
// String formats the properties in the protobuf struct field tag style.
func (p *Properties) String() string {
s := p.Wire
s = ","
s += strconv.Itoa(p.Tag)
if p.Required {
s += ",req"
}
if p.Optional {
s += ",opt"
}
if p.Repeated {
s += ",rep"
}
if p.Packed {
s += ",packed"
}
if p.OrigName != p.Name {
s += ",name=" + p.OrigName
}
if len(p.Enum) > 0 {
s += ",enum=" + p.Enum
}
if p.HasDefault {
s += ",def=" + p.Default
}
return s
}
// Parse populates p by parsing a string in the protobuf struct field tag style.
func (p *Properties) Parse(s string) {
// "bytes,49,opt,name=foo,def=hello!"
fields := strings.Split(s, ",") // breaks def=, but handled below.
if len(fields) < 2 {
fmt.Fprintf(os.Stderr, "proto: tag has too few fields: %q\n", s)
return
}
p.Wire = fields[0]
switch p.Wire {
case "varint":
p.WireType = WireVarint
p.valEnc = (*Buffer).EncodeVarint
p.valDec = (*Buffer).DecodeVarint
p.valSize = sizeVarint
case "fixed32":
p.WireType = WireFixed32
p.valEnc = (*Buffer).EncodeFixed32
p.valDec = (*Buffer).DecodeFixed32
p.valSize = sizeFixed32
case "fixed64":
p.WireType = WireFixed64
p.valEnc = (*Buffer).EncodeFixed64
p.valDec = (*Buffer).DecodeFixed64
p.valSize = sizeFixed64
case "zigzag32":
p.WireType = WireVarint
p.valEnc = (*Buffer).EncodeZigzag32
p.valDec = (*Buffer).DecodeZigzag32
p.valSize = sizeZigzag32
case "zigzag64":
p.WireType = WireVarint
p.valEnc = (*Buffer).EncodeZigzag64
p.valDec = (*Buffer).DecodeZigzag64
p.valSize = sizeZigzag64
case "bytes", "group":
p.WireType = WireBytes
// no numeric converter for non-numeric types
default:
fmt.Fprintf(os.Stderr, "proto: tag has unknown wire type: %q\n", s)
return
}
var err error
p.Tag, err = strconv.Atoi(fields[1])
if err != nil {
return
}
for i := 2; i < len(fields); i++ {
f := fields[i]
switch {
case f == "req":
p.Required = true
case f == "opt":
p.Optional = true
case f == "rep":
p.Repeated = true
case f == "packed":
p.Packed = true
case strings.HasPrefix(f, "name="):
p.OrigName = f[5:]
case strings.HasPrefix(f, "enum="):
p.Enum = f[5:]
case strings.HasPrefix(f, "def="):
p.HasDefault = true
p.Default = f[4:] // rest of string
if i+1 < len(fields) {
// Commas aren't escaped, and def is always last.
p.Default += "," + strings.Join(fields[i+1:], ",")
break
}
}
}
}
func logNoSliceEnc(t1, t2 reflect.Type) {
fmt.Fprintf(os.Stderr, "proto: no slice oenc for %T = []%T\n", t1, t2)
}
var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem()
// Initialize the fields for encoding and decoding.
func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
p.enc = nil
p.dec = nil
p.size = nil
switch t1 := typ; t1.Kind() {
default:
fmt.Fprintf(os.Stderr, "proto: no coders for %T\n", t1)
case reflect.Ptr:
switch t2 := t1.Elem(); t2.Kind() {
default:
fmt.Fprintf(os.Stderr, "proto: no encoder function for %T -> %T\n", t1, t2)
break
case reflect.Bool:
p.enc = (*Buffer).enc_bool
p.dec = (*Buffer).dec_bool
p.size = size_bool
case reflect.Int32:
p.enc = (*Buffer).enc_int32
p.dec = (*Buffer).dec_int32
p.size = size_int32
case reflect.Uint32:
p.enc = (*Buffer).enc_uint32
p.dec = (*Buffer).dec_int32 // can reuse
p.size = size_uint32
case reflect.Int64, reflect.Uint64:
p.enc = (*Buffer).enc_int64
p.dec = (*Buffer).dec_int64
p.size = size_int64
case reflect.Float32:
p.enc = (*Buffer).enc_uint32 // can just treat them as bits
p.dec = (*Buffer).dec_int32
p.size = size_uint32
case reflect.Float64:
p.enc = (*Buffer).enc_int64 // can just treat them as bits
p.dec = (*Buffer).dec_int64
p.size = size_int64
case reflect.String:
p.enc = (*Buffer).enc_string
p.dec = (*Buffer).dec_string
p.size = size_string
case reflect.Struct:
p.stype = t1.Elem()
p.isMarshaler = isMarshaler(t1)
p.isUnmarshaler = isUnmarshaler(t1)
if p.Wire == "bytes" {
p.enc = (*Buffer).enc_struct_message
p.dec = (*Buffer).dec_struct_message
p.size = size_struct_message
} else {
p.enc = (*Buffer).enc_struct_group
p.dec = (*Buffer).dec_struct_group
p.size = size_struct_group
}
}
case reflect.Slice:
switch t2 := t1.Elem(); t2.Kind() {
default:
logNoSliceEnc(t1, t2)
break
case reflect.Bool:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_bool
p.size = size_slice_packed_bool
} else {
p.enc = (*Buffer).enc_slice_bool
p.size = size_slice_bool
}
p.dec = (*Buffer).dec_slice_bool
p.packedDec = (*Buffer).dec_slice_packed_bool
case reflect.Int32:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_int32
p.size = size_slice_packed_int32
} else {
p.enc = (*Buffer).enc_slice_int32
p.size = size_slice_int32
}
p.dec = (*Buffer).dec_slice_int32
p.packedDec = (*Buffer).dec_slice_packed_int32
case reflect.Uint32:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_uint32
p.size = size_slice_packed_uint32
} else {
p.enc = (*Buffer).enc_slice_uint32
p.size = size_slice_uint32
}
p.dec = (*Buffer).dec_slice_int32
p.packedDec = (*Buffer).dec_slice_packed_int32
case reflect.Int64, reflect.Uint64:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_int64
p.size = size_slice_packed_int64
} else {
p.enc = (*Buffer).enc_slice_int64
p.size = size_slice_int64
}
p.dec = (*Buffer).dec_slice_int64
p.packedDec = (*Buffer).dec_slice_packed_int64
case reflect.Uint8:
p.enc = (*Buffer).enc_slice_byte
p.dec = (*Buffer).dec_slice_byte
p.size = size_slice_byte
case reflect.Float32, reflect.Float64:
switch t2.Bits() {
case 32:
// can just treat them as bits
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_uint32
p.size = size_slice_packed_uint32
} else {
p.enc = (*Buffer).enc_slice_uint32
p.size = size_slice_uint32
}
p.dec = (*Buffer).dec_slice_int32
p.packedDec = (*Buffer).dec_slice_packed_int32
case 64:
// can just treat them as bits
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_int64
p.size = size_slice_packed_int64
} else {
p.enc = (*Buffer).enc_slice_int64
p.size = size_slice_int64
}
p.dec = (*Buffer).dec_slice_int64
p.packedDec = (*Buffer).dec_slice_packed_int64
default:
logNoSliceEnc(t1, t2)
break
}
case reflect.String:
p.enc = (*Buffer).enc_slice_string
p.dec = (*Buffer).dec_slice_string
p.size = size_slice_string
case reflect.Ptr:
switch t3 := t2.Elem(); t3.Kind() {
default:
fmt.Fprintf(os.Stderr, "proto: no ptr oenc for %T -> %T -> %T\n", t1, t2, t3)
break
case reflect.Struct:
p.stype = t2.Elem()
p.isMarshaler = isMarshaler(t2)
p.isUnmarshaler = isUnmarshaler(t2)
if p.Wire == "bytes" {
p.enc = (*Buffer).enc_slice_struct_message
p.dec = (*Buffer).dec_slice_struct_message
p.size = size_slice_struct_message
} else {
p.enc = (*Buffer).enc_slice_struct_group
p.dec = (*Buffer).dec_slice_struct_group
p.size = size_slice_struct_group
}
}
case reflect.Slice:
switch t2.Elem().Kind() {
default:
fmt.Fprintf(os.Stderr, "proto: no slice elem oenc for %T -> %T -> %T\n", t1, t2, t2.Elem())
break
case reflect.Uint8:
p.enc = (*Buffer).enc_slice_slice_byte
p.dec = (*Buffer).dec_slice_slice_byte
p.size = size_slice_slice_byte
}
}
}
// precalculate tag code
wire := p.WireType
if p.Packed {
wire = WireBytes
}
x := uint32(p.Tag)<<3 | uint32(wire)
i := 0
for i = 0; x > 127; i++ {
p.tagbuf[i] = 0x80 | uint8(x&0x7F)
x >>= 7
}
p.tagbuf[i] = uint8(x)
p.tagcode = p.tagbuf[0 : i+1]
if p.stype != nil {
if lockGetProp {
p.sprop = GetProperties(p.stype)
} else {
p.sprop = getPropertiesLocked(p.stype)
}
}
}
var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
)
// isMarshaler reports whether type t implements Marshaler.
func isMarshaler(t reflect.Type) bool {
// We're checking for (likely) pointer-receiver methods
// so if t is not a pointer, something is very wrong.
// The calls above only invoke isMarshaler on pointer types.
if t.Kind() != reflect.Ptr {
panic("proto: misuse of isMarshaler")
}
return t.Implements(marshalerType)
}
// isUnmarshaler reports whether type t implements Unmarshaler.
func isUnmarshaler(t reflect.Type) bool {
// We're checking for (likely) pointer-receiver methods
// so if t is not a pointer, something is very wrong.
// The calls above only invoke isUnmarshaler on pointer types.
if t.Kind() != reflect.Ptr {
panic("proto: misuse of isUnmarshaler")
}
return t.Implements(unmarshalerType)
}
// Init populates the properties from a protocol buffer struct tag.
func (p *Properties) Init(typ reflect.Type, name, tag string, f *reflect.StructField) {
p.init(typ, name, tag, f, true)
}
func (p *Properties) init(typ reflect.Type, name, tag string, f *reflect.StructField, lockGetProp bool) {
// "bytes,49,opt,def=hello!"
p.Name = name
p.OrigName = name
if f != nil {
p.field = toField(f)
}
if tag == "" {
return
}
p.Parse(tag)
p.setEncAndDec(typ, lockGetProp)
}
var (
mutex sync.Mutex
propertiesMap = make(map[reflect.Type]*StructProperties)
)
// GetProperties returns the list of properties for the type represented by t.
func GetProperties(t reflect.Type) *StructProperties {
mutex.Lock()
sprop := getPropertiesLocked(t)
mutex.Unlock()
return sprop
}
// getPropertiesLocked requires that mutex is held.
func getPropertiesLocked(t reflect.Type) *StructProperties {
if prop, ok := propertiesMap[t]; ok {
if collectStats {
stats.Chit++
}
return prop
}
if collectStats {
stats.Cmiss++
}
prop := new(StructProperties)
// in case of recursive protos, fill this in now.
propertiesMap[t] = prop
// build properties
prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType)
prop.unrecField = invalidField
prop.Prop = make([]*Properties, t.NumField())
prop.order = make([]int, t.NumField())
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
p := new(Properties)
name := f.Name
p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false)
if f.Name == "XXX_extensions" { // special case
p.enc = (*Buffer).enc_map
p.dec = nil // not needed
p.size = size_map
}
if f.Name == "XXX_unrecognized" { // special case
prop.unrecField = toField(&f)
}
prop.Prop[i] = p
prop.order[i] = i
if debug {
print(i, " ", f.Name, " ", t.String(), " ")
if p.Tag > 0 {
print(p.String())
}
print("\n")
}
if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") {
fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]")
}
}
// Re-order prop.order.
sort.Sort(prop)
// build required counts
// build tags
reqCount := 0
prop.decoderOrigNames = make(map[string]int)
for i, p := range prop.Prop {
if strings.HasPrefix(p.Name, "XXX_") {
// Internal fields should not appear in tags/origNames maps.
// They are handled specially when encoding and decoding.
continue
}
if p.Required {
reqCount++
}
prop.decoderTags.put(p.Tag, i)
prop.decoderOrigNames[p.OrigName] = i
}
prop.reqCount = reqCount
return prop
}
// Return the Properties object for the x[0]'th field of the structure.
func propByIndex(t reflect.Type, x []int) *Properties {
if len(x) != 1 {
fmt.Fprintf(os.Stderr, "proto: field index dimension %d (not 1) for type %s\n", len(x), t)
return nil
}
prop := GetProperties(t)
return prop.Prop[x[0]]
}
// Get the address and type of a pointer to a struct from an interface.
func getbase(pb Message) (t reflect.Type, b structPointer, err error) {
if pb == nil {
err = ErrNil
return
}
// get the reflect type of the pointer to the struct.
t = reflect.TypeOf(pb)
// get the address of the struct.
value := reflect.ValueOf(pb)
b = toStructPointer(value)
return
}
// A global registry of enum types.
// The generated code will register the generated maps by calling RegisterEnum.
var enumValueMaps = make(map[string]map[string]int32)
// RegisterEnum is called from the generated code to install the enum descriptor
// maps into the global table to aid parsing text format protocol buffers.
func RegisterEnum(typeName string, unusedNameMap map[int32]string, valueMap map[string]int32) {
if _, ok := enumValueMaps[typeName]; ok {
panic("proto: duplicate enum registered: " + typeName)
}
enumValueMaps[typeName] = valueMap
}

View file

@ -0,0 +1,63 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
import (
"testing"
)
// This is a separate file and package from size_test.go because that one uses
// generated messages and thus may not be in package proto without having a circular
// dependency, whereas this file tests unexported details of size.go.
func TestVarintSize(t *testing.T) {
// Check the edge cases carefully.
testCases := []struct {
n uint64
size int
}{
{0, 1},
{1, 1},
{127, 1},
{128, 2},
{16383, 2},
{16384, 3},
{1<<63 - 1, 9},
{1 << 63, 10},
}
for _, tc := range testCases {
size := sizeVarint(tc.n)
if size != tc.size {
t.Errorf("sizeVarint(%d) = %d, want %d", tc.n, size, tc.size)
}
}
}

View file

@ -0,0 +1,120 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"log"
"testing"
pb "./testdata"
. "code.google.com/p/goprotobuf/proto"
)
var messageWithExtension1 = &pb.MyMessage{Count: Int32(7)}
// messageWithExtension2 is in equal_test.go.
var messageWithExtension3 = &pb.MyMessage{Count: Int32(8)}
func init() {
if err := SetExtension(messageWithExtension1, pb.E_Ext_More, &pb.Ext{Data: String("Abbott")}); err != nil {
log.Panicf("SetExtension: %v", err)
}
if err := SetExtension(messageWithExtension3, pb.E_Ext_More, &pb.Ext{Data: String("Costello")}); err != nil {
log.Panicf("SetExtension: %v", err)
}
// Force messageWithExtension3 to have the extension encoded.
Marshal(messageWithExtension3)
}
var SizeTests = []struct {
desc string
pb Message
}{
{"empty", &pb.OtherMessage{}},
// Basic types.
{"bool", &pb.Defaults{F_Bool: Bool(true)}},
{"int32", &pb.Defaults{F_Int32: Int32(12)}},
{"negative int32", &pb.Defaults{F_Int32: Int32(-1)}},
{"small int64", &pb.Defaults{F_Int64: Int64(1)}},
{"big int64", &pb.Defaults{F_Int64: Int64(1 << 20)}},
{"negative int64", &pb.Defaults{F_Int64: Int64(-1)}},
{"fixed32", &pb.Defaults{F_Fixed32: Uint32(71)}},
{"fixed64", &pb.Defaults{F_Fixed64: Uint64(72)}},
{"uint32", &pb.Defaults{F_Uint32: Uint32(123)}},
{"uint64", &pb.Defaults{F_Uint64: Uint64(124)}},
{"float", &pb.Defaults{F_Float: Float32(12.6)}},
{"double", &pb.Defaults{F_Double: Float64(13.9)}},
{"string", &pb.Defaults{F_String: String("niles")}},
{"bytes", &pb.Defaults{F_Bytes: []byte("wowsa")}},
{"bytes, empty", &pb.Defaults{F_Bytes: []byte{}}},
{"sint32", &pb.Defaults{F_Sint32: Int32(65)}},
{"sint64", &pb.Defaults{F_Sint64: Int64(67)}},
{"enum", &pb.Defaults{F_Enum: pb.Defaults_BLUE.Enum()}},
// Repeated.
{"empty repeated bool", &pb.MoreRepeated{Bools: []bool{}}},
{"repeated bool", &pb.MoreRepeated{Bools: []bool{false, true, true, false}}},
{"packed repeated bool", &pb.MoreRepeated{BoolsPacked: []bool{false, true, true, false, true, true, true}}},
{"repeated int32", &pb.MoreRepeated{Ints: []int32{1, 12203, 1729, -1}}},
{"repeated int32 packed", &pb.MoreRepeated{IntsPacked: []int32{1, 12203, 1729}}},
{"repeated int64 packed", &pb.MoreRepeated{Int64SPacked: []int64{
// Need enough large numbers to verify that the header is counting the number of bytes
// for the field, not the number of elements.
1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62,
1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62,
}}},
{"repeated string", &pb.MoreRepeated{Strings: []string{"r", "ken", "gri"}}},
{"repeated fixed", &pb.MoreRepeated{Fixeds: []uint32{1, 2, 3, 4}}},
// Nested.
{"nested", &pb.OldMessage{Nested: &pb.OldMessage_Nested{Name: String("whatever")}}},
{"group", &pb.GroupOld{G: &pb.GroupOld_G{X: Int32(12345)}}},
// Other things.
{"unrecognized", &pb.MoreRepeated{XXX_unrecognized: []byte{13<<3 | 0, 4}}},
{"extension (unencoded)", messageWithExtension1},
{"extension (encoded)", messageWithExtension3},
}
func TestSize(t *testing.T) {
for _, tc := range SizeTests {
size := Size(tc.pb)
b, err := Marshal(tc.pb)
if err != nil {
t.Errorf("%v: Marshal failed: %v", tc.desc, err)
continue
}
if size != len(b) {
t.Errorf("%v: Size(%v) = %d, want %d", tc.desc, tc.pb, size, len(b))
t.Logf("%v: bytes: %#v", tc.desc, b)
}
}
}

View file

@ -0,0 +1,50 @@
# Go support for Protocol Buffers - Google's data interchange format
#
# Copyright 2010 The Go Authors. All rights reserved.
# http://code.google.com/p/goprotobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include ../../Make.protobuf
all: regenerate
regenerate:
rm -f test.pb.go
make test.pb.go
# The following rules are just aids to development. Not needed for typical testing.
diff: regenerate
hg diff test.pb.go
restore:
cp test.pb.go.golden test.pb.go
preserve:
cp test.pb.go test.pb.go.golden

View file

@ -0,0 +1,86 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Verify that the compiler output for test.proto is unchanged.
package testdata
import (
"crypto/sha1"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"testing"
)
// sum returns in string form (for easy comparison) the SHA-1 hash of the named file.
func sum(t *testing.T, name string) string {
data, err := ioutil.ReadFile(name)
if err != nil {
t.Fatal(err)
}
t.Logf("sum(%q): length is %d", name, len(data))
hash := sha1.New()
_, err = hash.Write(data)
if err != nil {
t.Fatal(err)
}
return fmt.Sprintf("% x", hash.Sum(nil))
}
func run(t *testing.T, name string, args ...string) {
cmd := exec.Command(name, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Run()
if err != nil {
t.Fatal(err)
}
}
func TestGolden(t *testing.T) {
// Compute the original checksum.
goldenSum := sum(t, "test.pb.go")
// Run the proto compiler.
run(t, "protoc", "--go_out="+os.TempDir(), "test.proto")
newFile := filepath.Join(os.TempDir(), "test.pb.go")
defer os.Remove(newFile)
// Compute the new checksum.
newSum := sum(t, newFile)
// Verify
if newSum != goldenSum {
run(t, "diff", "-u", "test.pb.go", newFile)
t.Fatal("Code generated by protoc-gen-go has changed; update test.pb.go")
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,428 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// A feature-rich test file for the protocol compiler and libraries.
syntax = "proto2";
package testdata;
enum FOO { FOO1 = 1; };
message GoEnum {
required FOO foo = 1;
}
message GoTestField {
required string Label = 1;
required string Type = 2;
}
message GoTest {
// An enum, for completeness.
enum KIND {
VOID = 0;
// Basic types
BOOL = 1;
BYTES = 2;
FINGERPRINT = 3;
FLOAT = 4;
INT = 5;
STRING = 6;
TIME = 7;
// Groupings
TUPLE = 8;
ARRAY = 9;
MAP = 10;
// Table types
TABLE = 11;
// Functions
FUNCTION = 12; // last tag
};
// Some typical parameters
required KIND Kind = 1;
optional string Table = 2;
optional int32 Param = 3;
// Required, repeated and optional foreign fields.
required GoTestField RequiredField = 4;
repeated GoTestField RepeatedField = 5;
optional GoTestField OptionalField = 6;
// Required fields of all basic types
required bool F_Bool_required = 10;
required int32 F_Int32_required = 11;
required int64 F_Int64_required = 12;
required fixed32 F_Fixed32_required = 13;
required fixed64 F_Fixed64_required = 14;
required uint32 F_Uint32_required = 15;
required uint64 F_Uint64_required = 16;
required float F_Float_required = 17;
required double F_Double_required = 18;
required string F_String_required = 19;
required bytes F_Bytes_required = 101;
required sint32 F_Sint32_required = 102;
required sint64 F_Sint64_required = 103;
// Repeated fields of all basic types
repeated bool F_Bool_repeated = 20;
repeated int32 F_Int32_repeated = 21;
repeated int64 F_Int64_repeated = 22;
repeated fixed32 F_Fixed32_repeated = 23;
repeated fixed64 F_Fixed64_repeated = 24;
repeated uint32 F_Uint32_repeated = 25;
repeated uint64 F_Uint64_repeated = 26;
repeated float F_Float_repeated = 27;
repeated double F_Double_repeated = 28;
repeated string F_String_repeated = 29;
repeated bytes F_Bytes_repeated = 201;
repeated sint32 F_Sint32_repeated = 202;
repeated sint64 F_Sint64_repeated = 203;
// Optional fields of all basic types
optional bool F_Bool_optional = 30;
optional int32 F_Int32_optional = 31;
optional int64 F_Int64_optional = 32;
optional fixed32 F_Fixed32_optional = 33;
optional fixed64 F_Fixed64_optional = 34;
optional uint32 F_Uint32_optional = 35;
optional uint64 F_Uint64_optional = 36;
optional float F_Float_optional = 37;
optional double F_Double_optional = 38;
optional string F_String_optional = 39;
optional bytes F_Bytes_optional = 301;
optional sint32 F_Sint32_optional = 302;
optional sint64 F_Sint64_optional = 303;
// Default-valued fields of all basic types
optional bool F_Bool_defaulted = 40 [default=true];
optional int32 F_Int32_defaulted = 41 [default=32];
optional int64 F_Int64_defaulted = 42 [default=64];
optional fixed32 F_Fixed32_defaulted = 43 [default=320];
optional fixed64 F_Fixed64_defaulted = 44 [default=640];
optional uint32 F_Uint32_defaulted = 45 [default=3200];
optional uint64 F_Uint64_defaulted = 46 [default=6400];
optional float F_Float_defaulted = 47 [default=314159.];
optional double F_Double_defaulted = 48 [default=271828.];
optional string F_String_defaulted = 49 [default="hello, \"world!\"\n"];
optional bytes F_Bytes_defaulted = 401 [default="Bignose"];
optional sint32 F_Sint32_defaulted = 402 [default = -32];
optional sint64 F_Sint64_defaulted = 403 [default = -64];
// Packed repeated fields (no string or bytes).
repeated bool F_Bool_repeated_packed = 50 [packed=true];
repeated int32 F_Int32_repeated_packed = 51 [packed=true];
repeated int64 F_Int64_repeated_packed = 52 [packed=true];
repeated fixed32 F_Fixed32_repeated_packed = 53 [packed=true];
repeated fixed64 F_Fixed64_repeated_packed = 54 [packed=true];
repeated uint32 F_Uint32_repeated_packed = 55 [packed=true];
repeated uint64 F_Uint64_repeated_packed = 56 [packed=true];
repeated float F_Float_repeated_packed = 57 [packed=true];
repeated double F_Double_repeated_packed = 58 [packed=true];
repeated sint32 F_Sint32_repeated_packed = 502 [packed=true];
repeated sint64 F_Sint64_repeated_packed = 503 [packed=true];
// Required, repeated, and optional groups.
required group RequiredGroup = 70 {
required string RequiredField = 71;
};
repeated group RepeatedGroup = 80 {
required string RequiredField = 81;
};
optional group OptionalGroup = 90 {
required string RequiredField = 91;
};
}
// For testing skipping of unrecognized fields.
// Numbers are all big, larger than tag numbers in GoTestField,
// the message used in the corresponding test.
message GoSkipTest {
required int32 skip_int32 = 11;
required fixed32 skip_fixed32 = 12;
required fixed64 skip_fixed64 = 13;
required string skip_string = 14;
required group SkipGroup = 15 {
required int32 group_int32 = 16;
required string group_string = 17;
}
}
// For testing packed/non-packed decoder switching.
// A serialized instance of one should be deserializable as the other.
message NonPackedTest {
repeated int32 a = 1;
}
message PackedTest {
repeated int32 b = 1 [packed=true];
}
message MaxTag {
// Maximum possible tag number.
optional string last_field = 536870911;
}
message OldMessage {
message Nested {
optional string name = 1;
}
optional Nested nested = 1;
optional int32 num = 2;
}
// NewMessage is wire compatible with OldMessage;
// imagine it as a future version.
message NewMessage {
message Nested {
optional string name = 1;
optional string food_group = 2;
}
optional Nested nested = 1;
// This is an int32 in OldMessage.
optional int64 num = 2;
}
// Smaller tests for ASCII formatting.
message InnerMessage {
required string host = 1;
optional int32 port = 2 [default=4000];
optional bool connected = 3;
}
message OtherMessage {
optional int64 key = 1;
optional bytes value = 2;
optional float weight = 3;
optional InnerMessage inner = 4;
}
message MyMessage {
required int32 count = 1;
optional string name = 2;
optional string quote = 3;
repeated string pet = 4;
optional InnerMessage inner = 5;
repeated OtherMessage others = 6;
repeated InnerMessage rep_inner = 12;
enum Color {
RED = 0;
GREEN = 1;
BLUE = 2;
};
optional Color bikeshed = 7;
optional group SomeGroup = 8 {
optional int32 group_field = 9;
}
// This field becomes [][]byte in the generated code.
repeated bytes rep_bytes = 10;
optional double bigfloat = 11;
extensions 100 to max;
}
message Ext {
extend MyMessage {
optional Ext more = 103;
optional string text = 104;
optional int32 number = 105;
}
optional string data = 1;
}
extend MyMessage {
repeated string greeting = 106;
}
message MyMessageSet {
option message_set_wire_format = true;
extensions 100 to max;
}
message Empty {
}
extend MyMessageSet {
optional Empty x201 = 201;
optional Empty x202 = 202;
optional Empty x203 = 203;
optional Empty x204 = 204;
optional Empty x205 = 205;
optional Empty x206 = 206;
optional Empty x207 = 207;
optional Empty x208 = 208;
optional Empty x209 = 209;
optional Empty x210 = 210;
optional Empty x211 = 211;
optional Empty x212 = 212;
optional Empty x213 = 213;
optional Empty x214 = 214;
optional Empty x215 = 215;
optional Empty x216 = 216;
optional Empty x217 = 217;
optional Empty x218 = 218;
optional Empty x219 = 219;
optional Empty x220 = 220;
optional Empty x221 = 221;
optional Empty x222 = 222;
optional Empty x223 = 223;
optional Empty x224 = 224;
optional Empty x225 = 225;
optional Empty x226 = 226;
optional Empty x227 = 227;
optional Empty x228 = 228;
optional Empty x229 = 229;
optional Empty x230 = 230;
optional Empty x231 = 231;
optional Empty x232 = 232;
optional Empty x233 = 233;
optional Empty x234 = 234;
optional Empty x235 = 235;
optional Empty x236 = 236;
optional Empty x237 = 237;
optional Empty x238 = 238;
optional Empty x239 = 239;
optional Empty x240 = 240;
optional Empty x241 = 241;
optional Empty x242 = 242;
optional Empty x243 = 243;
optional Empty x244 = 244;
optional Empty x245 = 245;
optional Empty x246 = 246;
optional Empty x247 = 247;
optional Empty x248 = 248;
optional Empty x249 = 249;
optional Empty x250 = 250;
}
message MessageList {
repeated group Message = 1 {
required string name = 2;
required int32 count = 3;
}
}
message Strings {
optional string string_field = 1;
optional bytes bytes_field = 2;
}
message Defaults {
enum Color {
RED = 0;
GREEN = 1;
BLUE = 2;
}
// Default-valued fields of all basic types.
// Same as GoTest, but copied here to make testing easier.
optional bool F_Bool = 1 [default=true];
optional int32 F_Int32 = 2 [default=32];
optional int64 F_Int64 = 3 [default=64];
optional fixed32 F_Fixed32 = 4 [default=320];
optional fixed64 F_Fixed64 = 5 [default=640];
optional uint32 F_Uint32 = 6 [default=3200];
optional uint64 F_Uint64 = 7 [default=6400];
optional float F_Float = 8 [default=314159.];
optional double F_Double = 9 [default=271828.];
optional string F_String = 10 [default="hello, \"world!\"\n"];
optional bytes F_Bytes = 11 [default="Bignose"];
optional sint32 F_Sint32 = 12 [default=-32];
optional sint64 F_Sint64 = 13 [default=-64];
optional Color F_Enum = 14 [default=GREEN];
// More fields with crazy defaults.
optional float F_Pinf = 15 [default=inf];
optional float F_Ninf = 16 [default=-inf];
optional float F_Nan = 17 [default=nan];
// Sub-message.
optional SubDefaults sub = 18;
// Redundant but explicit defaults.
optional string str_zero = 19 [default=""];
}
message SubDefaults {
optional int64 n = 1 [default=7];
}
message RepeatedEnum {
enum Color {
RED = 1;
}
repeated Color color = 1;
}
message MoreRepeated {
repeated bool bools = 1;
repeated bool bools_packed = 2 [packed=true];
repeated int32 ints = 3;
repeated int32 ints_packed = 4 [packed=true];
repeated int64 int64s_packed = 7 [packed=true];
repeated string strings = 5;
repeated fixed32 fixeds = 6;
}
// GroupOld and GroupNew have the same wire format.
// GroupNew has a new field inside a group.
message GroupOld {
optional group G = 101 {
optional int32 x = 2;
}
}
message GroupNew {
optional group G = 101 {
optional int32 x = 2;
optional int32 y = 3;
}
}
message FloatingPoint {
required double f = 1;
}

View file

@ -0,0 +1,701 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
// Functions for writing the text protocol buffer format.
import (
"bufio"
"bytes"
"fmt"
"io"
"log"
"math"
"os"
"reflect"
"sort"
"strings"
)
var (
newline = []byte("\n")
spaces = []byte(" ")
gtNewline = []byte(">\n")
endBraceNewline = []byte("}\n")
backslashN = []byte{'\\', 'n'}
backslashR = []byte{'\\', 'r'}
backslashT = []byte{'\\', 't'}
backslashDQ = []byte{'\\', '"'}
backslashBS = []byte{'\\', '\\'}
posInf = []byte("inf")
negInf = []byte("-inf")
nan = []byte("nan")
)
type writer interface {
io.Writer
WriteByte(byte) error
}
// textWriter is an io.Writer that tracks its indentation level.
type textWriter struct {
ind int
complete bool // if the current position is a complete line
compact bool // whether to write out as a one-liner
w writer
}
// textMarshaler is implemented by Messages that can marshal themsleves.
// It is identical to encoding.TextMarshaler, introduced in go 1.2,
// which will eventually replace it.
type textMarshaler interface {
MarshalText() (text []byte, err error)
}
func (w *textWriter) WriteString(s string) (n int, err error) {
if !strings.Contains(s, "\n") {
if !w.compact && w.complete {
w.writeIndent()
}
w.complete = false
return io.WriteString(w.w, s)
}
// WriteString is typically called without newlines, so this
// codepath and its copy are rare. We copy to avoid
// duplicating all of Write's logic here.
return w.Write([]byte(s))
}
func (w *textWriter) Write(p []byte) (n int, err error) {
newlines := bytes.Count(p, newline)
if newlines == 0 {
if !w.compact && w.complete {
w.writeIndent()
}
n, err = w.w.Write(p)
w.complete = false
return n, err
}
frags := bytes.SplitN(p, newline, newlines+1)
if w.compact {
for i, frag := range frags {
if i > 0 {
if err := w.w.WriteByte(' '); err != nil {
return n, err
}
n++
}
nn, err := w.w.Write(frag)
n += nn
if err != nil {
return n, err
}
}
return n, nil
}
for i, frag := range frags {
if w.complete {
w.writeIndent()
}
nn, err := w.w.Write(frag)
n += nn
if err != nil {
return n, err
}
if i+1 < len(frags) {
if err := w.w.WriteByte('\n'); err != nil {
return n, err
}
n++
}
}
w.complete = len(frags[len(frags)-1]) == 0
return n, nil
}
func (w *textWriter) WriteByte(c byte) error {
if w.compact && c == '\n' {
c = ' '
}
if !w.compact && w.complete {
w.writeIndent()
}
err := w.w.WriteByte(c)
w.complete = c == '\n'
return err
}
func (w *textWriter) indent() { w.ind++ }
func (w *textWriter) unindent() {
if w.ind == 0 {
log.Printf("proto: textWriter unindented too far")
return
}
w.ind--
}
func writeName(w *textWriter, props *Properties) error {
if _, err := w.WriteString(props.OrigName); err != nil {
return err
}
if props.Wire != "group" {
return w.WriteByte(':')
}
return nil
}
var (
messageSetType = reflect.TypeOf((*MessageSet)(nil)).Elem()
)
// raw is the interface satisfied by RawMessage.
type raw interface {
Bytes() []byte
}
func writeStruct(w *textWriter, sv reflect.Value) error {
if sv.Type() == messageSetType {
return writeMessageSet(w, sv.Addr().Interface().(*MessageSet))
}
st := sv.Type()
sprops := GetProperties(st)
for i := 0; i < sv.NumField(); i++ {
fv := sv.Field(i)
props := sprops.Prop[i]
name := st.Field(i).Name
if strings.HasPrefix(name, "XXX_") {
// There are two XXX_ fields:
// XXX_unrecognized []byte
// XXX_extensions map[int32]proto.Extension
// The first is handled here;
// the second is handled at the bottom of this function.
if name == "XXX_unrecognized" && !fv.IsNil() {
if err := writeUnknownStruct(w, fv.Interface().([]byte)); err != nil {
return err
}
}
continue
}
if fv.Kind() == reflect.Ptr && fv.IsNil() {
// Field not filled in. This could be an optional field or
// a required field that wasn't filled in. Either way, there
// isn't anything we can show for it.
continue
}
if fv.Kind() == reflect.Slice && fv.IsNil() {
// Repeated field that is empty, or a bytes field that is unused.
continue
}
if props.Repeated && fv.Kind() == reflect.Slice {
// Repeated field.
for j := 0; j < fv.Len(); j++ {
if err := writeName(w, props); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte(' '); err != nil {
return err
}
}
v := fv.Index(j)
if v.Kind() == reflect.Ptr && v.IsNil() {
// A nil message in a repeated field is not valid,
// but we can handle that more gracefully than panicking.
if _, err := w.Write([]byte("<nil>\n")); err != nil {
return err
}
continue
}
if err := writeAny(w, v, props); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
return err
}
}
continue
}
if err := writeName(w, props); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte(' '); err != nil {
return err
}
}
if b, ok := fv.Interface().(raw); ok {
if err := writeRaw(w, b.Bytes()); err != nil {
return err
}
continue
}
// Enums have a String method, so writeAny will work fine.
if err := writeAny(w, fv, props); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
return err
}
}
// Extensions (the XXX_extensions field).
pv := sv.Addr()
if pv.Type().Implements(extendableProtoType) {
if err := writeExtensions(w, pv); err != nil {
return err
}
}
return nil
}
// writeRaw writes an uninterpreted raw message.
func writeRaw(w *textWriter, b []byte) error {
if err := w.WriteByte('<'); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte('\n'); err != nil {
return err
}
}
w.indent()
if err := writeUnknownStruct(w, b); err != nil {
return err
}
w.unindent()
if err := w.WriteByte('>'); err != nil {
return err
}
return nil
}
// writeAny writes an arbitrary field.
func writeAny(w *textWriter, v reflect.Value, props *Properties) error {
v = reflect.Indirect(v)
// Floats have special cases.
if v.Kind() == reflect.Float32 || v.Kind() == reflect.Float64 {
x := v.Float()
var b []byte
switch {
case math.IsInf(x, 1):
b = posInf
case math.IsInf(x, -1):
b = negInf
case math.IsNaN(x):
b = nan
}
if b != nil {
_, err := w.Write(b)
return err
}
// Other values are handled below.
}
// We don't attempt to serialise every possible value type; only those
// that can occur in protocol buffers.
switch v.Kind() {
case reflect.Slice:
// Should only be a []byte; repeated fields are handled in writeStruct.
if err := writeString(w, string(v.Interface().([]byte))); err != nil {
return err
}
case reflect.String:
if err := writeString(w, v.String()); err != nil {
return err
}
case reflect.Struct:
// Required/optional group/message.
var bra, ket byte = '<', '>'
if props != nil && props.Wire == "group" {
bra, ket = '{', '}'
}
if err := w.WriteByte(bra); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte('\n'); err != nil {
return err
}
}
w.indent()
if tm, ok := v.Interface().(textMarshaler); ok {
text, err := tm.MarshalText()
if err != nil {
return err
}
if _, err = w.Write(text); err != nil {
return err
}
} else if err := writeStruct(w, v); err != nil {
return err
}
w.unindent()
if err := w.WriteByte(ket); err != nil {
return err
}
default:
_, err := fmt.Fprint(w, v.Interface())
return err
}
return nil
}
// equivalent to C's isprint.
func isprint(c byte) bool {
return c >= 0x20 && c < 0x7f
}
// writeString writes a string in the protocol buffer text format.
// It is similar to strconv.Quote except we don't use Go escape sequences,
// we treat the string as a byte sequence, and we use octal escapes.
// These differences are to maintain interoperability with the other
// languages' implementations of the text format.
func writeString(w *textWriter, s string) error {
// use WriteByte here to get any needed indent
if err := w.WriteByte('"'); err != nil {
return err
}
// Loop over the bytes, not the runes.
for i := 0; i < len(s); i++ {
var err error
// Divergence from C++: we don't escape apostrophes.
// There's no need to escape them, and the C++ parser
// copes with a naked apostrophe.
switch c := s[i]; c {
case '\n':
_, err = w.w.Write(backslashN)
case '\r':
_, err = w.w.Write(backslashR)
case '\t':
_, err = w.w.Write(backslashT)
case '"':
_, err = w.w.Write(backslashDQ)
case '\\':
_, err = w.w.Write(backslashBS)
default:
if isprint(c) {
err = w.w.WriteByte(c)
} else {
_, err = fmt.Fprintf(w.w, "\\%03o", c)
}
}
if err != nil {
return err
}
}
return w.WriteByte('"')
}
func writeMessageSet(w *textWriter, ms *MessageSet) error {
for _, item := range ms.Item {
id := *item.TypeId
if msd, ok := messageSetMap[id]; ok {
// Known message set type.
if _, err := fmt.Fprintf(w, "[%s]: <\n", msd.name); err != nil {
return err
}
w.indent()
pb := reflect.New(msd.t.Elem())
if err := Unmarshal(item.Message, pb.Interface().(Message)); err != nil {
if _, err := fmt.Fprintf(w, "/* bad message: %v */\n", err); err != nil {
return err
}
} else {
if err := writeStruct(w, pb.Elem()); err != nil {
return err
}
}
} else {
// Unknown type.
if _, err := fmt.Fprintf(w, "[%d]: <\n", id); err != nil {
return err
}
w.indent()
if err := writeUnknownStruct(w, item.Message); err != nil {
return err
}
}
w.unindent()
if _, err := w.Write(gtNewline); err != nil {
return err
}
}
return nil
}
func writeUnknownStruct(w *textWriter, data []byte) (err error) {
if !w.compact {
if _, err := fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data)); err != nil {
return err
}
}
b := NewBuffer(data)
for b.index < len(b.buf) {
x, err := b.DecodeVarint()
if err != nil {
_, err := fmt.Fprintf(w, "/* %v */\n", err)
return err
}
wire, tag := x&7, x>>3
if wire == WireEndGroup {
w.unindent()
if _, err := w.Write(endBraceNewline); err != nil {
return err
}
continue
}
if _, err := fmt.Fprint(w, tag); err != nil {
return err
}
if wire != WireStartGroup {
if err := w.WriteByte(':'); err != nil {
return err
}
}
if !w.compact || wire == WireStartGroup {
if err := w.WriteByte(' '); err != nil {
return err
}
}
switch wire {
case WireBytes:
buf, e := b.DecodeRawBytes(false)
if e == nil {
_, err = fmt.Fprintf(w, "%q", buf)
} else {
_, err = fmt.Fprintf(w, "/* %v */", e)
}
case WireFixed32:
x, err = b.DecodeFixed32()
err = writeUnknownInt(w, x, err)
case WireFixed64:
x, err = b.DecodeFixed64()
err = writeUnknownInt(w, x, err)
case WireStartGroup:
err = w.WriteByte('{')
w.indent()
case WireVarint:
x, err = b.DecodeVarint()
err = writeUnknownInt(w, x, err)
default:
_, err = fmt.Fprintf(w, "/* unknown wire type %d */", wire)
}
if err != nil {
return err
}
if err = w.WriteByte('\n'); err != nil {
return err
}
}
return nil
}
func writeUnknownInt(w *textWriter, x uint64, err error) error {
if err == nil {
_, err = fmt.Fprint(w, x)
} else {
_, err = fmt.Fprintf(w, "/* %v */", err)
}
return err
}
type int32Slice []int32
func (s int32Slice) Len() int { return len(s) }
func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] }
func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// writeExtensions writes all the extensions in pv.
// pv is assumed to be a pointer to a protocol message struct that is extendable.
func writeExtensions(w *textWriter, pv reflect.Value) error {
emap := extensionMaps[pv.Type().Elem()]
ep := pv.Interface().(extendableProto)
// Order the extensions by ID.
// This isn't strictly necessary, but it will give us
// canonical output, which will also make testing easier.
m := ep.ExtensionMap()
ids := make([]int32, 0, len(m))
for id := range m {
ids = append(ids, id)
}
sort.Sort(int32Slice(ids))
for _, extNum := range ids {
ext := m[extNum]
var desc *ExtensionDesc
if emap != nil {
desc = emap[extNum]
}
if desc == nil {
// Unknown extension.
if err := writeUnknownStruct(w, ext.enc); err != nil {
return err
}
continue
}
pb, err := GetExtension(ep, desc)
if err != nil {
if _, err := fmt.Fprintln(os.Stderr, "proto: failed getting extension: ", err); err != nil {
return err
}
continue
}
// Repeated extensions will appear as a slice.
if !desc.repeated() {
if err := writeExtension(w, desc.Name, pb); err != nil {
return err
}
} else {
v := reflect.ValueOf(pb)
for i := 0; i < v.Len(); i++ {
if err := writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil {
return err
}
}
}
}
return nil
}
func writeExtension(w *textWriter, name string, pb interface{}) error {
if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte(' '); err != nil {
return err
}
}
if err := writeAny(w, reflect.ValueOf(pb), nil); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
return err
}
return nil
}
func (w *textWriter) writeIndent() {
if !w.complete {
return
}
remain := w.ind * 2
for remain > 0 {
n := remain
if n > len(spaces) {
n = len(spaces)
}
w.w.Write(spaces[:n])
remain -= n
}
w.complete = false
}
func marshalText(w io.Writer, pb Message, compact bool) error {
val := reflect.ValueOf(pb)
if pb == nil || val.IsNil() {
w.Write([]byte("<nil>"))
return nil
}
var bw *bufio.Writer
ww, ok := w.(writer)
if !ok {
bw = bufio.NewWriter(w)
ww = bw
}
aw := &textWriter{
w: ww,
complete: true,
compact: compact,
}
if tm, ok := pb.(textMarshaler); ok {
text, err := tm.MarshalText()
if err != nil {
return err
}
if _, err = aw.Write(text); err != nil {
return err
}
if bw != nil {
return bw.Flush()
}
return nil
}
// Dereference the received pointer so we don't have outer < and >.
v := reflect.Indirect(val)
if err := writeStruct(aw, v); err != nil {
return err
}
if bw != nil {
return bw.Flush()
}
return nil
}
// MarshalText writes a given protocol buffer in text format.
// The only errors returned are from w.
func MarshalText(w io.Writer, pb Message) error {
return marshalText(w, pb, false)
}
// MarshalTextString is the same as MarshalText, but returns the string directly.
func MarshalTextString(pb Message) string {
var buf bytes.Buffer
marshalText(&buf, pb, false)
return buf.String()
}
// CompactText writes a given protocol buffer in compact text format (one line).
func CompactText(w io.Writer, pb Message) error { return marshalText(w, pb, true) }
// CompactTextString is the same as CompactText, but returns the string directly.
func CompactTextString(pb Message) string {
var buf bytes.Buffer
marshalText(&buf, pb, true)
return buf.String()
}

View file

@ -0,0 +1,684 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
// Functions for parsing the Text protocol buffer format.
// TODO: message sets.
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"unicode/utf8"
)
// textUnmarshaler is implemented by Messages that can unmarshal themsleves.
// It is identical to encoding.TextUnmarshaler, introduced in go 1.2,
// which will eventually replace it.
type textUnmarshaler interface {
UnmarshalText(text []byte) error
}
type ParseError struct {
Message string
Line int // 1-based line number
Offset int // 0-based byte offset from start of input
}
func (p *ParseError) Error() string {
if p.Line == 1 {
// show offset only for first line
return fmt.Sprintf("line 1.%d: %v", p.Offset, p.Message)
}
return fmt.Sprintf("line %d: %v", p.Line, p.Message)
}
type token struct {
value string
err *ParseError
line int // line number
offset int // byte number from start of input, not start of line
unquoted string // the unquoted version of value, if it was a quoted string
}
func (t *token) String() string {
if t.err == nil {
return fmt.Sprintf("%q (line=%d, offset=%d)", t.value, t.line, t.offset)
}
return fmt.Sprintf("parse error: %v", t.err)
}
type textParser struct {
s string // remaining input
done bool // whether the parsing is finished (success or error)
backed bool // whether back() was called
offset, line int
cur token
}
func newTextParser(s string) *textParser {
p := new(textParser)
p.s = s
p.line = 1
p.cur.line = 1
return p
}
func (p *textParser) errorf(format string, a ...interface{}) *ParseError {
pe := &ParseError{fmt.Sprintf(format, a...), p.cur.line, p.cur.offset}
p.cur.err = pe
p.done = true
return pe
}
// Numbers and identifiers are matched by [-+._A-Za-z0-9]
func isIdentOrNumberChar(c byte) bool {
switch {
case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z':
return true
case '0' <= c && c <= '9':
return true
}
switch c {
case '-', '+', '.', '_':
return true
}
return false
}
func isWhitespace(c byte) bool {
switch c {
case ' ', '\t', '\n', '\r':
return true
}
return false
}
func (p *textParser) skipWhitespace() {
i := 0
for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') {
if p.s[i] == '#' {
// comment; skip to end of line or input
for i < len(p.s) && p.s[i] != '\n' {
i++
}
if i == len(p.s) {
break
}
}
if p.s[i] == '\n' {
p.line++
}
i++
}
p.offset += i
p.s = p.s[i:len(p.s)]
if len(p.s) == 0 {
p.done = true
}
}
func (p *textParser) advance() {
// Skip whitespace
p.skipWhitespace()
if p.done {
return
}
// Start of non-whitespace
p.cur.err = nil
p.cur.offset, p.cur.line = p.offset, p.line
p.cur.unquoted = ""
switch p.s[0] {
case '<', '>', '{', '}', ':', '[', ']', ';', ',':
// Single symbol
p.cur.value, p.s = p.s[0:1], p.s[1:len(p.s)]
case '"', '\'':
// Quoted string
i := 1
for i < len(p.s) && p.s[i] != p.s[0] && p.s[i] != '\n' {
if p.s[i] == '\\' && i+1 < len(p.s) {
// skip escaped char
i++
}
i++
}
if i >= len(p.s) || p.s[i] != p.s[0] {
p.errorf("unmatched quote")
return
}
unq, err := unquoteC(p.s[1:i], rune(p.s[0]))
if err != nil {
p.errorf("invalid quoted string %v", p.s[0:i+1])
return
}
p.cur.value, p.s = p.s[0:i+1], p.s[i+1:len(p.s)]
p.cur.unquoted = unq
default:
i := 0
for i < len(p.s) && isIdentOrNumberChar(p.s[i]) {
i++
}
if i == 0 {
p.errorf("unexpected byte %#x", p.s[0])
return
}
p.cur.value, p.s = p.s[0:i], p.s[i:len(p.s)]
}
p.offset += len(p.cur.value)
}
var (
errBadUTF8 = errors.New("proto: bad UTF-8")
errBadHex = errors.New("proto: bad hexadecimal")
)
func unquoteC(s string, quote rune) (string, error) {
// This is based on C++'s tokenizer.cc.
// Despite its name, this is *not* parsing C syntax.
// For instance, "\0" is an invalid quoted string.
// Avoid allocation in trivial cases.
simple := true
for _, r := range s {
if r == '\\' || r == quote {
simple = false
break
}
}
if simple {
return s, nil
}
buf := make([]byte, 0, 3*len(s)/2)
for len(s) > 0 {
r, n := utf8.DecodeRuneInString(s)
if r == utf8.RuneError && n == 1 {
return "", errBadUTF8
}
s = s[n:]
if r != '\\' {
if r < utf8.RuneSelf {
buf = append(buf, byte(r))
} else {
buf = append(buf, string(r)...)
}
continue
}
ch, tail, err := unescape(s)
if err != nil {
return "", err
}
buf = append(buf, ch...)
s = tail
}
return string(buf), nil
}
func unescape(s string) (ch string, tail string, err error) {
r, n := utf8.DecodeRuneInString(s)
if r == utf8.RuneError && n == 1 {
return "", "", errBadUTF8
}
s = s[n:]
switch r {
case 'a':
return "\a", s, nil
case 'b':
return "\b", s, nil
case 'f':
return "\f", s, nil
case 'n':
return "\n", s, nil
case 'r':
return "\r", s, nil
case 't':
return "\t", s, nil
case 'v':
return "\v", s, nil
case '?':
return "?", s, nil // trigraph workaround
case '\'', '"', '\\':
return string(r), s, nil
case '0', '1', '2', '3', '4', '5', '6', '7', 'x', 'X':
if len(s) < 2 {
return "", "", fmt.Errorf(`\%c requires 2 following digits`, r)
}
base := 8
ss := s[:2]
s = s[2:]
if r == 'x' || r == 'X' {
base = 16
} else {
ss = string(r) + ss
}
i, err := strconv.ParseUint(ss, base, 8)
if err != nil {
return "", "", err
}
return string([]byte{byte(i)}), s, nil
case 'u', 'U':
n := 4
if r == 'U' {
n = 8
}
if len(s) < n {
return "", "", fmt.Errorf(`\%c requires %d digits`, r, n)
}
bs := make([]byte, n/2)
for i := 0; i < n; i += 2 {
a, ok1 := unhex(s[i])
b, ok2 := unhex(s[i+1])
if !ok1 || !ok2 {
return "", "", errBadHex
}
bs[i/2] = a<<4 | b
}
s = s[n:]
return string(bs), s, nil
}
return "", "", fmt.Errorf(`unknown escape \%c`, r)
}
// Adapted from src/pkg/strconv/quote.go.
func unhex(b byte) (v byte, ok bool) {
switch {
case '0' <= b && b <= '9':
return b - '0', true
case 'a' <= b && b <= 'f':
return b - 'a' + 10, true
case 'A' <= b && b <= 'F':
return b - 'A' + 10, true
}
return 0, false
}
// Back off the parser by one token. Can only be done between calls to next().
// It makes the next advance() a no-op.
func (p *textParser) back() { p.backed = true }
// Advances the parser and returns the new current token.
func (p *textParser) next() *token {
if p.backed || p.done {
p.backed = false
return &p.cur
}
p.advance()
if p.done {
p.cur.value = ""
} else if len(p.cur.value) > 0 && p.cur.value[0] == '"' {
// Look for multiple quoted strings separated by whitespace,
// and concatenate them.
cat := p.cur
for {
p.skipWhitespace()
if p.done || p.s[0] != '"' {
break
}
p.advance()
if p.cur.err != nil {
return &p.cur
}
cat.value += " " + p.cur.value
cat.unquoted += p.cur.unquoted
}
p.done = false // parser may have seen EOF, but we want to return cat
p.cur = cat
}
return &p.cur
}
// Return an error indicating which required field was not set.
func (p *textParser) missingRequiredFieldError(sv reflect.Value) *ParseError {
st := sv.Type()
sprops := GetProperties(st)
for i := 0; i < st.NumField(); i++ {
if !isNil(sv.Field(i)) {
continue
}
props := sprops.Prop[i]
if props.Required {
return p.errorf("message %v missing required field %q", st, props.OrigName)
}
}
return p.errorf("message %v missing required field", st) // should not happen
}
// Returns the index in the struct for the named field, as well as the parsed tag properties.
func structFieldByName(st reflect.Type, name string) (int, *Properties, bool) {
sprops := GetProperties(st)
i, ok := sprops.decoderOrigNames[name]
if ok {
return i, sprops.Prop[i], true
}
return -1, nil, false
}
// Consume a ':' from the input stream (if the next token is a colon),
// returning an error if a colon is needed but not present.
func (p *textParser) checkForColon(props *Properties, typ reflect.Type) *ParseError {
tok := p.next()
if tok.err != nil {
return tok.err
}
if tok.value != ":" {
// Colon is optional when the field is a group or message.
needColon := true
switch props.Wire {
case "group":
needColon = false
case "bytes":
// A "bytes" field is either a message, a string, or a repeated field;
// those three become *T, *string and []T respectively, so we can check for
// this field being a pointer to a non-string.
if typ.Kind() == reflect.Ptr {
// *T or *string
if typ.Elem().Kind() == reflect.String {
break
}
} else if typ.Kind() == reflect.Slice {
// []T or []*T
if typ.Elem().Kind() != reflect.Ptr {
break
}
}
needColon = false
}
if needColon {
return p.errorf("expected ':', found %q", tok.value)
}
p.back()
}
return nil
}
func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError {
st := sv.Type()
reqCount := GetProperties(st).reqCount
// A struct is a sequence of "name: value", terminated by one of
// '>' or '}', or the end of the input. A name may also be
// "[extension]".
for {
tok := p.next()
if tok.err != nil {
return tok.err
}
if tok.value == terminator {
break
}
if tok.value == "[" {
// Looks like an extension.
//
// TODO: Check whether we need to handle
// namespace rooted names (e.g. ".something.Foo").
tok = p.next()
if tok.err != nil {
return tok.err
}
var desc *ExtensionDesc
// This could be faster, but it's functional.
// TODO: Do something smarter than a linear scan.
for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) {
if d.Name == tok.value {
desc = d
break
}
}
if desc == nil {
return p.errorf("unrecognized extension %q", tok.value)
}
// Check the extension terminator.
tok = p.next()
if tok.err != nil {
return tok.err
}
if tok.value != "]" {
return p.errorf("unrecognized extension terminator %q", tok.value)
}
props := &Properties{}
props.Parse(desc.Tag)
typ := reflect.TypeOf(desc.ExtensionType)
if err := p.checkForColon(props, typ); err != nil {
return err
}
rep := desc.repeated()
// Read the extension structure, and set it in
// the value we're constructing.
var ext reflect.Value
if !rep {
ext = reflect.New(typ).Elem()
} else {
ext = reflect.New(typ.Elem()).Elem()
}
if err := p.readAny(ext, props); err != nil {
return err
}
ep := sv.Addr().Interface().(extendableProto)
if !rep {
SetExtension(ep, desc, ext.Interface())
} else {
old, err := GetExtension(ep, desc)
var sl reflect.Value
if err == nil {
sl = reflect.ValueOf(old) // existing slice
} else {
sl = reflect.MakeSlice(typ, 0, 1)
}
sl = reflect.Append(sl, ext)
SetExtension(ep, desc, sl.Interface())
}
} else {
// This is a normal, non-extension field.
fi, props, ok := structFieldByName(st, tok.value)
if !ok {
return p.errorf("unknown field name %q in %v", tok.value, st)
}
dst := sv.Field(fi)
isDstNil := isNil(dst)
// Check that it's not already set if it's not a repeated field.
if !props.Repeated && !isDstNil {
return p.errorf("non-repeated field %q was repeated", tok.value)
}
if err := p.checkForColon(props, st.Field(fi).Type); err != nil {
return err
}
// Parse into the field.
if err := p.readAny(dst, props); err != nil {
return err
}
if props.Required {
reqCount--
}
}
// For backward compatibility, permit a semicolon or comma after a field.
tok = p.next()
if tok.err != nil {
return tok.err
}
if tok.value != ";" && tok.value != "," {
p.back()
}
}
if reqCount > 0 {
return p.missingRequiredFieldError(sv)
}
return nil
}
func (p *textParser) readAny(v reflect.Value, props *Properties) *ParseError {
tok := p.next()
if tok.err != nil {
return tok.err
}
if tok.value == "" {
return p.errorf("unexpected EOF")
}
switch fv := v; fv.Kind() {
case reflect.Slice:
at := v.Type()
if at.Elem().Kind() == reflect.Uint8 {
// Special case for []byte
if tok.value[0] != '"' && tok.value[0] != '\'' {
// Deliberately written out here, as the error after
// this switch statement would write "invalid []byte: ...",
// which is not as user-friendly.
return p.errorf("invalid string: %v", tok.value)
}
bytes := []byte(tok.unquoted)
fv.Set(reflect.ValueOf(bytes))
return nil
}
// Repeated field. May already exist.
flen := fv.Len()
if flen == fv.Cap() {
nav := reflect.MakeSlice(at, flen, 2*flen+1)
reflect.Copy(nav, fv)
fv.Set(nav)
}
fv.SetLen(flen + 1)
// Read one.
p.back()
return p.readAny(fv.Index(flen), props)
case reflect.Bool:
// Either "true", "false", 1 or 0.
switch tok.value {
case "true", "1":
fv.SetBool(true)
return nil
case "false", "0":
fv.SetBool(false)
return nil
}
case reflect.Float32, reflect.Float64:
v := tok.value
// Ignore 'f' for compatibility with output generated by C++, but don't
// remove 'f' when the value is "-inf" or "inf".
if strings.HasSuffix(v, "f") && tok.value != "-inf" && tok.value != "inf" {
v = v[:len(v)-1]
}
if f, err := strconv.ParseFloat(v, fv.Type().Bits()); err == nil {
fv.SetFloat(f)
return nil
}
case reflect.Int32:
if x, err := strconv.ParseInt(tok.value, 0, 32); err == nil {
fv.SetInt(x)
return nil
}
if len(props.Enum) == 0 {
break
}
m, ok := enumValueMaps[props.Enum]
if !ok {
break
}
x, ok := m[tok.value]
if !ok {
break
}
fv.SetInt(int64(x))
return nil
case reflect.Int64:
if x, err := strconv.ParseInt(tok.value, 0, 64); err == nil {
fv.SetInt(x)
return nil
}
case reflect.Ptr:
// A basic field (indirected through pointer), or a repeated message/group
p.back()
fv.Set(reflect.New(fv.Type().Elem()))
return p.readAny(fv.Elem(), props)
case reflect.String:
if tok.value[0] == '"' || tok.value[0] == '\'' {
fv.SetString(tok.unquoted)
return nil
}
case reflect.Struct:
var terminator string
switch tok.value {
case "{":
terminator = "}"
case "<":
terminator = ">"
default:
return p.errorf("expected '{' or '<', found %q", tok.value)
}
// TODO: Handle nested messages which implement textUnmarshaler.
return p.readStruct(fv, terminator)
case reflect.Uint32:
if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil {
fv.SetUint(uint64(x))
return nil
}
case reflect.Uint64:
if x, err := strconv.ParseUint(tok.value, 0, 64); err == nil {
fv.SetUint(x)
return nil
}
}
return p.errorf("invalid %v: %v", v.Type(), tok.value)
}
// UnmarshalText reads a protocol buffer in Text format. UnmarshalText resets pb
// before starting to unmarshal, so any existing data in pb is always removed.
func UnmarshalText(s string, pb Message) error {
if um, ok := pb.(textUnmarshaler); ok {
err := um.UnmarshalText([]byte(s))
return err
}
pb.Reset()
v := reflect.ValueOf(pb)
if pe := newTextParser(s).readStruct(v.Elem(), ""); pe != nil {
return pe
}
return nil
}

View file

@ -0,0 +1,462 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"math"
"reflect"
"testing"
. "./testdata"
. "code.google.com/p/goprotobuf/proto"
)
type UnmarshalTextTest struct {
in string
err string // if "", no error expected
out *MyMessage
}
func buildExtStructTest(text string) UnmarshalTextTest {
msg := &MyMessage{
Count: Int32(42),
}
SetExtension(msg, E_Ext_More, &Ext{
Data: String("Hello, world!"),
})
return UnmarshalTextTest{in: text, out: msg}
}
func buildExtDataTest(text string) UnmarshalTextTest {
msg := &MyMessage{
Count: Int32(42),
}
SetExtension(msg, E_Ext_Text, String("Hello, world!"))
SetExtension(msg, E_Ext_Number, Int32(1729))
return UnmarshalTextTest{in: text, out: msg}
}
func buildExtRepStringTest(text string) UnmarshalTextTest {
msg := &MyMessage{
Count: Int32(42),
}
if err := SetExtension(msg, E_Greeting, []string{"bula", "hola"}); err != nil {
panic(err)
}
return UnmarshalTextTest{in: text, out: msg}
}
var unMarshalTextTests = []UnmarshalTextTest{
// Basic
{
in: " count:42\n name:\"Dave\" ",
out: &MyMessage{
Count: Int32(42),
Name: String("Dave"),
},
},
// Empty quoted string
{
in: `count:42 name:""`,
out: &MyMessage{
Count: Int32(42),
Name: String(""),
},
},
// Quoted string concatenation
{
in: `count:42 name: "My name is "` + "\n" + `"elsewhere"`,
out: &MyMessage{
Count: Int32(42),
Name: String("My name is elsewhere"),
},
},
// Quoted string with escaped apostrophe
{
in: `count:42 name: "HOLIDAY - New Year\'s Day"`,
out: &MyMessage{
Count: Int32(42),
Name: String("HOLIDAY - New Year's Day"),
},
},
// Quoted string with single quote
{
in: `count:42 name: 'Roger "The Ramster" Ramjet'`,
out: &MyMessage{
Count: Int32(42),
Name: String(`Roger "The Ramster" Ramjet`),
},
},
// Quoted string with all the accepted special characters from the C++ test
{
in: `count:42 name: ` + "\"\\\"A string with \\' characters \\n and \\r newlines and \\t tabs and \\001 slashes \\\\ and multiple spaces\"",
out: &MyMessage{
Count: Int32(42),
Name: String("\"A string with ' characters \n and \r newlines and \t tabs and \001 slashes \\ and multiple spaces"),
},
},
// Quoted string with quoted backslash
{
in: `count:42 name: "\\'xyz"`,
out: &MyMessage{
Count: Int32(42),
Name: String(`\'xyz`),
},
},
// Quoted string with UTF-8 bytes.
{
in: "count:42 name: '\303\277\302\201\xAB'",
out: &MyMessage{
Count: Int32(42),
Name: String("\303\277\302\201\xAB"),
},
},
// Bad quoted string
{
in: `inner: < host: "\0" >` + "\n",
err: `line 1.15: invalid quoted string "\0"`,
},
// Number too large for int64
{
in: "count: 1 others { key: 123456789012345678901 }",
err: "line 1.23: invalid int64: 123456789012345678901",
},
// Number too large for int32
{
in: "count: 1234567890123",
err: "line 1.7: invalid int32: 1234567890123",
},
// Number in hexadecimal
{
in: "count: 0x2beef",
out: &MyMessage{
Count: Int32(0x2beef),
},
},
// Number in octal
{
in: "count: 024601",
out: &MyMessage{
Count: Int32(024601),
},
},
// Floating point number with "f" suffix
{
in: "count: 4 others:< weight: 17.0f >",
out: &MyMessage{
Count: Int32(4),
Others: []*OtherMessage{
{
Weight: Float32(17),
},
},
},
},
// Floating point positive infinity
{
in: "count: 4 bigfloat: inf",
out: &MyMessage{
Count: Int32(4),
Bigfloat: Float64(math.Inf(1)),
},
},
// Floating point negative infinity
{
in: "count: 4 bigfloat: -inf",
out: &MyMessage{
Count: Int32(4),
Bigfloat: Float64(math.Inf(-1)),
},
},
// Number too large for float32
{
in: "others:< weight: 12345678901234567890123456789012345678901234567890 >",
err: "line 1.17: invalid float32: 12345678901234567890123456789012345678901234567890",
},
// Number posing as a quoted string
{
in: `inner: < host: 12 >` + "\n",
err: `line 1.15: invalid string: 12`,
},
// Quoted string posing as int32
{
in: `count: "12"`,
err: `line 1.7: invalid int32: "12"`,
},
// Quoted string posing a float32
{
in: `others:< weight: "17.4" >`,
err: `line 1.17: invalid float32: "17.4"`,
},
// Enum
{
in: `count:42 bikeshed: BLUE`,
out: &MyMessage{
Count: Int32(42),
Bikeshed: MyMessage_BLUE.Enum(),
},
},
// Repeated field
{
in: `count:42 pet: "horsey" pet:"bunny"`,
out: &MyMessage{
Count: Int32(42),
Pet: []string{"horsey", "bunny"},
},
},
// Repeated message with/without colon and <>/{}
{
in: `count:42 others:{} others{} others:<> others:{}`,
out: &MyMessage{
Count: Int32(42),
Others: []*OtherMessage{
{},
{},
{},
{},
},
},
},
// Missing colon for inner message
{
in: `count:42 inner < host: "cauchy.syd" >`,
out: &MyMessage{
Count: Int32(42),
Inner: &InnerMessage{
Host: String("cauchy.syd"),
},
},
},
// Missing colon for string field
{
in: `name "Dave"`,
err: `line 1.5: expected ':', found "\"Dave\""`,
},
// Missing colon for int32 field
{
in: `count 42`,
err: `line 1.6: expected ':', found "42"`,
},
// Missing required field
{
in: ``,
err: `line 1.0: message testdata.MyMessage missing required field "count"`,
},
// Repeated non-repeated field
{
in: `name: "Rob" name: "Russ"`,
err: `line 1.12: non-repeated field "name" was repeated`,
},
// Group
{
in: `count: 17 SomeGroup { group_field: 12 }`,
out: &MyMessage{
Count: Int32(17),
Somegroup: &MyMessage_SomeGroup{
GroupField: Int32(12),
},
},
},
// Semicolon between fields
{
in: `count:3;name:"Calvin"`,
out: &MyMessage{
Count: Int32(3),
Name: String("Calvin"),
},
},
// Comma between fields
{
in: `count:4,name:"Ezekiel"`,
out: &MyMessage{
Count: Int32(4),
Name: String("Ezekiel"),
},
},
// Extension
buildExtStructTest(`count: 42 [testdata.Ext.more]:<data:"Hello, world!" >`),
buildExtStructTest(`count: 42 [testdata.Ext.more] {data:"Hello, world!"}`),
buildExtDataTest(`count: 42 [testdata.Ext.text]:"Hello, world!" [testdata.Ext.number]:1729`),
buildExtRepStringTest(`count: 42 [testdata.greeting]:"bula" [testdata.greeting]:"hola"`),
// Big all-in-one
{
in: "count:42 # Meaning\n" +
`name:"Dave" ` +
`quote:"\"I didn't want to go.\"" ` +
`pet:"bunny" ` +
`pet:"kitty" ` +
`pet:"horsey" ` +
`inner:<` +
` host:"footrest.syd" ` +
` port:7001 ` +
` connected:true ` +
`> ` +
`others:<` +
` key:3735928559 ` +
` value:"\x01A\a\f" ` +
`> ` +
`others:<` +
" weight:58.9 # Atomic weight of Co\n" +
` inner:<` +
` host:"lesha.mtv" ` +
` port:8002 ` +
` >` +
`>`,
out: &MyMessage{
Count: Int32(42),
Name: String("Dave"),
Quote: String(`"I didn't want to go."`),
Pet: []string{"bunny", "kitty", "horsey"},
Inner: &InnerMessage{
Host: String("footrest.syd"),
Port: Int32(7001),
Connected: Bool(true),
},
Others: []*OtherMessage{
{
Key: Int64(3735928559),
Value: []byte{0x1, 'A', '\a', '\f'},
},
{
Weight: Float32(58.9),
Inner: &InnerMessage{
Host: String("lesha.mtv"),
Port: Int32(8002),
},
},
},
},
},
}
func TestUnmarshalText(t *testing.T) {
for i, test := range unMarshalTextTests {
pb := new(MyMessage)
err := UnmarshalText(test.in, pb)
if test.err == "" {
// We don't expect failure.
if err != nil {
t.Errorf("Test %d: Unexpected error: %v", i, err)
} else if !reflect.DeepEqual(pb, test.out) {
t.Errorf("Test %d: Incorrect populated \nHave: %v\nWant: %v",
i, pb, test.out)
}
} else {
// We do expect failure.
if err == nil {
t.Errorf("Test %d: Didn't get expected error: %v", i, test.err)
} else if err.Error() != test.err {
t.Errorf("Test %d: Incorrect error.\nHave: %v\nWant: %v",
i, err.Error(), test.err)
}
}
}
}
func TestUnmarshalTextCustomMessage(t *testing.T) {
msg := &textMessage{}
if err := UnmarshalText("custom", msg); err != nil {
t.Errorf("Unexpected error from custom unmarshal: %v", err)
}
if UnmarshalText("not custom", msg) == nil {
t.Errorf("Didn't get expected error from custom unmarshal")
}
}
// Regression test; this caused a panic.
func TestRepeatedEnum(t *testing.T) {
pb := new(RepeatedEnum)
if err := UnmarshalText("color: RED", pb); err != nil {
t.Fatal(err)
}
exp := &RepeatedEnum{
Color: []RepeatedEnum_Color{RepeatedEnum_RED},
}
if !Equal(pb, exp) {
t.Errorf("Incorrect populated \nHave: %v\nWant: %v", pb, exp)
}
}
var benchInput string
func init() {
benchInput = "count: 4\n"
for i := 0; i < 1000; i++ {
benchInput += "pet: \"fido\"\n"
}
// Check it is valid input.
pb := new(MyMessage)
err := UnmarshalText(benchInput, pb)
if err != nil {
panic("Bad benchmark input: " + err.Error())
}
}
func BenchmarkUnmarshalText(b *testing.B) {
pb := new(MyMessage)
for i := 0; i < b.N; i++ {
UnmarshalText(benchInput, pb)
}
b.SetBytes(int64(len(benchInput)))
}

View file

@ -0,0 +1,408 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"bytes"
"errors"
"io/ioutil"
"math"
"strings"
"testing"
"code.google.com/p/goprotobuf/proto"
pb "./testdata"
)
// textMessage implements the methods that allow it to marshal and unmarshal
// itself as text.
type textMessage struct {
}
func (*textMessage) MarshalText() ([]byte, error) {
return []byte("custom"), nil
}
func (*textMessage) UnmarshalText(bytes []byte) error {
if string(bytes) != "custom" {
return errors.New("expected 'custom'")
}
return nil
}
func (*textMessage) Reset() {}
func (*textMessage) String() string { return "" }
func (*textMessage) ProtoMessage() {}
func newTestMessage() *pb.MyMessage {
msg := &pb.MyMessage{
Count: proto.Int32(42),
Name: proto.String("Dave"),
Quote: proto.String(`"I didn't want to go."`),
Pet: []string{"bunny", "kitty", "horsey"},
Inner: &pb.InnerMessage{
Host: proto.String("footrest.syd"),
Port: proto.Int32(7001),
Connected: proto.Bool(true),
},
Others: []*pb.OtherMessage{
{
Key: proto.Int64(0xdeadbeef),
Value: []byte{1, 65, 7, 12},
},
{
Weight: proto.Float32(6.022),
Inner: &pb.InnerMessage{
Host: proto.String("lesha.mtv"),
Port: proto.Int32(8002),
},
},
},
Bikeshed: pb.MyMessage_BLUE.Enum(),
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: proto.Int32(8),
},
// One normally wouldn't do this.
// This is an undeclared tag 13, as a varint (wire type 0) with value 4.
XXX_unrecognized: []byte{13<<3 | 0, 4},
}
ext := &pb.Ext{
Data: proto.String("Big gobs for big rats"),
}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext); err != nil {
panic(err)
}
greetings := []string{"adg", "easy", "cow"}
if err := proto.SetExtension(msg, pb.E_Greeting, greetings); err != nil {
panic(err)
}
// Add an unknown extension. We marshal a pb.Ext, and fake the ID.
b, err := proto.Marshal(&pb.Ext{Data: proto.String("3G skiing")})
if err != nil {
panic(err)
}
b = append(proto.EncodeVarint(201<<3|proto.WireBytes), b...)
proto.SetRawExtension(msg, 201, b)
// Extensions can be plain fields, too, so let's test that.
b = append(proto.EncodeVarint(202<<3|proto.WireVarint), 19)
proto.SetRawExtension(msg, 202, b)
return msg
}
const text = `count: 42
name: "Dave"
quote: "\"I didn't want to go.\""
pet: "bunny"
pet: "kitty"
pet: "horsey"
inner: <
host: "footrest.syd"
port: 7001
connected: true
>
others: <
key: 3735928559
value: "\001A\007\014"
>
others: <
weight: 6.022
inner: <
host: "lesha.mtv"
port: 8002
>
>
bikeshed: BLUE
SomeGroup {
group_field: 8
}
/* 2 unknown bytes */
13: 4
[testdata.Ext.more]: <
data: "Big gobs for big rats"
>
[testdata.greeting]: "adg"
[testdata.greeting]: "easy"
[testdata.greeting]: "cow"
/* 13 unknown bytes */
201: "\t3G skiing"
/* 3 unknown bytes */
202: 19
`
func TestMarshalText(t *testing.T) {
buf := new(bytes.Buffer)
if err := proto.MarshalText(buf, newTestMessage()); err != nil {
t.Fatalf("proto.MarshalText: %v", err)
}
s := buf.String()
if s != text {
t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", s, text)
}
}
func TestMarshalTextCustomMessage(t *testing.T) {
buf := new(bytes.Buffer)
if err := proto.MarshalText(buf, &textMessage{}); err != nil {
t.Fatalf("proto.MarshalText: %v", err)
}
s := buf.String()
if s != "custom" {
t.Errorf("Got %q, expected %q", s, "custom")
}
}
func TestMarshalTextNil(t *testing.T) {
want := "<nil>"
tests := []proto.Message{nil, (*pb.MyMessage)(nil)}
for i, test := range tests {
buf := new(bytes.Buffer)
if err := proto.MarshalText(buf, test); err != nil {
t.Fatal(err)
}
if got := buf.String(); got != want {
t.Errorf("%d: got %q want %q", i, got, want)
}
}
}
func TestMarshalTextUnknownEnum(t *testing.T) {
// The Color enum only specifies values 0-2.
m := &pb.MyMessage{Bikeshed: pb.MyMessage_Color(3).Enum()}
got := m.String()
const want = `bikeshed:3 `
if got != want {
t.Errorf("\n got %q\nwant %q", got, want)
}
}
func BenchmarkMarshalTextBuffered(b *testing.B) {
buf := new(bytes.Buffer)
m := newTestMessage()
for i := 0; i < b.N; i++ {
buf.Reset()
proto.MarshalText(buf, m)
}
}
func BenchmarkMarshalTextUnbuffered(b *testing.B) {
w := ioutil.Discard
m := newTestMessage()
for i := 0; i < b.N; i++ {
proto.MarshalText(w, m)
}
}
func compact(src string) string {
// s/[ \n]+/ /g; s/ $//;
dst := make([]byte, len(src))
space, comment := false, false
j := 0
for i := 0; i < len(src); i++ {
if strings.HasPrefix(src[i:], "/*") {
comment = true
i++
continue
}
if comment && strings.HasPrefix(src[i:], "*/") {
comment = false
i++
continue
}
if comment {
continue
}
c := src[i]
if c == ' ' || c == '\n' {
space = true
continue
}
if j > 0 && (dst[j-1] == ':' || dst[j-1] == '<' || dst[j-1] == '{') {
space = false
}
if c == '{' {
space = false
}
if space {
dst[j] = ' '
j++
space = false
}
dst[j] = c
j++
}
if space {
dst[j] = ' '
j++
}
return string(dst[0:j])
}
var compactText = compact(text)
func TestCompactText(t *testing.T) {
s := proto.CompactTextString(newTestMessage())
if s != compactText {
t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v\n===\n", s, compactText)
}
}
func TestStringEscaping(t *testing.T) {
testCases := []struct {
in *pb.Strings
out string
}{
{
// Test data from C++ test (TextFormatTest.StringEscape).
// Single divergence: we don't escape apostrophes.
&pb.Strings{StringField: proto.String("\"A string with ' characters \n and \r newlines and \t tabs and \001 slashes \\ and multiple spaces")},
"string_field: \"\\\"A string with ' characters \\n and \\r newlines and \\t tabs and \\001 slashes \\\\ and multiple spaces\"\n",
},
{
// Test data from the same C++ test.
&pb.Strings{StringField: proto.String("\350\260\267\346\255\214")},
"string_field: \"\\350\\260\\267\\346\\255\\214\"\n",
},
{
// Some UTF-8.
&pb.Strings{StringField: proto.String("\x00\x01\xff\x81")},
`string_field: "\000\001\377\201"` + "\n",
},
}
for i, tc := range testCases {
var buf bytes.Buffer
if err := proto.MarshalText(&buf, tc.in); err != nil {
t.Errorf("proto.MarsalText: %v", err)
continue
}
s := buf.String()
if s != tc.out {
t.Errorf("#%d: Got:\n%s\nExpected:\n%s\n", i, s, tc.out)
continue
}
// Check round-trip.
pb := new(pb.Strings)
if err := proto.UnmarshalText(s, pb); err != nil {
t.Errorf("#%d: UnmarshalText: %v", i, err)
continue
}
if !proto.Equal(pb, tc.in) {
t.Errorf("#%d: Round-trip failed:\nstart: %v\n end: %v", i, tc.in, pb)
}
}
}
// A limitedWriter accepts some output before it fails.
// This is a proxy for something like a nearly-full or imminently-failing disk,
// or a network connection that is about to die.
type limitedWriter struct {
b bytes.Buffer
limit int
}
var outOfSpace = errors.New("proto: insufficient space")
func (w *limitedWriter) Write(p []byte) (n int, err error) {
var avail = w.limit - w.b.Len()
if avail <= 0 {
return 0, outOfSpace
}
if len(p) <= avail {
return w.b.Write(p)
}
n, _ = w.b.Write(p[:avail])
return n, outOfSpace
}
func TestMarshalTextFailing(t *testing.T) {
// Try lots of different sizes to exercise more error code-paths.
for lim := 0; lim < len(text); lim++ {
buf := new(limitedWriter)
buf.limit = lim
err := proto.MarshalText(buf, newTestMessage())
// We expect a certain error, but also some partial results in the buffer.
if err != outOfSpace {
t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", err, outOfSpace)
}
s := buf.b.String()
x := text[:buf.limit]
if s != x {
t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", s, x)
}
}
}
func TestFloats(t *testing.T) {
tests := []struct {
f float64
want string
}{
{0, "0"},
{4.7, "4.7"},
{math.Inf(1), "inf"},
{math.Inf(-1), "-inf"},
{math.NaN(), "nan"},
}
for _, test := range tests {
msg := &pb.FloatingPoint{F: &test.f}
got := strings.TrimSpace(msg.String())
want := `f:` + test.want
if got != want {
t.Errorf("f=%f: got %q, want %q", test.f, got, want)
}
}
}
func TestRepeatedNilText(t *testing.T) {
m := &pb.MessageList{
Message: []*pb.MessageList_Message{
nil,
&pb.MessageList_Message{
Name: proto.String("Horse"),
},
nil,
},
}
want := `Message <nil>
Message {
name: "Horse"
}
Message <nil>
`
if s := proto.MarshalTextString(m); s != want {
t.Errorf(" got: %s\nwant: %s", s, want)
}
}

191
Godeps/_workspace/src/github.com/golang/glog/LICENSE generated vendored Normal file
View file

@ -0,0 +1,191 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright
owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities
that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or
indirect, to cause the direction or management of such entity, whether by
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including
but not limited to software source code, documentation source, and configuration
files.
"Object" form shall mean any form resulting from mechanical transformation or
translation of a Source form, including but not limited to compiled object code,
generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made
available under the License, as indicated by a copyright notice that is included
in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that
is based on (or derived from) the Work and for which the editorial revisions,
annotations, elaborations, or other modifications represent, as a whole, an
original work of authorship. For the purposes of this License, Derivative Works
shall not include works that remain separable from, or merely link (or bind by
name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version
of the Work and any modifications or additions to that Work or Derivative Works
thereof, that is intentionally submitted to Licensor for inclusion in the Work
by the copyright owner or by an individual or Legal Entity authorized to submit
on behalf of the copyright owner. For the purposes of this definition,
"submitted" means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for
the purpose of discussing and improving the Work, but excluding communication
that is conspicuously marked or otherwise designated in writing by the copyright
owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.
2. Grant of Copyright License.
Subject to the terms and conditions of this License, each Contributor hereby
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the Work and such
Derivative Works in Source or Object form.
3. Grant of Patent License.
Subject to the terms and conditions of this License, each Contributor hereby
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this section) patent license to make, have
made, use, offer to sell, sell, import, and otherwise transfer the Work, where
such license applies only to those patent claims licensable by such Contributor
that are necessarily infringed by their Contribution(s) alone or by combination
of their Contribution(s) with the Work to which such Contribution(s) was
submitted. If You institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work or a
Contribution incorporated within the Work constitutes direct or contributory
patent infringement, then any patent licenses granted to You under this License
for that Work shall terminate as of the date such litigation is filed.
4. Redistribution.
You may reproduce and distribute copies of the Work or Derivative Works thereof
in any medium, with or without modifications, and in Source or Object form,
provided that You meet the following conditions:
You must give any other recipients of the Work or Derivative Works a copy of
this License; and
You must cause any modified files to carry prominent notices stating that You
changed the files; and
You must retain, in the Source form of any Derivative Works that You distribute,
all copyright, patent, trademark, and attribution notices from the Source form
of the Work, excluding those notices that do not pertain to any part of the
Derivative Works; and
If the Work includes a "NOTICE" text file as part of its distribution, then any
Derivative Works that You distribute must include a readable copy of the
attribution notices contained within such NOTICE file, excluding those notices
that do not pertain to any part of the Derivative Works, in at least one of the
following places: within a NOTICE text file distributed as part of the
Derivative Works; within the Source form or documentation, if provided along
with the Derivative Works; or, within a display generated by the Derivative
Works, if and wherever such third-party notices normally appear. The contents of
the NOTICE file are for informational purposes only and do not modify the
License. You may add Your own attribution notices within Derivative Works that
You distribute, alongside or as an addendum to the NOTICE text from the Work,
provided that such additional attribution notices cannot be construed as
modifying the License.
You may add Your own copyright statement to Your modifications and may provide
additional or different license terms and conditions for use, reproduction, or
distribution of Your modifications, or for any such Derivative Works as a whole,
provided Your use, reproduction, and distribution of the Work otherwise complies
with the conditions stated in this License.
5. Submission of Contributions.
Unless You explicitly state otherwise, any Contribution intentionally submitted
for inclusion in the Work by You to the Licensor shall be under the terms and
conditions of this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify the terms of
any separate license agreement you may have executed with Licensor regarding
such Contributions.
6. Trademarks.
This License does not grant permission to use the trade names, trademarks,
service marks, or product names of the Licensor, except as required for
reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.
7. Disclaimer of Warranty.
Unless required by applicable law or agreed to in writing, Licensor provides the
Work (and each Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
including, without limitation, any warranties or conditions of TITLE,
NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
solely responsible for determining the appropriateness of using or
redistributing the Work and assume any risks associated with Your exercise of
permissions under this License.
8. Limitation of Liability.
In no event and under no legal theory, whether in tort (including negligence),
contract, or otherwise, unless required by applicable law (such as deliberate
and grossly negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special, incidental,
or consequential damages of any character arising as a result of this License or
out of the use or inability to use the Work (including but not limited to
damages for loss of goodwill, work stoppage, computer failure or malfunction, or
any and all other commercial damages or losses), even if such Contributor has
been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability.
While redistributing the Work or Derivative Works thereof, You may choose to
offer, and charge a fee for, acceptance of support, warranty, indemnity, or
other liability obligations and/or rights consistent with this License. However,
in accepting such obligations, You may act only on Your own behalf and on Your
sole responsibility, not on behalf of any other Contributor, and only if You
agree to indemnify, defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason of your
accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on
the same "printed page" as the copyright notice for easier identification within
third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

44
Godeps/_workspace/src/github.com/golang/glog/README generated vendored Normal file
View file

@ -0,0 +1,44 @@
glog
====
Leveled execution logs for Go.
This is an efficient pure Go implementation of leveled logs in the
manner of the open source C++ package
http://code.google.com/p/google-glog
By binding methods to booleans it is possible to use the log package
without paying the expense of evaluating the arguments to the log.
Through the -vmodule flag, the package also provides fine-grained
control over logging at the file level.
The comment from glog.go introduces the ideas:
Package glog implements logging analogous to the Google-internal
C++ INFO/ERROR/V setup. It provides functions Info, Warning,
Error, Fatal, plus formatting variants such as Infof. It
also provides V-style logging controlled by the -v and
-vmodule=file=2 flags.
Basic examples:
glog.Info("Prepare to repel boarders")
glog.Fatalf("Initialization failed: %s", err)
See the documentation for the V function for an explanation
of these examples:
if glog.V(2) {
glog.Info("Starting transaction...")
}
glog.V(2).Infoln("Processed", nItems, "elements")
The repository contains an open source version of the log package
used inside Google. The master copy of the source lives inside
Google, not here. The code in this repo is for export only and is not itself
under development. Feature requests will be ignored.
Send bug reports to golang-nuts@googlegroups.com.

1177
Godeps/_workspace/src/github.com/golang/glog/glog.go generated vendored Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,124 @@
// Go support for leveled logs, analogous to https://code.google.com/p/google-glog/
//
// Copyright 2013 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// File I/O for logs.
package glog
import (
"errors"
"flag"
"fmt"
"os"
"os/user"
"path/filepath"
"strings"
"sync"
"time"
)
// MaxSize is the maximum size of a log file in bytes.
var MaxSize uint64 = 1024 * 1024 * 1800
// logDirs lists the candidate directories for new log files.
var logDirs []string
// If non-empty, overrides the choice of directory in which to write logs.
// See createLogDirs for the full list of possible destinations.
var logDir = flag.String("log_dir", "", "If non-empty, write log files in this directory")
func createLogDirs() {
if *logDir != "" {
logDirs = append(logDirs, *logDir)
}
logDirs = append(logDirs, os.TempDir())
}
var (
pid = os.Getpid()
program = filepath.Base(os.Args[0])
host = "unknownhost"
userName = "unknownuser"
)
func init() {
h, err := os.Hostname()
if err == nil {
host = shortHostname(h)
}
current, err := user.Current()
if err == nil {
userName = current.Username
}
// Sanitize userName since it may contain filepath separators on Windows.
userName = strings.Replace(userName, `\`, "_", -1)
}
// shortHostname returns its argument, truncating at the first period.
// For instance, given "www.google.com" it returns "www".
func shortHostname(hostname string) string {
if i := strings.Index(hostname, "."); i >= 0 {
return hostname[:i]
}
return hostname
}
// logName returns a new log file name containing tag, with start time t, and
// the name for the symlink for tag.
func logName(tag string, t time.Time) (name, link string) {
name = fmt.Sprintf("%s.%s.%s.log.%s.%04d%02d%02d-%02d%02d%02d.%d",
program,
host,
userName,
tag,
t.Year(),
t.Month(),
t.Day(),
t.Hour(),
t.Minute(),
t.Second(),
pid)
return name, program + "." + tag
}
var onceLogDirs sync.Once
// create creates a new log file and returns the file and its filename, which
// contains tag ("INFO", "FATAL", etc.) and t. If the file is created
// successfully, create also attempts to update the symlink for that tag, ignoring
// errors.
func create(tag string, t time.Time) (f *os.File, filename string, err error) {
onceLogDirs.Do(createLogDirs)
if len(logDirs) == 0 {
return nil, "", errors.New("log: no log dirs")
}
name, link := logName(tag, t)
var lastErr error
for _, dir := range logDirs {
fname := filepath.Join(dir, name)
f, err := os.Create(fname)
if err == nil {
symlink := filepath.Join(dir, link)
os.Remove(symlink) // ignore err
os.Symlink(name, symlink) // ignore err
return f, fname, nil
}
lastErr = err
}
return nil, "", fmt.Errorf("log: cannot create log: %v", lastErr)
}

View file

@ -0,0 +1,415 @@
// Go support for leveled logs, analogous to https://code.google.com/p/google-glog/
//
// Copyright 2013 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package glog
import (
"bytes"
"fmt"
stdLog "log"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
"time"
)
// Test that shortHostname works as advertised.
func TestShortHostname(t *testing.T) {
for hostname, expect := range map[string]string{
"": "",
"host": "host",
"host.google.com": "host",
} {
if got := shortHostname(hostname); expect != got {
t.Errorf("shortHostname(%q): expected %q, got %q", hostname, expect, got)
}
}
}
// flushBuffer wraps a bytes.Buffer to satisfy flushSyncWriter.
type flushBuffer struct {
bytes.Buffer
}
func (f *flushBuffer) Flush() error {
return nil
}
func (f *flushBuffer) Sync() error {
return nil
}
// swap sets the log writers and returns the old array.
func (l *loggingT) swap(writers [numSeverity]flushSyncWriter) (old [numSeverity]flushSyncWriter) {
l.mu.Lock()
defer l.mu.Unlock()
old = l.file
for i, w := range writers {
logging.file[i] = w
}
return
}
// newBuffers sets the log writers to all new byte buffers and returns the old array.
func (l *loggingT) newBuffers() [numSeverity]flushSyncWriter {
return l.swap([numSeverity]flushSyncWriter{new(flushBuffer), new(flushBuffer), new(flushBuffer), new(flushBuffer)})
}
// contents returns the specified log value as a string.
func contents(s severity) string {
return logging.file[s].(*flushBuffer).String()
}
// contains reports whether the string is contained in the log.
func contains(s severity, str string, t *testing.T) bool {
return strings.Contains(contents(s), str)
}
// setFlags configures the logging flags how the test expects them.
func setFlags() {
logging.toStderr = false
}
// Test that Info works as advertised.
func TestInfo(t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
Info("test")
if !contains(infoLog, "I", t) {
t.Errorf("Info has wrong character: %q", contents(infoLog))
}
if !contains(infoLog, "test", t) {
t.Error("Info failed")
}
}
func TestInfoDepth(t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
f := func() { InfoDepth(1, "depth-test1") }
// The next three lines must stay together
_, _, wantLine, _ := runtime.Caller(0)
InfoDepth(0, "depth-test0")
f()
msgs := strings.Split(strings.TrimSuffix(contents(infoLog), "\n"), "\n")
if len(msgs) != 2 {
t.Fatalf("Got %d lines, expected 2", len(msgs))
}
for i, m := range msgs {
if !strings.HasPrefix(m, "I") {
t.Errorf("InfoDepth[%d] has wrong character: %q", i, m)
}
w := fmt.Sprintf("depth-test%d", i)
if !strings.Contains(m, w) {
t.Errorf("InfoDepth[%d] missing %q: %q", i, w, m)
}
// pull out the line number (between : and ])
msg := m[strings.LastIndex(m, ":")+1:]
x := strings.Index(msg, "]")
if x < 0 {
t.Errorf("InfoDepth[%d]: missing ']': %q", i, m)
continue
}
line, err := strconv.Atoi(msg[:x])
if err != nil {
t.Errorf("InfoDepth[%d]: bad line number: %q", i, m)
continue
}
wantLine++
if wantLine != line {
t.Errorf("InfoDepth[%d]: got line %d, want %d", i, line, wantLine)
}
}
}
func init() {
CopyStandardLogTo("INFO")
}
// Test that CopyStandardLogTo panics on bad input.
func TestCopyStandardLogToPanic(t *testing.T) {
defer func() {
if s, ok := recover().(string); !ok || !strings.Contains(s, "LOG") {
t.Errorf(`CopyStandardLogTo("LOG") should have panicked: %v`, s)
}
}()
CopyStandardLogTo("LOG")
}
// Test that using the standard log package logs to INFO.
func TestStandardLog(t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
stdLog.Print("test")
if !contains(infoLog, "I", t) {
t.Errorf("Info has wrong character: %q", contents(infoLog))
}
if !contains(infoLog, "test", t) {
t.Error("Info failed")
}
}
// Test that the header has the correct format.
func TestHeader(t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
defer func(previous func() time.Time) { timeNow = previous }(timeNow)
timeNow = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, .067890e9, time.Local)
}
pid = 1234
Info("test")
var line int
format := "I0102 15:04:05.067890 1234 glog_test.go:%d] test\n"
n, err := fmt.Sscanf(contents(infoLog), format, &line)
if n != 1 || err != nil {
t.Errorf("log format error: %d elements, error %s:\n%s", n, err, contents(infoLog))
}
// Scanf treats multiple spaces as equivalent to a single space,
// so check for correct space-padding also.
want := fmt.Sprintf(format, line)
if contents(infoLog) != want {
t.Errorf("log format error: got:\n\t%q\nwant:\t%q", contents(infoLog), want)
}
}
// Test that an Error log goes to Warning and Info.
// Even in the Info log, the source character will be E, so the data should
// all be identical.
func TestError(t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
Error("test")
if !contains(errorLog, "E", t) {
t.Errorf("Error has wrong character: %q", contents(errorLog))
}
if !contains(errorLog, "test", t) {
t.Error("Error failed")
}
str := contents(errorLog)
if !contains(warningLog, str, t) {
t.Error("Warning failed")
}
if !contains(infoLog, str, t) {
t.Error("Info failed")
}
}
// Test that a Warning log goes to Info.
// Even in the Info log, the source character will be W, so the data should
// all be identical.
func TestWarning(t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
Warning("test")
if !contains(warningLog, "W", t) {
t.Errorf("Warning has wrong character: %q", contents(warningLog))
}
if !contains(warningLog, "test", t) {
t.Error("Warning failed")
}
str := contents(warningLog)
if !contains(infoLog, str, t) {
t.Error("Info failed")
}
}
// Test that a V log goes to Info.
func TestV(t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
logging.verbosity.Set("2")
defer logging.verbosity.Set("0")
V(2).Info("test")
if !contains(infoLog, "I", t) {
t.Errorf("Info has wrong character: %q", contents(infoLog))
}
if !contains(infoLog, "test", t) {
t.Error("Info failed")
}
}
// Test that a vmodule enables a log in this file.
func TestVmoduleOn(t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
logging.vmodule.Set("glog_test=2")
defer logging.vmodule.Set("")
if !V(1) {
t.Error("V not enabled for 1")
}
if !V(2) {
t.Error("V not enabled for 2")
}
if V(3) {
t.Error("V enabled for 3")
}
V(2).Info("test")
if !contains(infoLog, "I", t) {
t.Errorf("Info has wrong character: %q", contents(infoLog))
}
if !contains(infoLog, "test", t) {
t.Error("Info failed")
}
}
// Test that a vmodule of another file does not enable a log in this file.
func TestVmoduleOff(t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
logging.vmodule.Set("notthisfile=2")
defer logging.vmodule.Set("")
for i := 1; i <= 3; i++ {
if V(Level(i)) {
t.Errorf("V enabled for %d", i)
}
}
V(2).Info("test")
if contents(infoLog) != "" {
t.Error("V logged incorrectly")
}
}
// vGlobs are patterns that match/don't match this file at V=2.
var vGlobs = map[string]bool{
// Easy to test the numeric match here.
"glog_test=1": false, // If -vmodule sets V to 1, V(2) will fail.
"glog_test=2": true,
"glog_test=3": true, // If -vmodule sets V to 1, V(3) will succeed.
// These all use 2 and check the patterns. All are true.
"*=2": true,
"?l*=2": true,
"????_*=2": true,
"??[mno]?_*t=2": true,
// These all use 2 and check the patterns. All are false.
"*x=2": false,
"m*=2": false,
"??_*=2": false,
"?[abc]?_*t=2": false,
}
// Test that vmodule globbing works as advertised.
func testVmoduleGlob(pat string, match bool, t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
defer logging.vmodule.Set("")
logging.vmodule.Set(pat)
if V(2) != Verbose(match) {
t.Errorf("incorrect match for %q: got %t expected %t", pat, V(2), match)
}
}
// Test that a vmodule globbing works as advertised.
func TestVmoduleGlob(t *testing.T) {
for glob, match := range vGlobs {
testVmoduleGlob(glob, match, t)
}
}
func TestRollover(t *testing.T) {
setFlags()
var err error
defer func(previous func(error)) { logExitFunc = previous }(logExitFunc)
logExitFunc = func(e error) {
err = e
}
defer func(previous uint64) { MaxSize = previous }(MaxSize)
MaxSize = 512
Info("x") // Be sure we have a file.
info, ok := logging.file[infoLog].(*syncBuffer)
if !ok {
t.Fatal("info wasn't created")
}
if err != nil {
t.Fatalf("info has initial error: %v", err)
}
fname0 := info.file.Name()
Info(strings.Repeat("x", int(MaxSize))) // force a rollover
if err != nil {
t.Fatalf("info has error after big write: %v", err)
}
// Make sure the next log file gets a file name with a different
// time stamp.
//
// TODO: determine whether we need to support subsecond log
// rotation. C++ does not appear to handle this case (nor does it
// handle Daylight Savings Time properly).
time.Sleep(1 * time.Second)
Info("x") // create a new file
if err != nil {
t.Fatalf("error after rotation: %v", err)
}
fname1 := info.file.Name()
if fname0 == fname1 {
t.Errorf("info.f.Name did not change: %v", fname0)
}
if info.nbytes >= MaxSize {
t.Errorf("file size was not reset: %d", info.nbytes)
}
}
func TestLogBacktraceAt(t *testing.T) {
setFlags()
defer logging.swap(logging.newBuffers())
// The peculiar style of this code simplifies line counting and maintenance of the
// tracing block below.
var infoLine string
setTraceLocation := func(file string, line int, ok bool, delta int) {
if !ok {
t.Fatal("could not get file:line")
}
_, file = filepath.Split(file)
infoLine = fmt.Sprintf("%s:%d", file, line+delta)
err := logging.traceLocation.Set(infoLine)
if err != nil {
t.Fatal("error setting log_backtrace_at: ", err)
}
}
{
// Start of tracing block. These lines know about each other's relative position.
_, file, line, ok := runtime.Caller(0)
setTraceLocation(file, line, ok, +2) // Two lines between Caller and Info calls.
Info("we want a stack trace here")
}
numAppearances := strings.Count(contents(infoLog), infoLine)
if numAppearances < 2 {
// Need 2 appearances, one in the log header and one in the trace:
// log_test.go:281: I0511 16:36:06.952398 02238 log_test.go:280] we want a stack trace here
// ...
// github.com/glog/glog_test.go:280 (0x41ba91)
// ...
// We could be more precise but that would require knowing the details
// of the traceback format, which may not be dependable.
t.Fatal("got no trace back; log is ", contents(infoLog))
}
}
func BenchmarkHeader(b *testing.B) {
for i := 0; i < b.N; i++ {
buf, _, _ := logging.header(infoLog, 0)
logging.putBuffer(buf)
}
}

View file

@ -0,0 +1,355 @@
// Copyright 2013 Matt T. Proud
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"bytes"
"math/rand"
"reflect"
"testing"
"testing/quick"
. "code.google.com/p/goprotobuf/proto"
. "code.google.com/p/goprotobuf/proto/testdata"
)
func TestWriteDelimited(t *testing.T) {
for _, test := range []struct {
msg Message
buf []byte
n int
err error
}{
{
msg: &Empty{},
n: 1,
buf: []byte{0},
},
{
msg: &GoEnum{Foo: FOO_FOO1.Enum()},
n: 3,
buf: []byte{2, 8, 1},
},
{
msg: &Strings{
StringField: String(`This is my gigantic, unhappy string. It exceeds
the encoding size of a single byte varint. We are using it to fuzz test the
correctness of the header decoding mechanisms, which may prove problematic.
I expect it may. Let's hope you enjoy testing as much as we do.`),
},
n: 271,
buf: []byte{141, 2, 10, 138, 2, 84, 104, 105, 115, 32, 105, 115, 32, 109,
121, 32, 103, 105, 103, 97, 110, 116, 105, 99, 44, 32, 117, 110, 104,
97, 112, 112, 121, 32, 115, 116, 114, 105, 110, 103, 46, 32, 32, 73,
116, 32, 101, 120, 99, 101, 101, 100, 115, 10, 116, 104, 101, 32, 101,
110, 99, 111, 100, 105, 110, 103, 32, 115, 105, 122, 101, 32, 111, 102,
32, 97, 32, 115, 105, 110, 103, 108, 101, 32, 98, 121, 116, 101, 32,
118, 97, 114, 105, 110, 116, 46, 32, 32, 87, 101, 32, 97, 114, 101, 32,
117, 115, 105, 110, 103, 32, 105, 116, 32, 116, 111, 32, 102, 117, 122,
122, 32, 116, 101, 115, 116, 32, 116, 104, 101, 10, 99, 111, 114, 114,
101, 99, 116, 110, 101, 115, 115, 32, 111, 102, 32, 116, 104, 101, 32,
104, 101, 97, 100, 101, 114, 32, 100, 101, 99, 111, 100, 105, 110, 103,
32, 109, 101, 99, 104, 97, 110, 105, 115, 109, 115, 44, 32, 119, 104,
105, 99, 104, 32, 109, 97, 121, 32, 112, 114, 111, 118, 101, 32, 112,
114, 111, 98, 108, 101, 109, 97, 116, 105, 99, 46, 10, 73, 32, 101, 120,
112, 101, 99, 116, 32, 105, 116, 32, 109, 97, 121, 46, 32, 32, 76, 101,
116, 39, 115, 32, 104, 111, 112, 101, 32, 121, 111, 117, 32, 101, 110,
106, 111, 121, 32, 116, 101, 115, 116, 105, 110, 103, 32, 97, 115, 32,
109, 117, 99, 104, 32, 97, 115, 32, 119, 101, 32, 100, 111, 46},
},
} {
var buf bytes.Buffer
if n, err := WriteDelimited(&buf, test.msg); n != test.n || err != test.err {
t.Fatalf("WriteDelimited(buf, %#v) = %v, %v; want %v, %v", test.msg, n, err, test.n, test.err)
}
if out := buf.Bytes(); !bytes.Equal(out, test.buf) {
t.Fatalf("WriteDelimited(buf, %#v); buf = %v; want %v", test.msg, out, test.buf)
}
}
}
func TestReadDelimited(t *testing.T) {
for _, test := range []struct {
buf []byte
msg Message
n int
err error
}{
{
buf: []byte{0},
msg: &Empty{},
n: 1,
},
{
n: 3,
buf: []byte{2, 8, 1},
msg: &GoEnum{Foo: FOO_FOO1.Enum()},
},
{
buf: []byte{141, 2, 10, 138, 2, 84, 104, 105, 115, 32, 105, 115, 32, 109,
121, 32, 103, 105, 103, 97, 110, 116, 105, 99, 44, 32, 117, 110, 104,
97, 112, 112, 121, 32, 115, 116, 114, 105, 110, 103, 46, 32, 32, 73,
116, 32, 101, 120, 99, 101, 101, 100, 115, 10, 116, 104, 101, 32, 101,
110, 99, 111, 100, 105, 110, 103, 32, 115, 105, 122, 101, 32, 111, 102,
32, 97, 32, 115, 105, 110, 103, 108, 101, 32, 98, 121, 116, 101, 32,
118, 97, 114, 105, 110, 116, 46, 32, 32, 87, 101, 32, 97, 114, 101, 32,
117, 115, 105, 110, 103, 32, 105, 116, 32, 116, 111, 32, 102, 117, 122,
122, 32, 116, 101, 115, 116, 32, 116, 104, 101, 10, 99, 111, 114, 114,
101, 99, 116, 110, 101, 115, 115, 32, 111, 102, 32, 116, 104, 101, 32,
104, 101, 97, 100, 101, 114, 32, 100, 101, 99, 111, 100, 105, 110, 103,
32, 109, 101, 99, 104, 97, 110, 105, 115, 109, 115, 44, 32, 119, 104,
105, 99, 104, 32, 109, 97, 121, 32, 112, 114, 111, 118, 101, 32, 112,
114, 111, 98, 108, 101, 109, 97, 116, 105, 99, 46, 10, 73, 32, 101, 120,
112, 101, 99, 116, 32, 105, 116, 32, 109, 97, 121, 46, 32, 32, 76, 101,
116, 39, 115, 32, 104, 111, 112, 101, 32, 121, 111, 117, 32, 101, 110,
106, 111, 121, 32, 116, 101, 115, 116, 105, 110, 103, 32, 97, 115, 32,
109, 117, 99, 104, 32, 97, 115, 32, 119, 101, 32, 100, 111, 46},
msg: &Strings{
StringField: String(`This is my gigantic, unhappy string. It exceeds
the encoding size of a single byte varint. We are using it to fuzz test the
correctness of the header decoding mechanisms, which may prove problematic.
I expect it may. Let's hope you enjoy testing as much as we do.`),
},
n: 271,
},
} {
msg := Clone(test.msg)
msg.Reset()
if n, err := ReadDelimited(bytes.NewBuffer(test.buf), msg); n != test.n || err != test.err {
t.Fatalf("ReadDelimited(%v, msg) = %v, %v; want %v, %v", test.buf, n, err, test.n, test.err)
}
if !Equal(msg, test.msg) {
t.Fatalf("ReadDelimited(%v, msg); msg = %v; want %v", test.buf, msg, test.msg)
}
}
}
func TestEndToEndValid(t *testing.T) {
for _, test := range [][]Message{
[]Message{&Empty{}},
[]Message{&GoEnum{Foo: FOO_FOO1.Enum()}, &Empty{}, &GoEnum{Foo: FOO_FOO1.Enum()}},
[]Message{&GoEnum{Foo: FOO_FOO1.Enum()}},
[]Message{&Strings{
StringField: String(`This is my gigantic, unhappy string. It exceeds
the encoding size of a single byte varint. We are using it to fuzz test the
correctness of the header decoding mechanisms, which may prove problematic.
I expect it may. Let's hope you enjoy testing as much as we do.`),
}},
} {
var buf bytes.Buffer
var written int
for i, msg := range test {
n, err := WriteDelimited(&buf, msg)
if err != nil {
// Assumption: TestReadDelimited and TestWriteDelimited are sufficient
// and inputs for this test are explicitly exercised there.
t.Fatalf("WriteDelimited(buf, %v[%d]) = ?, %v; wanted ?, nil", test, i, err)
}
written += n
}
var read int
for i, msg := range test {
out := Clone(msg)
out.Reset()
n, _ := ReadDelimited(&buf, out)
// Decide to do EOF checking?
read += n
if !Equal(out, msg) {
t.Fatalf("out = %v; want %v[%d] = %#v", out, test, i, msg)
}
}
if read != written {
t.Fatalf("%v read = %d; want %d", test, read, written)
}
}
}
// visitMessage empties the private state fields of the quick.Value()-generated
// Protocol Buffer messages, for they cause an inordinate amount of problems.
// This is because we are using an automated fuzz generator on a type with
// private fields.
func visitMessage(m Message) {
t := reflect.TypeOf(m)
if t.Kind() != reflect.Ptr {
return
}
derefed := t.Elem()
if derefed.Kind() != reflect.Struct {
return
}
v := reflect.ValueOf(m)
elem := v.Elem()
for i := 0; i < elem.NumField(); i++ {
field := elem.FieldByIndex([]int{i})
fieldType := field.Type()
if fieldType.Implements(reflect.TypeOf((*Message)(nil)).Elem()) {
visitMessage(field.Interface().(Message))
}
if field.Kind() == reflect.Slice {
for i := 0; i < field.Len(); i++ {
elem := field.Index(i)
elemType := elem.Type()
if elemType.Implements(reflect.TypeOf((*Message)(nil)).Elem()) {
visitMessage(elem.Interface().(Message))
}
}
}
}
if field := elem.FieldByName("XXX_unrecognized"); field.IsValid() {
field.Set(reflect.ValueOf([]byte{}))
}
if field := elem.FieldByName("XXX_extensions"); field.IsValid() {
field.Set(reflect.ValueOf(nil))
}
}
// rndMessage generates a random valid Protocol Buffer message.
func rndMessage(r *rand.Rand) Message {
var t reflect.Type
switch v := rand.Intn(23); v {
// TODO(br): Uncomment the elements below once fix is incorporated, except
// for the elements marked as patently incompatible.
// case 0:
// t = reflect.TypeOf(&GoEnum{})
// break
// case 1:
// t = reflect.TypeOf(&GoTestField{})
// break
case 2:
t = reflect.TypeOf(&GoTest{})
break
// case 3:
// t = reflect.TypeOf(&GoSkipTest{})
// break
// case 4:
// t = reflect.TypeOf(&NonPackedTest{})
// break
// case 5:
// t = reflect.TypeOf(&PackedTest{})
// break
case 6:
t = reflect.TypeOf(&MaxTag{})
break
case 7:
t = reflect.TypeOf(&OldMessage{})
break
case 8:
t = reflect.TypeOf(&NewMessage{})
break
case 9:
t = reflect.TypeOf(&InnerMessage{})
break
case 10:
t = reflect.TypeOf(&OtherMessage{})
break
case 11:
// PATENTLY INVALID FOR FUZZ GENERATION
// t = reflect.TypeOf(&MyMessage{})
break
// case 12:
// t = reflect.TypeOf(&Ext{})
// break
case 13:
// PATENTLY INVALID FOR FUZZ GENERATION
// t = reflect.TypeOf(&MyMessageSet{})
break
// case 14:
// t = reflect.TypeOf(&Empty{})
// break
// case 15:
// t = reflect.TypeOf(&MessageList{})
// break
// case 16:
// t = reflect.TypeOf(&Strings{})
// break
// case 17:
// t = reflect.TypeOf(&Defaults{})
// break
// case 17:
// t = reflect.TypeOf(&SubDefaults{})
// break
// case 18:
// t = reflect.TypeOf(&RepeatedEnum{})
// break
case 19:
t = reflect.TypeOf(&MoreRepeated{})
break
// case 20:
// t = reflect.TypeOf(&GroupOld{})
// break
// case 21:
// t = reflect.TypeOf(&GroupNew{})
// break
case 22:
t = reflect.TypeOf(&FloatingPoint{})
break
default:
// TODO(br): Replace with an unreachable once fixed.
t = reflect.TypeOf(&GoTest{})
break
}
if t == nil {
t = reflect.TypeOf(&GoTest{})
}
v, ok := quick.Value(t, r)
if !ok {
panic("attempt to generate illegal item; consult item 11")
}
visitMessage(v.Interface().(Message))
return v.Interface().(Message)
}
// rndMessages generates several random Protocol Buffer messages.
func rndMessages(r *rand.Rand) []Message {
n := r.Intn(128)
out := make([]Message, 0, n)
for i := 0; i < n; i++ {
out = append(out, rndMessage(r))
}
return out
}
func TestFuzz(t *testing.T) {
rnd := rand.New(rand.NewSource(42))
check := func() bool {
messages := rndMessages(rnd)
var buf bytes.Buffer
var written int
for i, msg := range messages {
n, err := WriteDelimited(&buf, msg)
if err != nil {
t.Fatalf("WriteDelimited(buf, %v[%d]) = ?, %v; wanted ?, nil", messages, i, err)
}
written += n
}
var read int
for i, msg := range messages {
out := Clone(msg)
out.Reset()
n, _ := ReadDelimited(&buf, out)
read += n
if !Equal(out, msg) {
t.Fatalf("out = %v; want %v[%d] = %#v", out, messages, i, msg)
}
}
if read != written {
t.Fatalf("%v read = %d; want %d", messages, read, written)
}
return true
}
if err := quick.Check(check, nil); err != nil {
t.Fatal(err)
}
}

View file

@ -0,0 +1,75 @@
// Copyright 2013 Matt T. Proud
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"encoding/binary"
"errors"
"io"
"code.google.com/p/goprotobuf/proto"
)
var errInvalidVarint = errors.New("invalid varint32 encountered")
// ReadDelimited decodes a message from the provided length-delimited stream,
// where the length is encoded as 32-bit varint prefix to the message body.
// It returns the total number of bytes read and any applicable error. This is
// roughly equivalent to the companion Java API's
// MessageLite#parseDelimitedFrom. As per the reader contract, this function
// calls r.Read repeatedly as required until exactly one message including its
// prefix is read and decoded (or an error has occurred). The function never
// reads more bytes from the stream than required. The function never returns
// an error if a message has been read and decoded correctly, even if the end
// of the stream has been reached in doing so. In that case, any subsequent
// calls return (0, io.EOF).
func ReadDelimited(r io.Reader, m proto.Message) (n int, err error) {
// Per AbstractParser#parsePartialDelimitedFrom with
// CodedInputStream#readRawVarint32.
headerBuf := make([]byte, binary.MaxVarintLen32)
var bytesRead, varIntBytes int
var messageLength uint64
for varIntBytes == 0 { // i.e. no varint has been decoded yet.
if bytesRead >= len(headerBuf) {
return bytesRead, errInvalidVarint
}
// We have to read byte by byte here to avoid reading more bytes
// than required. Each read byte is appended to what we have
// read before.
newBytesRead, err := r.Read(headerBuf[bytesRead : bytesRead+1])
if newBytesRead == 0 {
if err != nil {
return bytesRead, err
}
// A Reader should not return (0, nil), but if it does,
// it should be treated as no-op (according to the
// Reader contract). So let's go on...
continue
}
bytesRead += newBytesRead
// Now present everything read so far to the varint decoder and
// see if a varint can be decoded already.
messageLength, varIntBytes = proto.DecodeVarint(headerBuf[:bytesRead])
}
messageBuf := make([]byte, messageLength)
newBytesRead, err := io.ReadFull(r, messageBuf)
bytesRead += newBytesRead
if err != nil {
return bytesRead, err
}
return bytesRead, proto.Unmarshal(messageBuf, m)
}

View file

@ -0,0 +1,16 @@
// Copyright 2013 Matt T. Proud
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ext enables record length-delimited Protocol Buffer streaming.
package ext

View file

@ -0,0 +1,46 @@
// Copyright 2013 Matt T. Proud
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"encoding/binary"
"io"
"code.google.com/p/goprotobuf/proto"
)
// WriteDelimited encodes and dumps a message to the provided writer prefixed
// with a 32-bit varint indicating the length of the encoded message, producing
// a length-delimited record stream, which can be used to chain together
// encoded messages of the same type together in a file. It returns the total
// number of bytes written and any applicable error. This is roughly
// equivalent to the companion Java API's MessageLite#writeDelimitedTo.
func WriteDelimited(w io.Writer, m proto.Message) (n int, err error) {
buffer, err := proto.Marshal(m)
if err != nil {
return 0, err
}
buf := make([]byte, binary.MaxVarintLen32)
encodedLength := binary.PutUvarint(buf, uint64(len(buffer)))
sync, err := w.Write(buf[:encodedLength])
if err != nil {
return sync, err
}
n, err = w.Write(buffer)
return n + sync, err
}

View file

@ -0,0 +1,103 @@
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package ext
import (
. "code.google.com/p/goprotobuf/proto"
. "code.google.com/p/goprotobuf/proto/testdata"
)
// FROM https://code.google.com/p/goprotobuf/source/browse/proto/all_test.go.
func initGoTestField() *GoTestField {
f := new(GoTestField)
f.Label = String("label")
f.Type = String("type")
return f
}
// These are all structurally equivalent but the tag numbers differ.
// (It's remarkable that required, optional, and repeated all have
// 8 letters.)
func initGoTest_RequiredGroup() *GoTest_RequiredGroup {
return &GoTest_RequiredGroup{
RequiredField: String("required"),
}
}
func initGoTest_OptionalGroup() *GoTest_OptionalGroup {
return &GoTest_OptionalGroup{
RequiredField: String("optional"),
}
}
func initGoTest_RepeatedGroup() *GoTest_RepeatedGroup {
return &GoTest_RepeatedGroup{
RequiredField: String("repeated"),
}
}
func initGoTest(setdefaults bool) *GoTest {
pb := new(GoTest)
if setdefaults {
pb.F_BoolDefaulted = Bool(Default_GoTest_F_BoolDefaulted)
pb.F_Int32Defaulted = Int32(Default_GoTest_F_Int32Defaulted)
pb.F_Int64Defaulted = Int64(Default_GoTest_F_Int64Defaulted)
pb.F_Fixed32Defaulted = Uint32(Default_GoTest_F_Fixed32Defaulted)
pb.F_Fixed64Defaulted = Uint64(Default_GoTest_F_Fixed64Defaulted)
pb.F_Uint32Defaulted = Uint32(Default_GoTest_F_Uint32Defaulted)
pb.F_Uint64Defaulted = Uint64(Default_GoTest_F_Uint64Defaulted)
pb.F_FloatDefaulted = Float32(Default_GoTest_F_FloatDefaulted)
pb.F_DoubleDefaulted = Float64(Default_GoTest_F_DoubleDefaulted)
pb.F_StringDefaulted = String(Default_GoTest_F_StringDefaulted)
pb.F_BytesDefaulted = Default_GoTest_F_BytesDefaulted
pb.F_Sint32Defaulted = Int32(Default_GoTest_F_Sint32Defaulted)
pb.F_Sint64Defaulted = Int64(Default_GoTest_F_Sint64Defaulted)
}
pb.Kind = GoTest_TIME.Enum()
pb.RequiredField = initGoTestField()
pb.F_BoolRequired = Bool(true)
pb.F_Int32Required = Int32(3)
pb.F_Int64Required = Int64(6)
pb.F_Fixed32Required = Uint32(32)
pb.F_Fixed64Required = Uint64(64)
pb.F_Uint32Required = Uint32(3232)
pb.F_Uint64Required = Uint64(6464)
pb.F_FloatRequired = Float32(3232)
pb.F_DoubleRequired = Float64(6464)
pb.F_StringRequired = String("string")
pb.F_BytesRequired = []byte("bytes")
pb.F_Sint32Required = Int32(-32)
pb.F_Sint64Required = Int64(-64)
pb.Requiredgroup = initGoTest_RequiredGroup()
return pb
}

View file

@ -0,0 +1,4 @@
*.6
tags
test.out
a.out

19
Godeps/_workspace/src/github.com/miekg/dns/.travis.yml generated vendored Normal file
View file

@ -0,0 +1,19 @@
language: go
go:
- 1.2
- 1.3
env:
# "gvm update" resets GOOS and GOARCH environment variable, workaround it by setting
# BUILD_GOOS and BUILD_GOARCH and overriding GOARCH and GOOS in the build script
global:
- BUILD_GOARCH=amd64
matrix:
- BUILD_GOOS=linux
- BUILD_GOOS=darwin
- BUILD_GOOS=windows
script:
- gvm cross $BUILD_GOOS $BUILD_GOARCH
- GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go build
# only test on linux
- if [ $BUILD_GOOS == "linux" ]; then GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go test -bench=.; fi

1
Godeps/_workspace/src/github.com/miekg/dns/AUTHORS generated vendored Normal file
View file

@ -0,0 +1 @@
Miek Gieben <miek@miek.nl>

View file

@ -0,0 +1,9 @@
Alex A. Skinner
Andrew Tunnell-Jones
Ask Bjørn Hansen
Dave Cheney
Dusty Wilson
Marek Majkowski
Peter van Dijk
Omri Bahumi
Alex Sergeyev

9
Godeps/_workspace/src/github.com/miekg/dns/COPYRIGHT generated vendored Normal file
View file

@ -0,0 +1,9 @@
Copyright 2009 The Go Authors. All rights reserved. Use of this source code
is governed by a BSD-style license that can be found in the LICENSE file.
Extensions of the original work are copyright (c) 2011 Miek Gieben
Copyright 2011 Miek Gieben. All rights reserved. Use of this source code is
governed by a BSD-style license that can be found in the LICENSE file.
Copyright 2014 CloudFlare. All rights reserved. Use of this source code is
governed by a BSD-style license that can be found in the LICENSE file.

32
Godeps/_workspace/src/github.com/miekg/dns/LICENSE generated vendored Normal file
View file

@ -0,0 +1,32 @@
Extensions of the original work are copyright (c) 2011 Miek Gieben
As this is fork of the official Go code the same license applies:
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

140
Godeps/_workspace/src/github.com/miekg/dns/README.md generated vendored Normal file
View file

@ -0,0 +1,140 @@
[![Build Status](https://travis-ci.org/miekg/dns.svg?branch=master)](https://travis-ci.org/miekg/dns)
# Alternative (more granular) approach to a DNS library
> Less is more.
Complete and usable DNS library. All widely used Resource Records are
supported, including the DNSSEC types. It follows a lean and mean philosophy.
If there is stuff you should know as a DNS programmer there isn't a convenience
function for it. Server side and client side programming is supported, i.e. you
can build servers and resolvers with it.
If you like this, you may also be interested in:
* https://github.com/miekg/unbound -- Go wrapper for the Unbound resolver.
# Goals
* KISS;
* Fast;
* Small API, if its easy to code in Go, don't make a function for it.
# Users
A not-so-up-to-date-list-that-may-be-actually-current:
* https://github.com/abh/geodns
* http://www.statdns.com/
* http://www.dnsinspect.com/
* https://github.com/chuangbo/jianbing-dictionary-dns
* http://www.dns-lg.com/
* https://github.com/fcambus/rrda
* https://github.com/kenshinx/godns
* https://github.com/skynetservices/skydns
* https://github.com/DevelopersPL/godnsagent
* https://github.com/duedil-ltd/discodns
Send pull request if you want to be listed here.
# Features
* UDP/TCP queries, IPv4 and IPv6;
* RFC 1035 zone file parsing ($INCLUDE, $ORIGIN, $TTL and $GENERATE (for all record types) are supported;
* Fast:
* Reply speed around ~ 80K qps (faster hardware results in more qps);
* Parsing RRs ~ 100K RR/s, that's 5M records in about 50 seconds;
* Server side programming (mimicking the net/http package);
* Client side programming;
* DNSSEC: signing, validating and key generation for DSA, RSA and ECDSA;
* EDNS0, NSID;
* AXFR/IXFR;
* TSIG;
* DNS name compression;
* Depends only on the standard library.
Have fun!
Miek Gieben - 2010-2012 - <miek@miek.nl>
# Building
Building is done with the `go` tool. If you have setup your GOPATH
correctly, the following should work:
go get github.com/miekg/dns
go build github.com/miekg/dns
## Examples
A short "how to use the API" is at the beginning of dns.go (this also will show
when you call `godoc github.com/miekg/dns`).
Example programs can be found in the `github.com/miekg/exdns` repository.
## Supported RFCs
*all of them*
* 103{4,5} - DNS standard
* 1348 - NSAP record
* 1982 - Serial Arithmetic
* 1876 - LOC record
* 1995 - IXFR
* 1996 - DNS notify
* 2136 - DNS Update (dynamic updates)
* 2181 - RRset definition - there is no RRset type though, just []RR
* 2537 - RSAMD5 DNS keys
* 2065 - DNSSEC (updated in later RFCs)
* 2671 - EDNS record
* 2782 - SRV record
* 2845 - TSIG record
* 2915 - NAPTR record
* 2929 - DNS IANA Considerations
* 3110 - RSASHA1 DNS keys
* 3225 - DO bit (DNSSEC OK)
* 340{1,2,3} - NAPTR record
* 3445 - Limiting the scope of (DNS)KEY
* 3597 - Unkown RRs
* 403{3,4,5} - DNSSEC + validation functions
* 4255 - SSHFP record
* 4343 - Case insensitivity
* 4408 - SPF record
* 4509 - SHA256 Hash in DS
* 4592 - Wildcards in the DNS
* 4635 - HMAC SHA TSIG
* 4701 - DHCID
* 4892 - id.server
* 5001 - NSID
* 5155 - NSEC3 record
* 5205 - HIP record
* 5702 - SHA2 in the DNS
* 5936 - AXFR
* 5966 - TCP implementation recommendations
* 6605 - ECDSA
* 6742 - ILNP DNS
* 6891 - EDNS0 update
* 6895 - DNS IANA considerations
* 6975 - Algorithm Understanding in DNSSEC
* 7043 - EUI48/EUI64 records
* 7314 - DNS (EDNS) EXPIRE Option
* xxxx - URI record (draft)
* xxxx - EDNS0 DNS Update Lease (draft)
## Loosely based upon
* `ldns`
* `NSD`
* `Net::DNS`
* `GRONG`
## TODO
* privatekey.Precompute() when signing?
* Last remaining RRs: APL, ATMA, A6 and NXT;
* Missing in parsing: ISDN, UNSPEC, ATMA;
* CAA parsing is broken;
* NSEC(3) cover/match/closest enclose;
* Replies with TC bit are not parsed to the end;
* SIG(0);
* Create IsMsg to validate a message before fully parsing it.

319
Godeps/_workspace/src/github.com/miekg/dns/client.go generated vendored Normal file
View file

@ -0,0 +1,319 @@
package dns
// A client implementation.
import (
"bytes"
"io"
"net"
"time"
)
const dnsTimeout time.Duration = 2 * 1e9
const tcpIdleTimeout time.Duration = 8 * time.Second
// A Conn represents a connection to a DNS server.
type Conn struct {
net.Conn // a net.Conn holding the connection
UDPSize uint16 // minimum receive buffer for UDP messages
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
rtt time.Duration
t time.Time
tsigRequestMAC string
}
// A Client defines parameters for a DNS client.
type Client struct {
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
UDPSize uint16 // minimum receive buffer for UDP messages
DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
group singleflight
}
// Exchange performs a synchronous UDP query. It sends the message m to the address
// contained in a and waits for an reply. Exchange does not retry a failed query, nor
// will it fall back to TCP in case of truncation.
// If you need to send a DNS message on an already existing connection, you can use the
// following:
//
// co := &dns.Conn{Conn: c} // c is your net.Conn
// co.WriteMsg(m)
// in, err := co.ReadMsg()
// co.Close()
//
func Exchange(m *Msg, a string) (r *Msg, err error) {
var co *Conn
co, err = DialTimeout("udp", a, dnsTimeout)
if err != nil {
return nil, err
}
defer co.Close()
co.SetReadDeadline(time.Now().Add(dnsTimeout))
co.SetWriteDeadline(time.Now().Add(dnsTimeout))
if err = co.WriteMsg(m); err != nil {
return nil, err
}
r, err = co.ReadMsg()
return r, err
}
// ExchangeConn performs a synchronous query. It sends the message m via the connection
// c and waits for a reply. The connection c is not closed by ExchangeConn.
// This function is going away, but can easily be mimicked:
//
// co := &dns.Conn{Conn: c} // c is your net.Conn
// co.WriteMsg(m)
// in, _ := co.ReadMsg()
// co.Close()
//
func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
println("dns: this function is deprecated")
co := new(Conn)
co.Conn = c
if err = co.WriteMsg(m); err != nil {
return nil, err
}
r, err = co.ReadMsg()
return r, err
}
// Exchange performs an synchronous query. It sends the message m to the address
// contained in a and waits for an reply. Basic use pattern with a *dns.Client:
//
// c := new(dns.Client)
// in, rtt, err := c.Exchange(message, "127.0.0.1:53")
//
// Exchange does not retry a failed query, nor will it fall back to TCP in
// case of truncation.
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
if !c.SingleInflight {
return c.exchange(m, a)
}
// This adds a bunch of garbage, TODO(miek).
t := "nop"
if t1, ok := TypeToString[m.Question[0].Qtype]; ok {
t = t1
}
cl := "nop"
if cl1, ok := ClassToString[m.Question[0].Qclass]; ok {
cl = cl1
}
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
return c.exchange(m, a)
})
if err != nil {
return r, rtt, err
}
if shared {
return r.Copy(), rtt, nil
}
return r, rtt, nil
}
func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
timeout := dnsTimeout
var co *Conn
if c.DialTimeout != 0 {
timeout = c.DialTimeout
}
if c.Net == "" {
co, err = DialTimeout("udp", a, timeout)
} else {
co, err = DialTimeout(c.Net, a, timeout)
}
if err != nil {
return nil, 0, err
}
timeout = dnsTimeout
if c.ReadTimeout != 0 {
timeout = c.ReadTimeout
}
co.SetReadDeadline(time.Now().Add(timeout))
timeout = dnsTimeout
if c.WriteTimeout != 0 {
timeout = c.WriteTimeout
}
co.SetWriteDeadline(time.Now().Add(timeout))
defer co.Close()
opt := m.IsEdns0()
// If EDNS0 is used use that for size.
if opt != nil && opt.UDPSize() >= MinMsgSize {
co.UDPSize = opt.UDPSize()
}
// Otherwise use the client's configured UDP size.
if opt == nil && c.UDPSize >= MinMsgSize {
co.UDPSize = c.UDPSize
}
co.TsigSecret = c.TsigSecret
if err = co.WriteMsg(m); err != nil {
return nil, 0, err
}
r, err = co.ReadMsg()
return r, co.rtt, err
}
// ReadMsg reads a message from the connection co.
// If the received message contains a TSIG record the transaction
// signature is verified.
func (co *Conn) ReadMsg() (*Msg, error) {
var p []byte
m := new(Msg)
if _, ok := co.Conn.(*net.TCPConn); ok {
p = make([]byte, MaxMsgSize)
} else {
if co.UDPSize >= 512 {
p = make([]byte, co.UDPSize)
} else {
p = make([]byte, MinMsgSize)
}
}
n, err := co.Read(p)
if err != nil && n == 0 {
return nil, err
}
p = p[:n]
if err := m.Unpack(p); err != nil {
return nil, err
}
co.rtt = time.Since(co.t)
if t := m.IsTsig(); t != nil {
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
return m, ErrSecret
}
// Need to work on the original message p, as that was used to calculate the tsig.
err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
}
return m, err
}
// Read implements the net.Conn read method.
func (co *Conn) Read(p []byte) (n int, err error) {
if co.Conn == nil {
return 0, ErrConnEmpty
}
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
if t, ok := co.Conn.(*net.TCPConn); ok {
n, err = t.Read(p[0:2])
if err != nil || n != 2 {
return n, err
}
l, _ := unpackUint16(p[0:2], 0)
if l == 0 {
return 0, ErrShortRead
}
if int(l) > len(p) {
return int(l), io.ErrShortBuffer
}
n, err = t.Read(p[:l])
if err != nil {
return n, err
}
i := n
for i < int(l) {
j, err := t.Read(p[i:int(l)])
if err != nil {
return i, err
}
i += j
}
n = i
return n, err
}
// UDP connection
n, err = co.Conn.Read(p)
if err != nil {
return n, err
}
return n, err
}
// WriteMsg sends a message throught the connection co.
// If the message m contains a TSIG record the transaction
// signature is calculated.
func (co *Conn) WriteMsg(m *Msg) (err error) {
var out []byte
if t := m.IsTsig(); t != nil {
mac := ""
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
return ErrSecret
}
out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
// Set for the next read, allthough only used in zone transfers
co.tsigRequestMAC = mac
} else {
out, err = m.Pack()
}
if err != nil {
return err
}
co.t = time.Now()
if _, err = co.Write(out); err != nil {
return err
}
return nil
}
// Write implements the net.Conn Write method.
func (co *Conn) Write(p []byte) (n int, err error) {
if t, ok := co.Conn.(*net.TCPConn); ok {
lp := len(p)
if lp < 2 {
return 0, io.ErrShortBuffer
}
if lp > MaxMsgSize {
return 0, &Error{err: "message too large"}
}
l := make([]byte, 2, lp+2)
l[0], l[1] = packUint16(uint16(lp))
p = append(l, p...)
n, err := io.Copy(t, bytes.NewReader(p))
return int(n), err
}
n, err = co.Conn.(*net.UDPConn).Write(p)
return n, err
}
// Dial connects to the address on the named network.
func Dial(network, address string) (conn *Conn, err error) {
conn = new(Conn)
conn.Conn, err = net.Dial(network, address)
if err != nil {
return nil, err
}
return conn, nil
}
// Dialtimeout acts like Dial but takes a timeout.
func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) {
conn = new(Conn)
conn.Conn, err = net.DialTimeout(network, address, timeout)
if err != nil {
return nil, err
}
return conn, nil
}
// Close implements the net.Conn Close method.
func (co *Conn) Close() error { return co.Conn.Close() }
// LocalAddr implements the net.Conn LocalAddr method.
func (co *Conn) LocalAddr() net.Addr { return co.Conn.LocalAddr() }
// RemoteAddr implements the net.Conn RemoteAddr method.
func (co *Conn) RemoteAddr() net.Addr { return co.Conn.RemoteAddr() }
// SetDeadline implements the net.Conn SetDeadline method.
func (co *Conn) SetDeadline(t time.Time) error { return co.Conn.SetDeadline(t) }
// SetReadDeadline implements the net.Conn SetReadDeadline method.
func (co *Conn) SetReadDeadline(t time.Time) error { return co.Conn.SetReadDeadline(t) }
// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
func (co *Conn) SetWriteDeadline(t time.Time) error { return co.Conn.SetWriteDeadline(t) }

View file

@ -0,0 +1,195 @@
package dns
import (
"testing"
"time"
)
func TestClientSync(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
defer s.Shutdown()
m := new(Msg)
m.SetQuestion("miek.nl.", TypeSOA)
c := new(Client)
r, _, e := c.Exchange(m, addrstr)
if e != nil {
t.Logf("failed to exchange: %s", e.Error())
t.Fail()
}
if r != nil && r.Rcode != RcodeSuccess {
t.Log("failed to get an valid answer")
t.Fail()
t.Logf("%v\n", r)
}
// And now with plain Exchange().
r, e = Exchange(m, addrstr)
if e != nil {
t.Logf("failed to exchange: %s", e.Error())
t.Fail()
}
if r != nil && r.Rcode != RcodeSuccess {
t.Log("failed to get an valid answer")
t.Fail()
t.Logf("%v\n", r)
}
}
func TestClientEDNS0(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
defer s.Shutdown()
m := new(Msg)
m.SetQuestion("miek.nl.", TypeDNSKEY)
m.SetEdns0(2048, true)
c := new(Client)
r, _, e := c.Exchange(m, addrstr)
if e != nil {
t.Logf("failed to exchange: %s", e.Error())
t.Fail()
}
if r != nil && r.Rcode != RcodeSuccess {
t.Log("failed to get an valid answer")
t.Fail()
t.Logf("%v\n", r)
}
}
func TestSingleSingleInflight(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
defer s.Shutdown()
m := new(Msg)
m.SetQuestion("miek.nl.", TypeDNSKEY)
c := new(Client)
c.SingleInflight = true
nr := 10
ch := make(chan time.Duration)
for i := 0; i < nr; i++ {
go func() {
_, rtt, _ := c.Exchange(m, addrstr)
ch <- rtt
}()
}
i := 0
var first time.Duration
// With inflight *all* rtt are identical, and by doing actual lookups
// the changes that this is a coincidence is small.
Loop:
for {
select {
case rtt := <-ch:
if i == 0 {
first = rtt
} else {
if first != rtt {
t.Log("all rtts should be equal")
t.Fail()
}
}
i++
if i == 10 {
break Loop
}
}
}
}
/*
func TestClientTsigAXFR(t *testing.T) {
m := new(Msg)
m.SetAxfr("example.nl.")
m.SetTsig("axfr.", HmacMD5, 300, time.Now().Unix())
tr := new(Transfer)
tr.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
if a, err := tr.In(m, "176.58.119.54:53"); err != nil {
t.Log("failed to setup axfr: " + err.Error())
t.Fatal()
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("error %s\n", ex.Error.Error())
t.Fail()
break
}
for _, rr := range ex.RR {
t.Logf("%s\n", rr.String())
}
}
}
}
func TestClientAXFRMultipleEnvelopes(t *testing.T) {
m := new(Msg)
m.SetAxfr("nlnetlabs.nl.")
tr := new(Transfer)
if a, err := tr.In(m, "213.154.224.1:53"); err != nil {
t.Log("Failed to setup axfr" + err.Error())
t.Fail()
return
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("Error %s\n", ex.Error.Error())
t.Fail()
break
}
}
}
}
*/
// ExapleUpdateLeaseTSIG shows how to update a lease signed with TSIG.
func ExampleUpdateLeaseTSIG(t *testing.T) {
m := new(Msg)
m.SetUpdate("t.local.ip6.io.")
rr, _ := NewRR("t.local.ip6.io. 30 A 127.0.0.1")
rrs := make([]RR, 1)
rrs[0] = rr
m.Insert(rrs)
lease_rr := new(OPT)
lease_rr.Hdr.Name = "."
lease_rr.Hdr.Rrtype = TypeOPT
e := new(EDNS0_UL)
e.Code = EDNS0UL
e.Lease = 120
lease_rr.Option = append(lease_rr.Option, e)
m.Extra = append(m.Extra, lease_rr)
c := new(Client)
m.SetTsig("polvi.", HmacMD5, 300, time.Now().Unix())
c.TsigSecret = map[string]string{"polvi.": "pRZgBrBvI4NAHZYhxmhs/Q=="}
_, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Log(err.Error())
t.Fail()
}
}

View file

@ -0,0 +1,94 @@
package dns
import (
"bufio"
"os"
"strconv"
"strings"
)
// Wraps the contents of the /etc/resolv.conf.
type ClientConfig struct {
Servers []string // servers to use
Search []string // suffixes to append to local name
Port string // what port to use
Ndots int // number of dots in name to trigger absolute lookup
Timeout int // seconds before giving up on packet
Attempts int // lost packets before giving up on server, not used in the package dns
}
// ClientConfigFromFile parses a resolv.conf(5) like file and returns
// a *ClientConfig.
func ClientConfigFromFile(resolvconf string) (*ClientConfig, error) {
file, err := os.Open(resolvconf)
if err != nil {
return nil, err
}
defer file.Close()
c := new(ClientConfig)
b := bufio.NewReader(file)
c.Servers = make([]string, 0)
c.Search = make([]string, 0)
c.Port = "53"
c.Ndots = 1
c.Timeout = 5
c.Attempts = 2
for line, ok := b.ReadString('\n'); ok == nil; line, ok = b.ReadString('\n') {
f := strings.Fields(line)
if len(f) < 1 {
continue
}
switch f[0] {
case "nameserver": // add one name server
if len(f) > 1 {
// One more check: make sure server name is
// just an IP address. Otherwise we need DNS
// to look it up.
name := f[1]
c.Servers = append(c.Servers, name)
}
case "domain": // set search path to just this domain
if len(f) > 1 {
c.Search = make([]string, 1)
c.Search[0] = f[1]
} else {
c.Search = make([]string, 0)
}
case "search": // set search path to given servers
c.Search = make([]string, len(f)-1)
for i := 0; i < len(c.Search); i++ {
c.Search[i] = f[i+1]
}
case "options": // magic options
for i := 1; i < len(f); i++ {
s := f[i]
switch {
case len(s) >= 6 && s[:6] == "ndots:":
n, _ := strconv.Atoi(s[6:])
if n < 1 {
n = 1
}
c.Ndots = n
case len(s) >= 8 && s[:8] == "timeout:":
n, _ := strconv.Atoi(s[8:])
if n < 1 {
n = 1
}
c.Timeout = n
case len(s) >= 8 && s[:9] == "attempts:":
n, _ := strconv.Atoi(s[9:])
if n < 1 {
n = 1
}
c.Attempts = n
case s == "rotate":
/* not imp */
}
}
}
}
return c, nil
}

242
Godeps/_workspace/src/github.com/miekg/dns/defaults.go generated vendored Normal file
View file

@ -0,0 +1,242 @@
package dns
import (
"errors"
"net"
"strconv"
)
const hexDigit = "0123456789abcdef"
// Everything is assumed in ClassINET.
// SetReply creates a reply message from a request message.
func (dns *Msg) SetReply(request *Msg) *Msg {
dns.Id = request.Id
dns.RecursionDesired = request.RecursionDesired // Copy rd bit
dns.Response = true
dns.Opcode = OpcodeQuery
dns.Rcode = RcodeSuccess
if len(request.Question) > 0 {
dns.Question = make([]Question, 1)
dns.Question[0] = request.Question[0]
}
return dns
}
// SetQuestion creates a question message.
func (dns *Msg) SetQuestion(z string, t uint16) *Msg {
dns.Id = Id()
dns.RecursionDesired = true
dns.Question = make([]Question, 1)
dns.Question[0] = Question{z, t, ClassINET}
return dns
}
// SetNotify creates a notify message.
func (dns *Msg) SetNotify(z string) *Msg {
dns.Opcode = OpcodeNotify
dns.Authoritative = true
dns.Id = Id()
dns.Question = make([]Question, 1)
dns.Question[0] = Question{z, TypeSOA, ClassINET}
return dns
}
// SetRcode creates an error message suitable for the request.
func (dns *Msg) SetRcode(request *Msg, rcode int) *Msg {
dns.SetReply(request)
dns.Rcode = rcode
return dns
}
// SetRcodeFormatError creates a message with FormError set.
func (dns *Msg) SetRcodeFormatError(request *Msg) *Msg {
dns.Rcode = RcodeFormatError
dns.Opcode = OpcodeQuery
dns.Response = true
dns.Authoritative = false
dns.Id = request.Id
return dns
}
// SetUpdate makes the message a dynamic update message. It
// sets the ZONE section to: z, TypeSOA, ClassINET.
func (dns *Msg) SetUpdate(z string) *Msg {
dns.Id = Id()
dns.Response = false
dns.Opcode = OpcodeUpdate
dns.Compress = false // BIND9 cannot handle compression
dns.Question = make([]Question, 1)
dns.Question[0] = Question{z, TypeSOA, ClassINET}
return dns
}
// SetIxfr creates message for requesting an IXFR.
func (dns *Msg) SetIxfr(z string, serial uint32) *Msg {
dns.Id = Id()
dns.Question = make([]Question, 1)
dns.Ns = make([]RR, 1)
s := new(SOA)
s.Hdr = RR_Header{z, TypeSOA, ClassINET, defaultTtl, 0}
s.Serial = serial
dns.Question[0] = Question{z, TypeIXFR, ClassINET}
dns.Ns[0] = s
return dns
}
// SetAxfr creates message for requesting an AXFR.
func (dns *Msg) SetAxfr(z string) *Msg {
dns.Id = Id()
dns.Question = make([]Question, 1)
dns.Question[0] = Question{z, TypeAXFR, ClassINET}
return dns
}
// SetTsig appends a TSIG RR to the message.
// This is only a skeleton TSIG RR that is added as the last RR in the
// additional section. The Tsig is calculated when the message is being send.
func (dns *Msg) SetTsig(z, algo string, fudge, timesigned int64) *Msg {
t := new(TSIG)
t.Hdr = RR_Header{z, TypeTSIG, ClassANY, 0, 0}
t.Algorithm = algo
t.Fudge = 300
t.TimeSigned = uint64(timesigned)
t.OrigId = dns.Id
dns.Extra = append(dns.Extra, t)
return dns
}
// SetEdns0 appends a EDNS0 OPT RR to the message.
// TSIG should always the last RR in a message.
func (dns *Msg) SetEdns0(udpsize uint16, do bool) *Msg {
e := new(OPT)
e.Hdr.Name = "."
e.Hdr.Rrtype = TypeOPT
e.SetUDPSize(udpsize)
if do {
e.SetDo()
}
dns.Extra = append(dns.Extra, e)
return dns
}
// IsTsig checks if the message has a TSIG record as the last record
// in the additional section. It returns the TSIG record found or nil.
func (dns *Msg) IsTsig() *TSIG {
if len(dns.Extra) > 0 {
if dns.Extra[len(dns.Extra)-1].Header().Rrtype == TypeTSIG {
return dns.Extra[len(dns.Extra)-1].(*TSIG)
}
}
return nil
}
// IsEdns0 checks if the message has a EDNS0 (OPT) record, any EDNS0
// record in the additional section will do. It returns the OPT record
// found or nil.
func (dns *Msg) IsEdns0() *OPT {
for _, r := range dns.Extra {
if r.Header().Rrtype == TypeOPT {
return r.(*OPT)
}
}
return nil
}
// IsDomainName checks if s is a valid domainname, it returns
// the number of labels and true, when a domain name is valid.
// Note that non fully qualified domain name is considered valid, in this case the
// last label is counted in the number of labels.
// When false is returned the number of labels is not defined.
func IsDomainName(s string) (labels int, ok bool) {
_, labels, err := packDomainName(s, nil, 0, nil, false)
return labels, err == nil
}
// IsSubDomain checks if child is indeed a child of the parent. Both child and
// parent are *not* downcased before doing the comparison.
func IsSubDomain(parent, child string) bool {
// Entire child is contained in parent
return CompareDomainName(parent, child) == CountLabel(parent)
}
// IsMsg sanity checks buf and returns an error if it isn't a valid DNS packet.
// The checking is performed on the binary payload.
func IsMsg(buf []byte) error {
// Header
if len(buf) < 12 {
return errors.New("dns: bad message header")
}
// Header: Opcode
return nil
}
// IsFqdn checks if a domain name is fully qualified.
func IsFqdn(s string) bool {
l := len(s)
if l == 0 {
return false
}
return s[l-1] == '.'
}
// Fqdns return the fully qualified domain name from s.
// If s is already fully qualified, it behaves as the identity function.
func Fqdn(s string) string {
if IsFqdn(s) {
return s
}
return s + "."
}
// Copied from the official Go code.
// ReverseAddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
// address suitable for reverse DNS (PTR) record lookups or an error if it fails
// to parse the IP address.
func ReverseAddr(addr string) (arpa string, err error) {
ip := net.ParseIP(addr)
if ip == nil {
return "", &Error{err: "unrecognized address: " + addr}
}
if ip.To4() != nil {
return strconv.Itoa(int(ip[15])) + "." + strconv.Itoa(int(ip[14])) + "." + strconv.Itoa(int(ip[13])) + "." +
strconv.Itoa(int(ip[12])) + ".in-addr.arpa.", nil
}
// Must be IPv6
buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
// Add it, in reverse, to the buffer
for i := len(ip) - 1; i >= 0; i-- {
v := ip[i]
buf = append(buf, hexDigit[v&0xF])
buf = append(buf, '.')
buf = append(buf, hexDigit[v>>4])
buf = append(buf, '.')
}
// Append "ip6.arpa." and return (buf already has the final .)
buf = append(buf, "ip6.arpa."...)
return string(buf), nil
}
// String returns the string representation for the type t.
func (t Type) String() string {
if t1, ok := TypeToString[uint16(t)]; ok {
return t1
}
return "TYPE" + strconv.Itoa(int(t))
}
// String returns the string representation for the class c.
func (c Class) String() string {
if c1, ok := ClassToString[uint16(c)]; ok {
return c1
}
return "CLASS" + strconv.Itoa(int(c))
}
// String returns the string representation for the name n.
func (n Name) String() string {
return sprintName(string(n))
}

193
Godeps/_workspace/src/github.com/miekg/dns/dns.go generated vendored Normal file
View file

@ -0,0 +1,193 @@
// Package dns implements a full featured interface to the Domain Name System.
// Server- and client-side programming is supported.
// The package allows complete control over what is send out to the DNS. The package
// API follows the less-is-more principle, by presenting a small, clean interface.
//
// The package dns supports (asynchronous) querying/replying, incoming/outgoing zone transfers,
// TSIG, EDNS0, dynamic updates, notifies and DNSSEC validation/signing.
// Note that domain names MUST be fully qualified, before sending them, unqualified
// names in a message will result in a packing failure.
//
// Resource records are native types. They are not stored in wire format.
// Basic usage pattern for creating a new resource record:
//
// r := new(dns.MX)
// r.Hdr = dns.RR_Header{Name: "miek.nl.", Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: 3600}
// r.Preference = 10
// r.Mx = "mx.miek.nl."
//
// Or directly from a string:
//
// mx, err := dns.NewRR("miek.nl. 3600 IN MX 10 mx.miek.nl.")
//
// Or when the default TTL (3600) and class (IN) suit you:
//
// mx, err := dns.NewRR("miek.nl. MX 10 mx.miek.nl.")
//
// Or even:
//
// mx, err := dns.NewRR("$ORIGIN nl.\nmiek 1H IN MX 10 mx.miek")
//
// In the DNS messages are exchanged, these messages contain resource
// records (sets). Use pattern for creating a message:
//
// m := new(dns.Msg)
// m.SetQuestion("miek.nl.", dns.TypeMX)
//
// Or when not certain if the domain name is fully qualified:
//
// m.SetQuestion(dns.Fqdn("miek.nl"), dns.TypeMX)
//
// The message m is now a message with the question section set to ask
// the MX records for the miek.nl. zone.
//
// The following is slightly more verbose, but more flexible:
//
// m1 := new(dns.Msg)
// m1.Id = dns.Id()
// m1.RecursionDesired = true
// m1.Question = make([]dns.Question, 1)
// m1.Question[0] = dns.Question{"miek.nl.", dns.TypeMX, dns.ClassINET}
//
// After creating a message it can be send.
// Basic use pattern for synchronous querying the DNS at a
// server configured on 127.0.0.1 and port 53:
//
// c := new(dns.Client)
// in, rtt, err := c.Exchange(m1, "127.0.0.1:53")
//
// Suppressing
// multiple outstanding queries (with the same question, type and class) is as easy as setting:
//
// c.SingleInflight = true
//
// If these "advanced" features are not needed, a simple UDP query can be send,
// with:
//
// in, err := dns.Exchange(m1, "127.0.0.1:53")
//
// When this functions returns you will get dns message. A dns message consists
// out of four sections.
// The question section: in.Question, the answer section: in.Answer,
// the authority section: in.Ns and the additional section: in.Extra.
//
// Each of these sections (except the Question section) contain a []RR. Basic
// use pattern for accessing the rdata of a TXT RR as the first RR in
// the Answer section:
//
// if t, ok := in.Answer[0].(*dns.TXT); ok {
// // do something with t.Txt
// }
//
// Domain Name and TXT Character String Representations
//
// Both domain names and TXT character strings are converted to presentation
// form both when unpacked and when converted to strings.
//
// For TXT character strings, tabs, carriage returns and line feeds will be
// converted to \t, \r and \n respectively. Back slashes and quotations marks
// will be escaped. Bytes below 32 and above 127 will be converted to \DDD
// form.
//
// For domain names, in addition to the above rules brackets, periods,
// spaces, semicolons and the at symbol are escaped.
package dns
import (
"strconv"
)
const (
year68 = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits.
DefaultMsgSize = 4096 // Standard default for larger than 512 bytes.
MinMsgSize = 512 // Minimal size of a DNS packet.
MaxMsgSize = 65536 // Largest possible DNS packet.
defaultTtl = 3600 // Default TTL.
)
// Error represents a DNS error
type Error struct{ err string }
func (e *Error) Error() string {
if e == nil {
return "dns: <nil>"
}
return "dns: " + e.err
}
// An RR represents a resource record.
type RR interface {
// Header returns the header of an resource record. The header contains
// everything up to the rdata.
Header() *RR_Header
// String returns the text representation of the resource record.
String() string
// copy returns a copy of the RR
copy() RR
// len returns the length (in octects) of the uncompressed RR in wire format.
len() int
}
// DNS resource records.
// There are many types of RRs,
// but they all share the same header.
type RR_Header struct {
Name string `dns:"cdomain-name"`
Rrtype uint16
Class uint16
Ttl uint32
Rdlength uint16 // length of data after header
}
func (h *RR_Header) Header() *RR_Header { return h }
// Just to imlement the RR interface
func (h *RR_Header) copy() RR { return nil }
func (h *RR_Header) copyHeader() *RR_Header {
r := new(RR_Header)
r.Name = h.Name
r.Rrtype = h.Rrtype
r.Class = h.Class
r.Ttl = h.Ttl
r.Rdlength = h.Rdlength
return r
}
func (h *RR_Header) String() string {
var s string
if h.Rrtype == TypeOPT {
s = ";"
// and maybe other things
}
s += sprintName(h.Name) + "\t"
s += strconv.FormatInt(int64(h.Ttl), 10) + "\t"
s += Class(h.Class).String() + "\t"
s += Type(h.Rrtype).String() + "\t"
return s
}
func (h *RR_Header) len() int {
l := len(h.Name) + 1
l += 10 // rrtype(2) + class(2) + ttl(4) + rdlength(2)
return l
}
// ToRFC3597 converts a known RR to the unknown RR representation
// from RFC 3597.
func (rr *RFC3597) ToRFC3597(r RR) error {
buf := make([]byte, r.len()*2)
off, err := PackStruct(r, buf, 0)
if err != nil {
return err
}
buf = buf[:off]
rawSetRdlength(buf, 0, off)
_, err = UnpackStruct(rr, buf, 0)
if err != nil {
return err
}
return nil
}

501
Godeps/_workspace/src/github.com/miekg/dns/dns_test.go generated vendored Normal file
View file

@ -0,0 +1,501 @@
package dns
import (
"encoding/hex"
"net"
"testing"
)
func TestPackUnpack(t *testing.T) {
out := new(Msg)
out.Answer = make([]RR, 1)
key := new(DNSKEY)
key = &DNSKEY{Flags: 257, Protocol: 3, Algorithm: RSASHA1}
key.Hdr = RR_Header{Name: "miek.nl.", Rrtype: TypeDNSKEY, Class: ClassINET, Ttl: 3600}
key.PublicKey = "AwEAAaHIwpx3w4VHKi6i1LHnTaWeHCL154Jug0Rtc9ji5qwPXpBo6A5sRv7cSsPQKPIwxLpyCrbJ4mr2L0EPOdvP6z6YfljK2ZmTbogU9aSU2fiq/4wjxbdkLyoDVgtO+JsxNN4bjr4WcWhsmk1Hg93FV9ZpkWb0Tbad8DFqNDzr//kZ"
out.Answer[0] = key
msg, err := out.Pack()
if err != nil {
t.Log("failed to pack msg with DNSKEY")
t.Fail()
}
in := new(Msg)
if in.Unpack(msg) != nil {
t.Log("failed to unpack msg with DNSKEY")
t.Fail()
}
sig := new(RRSIG)
sig = &RRSIG{TypeCovered: TypeDNSKEY, Algorithm: RSASHA1, Labels: 2,
OrigTtl: 3600, Expiration: 4000, Inception: 4000, KeyTag: 34641, SignerName: "miek.nl.",
Signature: "AwEAAaHIwpx3w4VHKi6i1LHnTaWeHCL154Jug0Rtc9ji5qwPXpBo6A5sRv7cSsPQKPIwxLpyCrbJ4mr2L0EPOdvP6z6YfljK2ZmTbogU9aSU2fiq/4wjxbdkLyoDVgtO+JsxNN4bjr4WcWhsmk1Hg93FV9ZpkWb0Tbad8DFqNDzr//kZ"}
sig.Hdr = RR_Header{Name: "miek.nl.", Rrtype: TypeRRSIG, Class: ClassINET, Ttl: 3600}
out.Answer[0] = sig
msg, err = out.Pack()
if err != nil {
t.Log("failed to pack msg with RRSIG")
t.Fail()
}
if in.Unpack(msg) != nil {
t.Log("failed to unpack msg with RRSIG")
t.Fail()
}
}
func TestPackUnpack2(t *testing.T) {
m := new(Msg)
m.Extra = make([]RR, 1)
m.Answer = make([]RR, 1)
dom := "miek.nl."
rr := new(A)
rr.Hdr = RR_Header{Name: dom, Rrtype: TypeA, Class: ClassINET, Ttl: 0}
rr.A = net.IPv4(127, 0, 0, 1)
x := new(TXT)
x.Hdr = RR_Header{Name: dom, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}
x.Txt = []string{"heelalaollo"}
m.Extra[0] = x
m.Answer[0] = rr
_, err := m.Pack()
if err != nil {
t.Log("Packing failed: " + err.Error())
t.Fail()
return
}
}
func TestPackUnpack3(t *testing.T) {
m := new(Msg)
m.Extra = make([]RR, 2)
m.Answer = make([]RR, 1)
dom := "miek.nl."
rr := new(A)
rr.Hdr = RR_Header{Name: dom, Rrtype: TypeA, Class: ClassINET, Ttl: 0}
rr.A = net.IPv4(127, 0, 0, 1)
x1 := new(TXT)
x1.Hdr = RR_Header{Name: dom, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}
x1.Txt = []string{}
x2 := new(TXT)
x2.Hdr = RR_Header{Name: dom, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}
x2.Txt = []string{"heelalaollo"}
m.Extra[0] = x1
m.Extra[1] = x2
m.Answer[0] = rr
b, err := m.Pack()
if err != nil {
t.Log("packing failed: " + err.Error())
t.Fail()
return
}
var unpackMsg Msg
err = unpackMsg.Unpack(b)
if err != nil {
t.Log("unpacking failed")
t.Fail()
return
}
}
func TestBailiwick(t *testing.T) {
yes := map[string]string{
"miek.nl": "ns.miek.nl",
".": "miek.nl",
}
for parent, child := range yes {
if !IsSubDomain(parent, child) {
t.Logf("%s should be child of %s\n", child, parent)
t.Logf("comparelabels %d", CompareDomainName(parent, child))
t.Logf("lenlabels %d %d", CountLabel(parent), CountLabel(child))
t.Fail()
}
}
no := map[string]string{
"www.miek.nl": "ns.miek.nl",
"m\\.iek.nl": "ns.miek.nl",
"w\\.iek.nl": "w.iek.nl",
"p\\\\.iek.nl": "ns.p.iek.nl", // p\\.iek.nl , literal \ in domain name
"miek.nl": ".",
}
for parent, child := range no {
if IsSubDomain(parent, child) {
t.Logf("%s should not be child of %s\n", child, parent)
t.Logf("comparelabels %d", CompareDomainName(parent, child))
t.Logf("lenlabels %d %d", CountLabel(parent), CountLabel(child))
t.Fail()
}
}
}
func TestPack(t *testing.T) {
rr := []string{"US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534"}
m := new(Msg)
var err error
m.Answer = make([]RR, 1)
for _, r := range rr {
m.Answer[0], err = NewRR(r)
if err != nil {
t.Logf("failed to create RR: %s\n", err.Error())
t.Fail()
continue
}
if _, err := m.Pack(); err != nil {
t.Logf("packing failed: %s\n", err.Error())
t.Fail()
}
}
x := new(Msg)
ns, _ := NewRR("pool.ntp.org. 390 IN NS a.ntpns.org")
ns.(*NS).Ns = "a.ntpns.org"
x.Ns = append(m.Ns, ns)
x.Ns = append(m.Ns, ns)
x.Ns = append(m.Ns, ns)
// This crashes due to the fact the a.ntpns.org isn't a FQDN
// How to recover() from a remove panic()?
if _, err := x.Pack(); err == nil {
t.Log("packing should fail")
t.Fail()
}
x.Answer = make([]RR, 1)
x.Answer[0], err = NewRR(rr[0])
if _, err := x.Pack(); err == nil {
t.Log("packing should fail")
t.Fail()
}
x.Question = make([]Question, 1)
x.Question[0] = Question{";sd#edddds鍛↙赏‘℅∥↙xzztsestxssweewwsssstx@s@Z嵌e@cn.pool.ntp.org.", TypeA, ClassINET}
if _, err := x.Pack(); err == nil {
t.Log("packing should fail")
t.Fail()
}
}
func TestPackNAPTR(t *testing.T) {
for _, n := range []string{
`apple.com. IN NAPTR 100 50 "se" "SIP+D2U" "" _sip._udp.apple.com.`,
`apple.com. IN NAPTR 90 50 "se" "SIP+D2T" "" _sip._tcp.apple.com.`,
`apple.com. IN NAPTR 50 50 "se" "SIPS+D2T" "" _sips._tcp.apple.com.`,
} {
rr, _ := NewRR(n)
msg := make([]byte, rr.len())
if off, err := PackRR(rr, msg, 0, nil, false); err != nil {
t.Logf("packing failed: %s", err.Error())
t.Logf("length %d, need more than %d\n", rr.len(), off)
t.Fail()
} else {
t.Logf("buf size needed: %d\n", off)
}
}
}
func TestCompressLength(t *testing.T) {
m := new(Msg)
m.SetQuestion("miek.nl", TypeMX)
ul := m.Len()
m.Compress = true
if ul != m.Len() {
t.Fatalf("should be equal")
}
}
// Does the predicted length match final packed length?
func TestMsgCompressLength(t *testing.T) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrA, _ := NewRR(name1 + " 3600 IN A 192.0.2.1")
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
tests := []*Msg{
makeMsg(name1, []RR{rrA}, nil, nil),
makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)}
for _, msg := range tests {
predicted := msg.Len()
buf, err := msg.Pack()
if err != nil {
t.Error(err)
t.Fail()
}
if predicted < len(buf) {
t.Errorf("predicted compressed length is wrong: predicted %s (len=%d) %d, actual %d\n",
msg.Question[0].Name, len(msg.Answer), predicted, len(buf))
t.Fail()
}
}
}
func TestMsgLength(t *testing.T) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrA, _ := NewRR(name1 + " 3600 IN A 192.0.2.1")
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
tests := []*Msg{
makeMsg(name1, []RR{rrA}, nil, nil),
makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)}
for _, msg := range tests {
predicted := msg.Len()
buf, err := msg.Pack()
if err != nil {
t.Error(err)
t.Fail()
}
if predicted < len(buf) {
t.Errorf("predicted length is wrong: predicted %s (len=%d), actual %d\n",
msg.Question[0].Name, predicted, len(buf))
t.Fail()
}
}
}
func TestMsgLength2(t *testing.T) {
// Serialized replies
var testMessages = []string{
// google.com. IN A?
"064e81800001000b0004000506676f6f676c6503636f6d0000010001c00c00010001000000050004adc22986c00c00010001000000050004adc22987c00c00010001000000050004adc22988c00c00010001000000050004adc22989c00c00010001000000050004adc2298ec00c00010001000000050004adc22980c00c00010001000000050004adc22981c00c00010001000000050004adc22982c00c00010001000000050004adc22983c00c00010001000000050004adc22984c00c00010001000000050004adc22985c00c00020001000000050006036e7331c00cc00c00020001000000050006036e7332c00cc00c00020001000000050006036e7333c00cc00c00020001000000050006036e7334c00cc0d800010001000000050004d8ef200ac0ea00010001000000050004d8ef220ac0fc00010001000000050004d8ef240ac10e00010001000000050004d8ef260a0000290500000000050000",
// amazon.com. IN A? (reply has no EDNS0 record)
// TODO(miek): this one is off-by-one, need to find out why
//"6de1818000010004000a000806616d617a6f6e03636f6d0000010001c00c000100010000000500044815c2d4c00c000100010000000500044815d7e8c00c00010001000000050004b02062a6c00c00010001000000050004cdfbf236c00c000200010000000500140570646e733408756c747261646e73036f726700c00c000200010000000500150570646e733508756c747261646e7304696e666f00c00c000200010000000500160570646e733608756c747261646e7302636f02756b00c00c00020001000000050014036e7331037033310664796e656374036e657400c00c00020001000000050006036e7332c0cfc00c00020001000000050006036e7333c0cfc00c00020001000000050006036e7334c0cfc00c000200010000000500110570646e733108756c747261646e73c0dac00c000200010000000500080570646e7332c127c00c000200010000000500080570646e7333c06ec0cb00010001000000050004d04e461fc0eb00010001000000050004cc0dfa1fc0fd00010001000000050004d04e471fc10f00010001000000050004cc0dfb1fc12100010001000000050004cc4a6c01c121001c000100000005001020010502f3ff00000000000000000001c13e00010001000000050004cc4a6d01c13e001c0001000000050010261000a1101400000000000000000001",
// yahoo.com. IN A?
"fc2d81800001000300070008057961686f6f03636f6d0000010001c00c00010001000000050004628afd6dc00c00010001000000050004628bb718c00c00010001000000050004cebe242dc00c00020001000000050006036e7336c00cc00c00020001000000050006036e7338c00cc00c00020001000000050006036e7331c00cc00c00020001000000050006036e7332c00cc00c00020001000000050006036e7333c00cc00c00020001000000050006036e7334c00cc00c00020001000000050006036e7335c00cc07b0001000100000005000444b48310c08d00010001000000050004448eff10c09f00010001000000050004cb54dd35c0b100010001000000050004628a0b9dc0c30001000100000005000477a0f77cc05700010001000000050004ca2bdfaac06900010001000000050004caa568160000290500000000050000",
// microsoft.com. IN A?
"f4368180000100020005000b096d6963726f736f667403636f6d0000010001c00c0001000100000005000440040b25c00c0001000100000005000441373ac9c00c0002000100000005000e036e7331046d736674036e657400c00c00020001000000050006036e7332c04fc00c00020001000000050006036e7333c04fc00c00020001000000050006036e7334c04fc00c00020001000000050006036e7335c04fc04b000100010000000500044137253ec04b001c00010000000500102a010111200500000000000000010001c0650001000100000005000440043badc065001c00010000000500102a010111200600060000000000010001c07700010001000000050004d5c7b435c077001c00010000000500102a010111202000000000000000010001c08900010001000000050004cf2e4bfec089001c00010000000500102404f800200300000000000000010001c09b000100010000000500044137e28cc09b001c00010000000500102a010111200f000100000000000100010000290500000000050000",
// google.com. IN MX?
"724b8180000100050004000b06676f6f676c6503636f6d00000f0001c00c000f000100000005000c000a056173706d78016cc00cc00c000f0001000000050009001404616c7431c02ac00c000f0001000000050009001e04616c7432c02ac00c000f0001000000050009002804616c7433c02ac00c000f0001000000050009003204616c7434c02ac00c00020001000000050006036e7332c00cc00c00020001000000050006036e7333c00cc00c00020001000000050006036e7334c00cc00c00020001000000050006036e7331c00cc02a00010001000000050004adc2421bc02a001c00010000000500102a00145040080c01000000000000001bc04200010001000000050004adc2461bc05700010001000000050004adc2451bc06c000100010000000500044a7d8f1bc081000100010000000500044a7d191bc0ca00010001000000050004d8ef200ac09400010001000000050004d8ef220ac0a600010001000000050004d8ef240ac0b800010001000000050004d8ef260a0000290500000000050000",
// reddit.com. IN A?
"12b98180000100080000000c0672656464697403636f6d0000020001c00c0002000100000005000f046175733204616b616d036e657400c00c000200010000000500070475736534c02dc00c000200010000000500070475737733c02dc00c000200010000000500070475737735c02dc00c00020001000000050008056173696131c02dc00c00020001000000050008056173696139c02dc00c00020001000000050008056e73312d31c02dc00c0002000100000005000a076e73312d313935c02dc02800010001000000050004c30a242ec04300010001000000050004451f1d39c05600010001000000050004451f3bc7c0690001000100000005000460073240c07c000100010000000500046007fb81c090000100010000000500047c283484c090001c00010000000500102a0226f0006700000000000000000064c0a400010001000000050004c16c5b01c0a4001c000100000005001026001401000200000000000000000001c0b800010001000000050004c16c5bc3c0b8001c0001000000050010260014010002000000000000000000c30000290500000000050000",
}
for i, hexData := range testMessages {
// we won't fail the decoding of the hex
input, _ := hex.DecodeString(hexData)
m := new(Msg)
m.Unpack(input)
//println(m.String())
m.Compress = true
lenComp := m.Len()
b, _ := m.Pack()
pacComp := len(b)
m.Compress = false
lenUnComp := m.Len()
b, _ = m.Pack()
pacUnComp := len(b)
if pacComp+1 != lenComp {
t.Errorf("msg.Len(compressed)=%d actual=%d for test %d", lenComp, pacComp, i)
}
if pacUnComp+1 != lenUnComp {
t.Errorf("msg.Len(uncompressed)=%d actual=%d for test %d", lenUnComp, pacUnComp, i)
}
}
}
func TestMsgLengthCompressionMalformed(t *testing.T) {
// SOA with empty hostmaster, which is illegal
soa := &SOA{Hdr: RR_Header{Name: ".", Rrtype: TypeSOA, Class: ClassINET, Ttl: 12345},
Ns: ".",
Mbox: "",
Serial: 0,
Refresh: 28800,
Retry: 7200,
Expire: 604800,
Minttl: 60}
m := new(Msg)
m.Compress = true
m.Ns = []RR{soa}
m.Len() // Should not crash.
}
func BenchmarkMsgLength(b *testing.B) {
b.StopTimer()
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
b.StartTimer()
for i := 0; i < b.N; i++ {
msg.Len()
}
}
func BenchmarkMsgLengthPack(b *testing.B) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = msg.Pack()
}
}
func BenchmarkMsgPackBuffer(b *testing.B) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
buf := make([]byte, 512)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = msg.PackBuffer(buf)
}
}
func BenchmarkMsgUnpack(b *testing.B) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
msg_buf, _ := msg.Pack()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = msg.Unpack(msg_buf)
}
}
func BenchmarkPackDomainName(b *testing.B) {
name1 := "12345678901234567890123456789012345.12345678.123."
buf := make([]byte, len(name1)+1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = PackDomainName(name1, buf, 0, nil, false)
}
}
func BenchmarkUnpackDomainName(b *testing.B) {
name1 := "12345678901234567890123456789012345.12345678.123."
buf := make([]byte, len(name1)+1)
_, _ = PackDomainName(name1, buf, 0, nil, false)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = UnpackDomainName(buf, 0)
}
}
func TestToRFC3597(t *testing.T) {
a, _ := NewRR("miek.nl. IN A 10.0.1.1")
x := new(RFC3597)
x.ToRFC3597(a)
if x.String() != `miek.nl. 3600 IN A \# 4 0a000101` {
t.Fail()
}
}
func TestNoRdataPack(t *testing.T) {
data := make([]byte, 1024)
for typ, fn := range typeToRR {
if typ == TypeCAA {
continue // TODO(miek): known omission
}
r := fn()
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600}
_, e := PackRR(r, data, 0, nil, false)
if e != nil {
t.Logf("failed to pack RR with zero rdata: %s: %s\n", TypeToString[typ], e.Error())
t.Fail()
}
}
}
// TODO(miek): fix dns buffer too small errors this throws
func TestNoRdataUnpack(t *testing.T) {
data := make([]byte, 1024)
for typ, fn := range typeToRR {
if typ == TypeSOA || typ == TypeTSIG || typ == TypeWKS {
// SOA, TSIG will not be seen (like this) in dyn. updates?
// WKS is an bug, but...deprecated record.
continue
}
r := fn()
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600}
off, e := PackRR(r, data, 0, nil, false)
if e != nil {
// Should always works, TestNoDataPack should have catched this
continue
}
rr, _, e := UnpackRR(data[:off], 0)
if e != nil {
t.Logf("failed to unpack RR with zero rdata: %s: %s\n", TypeToString[typ], e.Error())
t.Fail()
}
t.Logf("%s\n", rr)
}
}
func TestRdataOverflow(t *testing.T) {
rr := new(RFC3597)
rr.Hdr.Name = "."
rr.Hdr.Class = ClassINET
rr.Hdr.Rrtype = 65280
rr.Rdata = hex.EncodeToString(make([]byte, 0xFFFF))
buf := make([]byte, 0xFFFF*2)
if _, err := PackRR(rr, buf, 0, nil, false); err != nil {
t.Fatalf("maximum size rrdata pack failed: %v", err)
}
rr.Rdata += "00"
if _, err := PackRR(rr, buf, 0, nil, false); err != ErrRdata {
t.Fatalf("oversize rrdata pack didn't return ErrRdata - instead: %v", err)
}
}
func TestCopy(t *testing.T) {
rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1") // Weird TTL to avoid catching TTL
rr1 := Copy(rr)
if rr.String() != rr1.String() {
t.Fatalf("Copy() failed %s != %s", rr.String(), rr1.String())
}
}

739
Godeps/_workspace/src/github.com/miekg/dns/dnssec.go generated vendored Normal file
View file

@ -0,0 +1,739 @@
// DNSSEC
//
// DNSSEC (DNS Security Extension) adds a layer of security to the DNS. It
// uses public key cryptography to sign resource records. The
// public keys are stored in DNSKEY records and the signatures in RRSIG records.
//
// Requesting DNSSEC information for a zone is done by adding the DO (DNSSEC OK) bit
// to an request.
//
// m := new(dns.Msg)
// m.SetEdns0(4096, true)
//
// Signature generation, signature verification and key generation are all supported.
package dns
import (
"bytes"
"crypto"
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/md5"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/hex"
"hash"
"io"
"math/big"
"sort"
"strings"
"time"
)
// DNSSEC encryption algorithm codes.
const (
RSAMD5 = 1
DH = 2
DSA = 3
ECC = 4
RSASHA1 = 5
DSANSEC3SHA1 = 6
RSASHA1NSEC3SHA1 = 7
RSASHA256 = 8
RSASHA512 = 10
ECCGOST = 12
ECDSAP256SHA256 = 13
ECDSAP384SHA384 = 14
INDIRECT = 252
PRIVATEDNS = 253 // Private (experimental keys)
PRIVATEOID = 254
)
// DNSSEC hashing algorithm codes.
const (
_ = iota
SHA1 // RFC 4034
SHA256 // RFC 4509
GOST94 // RFC 5933
SHA384 // Experimental
SHA512 // Experimental
)
// DNSKEY flag values.
const (
SEP = 1
REVOKE = 1 << 7
ZONE = 1 << 8
)
// The RRSIG needs to be converted to wireformat with some of
// the rdata (the signature) missing. Use this struct to easy
// the conversion (and re-use the pack/unpack functions).
type rrsigWireFmt struct {
TypeCovered uint16
Algorithm uint8
Labels uint8
OrigTtl uint32
Expiration uint32
Inception uint32
KeyTag uint16
SignerName string `dns:"domain-name"`
/* No Signature */
}
// Used for converting DNSKEY's rdata to wirefmt.
type dnskeyWireFmt struct {
Flags uint16
Protocol uint8
Algorithm uint8
PublicKey string `dns:"base64"`
/* Nothing is left out */
}
// KeyTag calculates the keytag (or key-id) of the DNSKEY.
func (k *DNSKEY) KeyTag() uint16 {
if k == nil {
return 0
}
var keytag int
switch k.Algorithm {
case RSAMD5:
// Look at the bottom two bytes of the modules, which the last
// item in the pubkey. We could do this faster by looking directly
// at the base64 values. But I'm lazy.
modulus, _ := packBase64([]byte(k.PublicKey))
if len(modulus) > 1 {
x, _ := unpackUint16(modulus, len(modulus)-2)
keytag = int(x)
}
default:
keywire := new(dnskeyWireFmt)
keywire.Flags = k.Flags
keywire.Protocol = k.Protocol
keywire.Algorithm = k.Algorithm
keywire.PublicKey = k.PublicKey
wire := make([]byte, DefaultMsgSize)
n, err := PackStruct(keywire, wire, 0)
if err != nil {
return 0
}
wire = wire[:n]
for i, v := range wire {
if i&1 != 0 {
keytag += int(v) // must be larger than uint32
} else {
keytag += int(v) << 8
}
}
keytag += (keytag >> 16) & 0xFFFF
keytag &= 0xFFFF
}
return uint16(keytag)
}
// ToDS converts a DNSKEY record to a DS record.
func (k *DNSKEY) ToDS(h int) *DS {
if k == nil {
return nil
}
ds := new(DS)
ds.Hdr.Name = k.Hdr.Name
ds.Hdr.Class = k.Hdr.Class
ds.Hdr.Rrtype = TypeDS
ds.Hdr.Ttl = k.Hdr.Ttl
ds.Algorithm = k.Algorithm
ds.DigestType = uint8(h)
ds.KeyTag = k.KeyTag()
keywire := new(dnskeyWireFmt)
keywire.Flags = k.Flags
keywire.Protocol = k.Protocol
keywire.Algorithm = k.Algorithm
keywire.PublicKey = k.PublicKey
wire := make([]byte, DefaultMsgSize)
n, err := PackStruct(keywire, wire, 0)
if err != nil {
return nil
}
wire = wire[:n]
owner := make([]byte, 255)
off, err1 := PackDomainName(strings.ToLower(k.Hdr.Name), owner, 0, nil, false)
if err1 != nil {
return nil
}
owner = owner[:off]
// RFC4034:
// digest = digest_algorithm( DNSKEY owner name | DNSKEY RDATA);
// "|" denotes concatenation
// DNSKEY RDATA = Flags | Protocol | Algorithm | Public Key.
// digest buffer
digest := append(owner, wire...) // another copy
switch h {
case SHA1:
s := sha1.New()
io.WriteString(s, string(digest))
ds.Digest = hex.EncodeToString(s.Sum(nil))
case SHA256:
s := sha256.New()
io.WriteString(s, string(digest))
ds.Digest = hex.EncodeToString(s.Sum(nil))
case SHA384:
s := sha512.New384()
io.WriteString(s, string(digest))
ds.Digest = hex.EncodeToString(s.Sum(nil))
case GOST94:
/* I have no clue */
default:
return nil
}
return ds
}
// Sign signs an RRSet. The signature needs to be filled in with
// the values: Inception, Expiration, KeyTag, SignerName and Algorithm.
// The rest is copied from the RRset. Sign returns true when the signing went OK,
// otherwise false.
// There is no check if RRSet is a proper (RFC 2181) RRSet.
// If OrigTTL is non zero, it is used as-is, otherwise the TTL of the RRset
// is used as the OrigTTL.
func (rr *RRSIG) Sign(k PrivateKey, rrset []RR) error {
if k == nil {
return ErrPrivKey
}
// s.Inception and s.Expiration may be 0 (rollover etc.), the rest must be set
if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
return ErrKey
}
rr.Hdr.Rrtype = TypeRRSIG
rr.Hdr.Name = rrset[0].Header().Name
rr.Hdr.Class = rrset[0].Header().Class
if rr.OrigTtl == 0 { // If set don't override
rr.OrigTtl = rrset[0].Header().Ttl
}
rr.TypeCovered = rrset[0].Header().Rrtype
rr.Labels = uint8(CountLabel(rrset[0].Header().Name))
if strings.HasPrefix(rrset[0].Header().Name, "*") {
rr.Labels-- // wildcard, remove from label count
}
sigwire := new(rrsigWireFmt)
sigwire.TypeCovered = rr.TypeCovered
sigwire.Algorithm = rr.Algorithm
sigwire.Labels = rr.Labels
sigwire.OrigTtl = rr.OrigTtl
sigwire.Expiration = rr.Expiration
sigwire.Inception = rr.Inception
sigwire.KeyTag = rr.KeyTag
// For signing, lowercase this name
sigwire.SignerName = strings.ToLower(rr.SignerName)
// Create the desired binary blob
signdata := make([]byte, DefaultMsgSize)
n, err := PackStruct(sigwire, signdata, 0)
if err != nil {
return err
}
signdata = signdata[:n]
wire, err := rawSignatureData(rrset, rr)
if err != nil {
return err
}
signdata = append(signdata, wire...)
var sighash []byte
var h hash.Hash
var ch crypto.Hash // Only need for RSA
switch rr.Algorithm {
case DSA, DSANSEC3SHA1:
// Implicit in the ParameterSizes
case RSASHA1, RSASHA1NSEC3SHA1:
h = sha1.New()
ch = crypto.SHA1
case RSASHA256, ECDSAP256SHA256:
h = sha256.New()
ch = crypto.SHA256
case ECDSAP384SHA384:
h = sha512.New384()
case RSASHA512:
h = sha512.New()
ch = crypto.SHA512
case RSAMD5:
fallthrough // Deprecated in RFC 6725
default:
return ErrAlg
}
io.WriteString(h, string(signdata))
sighash = h.Sum(nil)
switch p := k.(type) {
case *dsa.PrivateKey:
r1, s1, err := dsa.Sign(rand.Reader, p, sighash)
if err != nil {
return err
}
signature := []byte{0x4D} // T value, here the ASCII M for Miek (not used in DNSSEC)
signature = append(signature, r1.Bytes()...)
signature = append(signature, s1.Bytes()...)
rr.Signature = unpackBase64(signature)
case *rsa.PrivateKey:
// We can use nil as rand.Reader here (says AGL)
signature, err := rsa.SignPKCS1v15(nil, p, ch, sighash)
if err != nil {
return err
}
rr.Signature = unpackBase64(signature)
case *ecdsa.PrivateKey:
r1, s1, err := ecdsa.Sign(rand.Reader, p, sighash)
if err != nil {
return err
}
signature := r1.Bytes()
signature = append(signature, s1.Bytes()...)
rr.Signature = unpackBase64(signature)
default:
// Not given the correct key
return ErrKeyAlg
}
return nil
}
// Verify validates an RRSet with the signature and key. This is only the
// cryptographic test, the signature validity period must be checked separately.
// This function copies the rdata of some RRs (to lowercase domain names) for the validation to work.
func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
// First the easy checks
if len(rrset) == 0 {
return ErrRRset
}
if rr.KeyTag != k.KeyTag() {
return ErrKey
}
if rr.Hdr.Class != k.Hdr.Class {
return ErrKey
}
if rr.Algorithm != k.Algorithm {
return ErrKey
}
if strings.ToLower(rr.SignerName) != strings.ToLower(k.Hdr.Name) {
return ErrKey
}
if k.Protocol != 3 {
return ErrKey
}
for _, r := range rrset {
if r.Header().Class != rr.Hdr.Class {
return ErrRRset
}
if r.Header().Rrtype != rr.TypeCovered {
return ErrRRset
}
}
// RFC 4035 5.3.2. Reconstructing the Signed Data
// Copy the sig, except the rrsig data
sigwire := new(rrsigWireFmt)
sigwire.TypeCovered = rr.TypeCovered
sigwire.Algorithm = rr.Algorithm
sigwire.Labels = rr.Labels
sigwire.OrigTtl = rr.OrigTtl
sigwire.Expiration = rr.Expiration
sigwire.Inception = rr.Inception
sigwire.KeyTag = rr.KeyTag
sigwire.SignerName = strings.ToLower(rr.SignerName)
// Create the desired binary blob
signeddata := make([]byte, DefaultMsgSize)
n, err := PackStruct(sigwire, signeddata, 0)
if err != nil {
return err
}
signeddata = signeddata[:n]
wire, err := rawSignatureData(rrset, rr)
if err != nil {
return err
}
signeddata = append(signeddata, wire...)
sigbuf := rr.sigBuf() // Get the binary signature data
if rr.Algorithm == PRIVATEDNS { // PRIVATEOID
// TODO(mg)
// remove the domain name and assume its our
}
switch rr.Algorithm {
case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512, RSAMD5:
// TODO(mg): this can be done quicker, ie. cache the pubkey data somewhere??
pubkey := k.publicKeyRSA() // Get the key
if pubkey == nil {
return ErrKey
}
// Setup the hash as defined for this alg.
var h hash.Hash
var ch crypto.Hash
switch rr.Algorithm {
case RSAMD5:
h = md5.New()
ch = crypto.MD5
case RSASHA1, RSASHA1NSEC3SHA1:
h = sha1.New()
ch = crypto.SHA1
case RSASHA256:
h = sha256.New()
ch = crypto.SHA256
case RSASHA512:
h = sha512.New()
ch = crypto.SHA512
}
io.WriteString(h, string(signeddata))
sighash := h.Sum(nil)
return rsa.VerifyPKCS1v15(pubkey, ch, sighash, sigbuf)
case ECDSAP256SHA256, ECDSAP384SHA384:
pubkey := k.publicKeyCurve()
if pubkey == nil {
return ErrKey
}
var h hash.Hash
switch rr.Algorithm {
case ECDSAP256SHA256:
h = sha256.New()
case ECDSAP384SHA384:
h = sha512.New384()
}
io.WriteString(h, string(signeddata))
sighash := h.Sum(nil)
// Split sigbuf into the r and s coordinates
r := big.NewInt(0)
r.SetBytes(sigbuf[:len(sigbuf)/2])
s := big.NewInt(0)
s.SetBytes(sigbuf[len(sigbuf)/2:])
if ecdsa.Verify(pubkey, sighash, r, s) {
return nil
}
return ErrSig
}
// Unknown alg
return ErrAlg
}
// ValidityPeriod uses RFC1982 serial arithmetic to calculate
// if a signature period is valid. If t is the zero time, the
// current time is taken other t is.
func (rr *RRSIG) ValidityPeriod(t time.Time) bool {
var utc int64
if t.IsZero() {
utc = time.Now().UTC().Unix()
} else {
utc = t.UTC().Unix()
}
modi := (int64(rr.Inception) - utc) / year68
mode := (int64(rr.Expiration) - utc) / year68
ti := int64(rr.Inception) + (modi * year68)
te := int64(rr.Expiration) + (mode * year68)
return ti <= utc && utc <= te
}
// Return the signatures base64 encodedig sigdata as a byte slice.
func (s *RRSIG) sigBuf() []byte {
sigbuf, err := packBase64([]byte(s.Signature))
if err != nil {
return nil
}
return sigbuf
}
// setPublicKeyInPrivate sets the public key in the private key.
func (k *DNSKEY) setPublicKeyInPrivate(p PrivateKey) bool {
switch t := p.(type) {
case *dsa.PrivateKey:
x := k.publicKeyDSA()
if x == nil {
return false
}
t.PublicKey = *x
case *rsa.PrivateKey:
x := k.publicKeyRSA()
if x == nil {
return false
}
t.PublicKey = *x
case *ecdsa.PrivateKey:
x := k.publicKeyCurve()
if x == nil {
return false
}
t.PublicKey = *x
}
return true
}
// publicKeyRSA returns the RSA public key from a DNSKEY record.
func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey {
keybuf, err := packBase64([]byte(k.PublicKey))
if err != nil {
return nil
}
// RFC 2537/3110, section 2. RSA Public KEY Resource Records
// Length is in the 0th byte, unless its zero, then it
// it in bytes 1 and 2 and its a 16 bit number
explen := uint16(keybuf[0])
keyoff := 1
if explen == 0 {
explen = uint16(keybuf[1])<<8 | uint16(keybuf[2])
keyoff = 3
}
pubkey := new(rsa.PublicKey)
pubkey.N = big.NewInt(0)
shift := uint64((explen - 1) * 8)
expo := uint64(0)
for i := int(explen - 1); i > 0; i-- {
expo += uint64(keybuf[keyoff+i]) << shift
shift -= 8
}
// Remainder
expo += uint64(keybuf[keyoff])
if expo > 2<<31 {
// Larger expo than supported.
// println("dns: F5 primes (or larger) are not supported")
return nil
}
pubkey.E = int(expo)
pubkey.N.SetBytes(keybuf[keyoff+int(explen):])
return pubkey
}
// publicKeyCurve returns the Curve public key from the DNSKEY record.
func (k *DNSKEY) publicKeyCurve() *ecdsa.PublicKey {
keybuf, err := packBase64([]byte(k.PublicKey))
if err != nil {
return nil
}
pubkey := new(ecdsa.PublicKey)
switch k.Algorithm {
case ECDSAP256SHA256:
pubkey.Curve = elliptic.P256()
if len(keybuf) != 64 {
// wrongly encoded key
return nil
}
case ECDSAP384SHA384:
pubkey.Curve = elliptic.P384()
if len(keybuf) != 96 {
// Wrongly encoded key
return nil
}
}
pubkey.X = big.NewInt(0)
pubkey.X.SetBytes(keybuf[:len(keybuf)/2])
pubkey.Y = big.NewInt(0)
pubkey.Y.SetBytes(keybuf[len(keybuf)/2:])
return pubkey
}
func (k *DNSKEY) publicKeyDSA() *dsa.PublicKey {
keybuf, err := packBase64([]byte(k.PublicKey))
if err != nil {
return nil
}
if len(keybuf) < 22 { // TODO: check
return nil
}
t := int(keybuf[0])
size := 64 + t*8
pubkey := new(dsa.PublicKey)
pubkey.Parameters.Q = big.NewInt(0)
pubkey.Parameters.Q.SetBytes(keybuf[1:21]) // +/- 1 ?
pubkey.Parameters.P = big.NewInt(0)
pubkey.Parameters.P.SetBytes(keybuf[22 : 22+size])
pubkey.Parameters.G = big.NewInt(0)
pubkey.Parameters.G.SetBytes(keybuf[22+size+1 : 22+size*2])
pubkey.Y = big.NewInt(0)
pubkey.Y.SetBytes(keybuf[22+size*2+1 : 22+size*3])
return pubkey
}
// Set the public key (the value E and N)
func (k *DNSKEY) setPublicKeyRSA(_E int, _N *big.Int) bool {
if _E == 0 || _N == nil {
return false
}
buf := exponentToBuf(_E)
buf = append(buf, _N.Bytes()...)
k.PublicKey = unpackBase64(buf)
return true
}
// Set the public key for Elliptic Curves
func (k *DNSKEY) setPublicKeyCurve(_X, _Y *big.Int) bool {
if _X == nil || _Y == nil {
return false
}
buf := curveToBuf(_X, _Y)
// Check the length of the buffer, either 64 or 92 bytes
k.PublicKey = unpackBase64(buf)
return true
}
// Set the public key for DSA
func (k *DNSKEY) setPublicKeyDSA(_Q, _P, _G, _Y *big.Int) bool {
if _Q == nil || _P == nil || _G == nil || _Y == nil {
return false
}
buf := dsaToBuf(_Q, _P, _G, _Y)
k.PublicKey = unpackBase64(buf)
return true
}
// Set the public key (the values E and N) for RSA
// RFC 3110: Section 2. RSA Public KEY Resource Records
func exponentToBuf(_E int) []byte {
var buf []byte
i := big.NewInt(int64(_E))
if len(i.Bytes()) < 256 {
buf = make([]byte, 1)
buf[0] = uint8(len(i.Bytes()))
} else {
buf = make([]byte, 3)
buf[0] = 0
buf[1] = uint8(len(i.Bytes()) >> 8)
buf[2] = uint8(len(i.Bytes()))
}
buf = append(buf, i.Bytes()...)
return buf
}
// Set the public key for X and Y for Curve. The two
// values are just concatenated.
func curveToBuf(_X, _Y *big.Int) []byte {
buf := _X.Bytes()
buf = append(buf, _Y.Bytes()...)
return buf
}
// Set the public key for X and Y for Curve. The two
// values are just concatenated.
func dsaToBuf(_Q, _P, _G, _Y *big.Int) []byte {
t := byte((len(_G.Bytes()) - 64) / 8)
buf := []byte{t}
buf = append(buf, _Q.Bytes()...)
buf = append(buf, _P.Bytes()...)
buf = append(buf, _G.Bytes()...)
buf = append(buf, _Y.Bytes()...)
return buf
}
type wireSlice [][]byte
func (p wireSlice) Len() int { return len(p) }
func (p wireSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p wireSlice) Less(i, j int) bool {
_, ioff, _ := UnpackDomainName(p[i], 0)
_, joff, _ := UnpackDomainName(p[j], 0)
return bytes.Compare(p[i][ioff+10:], p[j][joff+10:]) < 0
}
// Return the raw signature data.
func rawSignatureData(rrset []RR, s *RRSIG) (buf []byte, err error) {
wires := make(wireSlice, len(rrset))
for i, r := range rrset {
r1 := r.copy()
r1.Header().Ttl = s.OrigTtl
labels := SplitDomainName(r1.Header().Name)
// 6.2. Canonical RR Form. (4) - wildcards
if len(labels) > int(s.Labels) {
// Wildcard
r1.Header().Name = "*." + strings.Join(labels[len(labels)-int(s.Labels):], ".") + "."
}
// RFC 4034: 6.2. Canonical RR Form. (2) - domain name to lowercase
r1.Header().Name = strings.ToLower(r1.Header().Name)
// 6.2. Canonical RR Form. (3) - domain rdata to lowercase.
// NS, MD, MF, CNAME, SOA, MB, MG, MR, PTR,
// HINFO, MINFO, MX, RP, AFSDB, RT, SIG, PX, NXT, NAPTR, KX,
// SRV, DNAME, A6
switch x := r1.(type) {
case *NS:
x.Ns = strings.ToLower(x.Ns)
case *CNAME:
x.Target = strings.ToLower(x.Target)
case *SOA:
x.Ns = strings.ToLower(x.Ns)
x.Mbox = strings.ToLower(x.Mbox)
case *MB:
x.Mb = strings.ToLower(x.Mb)
case *MG:
x.Mg = strings.ToLower(x.Mg)
case *MR:
x.Mr = strings.ToLower(x.Mr)
case *PTR:
x.Ptr = strings.ToLower(x.Ptr)
case *MINFO:
x.Rmail = strings.ToLower(x.Rmail)
x.Email = strings.ToLower(x.Email)
case *MX:
x.Mx = strings.ToLower(x.Mx)
case *NAPTR:
x.Replacement = strings.ToLower(x.Replacement)
case *KX:
x.Exchanger = strings.ToLower(x.Exchanger)
case *SRV:
x.Target = strings.ToLower(x.Target)
case *DNAME:
x.Target = strings.ToLower(x.Target)
}
// 6.2. Canonical RR Form. (5) - origTTL
wire := make([]byte, r1.len()+1) // +1 to be safe(r)
off, err1 := PackRR(r1, wire, 0, nil, false)
if err1 != nil {
return nil, err1
}
wire = wire[:off]
wires[i] = wire
}
sort.Sort(wires)
for _, wire := range wires {
buf = append(buf, wire...)
}
return buf, nil
}
// Map for algorithm names.
var AlgorithmToString = map[uint8]string{
RSAMD5: "RSAMD5",
DH: "DH",
DSA: "DSA",
RSASHA1: "RSASHA1",
DSANSEC3SHA1: "DSA-NSEC3-SHA1",
RSASHA1NSEC3SHA1: "RSASHA1-NSEC3-SHA1",
RSASHA256: "RSASHA256",
RSASHA512: "RSASHA512",
ECCGOST: "ECC-GOST",
ECDSAP256SHA256: "ECDSAP256SHA256",
ECDSAP384SHA384: "ECDSAP384SHA384",
INDIRECT: "INDIRECT",
PRIVATEDNS: "PRIVATEDNS",
PRIVATEOID: "PRIVATEOID",
}
// Map of algorithm strings.
var StringToAlgorithm = reverseInt8(AlgorithmToString)
// Map for hash names.
var HashToString = map[uint8]string{
SHA1: "SHA1",
SHA256: "SHA256",
GOST94: "GOST94",
SHA384: "SHA384",
SHA512: "SHA512",
}
// Map of hash strings.
var StringToHash = reverseInt8(HashToString)

View file

@ -0,0 +1,646 @@
package dns
import (
"crypto/rsa"
"reflect"
"strings"
"testing"
"time"
)
func getKey() *DNSKEY {
key := new(DNSKEY)
key.Hdr.Name = "miek.nl."
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 14400
key.Flags = 256
key.Protocol = 3
key.Algorithm = RSASHA256
key.PublicKey = "AwEAAcNEU67LJI5GEgF9QLNqLO1SMq1EdoQ6E9f85ha0k0ewQGCblyW2836GiVsm6k8Kr5ECIoMJ6fZWf3CQSQ9ycWfTyOHfmI3eQ/1Covhb2y4bAmL/07PhrL7ozWBW3wBfM335Ft9xjtXHPy7ztCbV9qZ4TVDTW/Iyg0PiwgoXVesz"
return key
}
func getSoa() *SOA {
soa := new(SOA)
soa.Hdr = RR_Header{"miek.nl.", TypeSOA, ClassINET, 14400, 0}
soa.Ns = "open.nlnetlabs.nl."
soa.Mbox = "miekg.atoom.net."
soa.Serial = 1293945905
soa.Refresh = 14400
soa.Retry = 3600
soa.Expire = 604800
soa.Minttl = 86400
return soa
}
func TestGenerateEC(t *testing.T) {
key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl."
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 14400
key.Flags = 256
key.Protocol = 3
key.Algorithm = ECDSAP256SHA256
privkey, _ := key.Generate(256)
t.Logf("%s\n", key.String())
t.Logf("%s\n", key.PrivateKeyString(privkey))
}
func TestGenerateDSA(t *testing.T) {
key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl."
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 14400
key.Flags = 256
key.Protocol = 3
key.Algorithm = DSA
privkey, _ := key.Generate(1024)
t.Logf("%s\n", key.String())
t.Logf("%s\n", key.PrivateKeyString(privkey))
}
func TestGenerateRSA(t *testing.T) {
key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl."
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 14400
key.Flags = 256
key.Protocol = 3
key.Algorithm = RSASHA256
privkey, _ := key.Generate(1024)
t.Logf("%s\n", key.String())
t.Logf("%s\n", key.PrivateKeyString(privkey))
}
func TestSecure(t *testing.T) {
soa := getSoa()
sig := new(RRSIG)
sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0}
sig.TypeCovered = TypeSOA
sig.Algorithm = RSASHA256
sig.Labels = 2
sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05"
sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05"
sig.OrigTtl = 14400
sig.KeyTag = 12051
sig.SignerName = "miek.nl."
sig.Signature = "oMCbslaAVIp/8kVtLSms3tDABpcPRUgHLrOR48OOplkYo+8TeEGWwkSwaz/MRo2fB4FxW0qj/hTlIjUGuACSd+b1wKdH5GvzRJc2pFmxtCbm55ygAh4EUL0F6U5cKtGJGSXxxg6UFCQ0doJCmiGFa78LolaUOXImJrk6AFrGa0M="
key := new(DNSKEY)
key.Hdr.Name = "miek.nl."
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 14400
key.Flags = 256
key.Protocol = 3
key.Algorithm = RSASHA256
key.PublicKey = "AwEAAcNEU67LJI5GEgF9QLNqLO1SMq1EdoQ6E9f85ha0k0ewQGCblyW2836GiVsm6k8Kr5ECIoMJ6fZWf3CQSQ9ycWfTyOHfmI3eQ/1Covhb2y4bAmL/07PhrL7ozWBW3wBfM335Ft9xjtXHPy7ztCbV9qZ4TVDTW/Iyg0PiwgoXVesz"
// It should validate. Period is checked seperately, so this will keep on working
if sig.Verify(key, []RR{soa}) != nil {
t.Log("failure to validate")
t.Fail()
}
}
func TestSignature(t *testing.T) {
sig := new(RRSIG)
sig.Hdr.Name = "miek.nl."
sig.Hdr.Class = ClassINET
sig.Hdr.Ttl = 3600
sig.TypeCovered = TypeDNSKEY
sig.Algorithm = RSASHA1
sig.Labels = 2
sig.OrigTtl = 4000
sig.Expiration = 1000 //Thu Jan 1 02:06:40 CET 1970
sig.Inception = 800 //Thu Jan 1 01:13:20 CET 1970
sig.KeyTag = 34641
sig.SignerName = "miek.nl."
sig.Signature = "AwEAAaHIwpx3w4VHKi6i1LHnTaWeHCL154Jug0Rtc9ji5qwPXpBo6A5sRv7cSsPQKPIwxLpyCrbJ4mr2L0EPOdvP6z6YfljK2ZmTbogU9aSU2fiq/4wjxbdkLyoDVgtO+JsxNN4bjr4WcWhsmk1Hg93FV9ZpkWb0Tbad8DFqNDzr//kZ"
// Should not be valid
if sig.ValidityPeriod(time.Now()) {
t.Log("should not be valid")
t.Fail()
}
sig.Inception = 315565800 //Tue Jan 1 10:10:00 CET 1980
sig.Expiration = 4102477800 //Fri Jan 1 10:10:00 CET 2100
if !sig.ValidityPeriod(time.Now()) {
t.Log("should be valid")
t.Fail()
}
}
func TestSignVerify(t *testing.T) {
// The record we want to sign
soa := new(SOA)
soa.Hdr = RR_Header{"miek.nl.", TypeSOA, ClassINET, 14400, 0}
soa.Ns = "open.nlnetlabs.nl."
soa.Mbox = "miekg.atoom.net."
soa.Serial = 1293945905
soa.Refresh = 14400
soa.Retry = 3600
soa.Expire = 604800
soa.Minttl = 86400
soa1 := new(SOA)
soa1.Hdr = RR_Header{"*.miek.nl.", TypeSOA, ClassINET, 14400, 0}
soa1.Ns = "open.nlnetlabs.nl."
soa1.Mbox = "miekg.atoom.net."
soa1.Serial = 1293945905
soa1.Refresh = 14400
soa1.Retry = 3600
soa1.Expire = 604800
soa1.Minttl = 86400
srv := new(SRV)
srv.Hdr = RR_Header{"srv.miek.nl.", TypeSRV, ClassINET, 14400, 0}
srv.Port = 1000
srv.Weight = 800
srv.Target = "web1.miek.nl."
// With this key
key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl."
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 14400
key.Flags = 256
key.Protocol = 3
key.Algorithm = RSASHA256
privkey, _ := key.Generate(512)
// Fill in the values of the Sig, before signing
sig := new(RRSIG)
sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0}
sig.TypeCovered = soa.Hdr.Rrtype
sig.Labels = uint8(CountLabel(soa.Hdr.Name)) // works for all 3
sig.OrigTtl = soa.Hdr.Ttl
sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05"
sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05"
sig.KeyTag = key.KeyTag() // Get the keyfrom the Key
sig.SignerName = key.Hdr.Name
sig.Algorithm = RSASHA256
for _, r := range []RR{soa, soa1, srv} {
if sig.Sign(privkey, []RR{r}) != nil {
t.Log("failure to sign the record")
t.Fail()
continue
}
if sig.Verify(key, []RR{r}) != nil {
t.Log("failure to validate")
t.Fail()
continue
}
t.Logf("validated: %s\n", r.Header().Name)
}
}
func Test65534(t *testing.T) {
t6 := new(RFC3597)
t6.Hdr = RR_Header{"miek.nl.", 65534, ClassINET, 14400, 0}
t6.Rdata = "505D870001"
key := new(DNSKEY)
key.Hdr.Name = "miek.nl."
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 14400
key.Flags = 256
key.Protocol = 3
key.Algorithm = RSASHA256
privkey, _ := key.Generate(1024)
sig := new(RRSIG)
sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0}
sig.TypeCovered = t6.Hdr.Rrtype
sig.Labels = uint8(CountLabel(t6.Hdr.Name))
sig.OrigTtl = t6.Hdr.Ttl
sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05"
sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05"
sig.KeyTag = key.KeyTag()
sig.SignerName = key.Hdr.Name
sig.Algorithm = RSASHA256
if err := sig.Sign(privkey, []RR{t6}); err != nil {
t.Log(err)
t.Log("failure to sign the TYPE65534 record")
t.Fail()
}
if err := sig.Verify(key, []RR{t6}); err != nil {
t.Log(err)
t.Log("failure to validate")
t.Fail()
} else {
t.Logf("validated: %s\n", t6.Header().Name)
}
}
func TestDnskey(t *testing.T) {
// f, _ := os.Open("t/Kmiek.nl.+010+05240.key")
pubkey, _ := ReadRR(strings.NewReader(`
miek.nl. IN DNSKEY 256 3 10 AwEAAZuMCu2FdugHkTrXYgl5qixvcDw1aDDlvL46/xJKbHBAHY16fNUb2b65cwko2Js/aJxUYJbZk5dwCDZxYfrfbZVtDPQuc3o8QaChVxC7/JYz2AHc9qHvqQ1j4VrH71RWINlQo6VYjzN/BGpMhOZoZOEwzp1HfsOE3lNYcoWU1smL ;{id = 5240 (zsk), size = 1024b}
`), "Kmiek.nl.+010+05240.key")
privkey, _ := pubkey.(*DNSKEY).ReadPrivateKey(strings.NewReader(`
Private-key-format: v1.2
Algorithm: 10 (RSASHA512)
Modulus: m4wK7YV26AeROtdiCXmqLG9wPDVoMOW8vjr/EkpscEAdjXp81RvZvrlzCSjYmz9onFRgltmTl3AINnFh+t9tlW0M9C5zejxBoKFXELv8ljPYAdz2oe+pDWPhWsfvVFYg2VCjpViPM38EakyE5mhk4TDOnUd+w4TeU1hyhZTWyYs=
PublicExponent: AQAB
PrivateExponent: UfCoIQ/Z38l8vB6SSqOI/feGjHEl/fxIPX4euKf0D/32k30fHbSaNFrFOuIFmWMB3LimWVEs6u3dpbB9CQeCVg7hwU5puG7OtuiZJgDAhNeOnxvo5btp4XzPZrJSxR4WNQnwIiYWbl0aFlL1VGgHC/3By89ENZyWaZcMLW4KGWE=
Prime1: yxwC6ogAu8aVcDx2wg1V0b5M5P6jP8qkRFVMxWNTw60Vkn+ECvw6YAZZBHZPaMyRYZLzPgUlyYRd0cjupy4+fQ==
Prime2: xA1bF8M0RTIQ6+A11AoVG6GIR/aPGg5sogRkIZ7ID/sF6g9HMVU/CM2TqVEBJLRPp73cv6ZeC3bcqOCqZhz+pw==
Exponent1: xzkblyZ96bGYxTVZm2/vHMOXswod4KWIyMoOepK6B/ZPcZoIT6omLCgtypWtwHLfqyCz3MK51Nc0G2EGzg8rFQ==
Exponent2: Pu5+mCEb7T5F+kFNZhQadHUklt0JUHbi3hsEvVoHpEGSw3BGDQrtIflDde0/rbWHgDPM4WQY+hscd8UuTXrvLw==
Coefficient: UuRoNqe7YHnKmQzE6iDWKTMIWTuoqqrFAmXPmKQnC+Y+BQzOVEHUo9bXdDnoI9hzXP1gf8zENMYwYLeWpuYlFQ==
`), "Kmiek.nl.+010+05240.private")
if pubkey.(*DNSKEY).PublicKey != "AwEAAZuMCu2FdugHkTrXYgl5qixvcDw1aDDlvL46/xJKbHBAHY16fNUb2b65cwko2Js/aJxUYJbZk5dwCDZxYfrfbZVtDPQuc3o8QaChVxC7/JYz2AHc9qHvqQ1j4VrH71RWINlQo6VYjzN/BGpMhOZoZOEwzp1HfsOE3lNYcoWU1smL" {
t.Log("pubkey is not what we've read")
t.Fail()
}
// Coefficient looks fishy...
t.Logf("%s", pubkey.(*DNSKEY).PrivateKeyString(privkey))
}
func TestTag(t *testing.T) {
key := new(DNSKEY)
key.Hdr.Name = "miek.nl."
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 3600
key.Flags = 256
key.Protocol = 3
key.Algorithm = RSASHA256
key.PublicKey = "AwEAAcNEU67LJI5GEgF9QLNqLO1SMq1EdoQ6E9f85ha0k0ewQGCblyW2836GiVsm6k8Kr5ECIoMJ6fZWf3CQSQ9ycWfTyOHfmI3eQ/1Covhb2y4bAmL/07PhrL7ozWBW3wBfM335Ft9xjtXHPy7ztCbV9qZ4TVDTW/Iyg0PiwgoXVesz"
tag := key.KeyTag()
if tag != 12051 {
t.Logf("wrong key tag: %d for key %v\n", tag, key)
t.Fail()
}
}
func TestKeyRSA(t *testing.T) {
key := new(DNSKEY)
key.Hdr.Name = "miek.nl."
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 3600
key.Flags = 256
key.Protocol = 3
key.Algorithm = RSASHA256
priv, _ := key.Generate(2048)
soa := new(SOA)
soa.Hdr = RR_Header{"miek.nl.", TypeSOA, ClassINET, 14400, 0}
soa.Ns = "open.nlnetlabs.nl."
soa.Mbox = "miekg.atoom.net."
soa.Serial = 1293945905
soa.Refresh = 14400
soa.Retry = 3600
soa.Expire = 604800
soa.Minttl = 86400
sig := new(RRSIG)
sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0}
sig.TypeCovered = TypeSOA
sig.Algorithm = RSASHA256
sig.Labels = 2
sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05"
sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05"
sig.OrigTtl = soa.Hdr.Ttl
sig.KeyTag = key.KeyTag()
sig.SignerName = key.Hdr.Name
if err := sig.Sign(priv, []RR{soa}); err != nil {
t.Logf("failed to sign")
t.Fail()
return
}
if err := sig.Verify(key, []RR{soa}); err != nil {
t.Logf("failed to verify")
t.Fail()
}
}
func TestKeyToDS(t *testing.T) {
key := new(DNSKEY)
key.Hdr.Name = "miek.nl."
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 3600
key.Flags = 256
key.Protocol = 3
key.Algorithm = RSASHA256
key.PublicKey = "AwEAAcNEU67LJI5GEgF9QLNqLO1SMq1EdoQ6E9f85ha0k0ewQGCblyW2836GiVsm6k8Kr5ECIoMJ6fZWf3CQSQ9ycWfTyOHfmI3eQ/1Covhb2y4bAmL/07PhrL7ozWBW3wBfM335Ft9xjtXHPy7ztCbV9qZ4TVDTW/Iyg0PiwgoXVesz"
ds := key.ToDS(SHA1)
if strings.ToUpper(ds.Digest) != "B5121BDB5B8D86D0CC5FFAFBAAABE26C3E20BAC1" {
t.Logf("wrong DS digest for SHA1\n%v\n", ds)
t.Fail()
}
}
func TestSignRSA(t *testing.T) {
pub := "miek.nl. IN DNSKEY 256 3 5 AwEAAb+8lGNCxJgLS8rYVer6EnHVuIkQDghdjdtewDzU3G5R7PbMbKVRvH2Ma7pQyYceoaqWZQirSj72euPWfPxQnMy9ucCylA+FuH9cSjIcPf4PqJfdupHk9X6EBYjxrCLY4p1/yBwgyBIRJtZtAqM3ceAH2WovEJD6rTtOuHo5AluJ"
priv := `Private-key-format: v1.3
Algorithm: 5 (RSASHA1)
Modulus: v7yUY0LEmAtLythV6voScdW4iRAOCF2N217APNTcblHs9sxspVG8fYxrulDJhx6hqpZlCKtKPvZ649Z8/FCczL25wLKUD4W4f1xKMhw9/g+ol926keT1foQFiPGsItjinX/IHCDIEhEm1m0Cozdx4AfZai8QkPqtO064ejkCW4k=
PublicExponent: AQAB
PrivateExponent: YPwEmwjk5HuiROKU4xzHQ6l1hG8Iiha4cKRG3P5W2b66/EN/GUh07ZSf0UiYB67o257jUDVEgwCuPJz776zfApcCB4oGV+YDyEu7Hp/rL8KcSN0la0k2r9scKwxTp4BTJT23zyBFXsV/1wRDK1A5NxsHPDMYi2SoK63Enm/1ptk=
Prime1: /wjOG+fD0ybNoSRn7nQ79udGeR1b0YhUA5mNjDx/x2fxtIXzygYk0Rhx9QFfDy6LOBvz92gbNQlzCLz3DJt5hw==
Prime2: wHZsJ8OGhkp5p3mrJFZXMDc2mbYusDVTA+t+iRPdS797Tj0pjvU2HN4vTnTj8KBQp6hmnY7dLp9Y1qserySGbw==
Exponent1: N0A7FsSRIg+IAN8YPQqlawoTtG1t1OkJ+nWrurPootScApX6iMvn8fyvw3p2k51rv84efnzpWAYiC8SUaQDNxQ==
Exponent2: SvuYRaGyvo0zemE3oS+WRm2scxR8eiA8WJGeOc+obwOKCcBgeZblXzfdHGcEC1KaOcetOwNW/vwMA46lpLzJNw==
Coefficient: 8+7ZN/JgByqv0NfULiFKTjtyegUcijRuyij7yNxYbCBneDvZGxJwKNi4YYXWx743pcAj4Oi4Oh86gcmxLs+hGw==
Created: 20110302104537
Publish: 20110302104537
Activate: 20110302104537`
xk, _ := NewRR(pub)
k := xk.(*DNSKEY)
p, err := k.NewPrivateKey(priv)
if err != nil {
t.Logf("%v\n", err)
t.Fail()
}
switch priv := p.(type) {
case *rsa.PrivateKey:
if 65537 != priv.PublicKey.E {
t.Log("exponenent should be 65537")
t.Fail()
}
default:
t.Logf("we should have read an RSA key: %v", priv)
t.Fail()
}
if k.KeyTag() != 37350 {
t.Logf("%d %v\n", k.KeyTag(), k)
t.Log("keytag should be 37350")
t.Fail()
}
soa := new(SOA)
soa.Hdr = RR_Header{"miek.nl.", TypeSOA, ClassINET, 14400, 0}
soa.Ns = "open.nlnetlabs.nl."
soa.Mbox = "miekg.atoom.net."
soa.Serial = 1293945905
soa.Refresh = 14400
soa.Retry = 3600
soa.Expire = 604800
soa.Minttl = 86400
sig := new(RRSIG)
sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0}
sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05"
sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05"
sig.KeyTag = k.KeyTag()
sig.SignerName = k.Hdr.Name
sig.Algorithm = k.Algorithm
sig.Sign(p, []RR{soa})
if sig.Signature != "D5zsobpQcmMmYsUMLxCVEtgAdCvTu8V/IEeP4EyLBjqPJmjt96bwM9kqihsccofA5LIJ7DN91qkCORjWSTwNhzCv7bMyr2o5vBZElrlpnRzlvsFIoAZCD9xg6ZY7ZyzUJmU6IcTwG4v3xEYajcpbJJiyaw/RqR90MuRdKPiBzSo=" {
t.Log("signature is not correct")
t.Logf("%v\n", sig)
t.Fail()
}
}
func TestSignVerifyECDSA(t *testing.T) {
pub := `example.net. 3600 IN DNSKEY 257 3 14 (
xKYaNhWdGOfJ+nPrL8/arkwf2EY3MDJ+SErKivBVSum1
w/egsXvSADtNJhyem5RCOpgQ6K8X1DRSEkrbYQ+OB+v8
/uX45NBwY8rp65F6Glur8I/mlVNgF6W/qTI37m40 )`
priv := `Private-key-format: v1.2
Algorithm: 14 (ECDSAP384SHA384)
PrivateKey: WURgWHCcYIYUPWgeLmiPY2DJJk02vgrmTfitxgqcL4vwW7BOrbawVmVe0d9V94SR`
eckey, err := NewRR(pub)
if err != nil {
t.Fatal(err.Error())
}
privkey, err := eckey.(*DNSKEY).NewPrivateKey(priv)
if err != nil {
t.Fatal(err.Error())
}
// TODO: Create seperate test for this
ds := eckey.(*DNSKEY).ToDS(SHA384)
if ds.KeyTag != 10771 {
t.Fatal("wrong keytag on DS")
}
if ds.Digest != "72d7b62976ce06438e9c0bf319013cf801f09ecc84b8d7e9495f27e305c6a9b0563a9b5f4d288405c3008a946df983d6" {
t.Fatal("wrong DS Digest")
}
a, _ := NewRR("www.example.net. 3600 IN A 192.0.2.1")
sig := new(RRSIG)
sig.Hdr = RR_Header{"example.net.", TypeRRSIG, ClassINET, 14400, 0}
sig.Expiration, _ = StringToTime("20100909102025")
sig.Inception, _ = StringToTime("20100812102025")
sig.KeyTag = eckey.(*DNSKEY).KeyTag()
sig.SignerName = eckey.(*DNSKEY).Hdr.Name
sig.Algorithm = eckey.(*DNSKEY).Algorithm
sig.Sign(privkey, []RR{a})
t.Logf("%s", sig.String())
if e := sig.Verify(eckey.(*DNSKEY), []RR{a}); e != nil {
t.Logf("failure to validate: %s", e.Error())
t.Fail()
}
}
func TestSignVerifyECDSA2(t *testing.T) {
srv1, err := NewRR("srv.miek.nl. IN SRV 1000 800 0 web1.miek.nl.")
if err != nil {
t.Fatalf(err.Error())
}
srv := srv1.(*SRV)
// With this key
key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl."
key.Hdr.Class = ClassINET
key.Hdr.Ttl = 14400
key.Flags = 256
key.Protocol = 3
key.Algorithm = ECDSAP256SHA256
privkey, err := key.Generate(256)
if err != nil {
t.Fatal("failure to generate key")
}
// Fill in the values of the Sig, before signing
sig := new(RRSIG)
sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0}
sig.TypeCovered = srv.Hdr.Rrtype
sig.Labels = uint8(CountLabel(srv.Hdr.Name)) // works for all 3
sig.OrigTtl = srv.Hdr.Ttl
sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05"
sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05"
sig.KeyTag = key.KeyTag() // Get the keyfrom the Key
sig.SignerName = key.Hdr.Name
sig.Algorithm = ECDSAP256SHA256
if sig.Sign(privkey, []RR{srv}) != nil {
t.Fatal("failure to sign the record")
}
err = sig.Verify(key, []RR{srv})
if err != nil {
t.Fatal("Failure to validate:", err)
}
}
// Here the test vectors from the relevant RFCs are checked.
// rfc6605 6.1
func TestRFC6605P256(t *testing.T) {
exDNSKEY := `example.net. 3600 IN DNSKEY 257 3 13 (
GojIhhXUN/u4v54ZQqGSnyhWJwaubCvTmeexv7bR6edb
krSqQpF64cYbcB7wNcP+e+MAnLr+Wi9xMWyQLc8NAA== )`
exPriv := `Private-key-format: v1.2
Algorithm: 13 (ECDSAP256SHA256)
PrivateKey: GU6SnQ/Ou+xC5RumuIUIuJZteXT2z0O/ok1s38Et6mQ=`
rrDNSKEY, err := NewRR(exDNSKEY)
if err != nil {
t.Fatal(err.Error())
}
priv, err := rrDNSKEY.(*DNSKEY).NewPrivateKey(exPriv)
if err != nil {
t.Fatal(err.Error())
}
exDS := `example.net. 3600 IN DS 55648 13 2 (
b4c8c1fe2e7477127b27115656ad6256f424625bf5c1
e2770ce6d6e37df61d17 )`
rrDS, err := NewRR(exDS)
if err != nil {
t.Fatal(err.Error())
}
ourDS := rrDNSKEY.(*DNSKEY).ToDS(SHA256)
if !reflect.DeepEqual(ourDS, rrDS.(*DS)) {
t.Errorf("DS record differs:\n%v\n%v\n", ourDS, rrDS.(*DS))
}
exA := `www.example.net. 3600 IN A 192.0.2.1`
exRRSIG := `www.example.net. 3600 IN RRSIG A 13 3 3600 (
20100909100439 20100812100439 55648 example.net.
qx6wLYqmh+l9oCKTN6qIc+bw6ya+KJ8oMz0YP107epXA
yGmt+3SNruPFKG7tZoLBLlUzGGus7ZwmwWep666VCw== )`
rrA, err := NewRR(exA)
if err != nil {
t.Fatal(err.Error())
}
rrRRSIG, err := NewRR(exRRSIG)
if err != nil {
t.Fatal(err.Error())
}
if err = rrRRSIG.(*RRSIG).Verify(rrDNSKEY.(*DNSKEY), []RR{rrA}); err != nil {
t.Errorf("Failure to validate the spec RRSIG: %v", err)
}
ourRRSIG := &RRSIG{
Hdr: RR_Header{
Ttl: rrA.Header().Ttl,
},
KeyTag: rrDNSKEY.(*DNSKEY).KeyTag(),
SignerName: rrDNSKEY.(*DNSKEY).Hdr.Name,
Algorithm: rrDNSKEY.(*DNSKEY).Algorithm,
}
ourRRSIG.Expiration, _ = StringToTime("20100909100439")
ourRRSIG.Inception, _ = StringToTime("20100812100439")
err = ourRRSIG.Sign(priv, []RR{rrA})
if err != nil {
t.Fatal(err.Error())
}
if err = ourRRSIG.Verify(rrDNSKEY.(*DNSKEY), []RR{rrA}); err != nil {
t.Errorf("Failure to validate our RRSIG: %v", err)
}
// Signatures are randomized
rrRRSIG.(*RRSIG).Signature = ""
ourRRSIG.Signature = ""
if !reflect.DeepEqual(ourRRSIG, rrRRSIG.(*RRSIG)) {
t.Fatalf("RRSIG record differs:\n%v\n%v\n", ourRRSIG, rrRRSIG.(*RRSIG))
}
}
// rfc6605 6.2
func TestRFC6605P384(t *testing.T) {
exDNSKEY := `example.net. 3600 IN DNSKEY 257 3 14 (
xKYaNhWdGOfJ+nPrL8/arkwf2EY3MDJ+SErKivBVSum1
w/egsXvSADtNJhyem5RCOpgQ6K8X1DRSEkrbYQ+OB+v8
/uX45NBwY8rp65F6Glur8I/mlVNgF6W/qTI37m40 )`
exPriv := `Private-key-format: v1.2
Algorithm: 14 (ECDSAP384SHA384)
PrivateKey: WURgWHCcYIYUPWgeLmiPY2DJJk02vgrmTfitxgqcL4vwW7BOrbawVmVe0d9V94SR`
rrDNSKEY, err := NewRR(exDNSKEY)
if err != nil {
t.Fatal(err.Error())
}
priv, err := rrDNSKEY.(*DNSKEY).NewPrivateKey(exPriv)
if err != nil {
t.Fatal(err.Error())
}
exDS := `example.net. 3600 IN DS 10771 14 4 (
72d7b62976ce06438e9c0bf319013cf801f09ecc84b8
d7e9495f27e305c6a9b0563a9b5f4d288405c3008a94
6df983d6 )`
rrDS, err := NewRR(exDS)
if err != nil {
t.Fatal(err.Error())
}
ourDS := rrDNSKEY.(*DNSKEY).ToDS(SHA384)
if !reflect.DeepEqual(ourDS, rrDS.(*DS)) {
t.Fatalf("DS record differs:\n%v\n%v\n", ourDS, rrDS.(*DS))
}
exA := `www.example.net. 3600 IN A 192.0.2.1`
exRRSIG := `www.example.net. 3600 IN RRSIG A 14 3 3600 (
20100909102025 20100812102025 10771 example.net.
/L5hDKIvGDyI1fcARX3z65qrmPsVz73QD1Mr5CEqOiLP
95hxQouuroGCeZOvzFaxsT8Glr74hbavRKayJNuydCuz
WTSSPdz7wnqXL5bdcJzusdnI0RSMROxxwGipWcJm )`
rrA, err := NewRR(exA)
if err != nil {
t.Fatal(err.Error())
}
rrRRSIG, err := NewRR(exRRSIG)
if err != nil {
t.Fatal(err.Error())
}
if err = rrRRSIG.(*RRSIG).Verify(rrDNSKEY.(*DNSKEY), []RR{rrA}); err != nil {
t.Errorf("Failure to validate the spec RRSIG: %v", err)
}
ourRRSIG := &RRSIG{
Hdr: RR_Header{
Ttl: rrA.Header().Ttl,
},
KeyTag: rrDNSKEY.(*DNSKEY).KeyTag(),
SignerName: rrDNSKEY.(*DNSKEY).Hdr.Name,
Algorithm: rrDNSKEY.(*DNSKEY).Algorithm,
}
ourRRSIG.Expiration, _ = StringToTime("20100909102025")
ourRRSIG.Inception, _ = StringToTime("20100812102025")
err = ourRRSIG.Sign(priv, []RR{rrA})
if err != nil {
t.Fatal(err.Error())
}
if err = ourRRSIG.Verify(rrDNSKEY.(*DNSKEY), []RR{rrA}); err != nil {
t.Errorf("Failure to validate our RRSIG: %v", err)
}
// Signatures are randomized
rrRRSIG.(*RRSIG).Signature = ""
ourRRSIG.Signature = ""
if !reflect.DeepEqual(ourRRSIG, rrRRSIG.(*RRSIG)) {
t.Fatalf("RRSIG record differs:\n%v\n%v\n", ourRRSIG, rrRRSIG.(*RRSIG))
}
}

View file

@ -0,0 +1,3 @@
package dns
// Find better solution

501
Godeps/_workspace/src/github.com/miekg/dns/edns.go generated vendored Normal file
View file

@ -0,0 +1,501 @@
// EDNS0
//
// EDNS0 is an extension mechanism for the DNS defined in RFC 2671 and updated
// by RFC 6891. It defines an new RR type, the OPT RR, which is then completely
// abused.
// Basic use pattern for creating an (empty) OPT RR:
//
// o := new(dns.OPT)
// o.Hdr.Name = "." // MUST be the root zone, per definition.
// o.Hdr.Rrtype = dns.TypeOPT
//
// The rdata of an OPT RR consists out of a slice of EDNS0 (RFC 6891)
// interfaces. Currently only a few have been standardized: EDNS0_NSID
// (RFC 5001) and EDNS0_SUBNET (draft-vandergaast-edns-client-subnet-02). Note
// that these options may be combined in an OPT RR.
// Basic use pattern for a server to check if (and which) options are set:
//
// // o is a dns.OPT
// for _, s := range o.Option {
// switch e := s.(type) {
// case *dns.EDNS0_NSID:
// // do stuff with e.Nsid
// case *dns.EDNS0_SUBNET:
// // access e.Family, e.Address, etc.
// }
// }
package dns
import (
"encoding/hex"
"errors"
"net"
"strconv"
)
// EDNS0 Option codes.
const (
EDNS0LLQ = 0x1 // long lived queries: http://tools.ietf.org/html/draft-sekar-dns-llq-01
EDNS0UL = 0x2 // update lease draft: http://files.dns-sd.org/draft-sekar-dns-ul.txt
EDNS0NSID = 0x3 // nsid (RFC5001)
EDNS0DAU = 0x5 // DNSSEC Algorithm Understood
EDNS0DHU = 0x6 // DS Hash Understood
EDNS0N3U = 0x7 // NSEC3 Hash Understood
EDNS0SUBNET = 0x8 // client-subnet (RFC6891)
EDNS0EXPIRE = 0x9 // EDNS0 expire
EDNS0SUBNETDRAFT = 0x50fa // Don't use! Use EDNS0SUBNET
_DO = 1 << 15 // dnssec ok
)
type OPT struct {
Hdr RR_Header
Option []EDNS0 `dns:"opt"`
}
func (rr *OPT) Header() *RR_Header {
return &rr.Hdr
}
func (rr *OPT) String() string {
s := "\n;; OPT PSEUDOSECTION:\n; EDNS: version " + strconv.Itoa(int(rr.Version())) + "; "
if rr.Do() {
s += "flags: do; "
} else {
s += "flags: ; "
}
s += "udp: " + strconv.Itoa(int(rr.UDPSize()))
for _, o := range rr.Option {
switch o.(type) {
case *EDNS0_NSID:
s += "\n; NSID: " + o.String()
h, e := o.pack()
var r string
if e == nil {
for _, c := range h {
r += "(" + string(c) + ")"
}
s += " " + r
}
case *EDNS0_SUBNET:
s += "\n; SUBNET: " + o.String()
if o.(*EDNS0_SUBNET).DraftOption {
s += " (draft)"
}
case *EDNS0_UL:
s += "\n; UPDATE LEASE: " + o.String()
case *EDNS0_LLQ:
s += "\n; LONG LIVED QUERIES: " + o.String()
case *EDNS0_DAU:
s += "\n; DNSSEC ALGORITHM UNDERSTOOD: " + o.String()
case *EDNS0_DHU:
s += "\n; DS HASH UNDERSTOOD: " + o.String()
case *EDNS0_N3U:
s += "\n; NSEC3 HASH UNDERSTOOD: " + o.String()
}
}
return s
}
func (rr *OPT) len() int {
l := rr.Hdr.len()
for i := 0; i < len(rr.Option); i++ {
lo, _ := rr.Option[i].pack()
l += 2 + len(lo)
}
return l
}
func (rr *OPT) copy() RR {
return &OPT{*rr.Hdr.copyHeader(), rr.Option}
}
// return the old value -> delete SetVersion?
// Version returns the EDNS version used. Only zero is defined.
func (rr *OPT) Version() uint8 {
return uint8((rr.Hdr.Ttl & 0x00FF0000) >> 16)
}
// SetVersion sets the version of EDNS. This is usually zero.
func (rr *OPT) SetVersion(v uint8) {
rr.Hdr.Ttl = rr.Hdr.Ttl&0xFF00FFFF | (uint32(v) << 16)
}
// ExtendedRcode returns the EDNS extended RCODE field (the upper 8 bits of the TTL).
func (rr *OPT) ExtendedRcode() uint8 {
return uint8((rr.Hdr.Ttl & 0xFF000000) >> 24)
}
// SetExtendedRcode sets the EDNS extended RCODE field.
func (rr *OPT) SetExtendedRcode(v uint8) {
rr.Hdr.Ttl = rr.Hdr.Ttl&0x00FFFFFF | (uint32(v) << 24)
}
// UDPSize returns the UDP buffer size.
func (rr *OPT) UDPSize() uint16 {
return rr.Hdr.Class
}
// SetUDPSize sets the UDP buffer size.
func (rr *OPT) SetUDPSize(size uint16) {
rr.Hdr.Class = size
}
// Do returns the value of the DO (DNSSEC OK) bit.
func (rr *OPT) Do() bool {
return rr.Hdr.Ttl&_DO == _DO
}
// SetDo sets the DO (DNSSEC OK) bit.
func (rr *OPT) SetDo() {
rr.Hdr.Ttl |= _DO
}
// EDNS0 defines an EDNS0 Option. An OPT RR can have multiple options appended to
// it.
type EDNS0 interface {
// Option returns the option code for the option.
Option() uint16
// pack returns the bytes of the option data.
pack() ([]byte, error)
// unpack sets the data as found in the buffer. Is also sets
// the length of the slice as the length of the option data.
unpack([]byte) error
// String returns the string representation of the option.
String() string
}
// The nsid EDNS0 option is used to retrieve a nameserver
// identifier. When sending a request Nsid must be set to the empty string
// The identifier is an opaque string encoded as hex.
// Basic use pattern for creating an nsid option:
//
// o := new(dns.OPT)
// o.Hdr.Name = "."
// o.Hdr.Rrtype = dns.TypeOPT
// e := new(dns.EDNS0_NSID)
// e.Code = dns.EDNS0NSID
// e.Nsid = "AA"
// o.Option = append(o.Option, e)
type EDNS0_NSID struct {
Code uint16 // Always EDNS0NSID
Nsid string // This string needs to be hex encoded
}
func (e *EDNS0_NSID) pack() ([]byte, error) {
h, err := hex.DecodeString(e.Nsid)
if err != nil {
return nil, err
}
return h, nil
}
func (e *EDNS0_NSID) Option() uint16 { return EDNS0NSID }
func (e *EDNS0_NSID) unpack(b []byte) error { e.Nsid = hex.EncodeToString(b); return nil }
func (e *EDNS0_NSID) String() string { return string(e.Nsid) }
// The subnet EDNS0 option is used to give the remote nameserver
// an idea of where the client lives. It can then give back a different
// answer depending on the location or network topology.
// Basic use pattern for creating an subnet option:
//
// o := new(dns.OPT)
// o.Hdr.Name = "."
// o.Hdr.Rrtype = dns.TypeOPT
// e := new(dns.EDNS0_SUBNET)
// e.Code = dns.EDNS0SUBNET
// e.Family = 1 // 1 for IPv4 source address, 2 for IPv6
// e.NetMask = 32 // 32 for IPV4, 128 for IPv6
// e.SourceScope = 0
// e.Address = net.ParseIP("127.0.0.1").To4() // for IPv4
// // e.Address = net.ParseIP("2001:7b8:32a::2") // for IPV6
// o.Option = append(o.Option, e)
type EDNS0_SUBNET struct {
Code uint16 // Always EDNS0SUBNET
Family uint16 // 1 for IP, 2 for IP6
SourceNetmask uint8
SourceScope uint8
Address net.IP
DraftOption bool // Set to true if using the old (0x50fa) option code
}
func (e *EDNS0_SUBNET) Option() uint16 {
if e.DraftOption {
return EDNS0SUBNETDRAFT
}
return EDNS0SUBNET
}
func (e *EDNS0_SUBNET) pack() ([]byte, error) {
b := make([]byte, 4)
b[0], b[1] = packUint16(e.Family)
b[2] = e.SourceNetmask
b[3] = e.SourceScope
switch e.Family {
case 1:
if e.SourceNetmask > net.IPv4len*8 {
return nil, errors.New("dns: bad netmask")
}
ip := make([]byte, net.IPv4len)
a := e.Address.To4().Mask(net.CIDRMask(int(e.SourceNetmask), net.IPv4len*8))
for i := 0; i < net.IPv4len; i++ {
if i+1 > len(e.Address) {
break
}
ip[i] = a[i]
}
needLength := e.SourceNetmask / 8
if e.SourceNetmask%8 > 0 {
needLength++
}
ip = ip[:needLength]
b = append(b, ip...)
case 2:
if e.SourceNetmask > net.IPv6len*8 {
return nil, errors.New("dns: bad netmask")
}
ip := make([]byte, net.IPv6len)
a := e.Address.Mask(net.CIDRMask(int(e.SourceNetmask), net.IPv6len*8))
for i := 0; i < net.IPv6len; i++ {
if i+1 > len(e.Address) {
break
}
ip[i] = a[i]
}
needLength := e.SourceNetmask / 8
if e.SourceNetmask%8 > 0 {
needLength++
}
ip = ip[:needLength]
b = append(b, ip...)
default:
return nil, errors.New("dns: bad address family")
}
return b, nil
}
func (e *EDNS0_SUBNET) unpack(b []byte) error {
lb := len(b)
if lb < 4 {
return ErrBuf
}
e.Family, _ = unpackUint16(b, 0)
e.SourceNetmask = b[2]
e.SourceScope = b[3]
switch e.Family {
case 1:
addr := make([]byte, 4)
for i := 0; i < int(e.SourceNetmask/8); i++ {
if 4+i > len(b) {
return ErrBuf
}
addr[i] = b[4+i]
}
e.Address = net.IPv4(addr[0], addr[1], addr[2], addr[3])
case 2:
addr := make([]byte, 16)
for i := 0; i < int(e.SourceNetmask/8); i++ {
if 4+i > len(b) {
return ErrBuf
}
addr[i] = b[4+i]
}
e.Address = net.IP{addr[0], addr[1], addr[2], addr[3], addr[4],
addr[5], addr[6], addr[7], addr[8], addr[9], addr[10],
addr[11], addr[12], addr[13], addr[14], addr[15]}
}
return nil
}
func (e *EDNS0_SUBNET) String() (s string) {
if e.Address == nil {
s = "<nil>"
} else if e.Address.To4() != nil {
s = e.Address.String()
} else {
s = "[" + e.Address.String() + "]"
}
s += "/" + strconv.Itoa(int(e.SourceNetmask)) + "/" + strconv.Itoa(int(e.SourceScope))
return
}
// The UL (Update Lease) EDNS0 (draft RFC) option is used to tell the server to set
// an expiration on an update RR. This is helpful for clients that cannot clean
// up after themselves. This is a draft RFC and more information can be found at
// http://files.dns-sd.org/draft-sekar-dns-ul.txt
//
// o := new(dns.OPT)
// o.Hdr.Name = "."
// o.Hdr.Rrtype = dns.TypeOPT
// e := new(dns.EDNS0_UL)
// e.Code = dns.EDNS0UL
// e.Lease = 120 // in seconds
// o.Option = append(o.Option, e)
type EDNS0_UL struct {
Code uint16 // Always EDNS0UL
Lease uint32
}
func (e *EDNS0_UL) Option() uint16 { return EDNS0UL }
func (e *EDNS0_UL) String() string { return strconv.FormatUint(uint64(e.Lease), 10) }
// Copied: http://golang.org/src/pkg/net/dnsmsg.go
func (e *EDNS0_UL) pack() ([]byte, error) {
b := make([]byte, 4)
b[0] = byte(e.Lease >> 24)
b[1] = byte(e.Lease >> 16)
b[2] = byte(e.Lease >> 8)
b[3] = byte(e.Lease)
return b, nil
}
func (e *EDNS0_UL) unpack(b []byte) error {
if len(b) < 4 {
return ErrBuf
}
e.Lease = uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
return nil
}
// Long Lived Queries: http://tools.ietf.org/html/draft-sekar-dns-llq-01
// Implemented for completeness, as the EDNS0 type code is assigned.
type EDNS0_LLQ struct {
Code uint16 // Always EDNS0LLQ
Version uint16
Opcode uint16
Error uint16
Id uint64
LeaseLife uint32
}
func (e *EDNS0_LLQ) Option() uint16 { return EDNS0LLQ }
func (e *EDNS0_LLQ) pack() ([]byte, error) {
b := make([]byte, 18)
b[0], b[1] = packUint16(e.Version)
b[2], b[3] = packUint16(e.Opcode)
b[4], b[5] = packUint16(e.Error)
b[6] = byte(e.Id >> 56)
b[7] = byte(e.Id >> 48)
b[8] = byte(e.Id >> 40)
b[9] = byte(e.Id >> 32)
b[10] = byte(e.Id >> 24)
b[11] = byte(e.Id >> 16)
b[12] = byte(e.Id >> 8)
b[13] = byte(e.Id)
b[14] = byte(e.LeaseLife >> 24)
b[15] = byte(e.LeaseLife >> 16)
b[16] = byte(e.LeaseLife >> 8)
b[17] = byte(e.LeaseLife)
return b, nil
}
func (e *EDNS0_LLQ) unpack(b []byte) error {
if len(b) < 18 {
return ErrBuf
}
e.Version, _ = unpackUint16(b, 0)
e.Opcode, _ = unpackUint16(b, 2)
e.Error, _ = unpackUint16(b, 4)
e.Id = uint64(b[6])<<56 | uint64(b[6+1])<<48 | uint64(b[6+2])<<40 |
uint64(b[6+3])<<32 | uint64(b[6+4])<<24 | uint64(b[6+5])<<16 | uint64(b[6+6])<<8 | uint64(b[6+7])
e.LeaseLife = uint32(b[14])<<24 | uint32(b[14+1])<<16 | uint32(b[14+2])<<8 | uint32(b[14+3])
return nil
}
func (e *EDNS0_LLQ) String() string {
s := strconv.FormatUint(uint64(e.Version), 10) + " " + strconv.FormatUint(uint64(e.Opcode), 10) +
" " + strconv.FormatUint(uint64(e.Error), 10) + " " + strconv.FormatUint(uint64(e.Id), 10) +
" " + strconv.FormatUint(uint64(e.LeaseLife), 10)
return s
}
type EDNS0_DAU struct {
Code uint16 // Always EDNS0DAU
AlgCode []uint8
}
func (e *EDNS0_DAU) Option() uint16 { return EDNS0DAU }
func (e *EDNS0_DAU) pack() ([]byte, error) { return e.AlgCode, nil }
func (e *EDNS0_DAU) unpack(b []byte) error { e.AlgCode = b; return nil }
func (e *EDNS0_DAU) String() string {
s := ""
for i := 0; i < len(e.AlgCode); i++ {
if a, ok := AlgorithmToString[e.AlgCode[i]]; ok {
s += " " + a
} else {
s += " " + strconv.Itoa(int(e.AlgCode[i]))
}
}
return s
}
type EDNS0_DHU struct {
Code uint16 // Always EDNS0DHU
AlgCode []uint8
}
func (e *EDNS0_DHU) Option() uint16 { return EDNS0DHU }
func (e *EDNS0_DHU) pack() ([]byte, error) { return e.AlgCode, nil }
func (e *EDNS0_DHU) unpack(b []byte) error { e.AlgCode = b; return nil }
func (e *EDNS0_DHU) String() string {
s := ""
for i := 0; i < len(e.AlgCode); i++ {
if a, ok := HashToString[e.AlgCode[i]]; ok {
s += " " + a
} else {
s += " " + strconv.Itoa(int(e.AlgCode[i]))
}
}
return s
}
type EDNS0_N3U struct {
Code uint16 // Always EDNS0N3U
AlgCode []uint8
}
func (e *EDNS0_N3U) Option() uint16 { return EDNS0N3U }
func (e *EDNS0_N3U) pack() ([]byte, error) { return e.AlgCode, nil }
func (e *EDNS0_N3U) unpack(b []byte) error { e.AlgCode = b; return nil }
func (e *EDNS0_N3U) String() string {
// Re-use the hash map
s := ""
for i := 0; i < len(e.AlgCode); i++ {
if a, ok := HashToString[e.AlgCode[i]]; ok {
s += " " + a
} else {
s += " " + strconv.Itoa(int(e.AlgCode[i]))
}
}
return s
}
type EDNS0_EXPIRE struct {
Code uint16 // Always EDNS0EXPIRE
Expire uint32
}
func (e *EDNS0_EXPIRE) Option() uint16 { return EDNS0EXPIRE }
func (e *EDNS0_EXPIRE) String() string { return strconv.FormatUint(uint64(e.Expire), 10) }
func (e *EDNS0_EXPIRE) pack() ([]byte, error) {
b := make([]byte, 4)
b[0] = byte(e.Expire >> 24)
b[1] = byte(e.Expire >> 16)
b[2] = byte(e.Expire >> 8)
b[3] = byte(e.Expire)
return b, nil
}
func (e *EDNS0_EXPIRE) unpack(b []byte) error {
if len(b) < 4 {
return ErrBuf
}
e.Expire = uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
return nil
}

View file

@ -0,0 +1,48 @@
package dns
import "testing"
func TestOPTTtl(t *testing.T) {
e := &OPT{}
e.Hdr.Name = "."
e.Hdr.Rrtype = TypeOPT
if e.Do() {
t.Fail()
}
e.SetDo()
if !e.Do() {
t.Fail()
}
oldTtl := e.Hdr.Ttl
if e.Version() != 0 {
t.Fail()
}
e.SetVersion(42)
if e.Version() != 42 {
t.Fail()
}
e.SetVersion(0)
if e.Hdr.Ttl != oldTtl {
t.Fail()
}
if e.ExtendedRcode() != 0 {
t.Fail()
}
e.SetExtendedRcode(42)
if e.ExtendedRcode() != 42 {
t.Fail()
}
e.SetExtendedRcode(0)
if e.Hdr.Ttl != oldTtl {
t.Fail()
}
}

View file

@ -0,0 +1,147 @@
package dns_test
import (
"errors"
"fmt"
"github.com/miekg/dns"
"log"
"net"
)
// Retrieve the MX records for miek.nl.
func ExampleMX() {
config, _ := dns.ClientConfigFromFile("/etc/resolv.conf")
c := new(dns.Client)
m := new(dns.Msg)
m.SetQuestion("miek.nl.", dns.TypeMX)
m.RecursionDesired = true
r, _, err := c.Exchange(m, config.Servers[0]+":"+config.Port)
if err != nil {
return
}
if r.Rcode != dns.RcodeSuccess {
return
}
for _, a := range r.Answer {
if mx, ok := a.(*dns.MX); ok {
fmt.Printf("%s\n", mx.String())
}
}
}
// Retrieve the DNSKEY records of a zone and convert them
// to DS records for SHA1, SHA256 and SHA384.
func ExampleDS(zone string) {
config, _ := dns.ClientConfigFromFile("/etc/resolv.conf")
c := new(dns.Client)
m := new(dns.Msg)
if zone == "" {
zone = "miek.nl"
}
m.SetQuestion(dns.Fqdn(zone), dns.TypeDNSKEY)
m.SetEdns0(4096, true)
r, _, err := c.Exchange(m, config.Servers[0]+":"+config.Port)
if err != nil {
return
}
if r.Rcode != dns.RcodeSuccess {
return
}
for _, k := range r.Answer {
if key, ok := k.(*dns.DNSKEY); ok {
for _, alg := range []int{dns.SHA1, dns.SHA256, dns.SHA384} {
fmt.Printf("%s; %d\n", key.ToDS(alg).String(), key.Flags)
}
}
}
}
const TypeAPAIR = 0x0F99
type APAIR struct {
addr [2]net.IP
}
func NewAPAIR() dns.PrivateRdata { return new(APAIR) }
func (rd *APAIR) String() string { return rd.addr[0].String() + " " + rd.addr[1].String() }
func (rd *APAIR) Parse(txt []string) error {
if len(txt) != 2 {
return errors.New("two addresses required for APAIR")
}
for i, s := range txt {
ip := net.ParseIP(s)
if ip == nil {
return errors.New("invalid IP in APAIR text representation")
}
rd.addr[i] = ip
}
return nil
}
func (rd *APAIR) Pack(buf []byte) (int, error) {
b := append([]byte(rd.addr[0]), []byte(rd.addr[1])...)
n := copy(buf, b)
if n != len(b) {
return n, dns.ErrBuf
}
return n, nil
}
func (rd *APAIR) Unpack(buf []byte) (int, error) {
ln := net.IPv4len * 2
if len(buf) != ln {
return 0, errors.New("invalid length of APAIR rdata")
}
cp := make([]byte, ln)
copy(cp, buf) // clone bytes to use them in IPs
rd.addr[0] = net.IP(cp[:3])
rd.addr[1] = net.IP(cp[4:])
return len(buf), nil
}
func (rd *APAIR) Copy(dest dns.PrivateRdata) error {
cp := make([]byte, rd.Len())
_, err := rd.Pack(cp)
if err != nil {
return err
}
d := dest.(*APAIR)
d.addr[0] = net.IP(cp[:3])
d.addr[1] = net.IP(cp[4:])
return nil
}
func (rd *APAIR) Len() int {
return net.IPv4len * 2
}
func ExamplePrivateHandle() {
dns.PrivateHandle("APAIR", TypeAPAIR, NewAPAIR)
defer dns.PrivateHandleRemove(TypeAPAIR)
rr, err := dns.NewRR("miek.nl. APAIR (1.2.3.4 1.2.3.5)")
if err != nil {
log.Fatal("could not parse APAIR record: ", err)
}
fmt.Println(rr)
// Output: miek.nl. 3600 IN APAIR 1.2.3.4 1.2.3.5
m := new(dns.Msg)
m.Id = 12345
m.SetQuestion("miek.nl.", TypeAPAIR)
m.Answer = append(m.Answer, rr)
fmt.Println(m)
// ;; opcode: QUERY, status: NOERROR, id: 12345
// ;; flags: rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0
//
// ;; QUESTION SECTION:
// ;miek.nl. IN APAIR
//
// ;; ANSWER SECTION:
// miek.nl. 3600 IN APAIR 1.2.3.4 1.2.3.5
}

View file

@ -0,0 +1,18 @@
package idn_test
import (
"fmt"
"github.com/miekg/dns/idn"
)
func ExampleToPunycode() {
name := "インターネット.テスト"
fmt.Printf("%s -> %s", name, idn.ToPunycode(name))
// Output: インターネット.テスト -> xn--eckucmux0ukc.xn--zckzah
}
func ExampleFromPunycode() {
name := "xn--mgbaja8a1hpac.xn--mgbachtv"
fmt.Printf("%s -> %s", name, idn.FromPunycode(name))
// Output: xn--mgbaja8a1hpac.xn--mgbachtv -> الانترنت.اختبار
}

View file

@ -0,0 +1,268 @@
// Package idn implements encoding from and to punycode as speficied by RFC 3492.
package idn
import (
"bytes"
"github.com/miekg/dns"
"strings"
"unicode"
)
// Implementation idea from RFC itself and from from IDNA::Punycode created by
// Tatsuhiko Miyagawa <miyagawa@bulknews.net> and released under Perl Artistic
// License in 2002.
const (
_MIN rune = 1
_MAX rune = 26
_SKEW rune = 38
_BASE rune = 36
_BIAS rune = 72
_N rune = 128
_DAMP rune = 700
_DELIMITER = '-'
_PREFIX = "xn--"
)
// ToPunycode converts unicode domain names to DNS-appropriate punycode names.
// This function would return incorrect result for strings for non-canonical
// unicode strings.
func ToPunycode(s string) string {
tokens := dns.SplitDomainName(s)
switch {
case s == "":
return ""
case tokens == nil: // s == .
return "."
case s[len(s)-1] == '.':
tokens = append(tokens, "")
}
for i := range tokens {
tokens[i] = string(encode([]byte(tokens[i])))
}
return strings.Join(tokens, ".")
}
// FromPunycode returns unicode domain name from provided punycode string.
func FromPunycode(s string) string {
tokens := dns.SplitDomainName(s)
switch {
case s == "":
return ""
case tokens == nil: // s == .
return "."
case s[len(s)-1] == '.':
tokens = append(tokens, "")
}
for i := range tokens {
tokens[i] = string(decode([]byte(tokens[i])))
}
return strings.Join(tokens, ".")
}
// digitval converts single byte into meaningful value that's used to calculate decoded unicode character.
const errdigit = 0xffff
func digitval(code rune) rune {
switch {
case code >= 'A' && code <= 'Z':
return code - 'A'
case code >= 'a' && code <= 'z':
return code - 'a'
case code >= '0' && code <= '9':
return code - '0' + 26
}
return errdigit
}
// lettercode finds BASE36 byte (a-z0-9) based on calculated number.
func lettercode(digit rune) rune {
switch {
case digit >= 0 && digit <= 25:
return digit + 'a'
case digit >= 26 && digit <= 36:
return digit - 26 + '0'
}
panic("dns: not reached")
}
// adapt calculates next bias to be used for next iteration delta.
func adapt(delta rune, numpoints int, firsttime bool) rune {
if firsttime {
delta /= _DAMP
} else {
delta /= 2
}
var k rune
for delta = delta + delta/rune(numpoints); delta > (_BASE-_MIN)*_MAX/2; k += _BASE {
delta /= _BASE - _MIN
}
return k + ((_BASE-_MIN+1)*delta)/(delta+_SKEW)
}
// next finds minimal rune (one with lowest codepoint value) that should be equal or above boundary.
func next(b []rune, boundary rune) rune {
if len(b) == 0 {
panic("dns: invalid set of runes to determine next one")
}
m := b[0]
for _, x := range b[1:] {
if x >= boundary && (m < boundary || x < m) {
m = x
}
}
return m
}
// preprune converts unicode rune to lower case. At this time it's not
// supporting all things described in RFCs
func preprune(r rune) rune {
if unicode.IsUpper(r) {
r = unicode.ToLower(r)
}
return r
}
// tfunc is a function that helps calculate each character weight
func tfunc(k, bias rune) rune {
switch {
case k <= bias:
return _MIN
case k >= bias+_MAX:
return _MAX
}
return k - bias
}
// encode transforms Unicode input bytes (that represent DNS label) into punycode bytestream
func encode(input []byte) []byte {
n, bias := _N, _BIAS
b := bytes.Runes(input)
for i := range b {
b[i] = preprune(b[i])
}
basic := make([]byte, 0, len(b))
for _, ltr := range b {
if ltr <= 0x7f {
basic = append(basic, byte(ltr))
}
}
basiclen := len(basic)
fulllen := len(b)
if basiclen == fulllen {
return basic
}
var out bytes.Buffer
out.WriteString(_PREFIX)
if basiclen > 0 {
out.Write(basic)
out.WriteByte(_DELIMITER)
}
var (
ltr, nextltr rune
delta, q rune // delta calculation (see rfc)
t, k, cp rune // weight and codepoint calculation
)
s := &bytes.Buffer{}
for h := basiclen; h < fulllen; n, delta = n+1, delta+1 {
nextltr = next(b, n)
s.Truncate(0)
s.WriteRune(nextltr)
delta, n = delta+(nextltr-n)*rune(h+1), nextltr
for _, ltr = range b {
if ltr < n {
delta++
}
if ltr == n {
q = delta
for k = _BASE; ; k += _BASE {
t = tfunc(k, bias)
if q < t {
break
}
cp = t + ((q - t) % (_BASE - t))
out.WriteRune(lettercode(cp))
q = (q - t) / (_BASE - t)
}
out.WriteRune(lettercode(q))
bias = adapt(delta, h+1, h == basiclen)
h, delta = h+1, 0
}
}
}
return out.Bytes()
}
// decode transforms punycode input bytes (that represent DNS label) into Unicode bytestream
func decode(b []byte) []byte {
src := b // b would move and we need to keep it
n, bias := _N, _BIAS
if !bytes.HasPrefix(b, []byte(_PREFIX)) {
return b
}
out := make([]rune, 0, len(b))
b = b[len(_PREFIX):]
for pos, x := range b {
if x == _DELIMITER {
out = append(out, bytes.Runes(b[:pos])...)
b = b[pos+1:] // trim source string
break
}
}
if len(b) == 0 {
return src
}
var (
i, oldi, w rune
ch byte
t, digit rune
ln int
)
for i = 0; len(b) > 0; i++ {
oldi, w = i, 1
for k := _BASE; len(b) > 0; k += _BASE {
ch, b = b[0], b[1:]
digit = digitval(rune(ch))
if digit == errdigit {
return src
}
i += digit * w
t = tfunc(k, bias)
if digit < t {
break
}
w *= _BASE - t
}
ln = len(out) + 1
bias = adapt(i-oldi, ln, oldi == 0)
n += i / rune(ln)
i = i % rune(ln)
// insert
out = append(out, 0)
copy(out[i+1:], out[i:])
out[i] = n
}
var ret bytes.Buffer
for _, r := range out {
ret.WriteRune(r)
}
return ret.Bytes()
}

View file

@ -0,0 +1,94 @@
package idn
import (
"strings"
"testing"
)
var testcases = [][2]string{
{"", ""},
{"a", "a"},
{"A-B", "a-b"},
{"AbC", "abc"},
{"я", "xn--41a"},
{"zя", "xn--z-0ub"},
{"ЯZ", "xn--z-zub"},
{"إختبار", "xn--kgbechtv"},
{"آزمایشی", "xn--hgbk6aj7f53bba"},
{"测试", "xn--0zwm56d"},
{"測試", "xn--g6w251d"},
{"Испытание", "xn--80akhbyknj4f"},
{"परीक्षा", "xn--11b5bs3a9aj6g"},
{"δοκιμή", "xn--jxalpdlp"},
{"테스트", "xn--9t4b11yi5a"},
{"טעסט", "xn--deba0ad"},
{"テスト", "xn--zckzah"},
{"பரிட்சை", "xn--hlcj6aya9esc7a"},
}
func TestEncodeDecodePunycode(t *testing.T) {
for _, tst := range testcases {
enc := encode([]byte(tst[0]))
if string(enc) != tst[1] {
t.Errorf("%s encodeded as %s but should be %s", tst[0], enc, tst[1])
}
dec := decode([]byte(tst[1]))
if string(dec) != strings.ToLower(tst[0]) {
t.Errorf("%s decoded as %s but should be %s", tst[1], dec, strings.ToLower(tst[0]))
}
}
}
func TestToFromPunycode(t *testing.T) {
for _, tst := range testcases {
// assert unicode.com == punycode.com
full := ToPunycode(tst[0] + ".com")
if full != tst[1]+".com" {
t.Errorf("invalid result from string conversion to punycode, %s and should be %s.com", full, tst[1])
}
// assert punycode.punycode == unicode.unicode
decoded := FromPunycode(tst[1] + "." + tst[1])
if decoded != strings.ToLower(tst[0]+"."+tst[0]) {
t.Errorf("invalid result from string conversion to punycode, %s and should be %s.%s", decoded, tst[0], tst[0])
}
}
}
func TestEncodeDecodeFinalPeriod(t *testing.T) {
for _, tst := range testcases {
// assert unicode.com. == punycode.com.
full := ToPunycode(tst[0] + ".")
if full != tst[1]+"." {
t.Errorf("invalid result from string conversion to punycode when period added at the end, %#v and should be %#v", full, tst[1]+".")
}
// assert punycode.com. == unicode.com.
decoded := FromPunycode(tst[1] + ".")
if decoded != strings.ToLower(tst[0]+".") {
t.Errorf("invalid result from string conversion to punycode when period added, %#v and should be %#v", decoded, tst[0]+".")
}
full = ToPunycode(tst[0])
if full != tst[1] {
t.Errorf("invalid result from string conversion to punycode when no period added at the end, %#v and should be %#v", full, tst[1]+".")
}
// assert punycode.com. == unicode.com.
decoded = FromPunycode(tst[1])
if decoded != strings.ToLower(tst[0]) {
t.Errorf("invalid result from string conversion to punycode when no period added, %#v and should be %#v", decoded, tst[0]+".")
}
}
}
var invalid = []string{
"xn--*",
"xn--",
"xn---",
}
func TestInvalidPunycode(t *testing.T) {
for _, d := range invalid {
s := FromPunycode(d)
if s != d {
t.Errorf("Changed invalid name %s to %#v", d, s)
}
}
}

149
Godeps/_workspace/src/github.com/miekg/dns/keygen.go generated vendored Normal file
View file

@ -0,0 +1,149 @@
package dns
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"math/big"
"strconv"
)
const _FORMAT = "Private-key-format: v1.3\n"
// Empty interface that is used as a wrapper around all possible
// private key implementations from the crypto package.
type PrivateKey interface{}
// Generate generates a DNSKEY of the given bit size.
// The public part is put inside the DNSKEY record.
// The Algorithm in the key must be set as this will define
// what kind of DNSKEY will be generated.
// The ECDSA algorithms imply a fixed keysize, in that case
// bits should be set to the size of the algorithm.
func (r *DNSKEY) Generate(bits int) (PrivateKey, error) {
switch r.Algorithm {
case DSA, DSANSEC3SHA1:
if bits != 1024 {
return nil, ErrKeySize
}
case RSAMD5, RSASHA1, RSASHA256, RSASHA1NSEC3SHA1:
if bits < 512 || bits > 4096 {
return nil, ErrKeySize
}
case RSASHA512:
if bits < 1024 || bits > 4096 {
return nil, ErrKeySize
}
case ECDSAP256SHA256:
if bits != 256 {
return nil, ErrKeySize
}
case ECDSAP384SHA384:
if bits != 384 {
return nil, ErrKeySize
}
}
switch r.Algorithm {
case DSA, DSANSEC3SHA1:
params := new(dsa.Parameters)
if err := dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160); err != nil {
return nil, err
}
priv := new(dsa.PrivateKey)
priv.PublicKey.Parameters = *params
err := dsa.GenerateKey(priv, rand.Reader)
if err != nil {
return nil, err
}
r.setPublicKeyDSA(params.Q, params.P, params.G, priv.PublicKey.Y)
return priv, nil
case RSAMD5, RSASHA1, RSASHA256, RSASHA512, RSASHA1NSEC3SHA1:
priv, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
r.setPublicKeyRSA(priv.PublicKey.E, priv.PublicKey.N)
return priv, nil
case ECDSAP256SHA256, ECDSAP384SHA384:
var c elliptic.Curve
switch r.Algorithm {
case ECDSAP256SHA256:
c = elliptic.P256()
case ECDSAP384SHA384:
c = elliptic.P384()
}
priv, err := ecdsa.GenerateKey(c, rand.Reader)
if err != nil {
return nil, err
}
r.setPublicKeyCurve(priv.PublicKey.X, priv.PublicKey.Y)
return priv, nil
default:
return nil, ErrAlg
}
return nil, nil // Dummy return
}
// PrivateKeyString converts a PrivateKey to a string. This
// string has the same format as the private-key-file of BIND9 (Private-key-format: v1.3).
// It needs some info from the key (hashing, keytag), so its a method of the DNSKEY.
func (r *DNSKEY) PrivateKeyString(p PrivateKey) (s string) {
switch t := p.(type) {
case *rsa.PrivateKey:
algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
modulus := unpackBase64(t.PublicKey.N.Bytes())
e := big.NewInt(int64(t.PublicKey.E))
publicExponent := unpackBase64(e.Bytes())
privateExponent := unpackBase64(t.D.Bytes())
prime1 := unpackBase64(t.Primes[0].Bytes())
prime2 := unpackBase64(t.Primes[1].Bytes())
// Calculate Exponent1/2 and Coefficient as per: http://en.wikipedia.org/wiki/RSA#Using_the_Chinese_remainder_algorithm
// and from: http://code.google.com/p/go/issues/detail?id=987
one := big.NewInt(1)
minusone := big.NewInt(-1)
p_1 := big.NewInt(0).Sub(t.Primes[0], one)
q_1 := big.NewInt(0).Sub(t.Primes[1], one)
exp1 := big.NewInt(0).Mod(t.D, p_1)
exp2 := big.NewInt(0).Mod(t.D, q_1)
coeff := big.NewInt(0).Exp(t.Primes[1], minusone, t.Primes[0])
exponent1 := unpackBase64(exp1.Bytes())
exponent2 := unpackBase64(exp2.Bytes())
coefficient := unpackBase64(coeff.Bytes())
s = _FORMAT +
"Algorithm: " + algorithm + "\n" +
"Modules: " + modulus + "\n" +
"PublicExponent: " + publicExponent + "\n" +
"PrivateExponent: " + privateExponent + "\n" +
"Prime1: " + prime1 + "\n" +
"Prime2: " + prime2 + "\n" +
"Exponent1: " + exponent1 + "\n" +
"Exponent2: " + exponent2 + "\n" +
"Coefficient: " + coefficient + "\n"
case *ecdsa.PrivateKey:
algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
private := unpackBase64(t.D.Bytes())
s = _FORMAT +
"Algorithm: " + algorithm + "\n" +
"PrivateKey: " + private + "\n"
case *dsa.PrivateKey:
algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
prime := unpackBase64(t.PublicKey.Parameters.P.Bytes())
subprime := unpackBase64(t.PublicKey.Parameters.Q.Bytes())
base := unpackBase64(t.PublicKey.Parameters.G.Bytes())
priv := unpackBase64(t.X.Bytes())
pub := unpackBase64(t.PublicKey.Y.Bytes())
s = _FORMAT +
"Algorithm: " + algorithm + "\n" +
"Prime(p): " + prime + "\n" +
"Subprime(q): " + subprime + "\n" +
"Base(g): " + base + "\n" +
"Private_value(x): " + priv + "\n" +
"Public_value(y): " + pub + "\n"
}
return
}

244
Godeps/_workspace/src/github.com/miekg/dns/kscan.go generated vendored Normal file
View file

@ -0,0 +1,244 @@
package dns
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/rsa"
"io"
"math/big"
"strings"
)
func (k *DNSKEY) NewPrivateKey(s string) (PrivateKey, error) {
if s[len(s)-1] != '\n' { // We need a closing newline
return k.ReadPrivateKey(strings.NewReader(s+"\n"), "")
}
return k.ReadPrivateKey(strings.NewReader(s), "")
}
// ReadPrivateKey reads a private key from the io.Reader q. The string file is
// only used in error reporting.
// The public key must be
// known, because some cryptographic algorithms embed the public inside the privatekey.
func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (PrivateKey, error) {
m, e := parseKey(q, file)
if m == nil {
return nil, e
}
if _, ok := m["private-key-format"]; !ok {
return nil, ErrPrivKey
}
if m["private-key-format"] != "v1.2" && m["private-key-format"] != "v1.3" {
return nil, ErrPrivKey
}
// TODO(mg): check if the pubkey matches the private key
switch m["algorithm"] {
case "3 (DSA)":
p, e := readPrivateKeyDSA(m)
if e != nil {
return nil, e
}
if !k.setPublicKeyInPrivate(p) {
return nil, ErrPrivKey
}
return p, e
case "1 (RSAMD5)":
fallthrough
case "5 (RSASHA1)":
fallthrough
case "7 (RSASHA1NSEC3SHA1)":
fallthrough
case "8 (RSASHA256)":
fallthrough
case "10 (RSASHA512)":
p, e := readPrivateKeyRSA(m)
if e != nil {
return nil, e
}
if !k.setPublicKeyInPrivate(p) {
return nil, ErrPrivKey
}
return p, e
case "12 (ECC-GOST)":
p, e := readPrivateKeyGOST(m)
if e != nil {
return nil, e
}
// setPublicKeyInPrivate(p)
return p, e
case "13 (ECDSAP256SHA256)":
fallthrough
case "14 (ECDSAP384SHA384)":
p, e := readPrivateKeyECDSA(m)
if e != nil {
return nil, e
}
if !k.setPublicKeyInPrivate(p) {
return nil, ErrPrivKey
}
return p, e
}
return nil, ErrPrivKey
}
// Read a private key (file) string and create a public key. Return the private key.
func readPrivateKeyRSA(m map[string]string) (PrivateKey, error) {
p := new(rsa.PrivateKey)
p.Primes = []*big.Int{nil, nil}
for k, v := range m {
switch k {
case "modulus", "publicexponent", "privateexponent", "prime1", "prime2":
v1, err := packBase64([]byte(v))
if err != nil {
return nil, err
}
switch k {
case "modulus":
p.PublicKey.N = big.NewInt(0)
p.PublicKey.N.SetBytes(v1)
case "publicexponent":
i := big.NewInt(0)
i.SetBytes(v1)
p.PublicKey.E = int(i.Int64()) // int64 should be large enough
case "privateexponent":
p.D = big.NewInt(0)
p.D.SetBytes(v1)
case "prime1":
p.Primes[0] = big.NewInt(0)
p.Primes[0].SetBytes(v1)
case "prime2":
p.Primes[1] = big.NewInt(0)
p.Primes[1].SetBytes(v1)
}
case "exponent1", "exponent2", "coefficient":
// not used in Go (yet)
case "created", "publish", "activate":
// not used in Go (yet)
}
}
return p, nil
}
func readPrivateKeyDSA(m map[string]string) (PrivateKey, error) {
p := new(dsa.PrivateKey)
p.X = big.NewInt(0)
for k, v := range m {
switch k {
case "private_value(x)":
v1, err := packBase64([]byte(v))
if err != nil {
return nil, err
}
p.X.SetBytes(v1)
case "created", "publish", "activate":
/* not used in Go (yet) */
}
}
return p, nil
}
func readPrivateKeyECDSA(m map[string]string) (PrivateKey, error) {
p := new(ecdsa.PrivateKey)
p.D = big.NewInt(0)
// TODO: validate that the required flags are present
for k, v := range m {
switch k {
case "privatekey":
v1, err := packBase64([]byte(v))
if err != nil {
return nil, err
}
p.D.SetBytes(v1)
case "created", "publish", "activate":
/* not used in Go (yet) */
}
}
return p, nil
}
func readPrivateKeyGOST(m map[string]string) (PrivateKey, error) {
// TODO(miek)
return nil, nil
}
// parseKey reads a private key from r. It returns a map[string]string,
// with the key-value pairs, or an error when the file is not correct.
func parseKey(r io.Reader, file string) (map[string]string, error) {
s := scanInit(r)
m := make(map[string]string)
c := make(chan lex)
k := ""
// Start the lexer
go klexer(s, c)
for l := range c {
// It should alternate
switch l.value {
case _KEY:
k = l.token
case _VALUE:
if k == "" {
return nil, &ParseError{file, "no private key seen", l}
}
//println("Setting", strings.ToLower(k), "to", l.token, "b")
m[strings.ToLower(k)] = l.token
k = ""
}
}
return m, nil
}
// klexer scans the sourcefile and returns tokens on the channel c.
func klexer(s *scan, c chan lex) {
var l lex
str := "" // Hold the current read text
commt := false
key := true
x, err := s.tokenText()
defer close(c)
for err == nil {
l.column = s.position.Column
l.line = s.position.Line
switch x {
case ':':
if commt {
break
}
l.token = str
if key {
l.value = _KEY
c <- l
// Next token is a space, eat it
s.tokenText()
key = false
str = ""
} else {
l.value = _VALUE
}
case ';':
commt = true
case '\n':
if commt {
// Reset a comment
commt = false
}
l.value = _VALUE
l.token = str
c <- l
str = ""
commt = false
key = true
default:
if commt {
break
}
str += string(x)
}
x, err = s.tokenText()
}
if len(str) > 0 {
// Send remainder
l.token = str
l.value = _VALUE
c <- l
}
}

162
Godeps/_workspace/src/github.com/miekg/dns/labels.go generated vendored Normal file
View file

@ -0,0 +1,162 @@
package dns
// Holds a bunch of helper functions for dealing with labels.
// SplitDomainName splits a name string into it's labels.
// www.miek.nl. returns []string{"www", "miek", "nl"}
// The root label (.) returns nil. Note that using
// strings.Split(s) will work in most cases, but does not handle
// escaped dots (\.) for instance.
func SplitDomainName(s string) (labels []string) {
if len(s) == 0 {
return nil
}
fqdnEnd := 0 // offset of the final '.' or the length of the name
idx := Split(s)
begin := 0
if s[len(s)-1] == '.' {
fqdnEnd = len(s) - 1
} else {
fqdnEnd = len(s)
}
switch len(idx) {
case 0:
return nil
case 1:
// no-op
default:
end := 0
for i := 1; i < len(idx); i++ {
end = idx[i]
labels = append(labels, s[begin:end-1])
begin = end
}
}
labels = append(labels, s[begin:fqdnEnd])
return labels
}
// CompareDomainName compares the names s1 and s2 and
// returns how many labels they have in common starting from the *right*.
// The comparison stops at the first inequality. The names are not downcased
// before the comparison.
//
// www.miek.nl. and miek.nl. have two labels in common: miek and nl
// www.miek.nl. and www.bla.nl. have one label in common: nl
func CompareDomainName(s1, s2 string) (n int) {
s1 = Fqdn(s1)
s2 = Fqdn(s2)
l1 := Split(s1)
l2 := Split(s2)
// the first check: root label
if l1 == nil || l2 == nil {
return
}
j1 := len(l1) - 1 // end
i1 := len(l1) - 2 // start
j2 := len(l2) - 1
i2 := len(l2) - 2
// the second check can be done here: last/only label
// before we fall through into the for-loop below
if s1[l1[j1]:] == s2[l2[j2]:] {
n++
} else {
return
}
for {
if i1 < 0 || i2 < 0 {
break
}
if s1[l1[i1]:l1[j1]] == s2[l2[i2]:l2[j2]] {
n++
} else {
break
}
j1--
i1--
j2--
i2--
}
return
}
// CountLabel counts the the number of labels in the string s.
func CountLabel(s string) (labels int) {
if s == "." {
return
}
off := 0
end := false
for {
off, end = NextLabel(s, off)
labels++
if end {
return
}
}
panic("dns: not reached")
}
// Split splits a name s into its label indexes.
// www.miek.nl. returns []int{0, 4, 9}, www.miek.nl also returns []int{0, 4, 9}.
// The root name (.) returns nil. Also see dns.SplitDomainName.
func Split(s string) []int {
if s == "." {
return nil
}
idx := make([]int, 1, 3)
off := 0
end := false
for {
off, end = NextLabel(s, off)
if end {
return idx
}
idx = append(idx, off)
}
panic("dns: not reached")
}
// NextLabel returns the index of the start of the next label in the
// string s starting at offset.
// The bool end is true when the end of the string has been reached.
func NextLabel(s string, offset int) (i int, end bool) {
quote := false
for i = offset; i < len(s)-1; i++ {
switch s[i] {
case '\\':
quote = !quote
default:
quote = false
case '.':
if quote {
quote = !quote
continue
}
return i + 1, false
}
}
return i + 1, true
}
// PrevLabel returns the index of the label when starting from the right and
// jumping n labels to the left.
// The bool start is true when the start of the string has been overshot.
func PrevLabel(s string, n int) (i int, start bool) {
if n == 0 {
return len(s), false
}
lab := Split(s)
if lab == nil {
return 0, true
}
if n > len(lab) {
return 0, true
}
return lab[len(lab)-n], false
}

View file

@ -0,0 +1,214 @@
package dns
import (
"testing"
)
func TestCompareDomainName(t *testing.T) {
s1 := "www.miek.nl."
s2 := "miek.nl."
s3 := "www.bla.nl."
s4 := "nl.www.bla."
s5 := "nl"
s6 := "miek.nl"
if CompareDomainName(s1, s2) != 2 {
t.Logf("%s with %s should be %d", s1, s2, 2)
t.Fail()
}
if CompareDomainName(s1, s3) != 1 {
t.Logf("%s with %s should be %d", s1, s3, 1)
t.Fail()
}
if CompareDomainName(s3, s4) != 0 {
t.Logf("%s with %s should be %d", s3, s4, 0)
t.Fail()
}
// Non qualified tests
if CompareDomainName(s1, s5) != 1 {
t.Logf("%s with %s should be %d", s1, s5, 1)
t.Fail()
}
if CompareDomainName(s1, s6) != 2 {
t.Logf("%s with %s should be %d", s1, s5, 2)
t.Fail()
}
if CompareDomainName(s1, ".") != 0 {
t.Logf("%s with %s should be %d", s1, s5, 0)
t.Fail()
}
if CompareDomainName(".", ".") != 0 {
t.Logf("%s with %s should be %d", ".", ".", 0)
t.Fail()
}
}
func TestSplit(t *testing.T) {
splitter := map[string]int{
"www.miek.nl.": 3,
"www.miek.nl": 3,
"www..miek.nl": 4,
`www\.miek.nl.`: 2,
`www\\.miek.nl.`: 3,
".": 0,
"nl.": 1,
"nl": 1,
"com.": 1,
".com.": 2,
}
for s, i := range splitter {
if x := len(Split(s)); x != i {
t.Logf("labels should be %d, got %d: %s %v\n", i, x, s, Split(s))
t.Fail()
} else {
t.Logf("%s %v\n", s, Split(s))
}
}
}
func TestSplit2(t *testing.T) {
splitter := map[string][]int{
"www.miek.nl.": []int{0, 4, 9},
"www.miek.nl": []int{0, 4, 9},
"nl": []int{0},
}
for s, i := range splitter {
x := Split(s)
switch len(i) {
case 1:
if x[0] != i[0] {
t.Logf("labels should be %v, got %v: %s\n", i, x, s)
t.Fail()
}
default:
if x[0] != i[0] || x[1] != i[1] || x[2] != i[2] {
t.Logf("labels should be %v, got %v: %s\n", i, x, s)
t.Fail()
}
}
}
}
func TestPrevLabel(t *testing.T) {
type prev struct {
string
int
}
prever := map[prev]int{
prev{"www.miek.nl.", 0}: 12,
prev{"www.miek.nl.", 1}: 9,
prev{"www.miek.nl.", 2}: 4,
prev{"www.miek.nl", 0}: 11,
prev{"www.miek.nl", 1}: 9,
prev{"www.miek.nl", 2}: 4,
prev{"www.miek.nl.", 5}: 0,
prev{"www.miek.nl", 5}: 0,
prev{"www.miek.nl.", 3}: 0,
prev{"www.miek.nl", 3}: 0,
}
for s, i := range prever {
x, ok := PrevLabel(s.string, s.int)
if i != x {
t.Logf("label should be %d, got %d, %t: preving %d, %s\n", i, x, ok, s.int, s.string)
t.Fail()
}
}
}
func TestCountLabel(t *testing.T) {
splitter := map[string]int{
"www.miek.nl.": 3,
"www.miek.nl": 3,
"nl": 1,
".": 0,
}
for s, i := range splitter {
x := CountLabel(s)
if x != i {
t.Logf("CountLabel should have %d, got %d\n", i, x)
t.Fail()
}
}
}
func TestSplitDomainName(t *testing.T) {
labels := map[string][]string{
"miek.nl": []string{"miek", "nl"},
".": nil,
"www.miek.nl.": []string{"www", "miek", "nl"},
"www.miek.nl": []string{"www", "miek", "nl"},
"www..miek.nl": []string{"www", "", "miek", "nl"},
`www\.miek.nl`: []string{`www\.miek`, "nl"},
`www\\.miek.nl`: []string{`www\\`, "miek", "nl"},
}
domainLoop:
for domain, splits := range labels {
parts := SplitDomainName(domain)
if len(parts) != len(splits) {
t.Logf("SplitDomainName returned %v for %s, expected %v", parts, domain, splits)
t.Fail()
continue domainLoop
}
for i := range parts {
if parts[i] != splits[i] {
t.Logf("SplitDomainName returned %v for %s, expected %v", parts, domain, splits)
t.Fail()
continue domainLoop
}
}
}
}
func TestIsDomainName(t *testing.T) {
type ret struct {
ok bool
lab int
}
names := map[string]*ret{
"..": &ret{false, 1},
"@.": &ret{true, 1},
"www.example.com": &ret{true, 3},
"www.e%ample.com": &ret{true, 3},
"www.example.com.": &ret{true, 3},
"mi\\k.nl.": &ret{true, 2},
"mi\\k.nl": &ret{true, 2},
}
for d, ok := range names {
l, k := IsDomainName(d)
if ok.ok != k || ok.lab != l {
t.Logf(" got %v %d for %s ", k, l, d)
t.Logf("have %v %d for %s ", ok.ok, ok.lab, d)
t.Fail()
}
}
}
func BenchmarkSplitLabels(b *testing.B) {
for i := 0; i < b.N; i++ {
Split("www.example.com")
}
}
func BenchmarkLenLabels(b *testing.B) {
for i := 0; i < b.N; i++ {
CountLabel("www.example.com")
}
}
func BenchmarkCompareLabels(b *testing.B) {
for i := 0; i < b.N; i++ {
CompareDomainName("www.example.com", "aa.example.com")
}
}
func BenchmarkIsSubDomain(b *testing.B) {
for i := 0; i < b.N; i++ {
IsSubDomain("www.example.com", "aa.example.com")
IsSubDomain("example.com", "aa.example.com")
IsSubDomain("miek.nl", "aa.example.com")
}
}

1881
Godeps/_workspace/src/github.com/miekg/dns/msg.go generated vendored Normal file

File diff suppressed because it is too large Load diff

110
Godeps/_workspace/src/github.com/miekg/dns/nsecx.go generated vendored Normal file
View file

@ -0,0 +1,110 @@
package dns
import (
"crypto/sha1"
"hash"
"io"
"strings"
)
type saltWireFmt struct {
Salt string `dns:"size-hex"`
}
// HashName hashes a string (label) according to RFC 5155. It returns the hashed string in
// uppercase.
func HashName(label string, ha uint8, iter uint16, salt string) string {
saltwire := new(saltWireFmt)
saltwire.Salt = salt
wire := make([]byte, DefaultMsgSize)
n, err := PackStruct(saltwire, wire, 0)
if err != nil {
return ""
}
wire = wire[:n]
name := make([]byte, 255)
off, err := PackDomainName(strings.ToLower(label), name, 0, nil, false)
if err != nil {
return ""
}
name = name[:off]
var s hash.Hash
switch ha {
case SHA1:
s = sha1.New()
default:
return ""
}
// k = 0
name = append(name, wire...)
io.WriteString(s, string(name))
nsec3 := s.Sum(nil)
// k > 0
for k := uint16(0); k < iter; k++ {
s.Reset()
nsec3 = append(nsec3, wire...)
io.WriteString(s, string(nsec3))
nsec3 = s.Sum(nil)
}
return unpackBase32(nsec3)
}
type Denialer interface {
// Cover will check if the (unhashed) name is being covered by this NSEC or NSEC3.
Cover(name string) bool
// Match will check if the ownername matches the (unhashed) name for this NSEC3 or NSEC3.
Match(name string) bool
}
// Cover implements the Denialer interface.
func (rr *NSEC) Cover(name string) bool {
return true
}
// Match implements the Denialer interface.
func (rr *NSEC) Match(name string) bool {
return true
}
// Cover implements the Denialer interface.
func (rr *NSEC3) Cover(name string) bool {
// FIXME(miek): check if the zones match
// FIXME(miek): check if we're not dealing with parent nsec3
hname := HashName(name, rr.Hash, rr.Iterations, rr.Salt)
labels := Split(rr.Hdr.Name)
if len(labels) < 2 {
return false
}
hash := strings.ToUpper(rr.Hdr.Name[labels[0] : labels[1]-1]) // -1 to remove the dot
if hash == rr.NextDomain {
return false // empty interval
}
if hash > rr.NextDomain { // last name, points to apex
// hname > hash
// hname > rr.NextDomain
// TODO(miek)
}
if hname <= hash {
return false
}
if hname >= rr.NextDomain {
return false
}
return true
}
// Match implements the Denialer interface.
func (rr *NSEC3) Match(name string) bool {
// FIXME(miek): Check if we are in the same zone
hname := HashName(name, rr.Hash, rr.Iterations, rr.Salt)
labels := Split(rr.Hdr.Name)
if len(labels) < 2 {
return false
}
hash := strings.ToUpper(rr.Hdr.Name[labels[0] : labels[1]-1]) // -1 to remove the .
if hash == hname {
return true
}
return false
}

View file

@ -0,0 +1,33 @@
package dns
import (
"testing"
)
func TestPackNsec3(t *testing.T) {
nsec3 := HashName("dnsex.nl.", SHA1, 0, "DEAD")
if nsec3 != "ROCCJAE8BJJU7HN6T7NG3TNM8ACRS87J" {
t.Logf("%v\n", nsec3)
t.Fail()
}
nsec3 = HashName("a.b.c.example.org.", SHA1, 2, "DEAD")
if nsec3 != "6LQ07OAHBTOOEU2R9ANI2AT70K5O0RCG" {
t.Logf("%v\n", nsec3)
t.Fail()
}
}
func TestNsec3(t *testing.T) {
// examples taken from .nl
nsec3, _ := NewRR("39p91242oslggest5e6a7cci4iaeqvnk.nl. IN NSEC3 1 1 5 F10E9F7EA83FC8F3 39P99DCGG0MDLARTCRMCF6OFLLUL7PR6 NS DS RRSIG")
if !nsec3.(*NSEC3).Cover("snasajsksasasa.nl.") { // 39p94jrinub66hnpem8qdpstrec86pg3
t.Logf("39p94jrinub66hnpem8qdpstrec86pg3. should be covered by 39p91242oslggest5e6a7cci4iaeqvnk.nl. - 39P99DCGG0MDLARTCRMCF6OFLLUL7PR6")
t.Fail()
}
nsec3, _ = NewRR("sk4e8fj94u78smusb40o1n0oltbblu2r.nl. IN NSEC3 1 1 5 F10E9F7EA83FC8F3 SK4F38CQ0ATIEI8MH3RGD0P5I4II6QAN NS SOA TXT RRSIG DNSKEY NSEC3PARAM")
if !nsec3.(*NSEC3).Match("nl.") { // sk4e8fj94u78smusb40o1n0oltbblu2r.nl.
t.Logf("sk4e8fj94u78smusb40o1n0oltbblu2r.nl. should match sk4e8fj94u78smusb40o1n0oltbblu2r.nl.")
t.Fail()
}
}

1287
Godeps/_workspace/src/github.com/miekg/dns/parse_test.go generated vendored Normal file

File diff suppressed because it is too large Load diff

122
Godeps/_workspace/src/github.com/miekg/dns/privaterr.go generated vendored Normal file
View file

@ -0,0 +1,122 @@
/*
PRIVATE RR
RFC 6895 sets aside a range of type codes for private use. This range
is 65,280 - 65,534 (0xFF00 - 0xFFFE). When experimenting with new Resource Records these
can be used, before requesting an official type code from IANA.
*/
package dns
import (
"fmt"
"strings"
)
// PrivateRdata is an interface used for implementing "Private Use" RR types, see
// RFC 6895. This allows one to experiment with new RR types, without requesting an
// official type code. Also see dns.PrivateHandle and dns.PrivateHandleRemove.
type PrivateRdata interface {
// String returns the text presentaton of the Rdata of the Private RR.
String() string
// Parse parses the Rdata of the private RR.
Parse([]string) error
// Pack is used when packing a private RR into a buffer.
Pack([]byte) (int, error)
// Unpack is used when unpacking a private RR from a buffer.
// TODO(miek): diff. signature than Pack, see edns0.go for instance.
Unpack([]byte) (int, error)
// Copy copies the Rdata.
Copy(PrivateRdata) error
// Len returns the length in octets of the Rdata.
Len() int
}
// PrivateRR represents an RR that uses a PrivateRdata user-defined type.
// It mocks normal RRs and implements dns.RR interface.
type PrivateRR struct {
Hdr RR_Header
Data PrivateRdata
}
func mkPrivateRR(rrtype uint16) *PrivateRR {
// Panics if RR is not an instance of PrivateRR.
rrfunc, ok := typeToRR[rrtype]
if !ok {
panic(fmt.Sprintf("dns: invalid operation with Private RR type %d", rrtype))
}
anyrr := rrfunc()
switch rr := anyrr.(type) {
case *PrivateRR:
return rr
}
panic(fmt.Sprintf("dns: RR is not a PrivateRR, typeToRR[%d] generator returned %T", rrtype, anyrr))
}
func (r *PrivateRR) Header() *RR_Header { return &r.Hdr }
func (r *PrivateRR) String() string { return r.Hdr.String() + r.Data.String() }
// Private len and copy parts to satisfy RR interface.
func (r *PrivateRR) len() int { return r.Hdr.len() + r.Data.Len() }
func (r *PrivateRR) copy() RR {
// make new RR like this:
rr := mkPrivateRR(r.Hdr.Rrtype)
newh := r.Hdr.copyHeader()
rr.Hdr = *newh
err := r.Data.Copy(rr.Data)
if err != nil {
panic("dns: got value that could not be used to copy Private rdata")
}
return rr
}
// PrivateHandle registers a private resource record type. It requires
// string and numeric representation of private RR type and generator function as argument.
func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata) {
rtypestr = strings.ToUpper(rtypestr)
typeToRR[rtype] = func() RR { return &PrivateRR{RR_Header{}, generator()} }
TypeToString[rtype] = rtypestr
StringToType[rtypestr] = rtype
setPrivateRR := func(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := mkPrivateRR(h.Rrtype)
rr.Hdr = h
var l lex
text := make([]string, 0, 2) // could be 0..N elements, median is probably 1
FETCH:
for {
// TODO(miek): we could also be returning _QUOTE, this might or might not
// be an issue (basically parsing TXT becomes hard)
switch l = <-c; l.value {
case _NEWLINE, _EOF:
break FETCH
case _STRING:
text = append(text, l.token)
}
}
err := rr.Data.Parse(text)
if err != nil {
return nil, &ParseError{f, err.Error(), l}, ""
}
return rr, nil, ""
}
typeToparserFunc[rtype] = parserFunc{setPrivateRR, true}
}
// PrivateHandleRemove removes defenitions required to support private RR type.
func PrivateHandleRemove(rtype uint16) {
rtypestr, ok := TypeToString[rtype]
if ok {
delete(typeToRR, rtype)
delete(TypeToString, rtype)
delete(typeToparserFunc, rtype)
delete(StringToType, rtypestr)
}
return
}

View file

@ -0,0 +1,169 @@
package dns_test
import (
"github.com/miekg/dns"
"strings"
"testing"
)
const TypeISBN uint16 = 0x0F01
// A crazy new RR type :)
type ISBN struct {
x string // rdata with 10 or 13 numbers, dashes or spaces allowed
}
func NewISBN() dns.PrivateRdata { return &ISBN{""} }
func (rd *ISBN) Len() int { return len([]byte(rd.x)) }
func (rd *ISBN) String() string { return rd.x }
func (rd *ISBN) Parse(txt []string) error {
rd.x = strings.TrimSpace(strings.Join(txt, " "))
return nil
}
func (rd *ISBN) Pack(buf []byte) (int, error) {
b := []byte(rd.x)
n := copy(buf, b)
if n != len(b) {
return n, dns.ErrBuf
}
return n, nil
}
func (rd *ISBN) Unpack(buf []byte) (int, error) {
rd.x = string(buf)
return len(buf), nil
}
func (rd *ISBN) Copy(dest dns.PrivateRdata) error {
isbn, ok := dest.(*ISBN)
if !ok {
return dns.ErrRdata
}
isbn.x = rd.x
return nil
}
var testrecord = strings.Join([]string{"example.org.", "3600", "IN", "ISBN", "12-3 456789-0-123"}, "\t")
func TestPrivateText(t *testing.T) {
dns.PrivateHandle("ISBN", TypeISBN, NewISBN)
defer dns.PrivateHandleRemove(TypeISBN)
rr, err := dns.NewRR(testrecord)
if err != nil {
t.Fatal(err)
}
if rr.String() != testrecord {
t.Errorf("record string representation did not match original %#v != %#v", rr.String(), testrecord)
} else {
t.Log(rr.String())
}
}
func TestPrivateByteSlice(t *testing.T) {
dns.PrivateHandle("ISBN", TypeISBN, NewISBN)
defer dns.PrivateHandleRemove(TypeISBN)
rr, err := dns.NewRR(testrecord)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 100)
off, err := dns.PackRR(rr, buf, 0, nil, false)
if err != nil {
t.Errorf("got error packing ISBN: %s", err)
}
custrr := rr.(*dns.PrivateRR)
if ln := custrr.Data.Len() + len(custrr.Header().Name) + 11; ln != off {
t.Errorf("offset is not matching to length of Private RR: %d!=%d", off, ln)
}
rr1, off1, err := dns.UnpackRR(buf[:off], 0)
if err != nil {
t.Errorf("got error unpacking ISBN: %s", err)
}
if off1 != off {
t.Errorf("Offset after unpacking differs: %d != %d", off1, off)
}
if rr1.String() != testrecord {
t.Errorf("Record string representation did not match original %#v != %#v", rr1.String(), testrecord)
} else {
t.Log(rr1.String())
}
}
const TypeVERSION uint16 = 0x0F02
type VERSION struct {
x string
}
func NewVersion() dns.PrivateRdata { return &VERSION{""} }
func (rd *VERSION) String() string { return rd.x }
func (rd *VERSION) Parse(txt []string) error {
rd.x = strings.TrimSpace(strings.Join(txt, " "))
return nil
}
func (rd *VERSION) Pack(buf []byte) (int, error) {
b := []byte(rd.x)
n := copy(buf, b)
if n != len(b) {
return n, dns.ErrBuf
}
return n, nil
}
func (rd *VERSION) Unpack(buf []byte) (int, error) {
rd.x = string(buf)
return len(buf), nil
}
func (rd *VERSION) Copy(dest dns.PrivateRdata) error {
isbn, ok := dest.(*VERSION)
if !ok {
return dns.ErrRdata
}
isbn.x = rd.x
return nil
}
func (rd *VERSION) Len() int {
return len([]byte(rd.x))
}
var smallzone = `$ORIGIN example.org.
@ SOA sns.dns.icann.org. noc.dns.icann.org. (
2014091518 7200 3600 1209600 3600
)
A 1.2.3.4
ok ISBN 1231-92110-12
go VERSION (
1.3.1 ; comment
)
www ISBN 1231-92110-16
* CNAME @
`
func TestPrivateZoneParser(t *testing.T) {
dns.PrivateHandle("ISBN", TypeISBN, NewISBN)
dns.PrivateHandle("VERSION", TypeVERSION, NewVersion)
defer dns.PrivateHandleRemove(TypeISBN)
defer dns.PrivateHandleRemove(TypeVERSION)
r := strings.NewReader(smallzone)
for x := range dns.ParseZone(r, ".", "") {
if err := x.Error; err != nil {
t.Fatal(err)
}
t.Log(x.RR)
}
}

95
Godeps/_workspace/src/github.com/miekg/dns/rawmsg.go generated vendored Normal file
View file

@ -0,0 +1,95 @@
package dns
// These raw* functions do not use reflection, they directly set the values
// in the buffer. There are faster than their reflection counterparts.
// RawSetId sets the message id in buf.
func rawSetId(msg []byte, i uint16) bool {
if len(msg) < 2 {
return false
}
msg[0], msg[1] = packUint16(i)
return true
}
// rawSetQuestionLen sets the length of the question section.
func rawSetQuestionLen(msg []byte, i uint16) bool {
if len(msg) < 6 {
return false
}
msg[4], msg[5] = packUint16(i)
return true
}
// rawSetAnswerLen sets the lenght of the answer section.
func rawSetAnswerLen(msg []byte, i uint16) bool {
if len(msg) < 8 {
return false
}
msg[6], msg[7] = packUint16(i)
return true
}
// rawSetsNsLen sets the lenght of the authority section.
func rawSetNsLen(msg []byte, i uint16) bool {
if len(msg) < 10 {
return false
}
msg[8], msg[9] = packUint16(i)
return true
}
// rawSetExtraLen sets the lenght of the additional section.
func rawSetExtraLen(msg []byte, i uint16) bool {
if len(msg) < 12 {
return false
}
msg[10], msg[11] = packUint16(i)
return true
}
// rawSetRdlength sets the rdlength in the header of
// the RR. The offset 'off' must be positioned at the
// start of the header of the RR, 'end' must be the
// end of the RR.
func rawSetRdlength(msg []byte, off, end int) bool {
l := len(msg)
Loop:
for {
if off+1 > l {
return false
}
c := int(msg[off])
off++
switch c & 0xC0 {
case 0x00:
if c == 0x00 {
// End of the domainname
break Loop
}
if off+c > l {
return false
}
off += c
case 0xC0:
// pointer, next byte included, ends domainname
off++
break Loop
}
}
// The domainname has been seen, we at the start of the fixed part in the header.
// Type is 2 bytes, class is 2 bytes, ttl 4 and then 2 bytes for the length.
off += 2 + 2 + 4
if off+2 > l {
return false
}
//off+1 is the end of the header, 'end' is the end of the rr
//so 'end' - 'off+2' is the length of the rdata
rdatalen := end - (off + 2)
if rdatalen > 0xFFFF {
return false
}
msg[off], msg[off+1] = packUint16(uint16(rdatalen))
return true
}

43
Godeps/_workspace/src/github.com/miekg/dns/scanner.go generated vendored Normal file
View file

@ -0,0 +1,43 @@
package dns
// Implement a simple scanner, return a byte stream from an io reader.
import (
"bufio"
"io"
"text/scanner"
)
type scan struct {
src *bufio.Reader
position scanner.Position
eof bool // Have we just seen a eof
}
func scanInit(r io.Reader) *scan {
s := new(scan)
s.src = bufio.NewReader(r)
s.position.Line = 1
return s
}
// tokenText returns the next byte from the input
func (s *scan) tokenText() (byte, error) {
c, err := s.src.ReadByte()
if err != nil {
return c, err
}
// delay the newline handling until the next token is delivered,
// fixes off-by-one errors when reporting a parse error.
if s.eof == true {
s.position.Line++
s.position.Column = 0
s.eof = false
}
if c == '\n' {
s.eof = true
return c, nil
}
s.position.Column++
return c, nil
}

612
Godeps/_workspace/src/github.com/miekg/dns/server.go generated vendored Normal file
View file

@ -0,0 +1,612 @@
// DNS server implementation.
package dns
import (
"bytes"
"io"
"net"
"sync"
"time"
)
type Handler interface {
ServeDNS(w ResponseWriter, r *Msg)
}
// A ResponseWriter interface is used by an DNS handler to
// construct an DNS response.
type ResponseWriter interface {
// LocalAddr returns the net.Addr of the server
LocalAddr() net.Addr
// RemoteAddr returns the net.Addr of the client that sent the current request.
RemoteAddr() net.Addr
// WriteMsg writes a reply back to the client.
WriteMsg(*Msg) error
// Write writes a raw buffer back to the client.
Write([]byte) (int, error)
// Close closes the connection.
Close() error
// TsigStatus returns the status of the Tsig.
TsigStatus() error
// TsigTimersOnly sets the tsig timers only boolean.
TsigTimersOnly(bool)
// Hijack lets the caller take over the connection.
// After a call to Hijack(), the DNS package will not do anything with the connection.
Hijack()
}
type response struct {
hijacked bool // connection has been hijacked by handler
tsigStatus error
tsigTimersOnly bool
tsigRequestMAC string
tsigSecret map[string]string // the tsig secrets
udp *net.UDPConn // i/o connection if UDP was used
tcp *net.TCPConn // i/o connection if TCP was used
udpSession *sessionUDP // oob data to get egress interface right
remoteAddr net.Addr // address of the client
}
// ServeMux is an DNS request multiplexer. It matches the
// zone name of each incoming request against a list of
// registered patterns add calls the handler for the pattern
// that most closely matches the zone name. ServeMux is DNSSEC aware, meaning
// that queries for the DS record are redirected to the parent zone (if that
// is also registered), otherwise the child gets the query.
// ServeMux is also safe for concurrent access from multiple goroutines.
type ServeMux struct {
z map[string]Handler
m *sync.RWMutex
}
// NewServeMux allocates and returns a new ServeMux.
func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} }
// DefaultServeMux is the default ServeMux used by Serve.
var DefaultServeMux = NewServeMux()
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as DNS handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler object that calls f.
type HandlerFunc func(ResponseWriter, *Msg)
// ServerDNS calls f(w, r)
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
f(w, r)
}
// FailedHandler returns a HandlerFunc
// returns SERVFAIL for every request it gets.
func HandleFailed(w ResponseWriter, r *Msg) {
m := new(Msg)
m.SetRcode(r, RcodeServerFailure)
// does not matter if this write fails
w.WriteMsg(m)
}
func failedHandler() Handler { return HandlerFunc(HandleFailed) }
// ListenAndServe Starts a server on addresss and network speficied. Invoke handler
// for incoming queries.
func ListenAndServe(addr string, network string, handler Handler) error {
server := &Server{Addr: addr, Net: network, Handler: handler}
return server.ListenAndServe()
}
// ActivateAndServe activates a server with a listener from systemd,
// l and p should not both be non-nil.
// If both l and p are not nil only p will be used.
// Invoke handler for incoming queries.
func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
server := &Server{Listener: l, PacketConn: p, Handler: handler}
return server.ActivateAndServe()
}
func (mux *ServeMux) match(q string, t uint16) Handler {
mux.m.RLock()
defer mux.m.RUnlock()
var handler Handler
b := make([]byte, len(q)) // worst case, one label of length q
off := 0
end := false
for {
l := len(q[off:])
for i := 0; i < l; i++ {
b[i] = q[off+i]
if b[i] >= 'A' && b[i] <= 'Z' {
b[i] |= ('a' - 'A')
}
}
if h, ok := mux.z[string(b[:l])]; ok { // 'causes garbage, might want to change the map key
if t != TypeDS {
return h
} else {
// Continue for DS to see if we have a parent too, if so delegeate to the parent
handler = h
}
}
off, end = NextLabel(q, off)
if end {
break
}
}
// Wildcard match, if we have found nothing try the root zone as a last resort.
if h, ok := mux.z["."]; ok {
return h
}
return handler
}
// Handle adds a handler to the ServeMux for pattern.
func (mux *ServeMux) Handle(pattern string, handler Handler) {
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
mux.m.Lock()
mux.z[Fqdn(pattern)] = handler
mux.m.Unlock()
}
// Handle adds a handler to the ServeMux for pattern.
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
mux.Handle(pattern, HandlerFunc(handler))
}
// HandleRemove deregistrars the handler specific for pattern from the ServeMux.
func (mux *ServeMux) HandleRemove(pattern string) {
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
// don't need a mutex here, because deleting is OK, even if the
// entry is note there.
delete(mux.z, Fqdn(pattern))
}
// ServeDNS dispatches the request to the handler whose
// pattern most closely matches the request message. If DefaultServeMux
// is used the correct thing for DS queries is done: a possible parent
// is sought.
// If no handler is found a standard SERVFAIL message is returned
// If the request message does not have exactly one question in the
// question section a SERVFAIL is returned.
func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
var h Handler
if len(request.Question) != 1 {
h = failedHandler()
} else {
if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
h = failedHandler()
}
}
h.ServeDNS(w, request)
}
// Handle registers the handler with the given pattern
// in the DefaultServeMux. The documentation for
// ServeMux explains how patterns are matched.
func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
// HandleRemove deregisters the handle with the given pattern
// in the DefaultServeMux.
func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
// HandleFunc registers the handler function with the given pattern
// in the DefaultServeMux.
func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
DefaultServeMux.HandleFunc(pattern, handler)
}
// A Server defines parameters for running an DNS server.
type Server struct {
// Address to listen on, ":dns" if empty.
Addr string
// if "tcp" it will invoke a TCP listener, otherwise an UDP one.
Net string
// TCP Listener to use, this is to aid in systemd's socket activation.
Listener net.Listener
// UDP "Listener" to use, this is to aid in systemd's socket activation.
PacketConn net.PacketConn
// Handler to invoke, dns.DefaultServeMux if nil.
Handler Handler
// Default buffer size to use to read incoming UDP messages. If not set
// it defaults to MinMsgSize (512 B).
UDPSize int
// The net.Conn.SetReadTimeout value for new connections, defaults to 2 * time.Second.
ReadTimeout time.Duration
// The net.Conn.SetWriteTimeout value for new connections, defaults to 2 * time.Second.
WriteTimeout time.Duration
// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
IdleTimeout func() time.Duration
// Secret(s) for Tsig map[<zonename>]<base64 secret>.
TsigSecret map[string]string
// For graceful shutdown.
stopUDP chan bool
stopTCP chan bool
wgUDP sync.WaitGroup
wgTCP sync.WaitGroup
// make start/shutdown not racy
lock sync.Mutex
started bool
}
// ListenAndServe starts a nameserver on the configured address in *Server.
func (srv *Server) ListenAndServe() error {
srv.lock.Lock()
if srv.started {
return &Error{err: "server already started"}
}
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
srv.started = true
srv.lock.Unlock()
addr := srv.Addr
if addr == "" {
addr = ":domain"
}
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
switch srv.Net {
case "tcp", "tcp4", "tcp6":
a, e := net.ResolveTCPAddr(srv.Net, addr)
if e != nil {
return e
}
l, e := net.ListenTCP(srv.Net, a)
if e != nil {
return e
}
return srv.serveTCP(l)
case "udp", "udp4", "udp6":
a, e := net.ResolveUDPAddr(srv.Net, addr)
if e != nil {
return e
}
l, e := net.ListenUDP(srv.Net, a)
if e != nil {
return e
}
if e := setUDPSocketOptions(l); e != nil {
return e
}
return srv.serveUDP(l)
}
return &Error{err: "bad network"}
}
// ActivateAndServe starts a nameserver with the PacketConn or Listener
// configured in *Server. Its main use is to start a server from systemd.
func (srv *Server) ActivateAndServe() error {
srv.lock.Lock()
if srv.started {
return &Error{err: "server already started"}
}
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
srv.started = true
srv.lock.Unlock()
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
if srv.PacketConn != nil {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
if t, ok := srv.PacketConn.(*net.UDPConn); ok {
if e := setUDPSocketOptions(t); e != nil {
return e
}
return srv.serveUDP(t)
}
}
if srv.Listener != nil {
if t, ok := srv.Listener.(*net.TCPListener); ok {
return srv.serveTCP(t)
}
}
return &Error{err: "bad listeners"}
}
// Shutdown gracefully shuts down a server. After a call to Shutdown, ListenAndServe and
// ActivateAndServe will return. All in progress queries are completed before the server
// is taken down. If the Shutdown is taking longer than the reading timeout and error
// is returned.
func (srv *Server) Shutdown() error {
srv.lock.Lock()
if !srv.started {
return &Error{err: "server not started"}
}
srv.started = false
srv.lock.Unlock()
net, addr := srv.Net, srv.Addr
switch {
case srv.Listener != nil:
a := srv.Listener.Addr()
net, addr = a.Network(), a.String()
case srv.PacketConn != nil:
a := srv.PacketConn.LocalAddr()
net, addr = a.Network(), a.String()
}
fin := make(chan bool)
switch net {
case "tcp", "tcp4", "tcp6":
go func() {
srv.stopTCP <- true
srv.wgTCP.Wait()
fin <- true
}()
case "udp", "udp4", "udp6":
go func() {
srv.stopUDP <- true
srv.wgUDP.Wait()
fin <- true
}()
}
c := &Client{Net: net}
go c.Exchange(new(Msg), addr) // extra query to help ReadXXX loop to pass
select {
case <-time.After(srv.getReadTimeout()):
return &Error{err: "server shutdown is pending"}
case <-fin:
return nil
}
}
// getReadTimeout is a helper func to use system timeout if server did not intend to change it.
func (srv *Server) getReadTimeout() time.Duration {
rtimeout := dnsTimeout
if srv.ReadTimeout != 0 {
rtimeout = srv.ReadTimeout
}
return rtimeout
}
// serveTCP starts a TCP listener for the server.
// Each request is handled in a seperate goroutine.
func (srv *Server) serveTCP(l *net.TCPListener) error {
defer l.Close()
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
rw, e := l.AcceptTCP()
if e != nil {
continue
}
m, e := srv.readTCP(rw, rtimeout)
select {
case <-srv.stopTCP:
return nil
default:
}
if e != nil {
continue
}
srv.wgTCP.Add(1)
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
}
panic("dns: not reached")
}
// serveUDP starts a UDP listener for the server.
// Each request is handled in a seperate goroutine.
func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close()
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
m, s, e := srv.readUDP(l, rtimeout)
select {
case <-srv.stopUDP:
return nil
default:
}
if e != nil {
continue
}
srv.wgUDP.Add(1)
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
}
panic("dns: not reached")
}
// Serve a new connection.
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *sessionUDP, t *net.TCPConn) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
q := 0
defer func() {
if u != nil {
srv.wgUDP.Done()
}
if t != nil {
srv.wgTCP.Done()
}
}()
Redo:
req := new(Msg)
err := req.Unpack(m)
if err != nil { // Send a FormatError back
x := new(Msg)
x.SetRcodeFormatError(req)
w.WriteMsg(x)
goto Exit
}
w.tsigStatus = nil
if w.tsigSecret != nil {
if t := req.IsTsig(); t != nil {
secret := t.Hdr.Name
if _, ok := w.tsigSecret[secret]; !ok {
w.tsigStatus = ErrKeyAlg
}
w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
}
}
h.ServeDNS(w, req) // Writes back to the client
Exit:
if w.hijacked {
return // client calls Close()
}
if u != nil { // UDP, "close" and return
w.Close()
return
}
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
}
m, e := srv.readTCP(w.tcp, idleTimeout)
if e == nil {
q++
// TODO(miek): make this number configurable?
if q > 128 { // close socket after this many queries
w.Close()
return
}
goto Redo
}
w.Close()
return
}
func (srv *Server) readTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
l := make([]byte, 2)
n, err := conn.Read(l)
if err != nil || n != 2 {
if err != nil {
return nil, err
}
return nil, ErrShortRead
}
length, _ := unpackUint16(l, 0)
if length == 0 {
return nil, ErrShortRead
}
m := make([]byte, int(length))
n, err = conn.Read(m[:int(length)])
if err != nil || n == 0 {
if err != nil {
return nil, err
}
return nil, ErrShortRead
}
i := n
for i < int(length) {
j, err := conn.Read(m[i:int(length)])
if err != nil {
return nil, err
}
i += j
}
n = i
m = m[:n]
return m, nil
}
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *sessionUDP, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
m := make([]byte, srv.UDPSize)
n, s, e := readFromSessionUDP(conn, m)
if e != nil || n == 0 {
if e != nil {
return nil, nil, e
}
return nil, nil, ErrShortRead
}
m = m[:n]
return m, s, nil
}
// WriteMsg implements the ResponseWriter.WriteMsg method.
func (w *response) WriteMsg(m *Msg) (err error) {
var data []byte
if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
if t := m.IsTsig(); t != nil {
data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
if err != nil {
return err
}
_, err = w.Write(data)
return err
}
}
data, err = m.Pack()
if err != nil {
return err
}
_, err = w.Write(data)
return err
}
// Write implements the ResponseWriter.Write method.
func (w *response) Write(m []byte) (int, error) {
switch {
case w.udp != nil:
n, err := writeToSessionUDP(w.udp, m, w.udpSession)
return n, err
case w.tcp != nil:
lm := len(m)
if lm < 2 {
return 0, io.ErrShortBuffer
}
if lm > MaxMsgSize {
return 0, &Error{err: "message too large"}
}
l := make([]byte, 2, 2+lm)
l[0], l[1] = packUint16(uint16(lm))
m = append(l, m...)
n, err := io.Copy(w.tcp, bytes.NewReader(m))
return int(n), err
}
panic("not reached")
}
// LocalAddr implements the ResponseWriter.LocalAddr method.
func (w *response) LocalAddr() net.Addr {
if w.tcp != nil {
return w.tcp.LocalAddr()
}
return w.udp.LocalAddr()
}
// RemoteAddr implements the ResponseWriter.RemoteAddr method.
func (w *response) RemoteAddr() net.Addr { return w.remoteAddr }
// TsigStatus implements the ResponseWriter.TsigStatus method.
func (w *response) TsigStatus() error { return w.tsigStatus }
// TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method.
func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b }
// Hijack implements the ResponseWriter.Hijack method.
func (w *response) Hijack() { w.hijacked = true }
// Close implements the ResponseWriter.Close method
func (w *response) Close() error {
// Can't close the udp conn, as that is actually the listener.
if w.tcp != nil {
e := w.tcp.Close()
w.tcp = nil
return e
}
return nil
}

View file

@ -0,0 +1,331 @@
package dns
import (
"fmt"
"net"
"runtime"
"sync"
"testing"
)
func HelloServer(w ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)
m.Extra = make([]RR, 1)
m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
w.WriteMsg(m)
}
func AnotherHelloServer(w ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)
m.Extra = make([]RR, 1)
m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello example"}}
w.WriteMsg(m)
}
func RunLocalUDPServer(laddr string) (*Server, string, error) {
pc, err := net.ListenPacket("udp", laddr)
if err != nil {
return nil, "", err
}
server := &Server{PacketConn: pc}
go func() {
server.ActivateAndServe()
pc.Close()
}()
return server, pc.LocalAddr().String(), nil
}
func RunLocalTCPServer(laddr string) (*Server, string, error) {
l, err := net.Listen("tcp", laddr)
if err != nil {
return nil, "", err
}
server := &Server{Listener: l}
go func() {
server.ActivateAndServe()
l.Close()
}()
return server, l.Addr().String(), nil
}
func TestServing(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
HandleFunc("example.com.", AnotherHelloServer)
defer HandleRemove("miek.nl.")
defer HandleRemove("example.com.")
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
defer s.Shutdown()
c := new(Client)
m := new(Msg)
m.SetQuestion("miek.nl.", TypeTXT)
r, _, err := c.Exchange(m, addrstr)
if err != nil {
t.Log("failed to exchange miek.nl", err)
t.Fatal()
}
txt := r.Extra[0].(*TXT).Txt[0]
if txt != "Hello world" {
t.Log("Unexpected result for miek.nl", txt, "!= Hello world")
t.Fail()
}
m.SetQuestion("example.com.", TypeTXT)
r, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Log("failed to exchange example.com", err)
t.Fatal()
}
txt = r.Extra[0].(*TXT).Txt[0]
if txt != "Hello example" {
t.Log("Unexpected result for example.com", txt, "!= Hello example")
t.Fail()
}
// Test Mixes cased as noticed by Ask.
m.SetQuestion("eXaMplE.cOm.", TypeTXT)
r, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Log("failed to exchange eXaMplE.cOm", err)
t.Fail()
}
txt = r.Extra[0].(*TXT).Txt[0]
if txt != "Hello example" {
t.Log("Unexpected result for example.com", txt, "!= Hello example")
t.Fail()
}
}
func BenchmarkServe(b *testing.B) {
b.StopTimer()
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
a := runtime.GOMAXPROCS(4)
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
b.Fatalf("Unable to run test server: %s", err)
}
defer s.Shutdown()
c := new(Client)
m := new(Msg)
m.SetQuestion("miek.nl", TypeSOA)
b.StartTimer()
for i := 0; i < b.N; i++ {
c.Exchange(m, addrstr)
}
runtime.GOMAXPROCS(a)
}
func benchmarkServe6(b *testing.B) {
b.StopTimer()
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
a := runtime.GOMAXPROCS(4)
s, addrstr, err := RunLocalUDPServer("[::1]:0")
if err != nil {
b.Fatalf("Unable to run test server: %s", err)
}
defer s.Shutdown()
c := new(Client)
m := new(Msg)
m.SetQuestion("miek.nl", TypeSOA)
b.StartTimer()
for i := 0; i < b.N; i++ {
c.Exchange(m, addrstr)
}
runtime.GOMAXPROCS(a)
}
func HelloServerCompress(w ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)
m.Extra = make([]RR, 1)
m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
m.Compress = true
w.WriteMsg(m)
}
func BenchmarkServeCompress(b *testing.B) {
b.StopTimer()
HandleFunc("miek.nl.", HelloServerCompress)
defer HandleRemove("miek.nl.")
a := runtime.GOMAXPROCS(4)
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
b.Fatalf("Unable to run test server: %s", err)
}
defer s.Shutdown()
c := new(Client)
m := new(Msg)
m.SetQuestion("miek.nl", TypeSOA)
b.StartTimer()
for i := 0; i < b.N; i++ {
c.Exchange(m, addrstr)
}
runtime.GOMAXPROCS(a)
}
func TestDotAsCatchAllWildcard(t *testing.T) {
mux := NewServeMux()
mux.Handle(".", HandlerFunc(HelloServer))
mux.Handle("example.com.", HandlerFunc(AnotherHelloServer))
handler := mux.match("www.miek.nl.", TypeTXT)
if handler == nil {
t.Error("wildcard match failed")
}
handler = mux.match("www.example.com.", TypeTXT)
if handler == nil {
t.Error("example.com match failed")
}
handler = mux.match("a.www.example.com.", TypeTXT)
if handler == nil {
t.Error("a.www.example.com match failed")
}
handler = mux.match("boe.", TypeTXT)
if handler == nil {
t.Error("boe. match failed")
}
}
func TestCaseFolding(t *testing.T) {
mux := NewServeMux()
mux.Handle("_udp.example.com.", HandlerFunc(HelloServer))
handler := mux.match("_dns._udp.example.com.", TypeSRV)
if handler == nil {
t.Error("case sensitive characters folded")
}
handler = mux.match("_DNS._UDP.EXAMPLE.COM.", TypeSRV)
if handler == nil {
t.Error("case insensitive characters not folded")
}
}
func TestRootServer(t *testing.T) {
mux := NewServeMux()
mux.Handle(".", HandlerFunc(HelloServer))
handler := mux.match(".", TypeNS)
if handler == nil {
t.Error("root match failed")
}
}
type maxRec struct {
max int
sync.RWMutex
}
var M = new(maxRec)
func HelloServerLargeResponse(resp ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)
m.Authoritative = true
m1 := 0
M.RLock()
m1 = M.max
M.RUnlock()
for i := 0; i < m1; i++ {
aRec := &A{
Hdr: RR_Header{
Name: req.Question[0].Name,
Rrtype: TypeA,
Class: ClassINET,
Ttl: 0,
},
A: net.ParseIP(fmt.Sprintf("127.0.0.%d", i+1)).To4(),
}
m.Answer = append(m.Answer, aRec)
}
resp.WriteMsg(m)
}
func TestServingLargeResponses(t *testing.T) {
HandleFunc("example.", HelloServerLargeResponse)
defer HandleRemove("example.")
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
defer s.Shutdown()
// Create request
m := new(Msg)
m.SetQuestion("web.service.example.", TypeANY)
c := new(Client)
c.Net = "udp"
M.Lock()
M.max = 2
M.Unlock()
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Logf("failed to exchange: %s", err.Error())
t.Fail()
}
// This must fail
M.Lock()
M.max = 20
M.Unlock()
_, _, err = c.Exchange(m, addrstr)
if err == nil {
t.Logf("failed to fail exchange, this should generate packet error")
t.Fail()
}
// But this must work again
c.UDPSize = 7000
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Logf("failed to exchange: %s", err.Error())
t.Fail()
}
}
func TestShutdownTCP(t *testing.T) {
s, _, err := RunLocalTCPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
// it normally is too early to shutting down because server
// activates in goroutine.
runtime.Gosched()
err = s.Shutdown()
if err != nil {
t.Errorf("Could not shutdown test TCP server, %s", err)
}
}
func TestShutdownUDP(t *testing.T) {
s, _, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
// it normally is too early to shutting down because server
// activates in goroutine.
runtime.Gosched()
err = s.Shutdown()
if err != nil {
t.Errorf("Could not shutdown test UDP server, %s", err)
}
}

View file

@ -0,0 +1,57 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted for dns package usage by Miek Gieben.
package dns
import "sync"
import "time"
// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup
val *Msg
rtt time.Duration
err error
dups int
}
// singleflight represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type singleflight struct {
sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value shared indicates whether v was given to multiple callers.
func (g *singleflight) Do(key string, fn func() (*Msg, time.Duration, error)) (v *Msg, rtt time.Duration, err error, shared bool) {
g.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
g.Unlock()
c.wg.Wait()
return c.val, c.rtt, c.err, true
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.Unlock()
c.val, c.rtt, c.err = fn()
c.wg.Done()
g.Lock()
delete(g.m, key)
g.Unlock()
return c.val, c.rtt, c.err, c.dups > 0
}

84
Godeps/_workspace/src/github.com/miekg/dns/tlsa.go generated vendored Normal file
View file

@ -0,0 +1,84 @@
package dns
import (
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/hex"
"errors"
"io"
"net"
"strconv"
)
// CertificateToDANE converts a certificate to a hex string as used in the TLSA record.
func CertificateToDANE(selector, matchingType uint8, cert *x509.Certificate) (string, error) {
switch matchingType {
case 0:
switch selector {
case 0:
return hex.EncodeToString(cert.Raw), nil
case 1:
return hex.EncodeToString(cert.RawSubjectPublicKeyInfo), nil
}
case 1:
h := sha256.New()
switch selector {
case 0:
return hex.EncodeToString(cert.Raw), nil
case 1:
io.WriteString(h, string(cert.RawSubjectPublicKeyInfo))
return hex.EncodeToString(h.Sum(nil)), nil
}
case 2:
h := sha512.New()
switch selector {
case 0:
return hex.EncodeToString(cert.Raw), nil
case 1:
io.WriteString(h, string(cert.RawSubjectPublicKeyInfo))
return hex.EncodeToString(h.Sum(nil)), nil
}
}
return "", errors.New("dns: bad TLSA MatchingType or TLSA Selector")
}
// Sign creates a TLSA record from an SSL certificate.
func (r *TLSA) Sign(usage, selector, matchingType int, cert *x509.Certificate) (err error) {
r.Hdr.Rrtype = TypeTLSA
r.Usage = uint8(usage)
r.Selector = uint8(selector)
r.MatchingType = uint8(matchingType)
r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert)
if err != nil {
return err
}
return nil
}
// Verify verifies a TLSA record against an SSL certificate. If it is OK
// a nil error is returned.
func (r *TLSA) Verify(cert *x509.Certificate) error {
c, err := CertificateToDANE(r.Selector, r.MatchingType, cert)
if err != nil {
return err // Not also ErrSig?
}
if r.Certificate == c {
return nil
}
return ErrSig // ErrSig, really?
}
// TLSAName returns the ownername of a TLSA resource record as per the
// rules specified in RFC 6698, Section 3.
func TLSAName(name, service, network string) (string, error) {
if !IsFqdn(name) {
return "", ErrFqdn
}
p, e := net.LookupPort(network, service)
if e != nil {
return "", e
}
return "_" + strconv.Itoa(p) + "_" + network + "." + name, nil
}

371
Godeps/_workspace/src/github.com/miekg/dns/tsig.go generated vendored Normal file
View file

@ -0,0 +1,371 @@
// TRANSACTION SIGNATURE
//
// An TSIG or transaction signature adds a HMAC TSIG record to each message sent.
// The supported algorithms include: HmacMD5, HmacSHA1 and HmacSHA256.
//
// Basic use pattern when querying with a TSIG name "axfr." (note that these key names
// must be fully qualified - as they are domain names) and the base64 secret
// "so6ZGir4GPAqINNh9U5c3A==":
//
// c := new(dns.Client)
// c.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
// m := new(dns.Msg)
// m.SetQuestion("miek.nl.", dns.TypeMX)
// m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
// ...
// // When sending the TSIG RR is calculated and filled in before sending
//
// When requesting an zone transfer (almost all TSIG usage is when requesting zone transfers), with
// TSIG, this is the basic use pattern. In this example we request an AXFR for
// miek.nl. with TSIG key named "axfr." and secret "so6ZGir4GPAqINNh9U5c3A=="
// and using the server 176.58.119.54:
//
// t := new(dns.Transfer)
// m := new(dns.Msg)
// t.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
// m.SetAxfr("miek.nl.")
// m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
// c, err := t.In(m, "176.58.119.54:53")
// for r := range c { /* r.RR */ }
//
// You can now read the records from the transfer as they come in. Each envelope is checked with TSIG.
// If something is not correct an error is returned.
//
// Basic use pattern validating and replying to a message that has TSIG set.
//
// server := &dns.Server{Addr: ":53", Net: "udp"}
// server.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
// go server.ListenAndServe()
// dns.HandleFunc(".", handleRequest)
//
// func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
// m := new(Msg)
// m.SetReply(r)
// if r.IsTsig() {
// if w.TsigStatus() == nil {
// // *Msg r has an TSIG record and it was validated
// m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
// } else {
// // *Msg r has an TSIG records and it was not valided
// }
// }
// w.WriteMsg(m)
// }
package dns
import (
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"hash"
"io"
"strconv"
"strings"
"time"
)
// HMAC hashing codes. These are transmitted as domain names.
const (
HmacMD5 = "hmac-md5.sig-alg.reg.int."
HmacSHA1 = "hmac-sha1."
HmacSHA256 = "hmac-sha256."
)
type TSIG struct {
Hdr RR_Header
Algorithm string `dns:"domain-name"`
TimeSigned uint64 `dns:"uint48"`
Fudge uint16
MACSize uint16
MAC string `dns:"size-hex"`
OrigId uint16
Error uint16
OtherLen uint16
OtherData string `dns:"size-hex"`
}
func (rr *TSIG) Header() *RR_Header {
return &rr.Hdr
}
// TSIG has no official presentation format, but this will suffice.
func (rr *TSIG) String() string {
s := "\n;; TSIG PSEUDOSECTION:\n"
s += rr.Hdr.String() +
" " + rr.Algorithm +
" " + tsigTimeToString(rr.TimeSigned) +
" " + strconv.Itoa(int(rr.Fudge)) +
" " + strconv.Itoa(int(rr.MACSize)) +
" " + strings.ToUpper(rr.MAC) +
" " + strconv.Itoa(int(rr.OrigId)) +
" " + strconv.Itoa(int(rr.Error)) + // BIND prints NOERROR
" " + strconv.Itoa(int(rr.OtherLen)) +
" " + rr.OtherData
return s
}
func (rr *TSIG) len() int {
return rr.Hdr.len() + len(rr.Algorithm) + 1 + 6 +
4 + len(rr.MAC)/2 + 1 + 6 + len(rr.OtherData)/2 + 1
}
func (rr *TSIG) copy() RR {
return &TSIG{*rr.Hdr.copyHeader(), rr.Algorithm, rr.TimeSigned, rr.Fudge, rr.MACSize, rr.MAC, rr.OrigId, rr.Error, rr.OtherLen, rr.OtherData}
}
// The following values must be put in wireformat, so that the MAC can be calculated.
// RFC 2845, section 3.4.2. TSIG Variables.
type tsigWireFmt struct {
// From RR_Header
Name string `dns:"domain-name"`
Class uint16
Ttl uint32
// Rdata of the TSIG
Algorithm string `dns:"domain-name"`
TimeSigned uint64 `dns:"uint48"`
Fudge uint16
// MACSize, MAC and OrigId excluded
Error uint16
OtherLen uint16
OtherData string `dns:"size-hex"`
}
// If we have the MAC use this type to convert it to wiredata.
// Section 3.4.3. Request MAC
type macWireFmt struct {
MACSize uint16
MAC string `dns:"size-hex"`
}
// 3.3. Time values used in TSIG calculations
type timerWireFmt struct {
TimeSigned uint64 `dns:"uint48"`
Fudge uint16
}
// TsigGenerate fills out the TSIG record attached to the message.
// The message should contain
// a "stub" TSIG RR with the algorithm, key name (owner name of the RR),
// time fudge (defaults to 300 seconds) and the current time
// The TSIG MAC is saved in that Tsig RR.
// When TsigGenerate is called for the first time requestMAC is set to the empty string and
// timersOnly is false.
// If something goes wrong an error is returned, otherwise it is nil.
func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) {
if m.IsTsig() == nil {
panic("dns: TSIG not last RR in additional")
}
// If we barf here, the caller is to blame
rawsecret, err := packBase64([]byte(secret))
if err != nil {
return nil, "", err
}
rr := m.Extra[len(m.Extra)-1].(*TSIG)
m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg
mbuf, err := m.Pack()
if err != nil {
return nil, "", err
}
buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly)
t := new(TSIG)
var h hash.Hash
switch rr.Algorithm {
case HmacMD5:
h = hmac.New(md5.New, []byte(rawsecret))
case HmacSHA1:
h = hmac.New(sha1.New, []byte(rawsecret))
case HmacSHA256:
h = hmac.New(sha256.New, []byte(rawsecret))
default:
return nil, "", ErrKeyAlg
}
io.WriteString(h, string(buf))
t.MAC = hex.EncodeToString(h.Sum(nil))
t.MACSize = uint16(len(t.MAC) / 2) // Size is half!
t.Hdr = RR_Header{Name: rr.Hdr.Name, Rrtype: TypeTSIG, Class: ClassANY, Ttl: 0}
t.Fudge = rr.Fudge
t.TimeSigned = rr.TimeSigned
t.Algorithm = rr.Algorithm
t.OrigId = m.Id
tbuf := make([]byte, t.len())
if off, err := PackRR(t, tbuf, 0, nil, false); err == nil {
tbuf = tbuf[:off] // reset to actual size used
} else {
return nil, "", err
}
mbuf = append(mbuf, tbuf...)
rawSetExtraLen(mbuf, uint16(len(m.Extra)+1))
return mbuf, t.MAC, nil
}
// TsigVerify verifies the TSIG on a message.
// If the signature does not validate err contains the
// error, otherwise it is nil.
func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
rawsecret, err := packBase64([]byte(secret))
if err != nil {
return err
}
// Strip the TSIG from the incoming msg
stripped, tsig, err := stripTsig(msg)
if err != nil {
return err
}
msgMAC, err := hex.DecodeString(tsig.MAC)
if err != nil {
return err
}
buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly)
ti := uint64(time.Now().Unix()) - tsig.TimeSigned
if uint64(tsig.Fudge) < ti {
return ErrTime
}
var h hash.Hash
switch tsig.Algorithm {
case HmacMD5:
h = hmac.New(md5.New, rawsecret)
case HmacSHA1:
h = hmac.New(sha1.New, rawsecret)
case HmacSHA256:
h = hmac.New(sha256.New, rawsecret)
default:
return ErrKeyAlg
}
h.Write(buf)
if !hmac.Equal(h.Sum(nil), msgMAC) {
return ErrSig
}
return nil
}
// Create a wiredata buffer for the MAC calculation.
func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []byte {
var buf []byte
if rr.TimeSigned == 0 {
rr.TimeSigned = uint64(time.Now().Unix())
}
if rr.Fudge == 0 {
rr.Fudge = 300 // Standard (RFC) default.
}
if requestMAC != "" {
m := new(macWireFmt)
m.MACSize = uint16(len(requestMAC) / 2)
m.MAC = requestMAC
buf = make([]byte, len(requestMAC)) // long enough
n, _ := PackStruct(m, buf, 0)
buf = buf[:n]
}
tsigvar := make([]byte, DefaultMsgSize)
if timersOnly {
tsig := new(timerWireFmt)
tsig.TimeSigned = rr.TimeSigned
tsig.Fudge = rr.Fudge
n, _ := PackStruct(tsig, tsigvar, 0)
tsigvar = tsigvar[:n]
} else {
tsig := new(tsigWireFmt)
tsig.Name = strings.ToLower(rr.Hdr.Name)
tsig.Class = ClassANY
tsig.Ttl = rr.Hdr.Ttl
tsig.Algorithm = strings.ToLower(rr.Algorithm)
tsig.TimeSigned = rr.TimeSigned
tsig.Fudge = rr.Fudge
tsig.Error = rr.Error
tsig.OtherLen = rr.OtherLen
tsig.OtherData = rr.OtherData
n, _ := PackStruct(tsig, tsigvar, 0)
tsigvar = tsigvar[:n]
}
if requestMAC != "" {
x := append(buf, msgbuf...)
buf = append(x, tsigvar...)
} else {
buf = append(msgbuf, tsigvar...)
}
return buf
}
// Strip the TSIG from the raw message.
func stripTsig(msg []byte) ([]byte, *TSIG, error) {
// Copied from msg.go's Unpack()
// Header.
var dh Header
var err error
dns := new(Msg)
rr := new(TSIG)
off := 0
tsigoff := 0
if off, err = UnpackStruct(&dh, msg, off); err != nil {
return nil, nil, err
}
if dh.Arcount == 0 {
return nil, nil, ErrNoSig
}
// Rcode, see msg.go Unpack()
if int(dh.Bits&0xF) == RcodeNotAuth {
return nil, nil, ErrAuth
}
// Arrays.
dns.Question = make([]Question, dh.Qdcount)
dns.Answer = make([]RR, dh.Ancount)
dns.Ns = make([]RR, dh.Nscount)
dns.Extra = make([]RR, dh.Arcount)
for i := 0; i < len(dns.Question); i++ {
off, err = UnpackStruct(&dns.Question[i], msg, off)
if err != nil {
return nil, nil, err
}
}
for i := 0; i < len(dns.Answer); i++ {
dns.Answer[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
}
for i := 0; i < len(dns.Ns); i++ {
dns.Ns[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
}
for i := 0; i < len(dns.Extra); i++ {
tsigoff = off
dns.Extra[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
if dns.Extra[i].Header().Rrtype == TypeTSIG {
rr = dns.Extra[i].(*TSIG)
// Adjust Arcount.
arcount, _ := unpackUint16(msg, 10)
msg[10], msg[11] = packUint16(arcount - 1)
break
}
}
if rr == nil {
return nil, nil, ErrNoSig
}
return msg[:tsigoff], rr, nil
}
// Translate the TSIG time signed into a date. There is no
// need for RFC1982 calculations as this date is 48 bits.
func tsigTimeToString(t uint64) string {
ti := time.Unix(int64(t), 0).UTC()
return ti.Format("20060102150405")
}

1685
Godeps/_workspace/src/github.com/miekg/dns/types.go generated vendored Normal file

File diff suppressed because it is too large Load diff

55
Godeps/_workspace/src/github.com/miekg/dns/udp.go generated vendored Normal file
View file

@ -0,0 +1,55 @@
// +build !windows
package dns
import (
"net"
"syscall"
)
type sessionUDP struct {
raddr *net.UDPAddr
context []byte
}
func (s *sessionUDP) RemoteAddr() net.Addr { return s.raddr }
// setUDPSocketOptions sets the UDP socket options.
// This function is implemented on a per platform basis. See udp_*.go for more details
func setUDPSocketOptions(conn *net.UDPConn) error {
sa, err := getUDPSocketName(conn)
if err != nil {
return err
}
switch sa.(type) {
case *syscall.SockaddrInet6:
v6only, err := getUDPSocketOptions6Only(conn)
if err != nil {
return err
}
setUDPSocketOptions6(conn)
if !v6only {
setUDPSocketOptions4(conn)
}
case *syscall.SockaddrInet4:
setUDPSocketOptions4(conn)
}
return nil
}
// readFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a
// net.UDPAddr.
func readFromSessionUDP(conn *net.UDPConn, b []byte) (int, *sessionUDP, error) {
oob := make([]byte, 40)
n, oobn, _, raddr, err := conn.ReadMsgUDP(b, oob)
if err != nil {
return n, nil, err
}
return n, &sessionUDP{raddr, oob[:oobn]}, err
}
// writeToSessionUDP acts just like net.UDPConn.WritetTo(), but uses a *sessionUDP instead of a net.Addr.
func writeToSessionUDP(conn *net.UDPConn, b []byte, session *sessionUDP) (int, error) {
n, _, err := conn.WriteMsgUDP(b, session.context, session.raddr)
return n, err
}

View file

@ -0,0 +1,63 @@
// +build linux
package dns
// See:
// * http://stackoverflow.com/questions/3062205/setting-the-source-ip-for-a-udp-socket and
// * http://blog.powerdns.com/2012/10/08/on-binding-datagram-udp-sockets-to-the-any-addresses/
//
// Why do we need this: When listening on 0.0.0.0 with UDP so kernel decides what is the outgoing
// interface, this might not always be the correct one. This code will make sure the egress
// packet's interface matched the ingress' one.
import (
"net"
"syscall"
)
// setUDPSocketOptions4 prepares the v4 socket for sessions.
func setUDPSocketOptions4(conn *net.UDPConn) error {
file, err := conn.File()
if err != nil {
return err
}
if err := syscall.SetsockoptInt(int(file.Fd()), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1); err != nil {
return err
}
return nil
}
// setUDPSocketOptions6 prepares the v6 socket for sessions.
func setUDPSocketOptions6(conn *net.UDPConn) error {
file, err := conn.File()
if err != nil {
return err
}
if err := syscall.SetsockoptInt(int(file.Fd()), syscall.IPPROTO_IPV6, syscall.IPV6_RECVPKTINFO, 1); err != nil {
return err
}
return nil
}
// getUDPSocketOption6Only return true if the socket is v6 only and false when it is v4/v6 combined
// (dualstack).
func getUDPSocketOptions6Only(conn *net.UDPConn) (bool, error) {
file, err := conn.File()
if err != nil {
return false, err
}
// dual stack. See http://stackoverflow.com/questions/1618240/how-to-support-both-ipv4-and-ipv6-connections
v6only, err := syscall.GetsockoptInt(int(file.Fd()), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY)
if err != nil {
return false, err
}
return v6only == 1, nil
}
func getUDPSocketName(conn *net.UDPConn) (syscall.Sockaddr, error) {
file, err := conn.File()
if err != nil {
return nil, err
}
return syscall.Getsockname(int(file.Fd()))
}

View file

@ -0,0 +1,17 @@
// +build !linux
package dns
import (
"net"
"syscall"
)
// These do nothing. See udp_linux.go for an example of how to implement this.
// We tried to adhire to some kind of naming scheme.
func setUDPSocketOptions4(conn *net.UDPConn) error { return nil }
func setUDPSocketOptions6(conn *net.UDPConn) error { return nil }
func getUDPSocketOptions6Only(conn *net.UDPConn) (bool, error) { return false, nil }
func getUDPSocketName(conn *net.UDPConn) (syscall.Sockaddr, error) { return nil, nil }

View file

@ -0,0 +1,34 @@
// +build windows
package dns
import "net"
type sessionUDP struct {
raddr *net.UDPAddr
}
// readFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a
// net.UDPAddr.
func readFromSessionUDP(conn *net.UDPConn, b []byte) (int, *sessionUDP, error) {
n, raddr, err := conn.ReadFrom(b)
if err != nil {
return n, nil, err
}
session := &sessionUDP{raddr.(*net.UDPAddr)}
return n, session, err
}
// writeToSessionUDP acts just like net.UDPConn.WritetTo(), but uses a *sessionUDP instead of a net.Addr.
func writeToSessionUDP(conn *net.UDPConn, b []byte, session *sessionUDP) (int, error) {
n, err := conn.WriteTo(b, session.raddr)
return n, err
}
func (s *sessionUDP) RemoteAddr() net.Addr { return s.raddr }
// setUDPSocketOptions sets the UDP socket options.
// This function is implemented on a per platform basis. See udp_*.go for more details
func setUDPSocketOptions(conn *net.UDPConn) error {
return nil
}

132
Godeps/_workspace/src/github.com/miekg/dns/update.go generated vendored Normal file
View file

@ -0,0 +1,132 @@
// DYNAMIC UPDATES
//
// Dynamic updates reuses the DNS message format, but renames three of
// the sections. Question is Zone, Answer is Prerequisite, Authority is
// Update, only the Additional is not renamed. See RFC 2136 for the gory details.
//
// You can set a rather complex set of rules for the existence of absence of
// certain resource records or names in a zone to specify if resource records
// should be added or removed. The table from RFC 2136 supplemented with the Go
// DNS function shows which functions exist to specify the prerequisites.
//
// 3.2.4 - Table Of Metavalues Used In Prerequisite Section
//
// CLASS TYPE RDATA Meaning Function
// --------------------------------------------------------------
// ANY ANY empty Name is in use dns.NameUsed
// ANY rrset empty RRset exists (value indep) dns.RRsetUsed
// NONE ANY empty Name is not in use dns.NameNotUsed
// NONE rrset empty RRset does not exist dns.RRsetNotUsed
// zone rrset rr RRset exists (value dep) dns.Used
//
// The prerequisite section can also be left empty.
// If you have decided on the prerequisites you can tell what RRs should
// be added or deleted. The next table shows the options you have and
// what functions to call.
//
// 3.4.2.6 - Table Of Metavalues Used In Update Section
//
// CLASS TYPE RDATA Meaning Function
// ---------------------------------------------------------------
// ANY ANY empty Delete all RRsets from name dns.RemoveName
// ANY rrset empty Delete an RRset dns.RemoveRRset
// NONE rrset rr Delete an RR from RRset dns.Remove
// zone rrset rr Add to an RRset dns.Insert
//
package dns
// NameUsed sets the RRs in the prereq section to
// "Name is in use" RRs. RFC 2136 section 2.4.4.
func (u *Msg) NameUsed(rr []RR) {
u.Answer = make([]RR, len(rr))
for i, r := range rr {
u.Answer[i] = &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: TypeANY, Class: ClassANY}}
}
}
// NameNotUsed sets the RRs in the prereq section to
// "Name is in not use" RRs. RFC 2136 section 2.4.5.
func (u *Msg) NameNotUsed(rr []RR) {
u.Answer = make([]RR, len(rr))
for i, r := range rr {
u.Answer[i] = &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: TypeANY, Class: ClassNONE}}
}
}
// Used sets the RRs in the prereq section to
// "RRset exists (value dependent -- with rdata)" RRs. RFC 2136 section 2.4.2.
func (u *Msg) Used(rr []RR) {
if len(u.Question) == 0 {
panic("dns: empty question section")
}
u.Answer = make([]RR, len(rr))
for i, r := range rr {
u.Answer[i] = r
u.Answer[i].Header().Class = u.Question[0].Qclass
}
}
// RRsetUsed sets the RRs in the prereq section to
// "RRset exists (value independent -- no rdata)" RRs. RFC 2136 section 2.4.1.
func (u *Msg) RRsetUsed(rr []RR) {
u.Answer = make([]RR, len(rr))
for i, r := range rr {
u.Answer[i] = r
u.Answer[i].Header().Class = ClassANY
u.Answer[i].Header().Ttl = 0
u.Answer[i].Header().Rdlength = 0
}
}
// RRsetNotUsed sets the RRs in the prereq section to
// "RRset does not exist" RRs. RFC 2136 section 2.4.3.
func (u *Msg) RRsetNotUsed(rr []RR) {
u.Answer = make([]RR, len(rr))
for i, r := range rr {
u.Answer[i] = r
u.Answer[i].Header().Class = ClassNONE
u.Answer[i].Header().Rdlength = 0
u.Answer[i].Header().Ttl = 0
}
}
// Insert creates a dynamic update packet that adds an complete RRset, see RFC 2136 section 2.5.1.
func (u *Msg) Insert(rr []RR) {
if len(u.Question) == 0 {
panic("dns: empty question section")
}
u.Ns = make([]RR, len(rr))
for i, r := range rr {
u.Ns[i] = r
u.Ns[i].Header().Class = u.Question[0].Qclass
}
}
// RemoveRRset creates a dynamic update packet that deletes an RRset, see RFC 2136 section 2.5.2.
func (u *Msg) RemoveRRset(rr []RR) {
u.Ns = make([]RR, len(rr))
for i, r := range rr {
u.Ns[i] = r
u.Ns[i].Header().Class = ClassANY
u.Ns[i].Header().Rdlength = 0
u.Ns[i].Header().Ttl = 0
}
}
// RemoveName creates a dynamic update packet that deletes all RRsets of a name, see RFC 2136 section 2.5.3
func (u *Msg) RemoveName(rr []RR) {
u.Ns = make([]RR, len(rr))
for i, r := range rr {
u.Ns[i] = &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: TypeANY, Class: ClassANY}}
}
}
// Remove creates a dynamic update packet deletes RR from the RRSset, see RFC 2136 section 2.5.4
func (u *Msg) Remove(rr []RR) {
u.Ns = make([]RR, len(rr))
for i, r := range rr {
u.Ns[i] = r
u.Ns[i].Header().Class = ClassNONE
u.Ns[i].Header().Ttl = 0
}
}

236
Godeps/_workspace/src/github.com/miekg/dns/xfr.go generated vendored Normal file
View file

@ -0,0 +1,236 @@
package dns
import (
"time"
)
// Envelope is used when doing a zone transfer with a remote server.
type Envelope struct {
RR []RR // The set of RRs in the answer section of the xfr reply message.
Error error // If something went wrong, this contains the error.
}
// A Transfer defines parameters that are used during a zone transfer.
type Transfer struct {
*Conn
DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9
TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
tsigTimersOnly bool
}
// Think we need to away to stop the transfer
// In performs an incoming transfer with the server in a.
func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) {
timeout := dnsTimeout
if t.DialTimeout != 0 {
timeout = t.DialTimeout
}
t.Conn, err = DialTimeout("tcp", a, timeout)
if err != nil {
return nil, err
}
if err := t.WriteMsg(q); err != nil {
return nil, err
}
env = make(chan *Envelope)
go func() {
if q.Question[0].Qtype == TypeAXFR {
go t.inAxfr(q.Id, env)
return
}
if q.Question[0].Qtype == TypeIXFR {
go t.inIxfr(q.Id, env)
return
}
}()
return env, nil
}
func (t *Transfer) inAxfr(id uint16, c chan *Envelope) {
first := true
defer t.Close()
defer close(c)
timeout := dnsTimeout
if t.ReadTimeout != 0 {
timeout = t.ReadTimeout
}
for {
t.Conn.SetReadDeadline(time.Now().Add(timeout))
in, err := t.ReadMsg()
if err != nil {
c <- &Envelope{nil, err}
return
}
if id != in.Id {
c <- &Envelope{in.Answer, ErrId}
return
}
if first {
if !isSOAFirst(in) {
c <- &Envelope{in.Answer, ErrSoa}
return
}
first = !first
// only one answer that is SOA, receive more
if len(in.Answer) == 1 {
t.tsigTimersOnly = true
c <- &Envelope{in.Answer, nil}
continue
}
}
if !first {
t.tsigTimersOnly = true // Subsequent envelopes use this.
if isSOALast(in) {
c <- &Envelope{in.Answer, nil}
return
}
c <- &Envelope{in.Answer, nil}
}
}
panic("dns: not reached")
}
func (t *Transfer) inIxfr(id uint16, c chan *Envelope) {
serial := uint32(0) // The first serial seen is the current server serial
first := true
defer t.Close()
defer close(c)
timeout := dnsTimeout
if t.ReadTimeout != 0 {
timeout = t.ReadTimeout
}
for {
t.SetReadDeadline(time.Now().Add(timeout))
in, err := t.ReadMsg()
if err != nil {
c <- &Envelope{in.Answer, err}
return
}
if id != in.Id {
c <- &Envelope{in.Answer, ErrId}
return
}
if first {
// A single SOA RR signals "no changes"
if len(in.Answer) == 1 && isSOAFirst(in) {
c <- &Envelope{in.Answer, nil}
return
}
// Check if the returned answer is ok
if !isSOAFirst(in) {
c <- &Envelope{in.Answer, ErrSoa}
return
}
// This serial is important
serial = in.Answer[0].(*SOA).Serial
first = !first
}
// Now we need to check each message for SOA records, to see what we need to do
if !first {
t.tsigTimersOnly = true
// If the last record in the IXFR contains the servers' SOA, we should quit
if v, ok := in.Answer[len(in.Answer)-1].(*SOA); ok {
if v.Serial == serial {
c <- &Envelope{in.Answer, nil}
return
}
}
c <- &Envelope{in.Answer, nil}
}
}
}
// Out performs an outgoing transfer with the client connecting in w.
// Basic use pattern:
//
// ch := make(chan *dns.Envelope)
// tr := new(dns.Transfer)
// tr.Out(w, r, ch)
// c <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}}
// close(ch)
// w.Hijack()
// // w.Close() // Client closes connection
//
// The server is responsible for sending the correct sequence of RRs through the
// channel ch.
func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
r := new(Msg)
// Compress?
r.SetReply(q)
r.Authoritative = true
go func() {
for x := range ch {
// assume it fits TODO(miek): fix
r.Answer = append(r.Answer, x.RR...)
if err := w.WriteMsg(r); err != nil {
return
}
}
w.TsigTimersOnly(true)
r.Answer = nil
}()
return nil
}
// ReadMsg reads a message from the transfer connection t.
func (t *Transfer) ReadMsg() (*Msg, error) {
m := new(Msg)
p := make([]byte, MaxMsgSize)
n, err := t.Read(p)
if err != nil && n == 0 {
return nil, err
}
p = p[:n]
if err := m.Unpack(p); err != nil {
return nil, err
}
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
return m, ErrSecret
}
// Need to work on the original message p, as that was used to calculate the tsig.
err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
}
return m, err
}
// WriteMsg writes a message through the transfer connection t.
func (t *Transfer) WriteMsg(m *Msg) (err error) {
var out []byte
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
return ErrSecret
}
out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
} else {
out, err = m.Pack()
}
if err != nil {
return err
}
if _, err = t.Write(out); err != nil {
return err
}
return nil
}
func isSOAFirst(in *Msg) bool {
if len(in.Answer) > 0 {
return in.Answer[0].Header().Rrtype == TypeSOA
}
return false
}
func isSOALast(in *Msg) bool {
if len(in.Answer) > 0 {
return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
}
return false
}

157
Godeps/_workspace/src/github.com/miekg/dns/zgenerate.go generated vendored Normal file
View file

@ -0,0 +1,157 @@
package dns
import (
"fmt"
"strconv"
"strings"
)
// Parse the $GENERATE statement as used in BIND9 zones.
// See http://www.zytrax.com/books/dns/ch8/generate.html for instance.
// We are called after '$GENERATE '. After which we expect:
// * the range (12-24/2)
// * lhs (ownername)
// * [[ttl][class]]
// * type
// * rhs (rdata)
// But we are lazy here, only the range is parsed *all* occurences
// of $ after that are interpreted.
// Any error are returned as a string value, the empty string signals
// "no error".
func generate(l lex, c chan lex, t chan *Token, o string) string {
step := 1
if i := strings.IndexAny(l.token, "/"); i != -1 {
if i+1 == len(l.token) {
return "bad step in $GENERATE range"
}
if s, e := strconv.Atoi(l.token[i+1:]); e != nil {
return "bad step in $GENERATE range"
} else {
if s < 0 {
return "bad step in $GENERATE range"
}
step = s
}
l.token = l.token[:i]
}
sx := strings.SplitN(l.token, "-", 2)
if len(sx) != 2 {
return "bad start-stop in $GENERATE range"
}
start, err := strconv.Atoi(sx[0])
if err != nil {
return "bad start in $GENERATE range"
}
end, err := strconv.Atoi(sx[1])
if err != nil {
return "bad stop in $GENERATE range"
}
if end < 0 || start < 0 || end <= start {
return "bad range in $GENERATE range"
}
<-c // _BLANK
// Create a complete new string, which we then parse again.
s := ""
BuildRR:
l = <-c
if l.value != _NEWLINE && l.value != _EOF {
s += l.token
goto BuildRR
}
for i := start; i <= end; i += step {
var (
escape bool
dom string
mod string
err string
offset int
)
for j := 0; j < len(s); j++ { // No 'range' because we need to jump around
switch s[j] {
case '\\':
if escape {
dom += "\\"
escape = false
continue
}
escape = true
case '$':
mod = "%d"
offset = 0
if escape {
dom += "$"
escape = false
continue
}
escape = false
if j+1 >= len(s) { // End of the string
dom += fmt.Sprintf(mod, i+offset)
continue
} else {
if s[j+1] == '$' {
dom += "$"
j++
continue
}
}
// Search for { and }
if s[j+1] == '{' { // Modifier block
sep := strings.Index(s[j+2:], "}")
if sep == -1 {
return "bad modifier in $GENERATE"
}
mod, offset, err = modToPrintf(s[j+2 : j+2+sep])
if err != "" {
return err
}
j += 2 + sep // Jump to it
}
dom += fmt.Sprintf(mod, i+offset)
default:
if escape { // Pretty useless here
escape = false
continue
}
dom += string(s[j])
}
}
// Re-parse the RR and send it on the current channel t
rx, e := NewRR("$ORIGIN " + o + "\n" + dom)
if e != nil {
return e.(*ParseError).err
}
t <- &Token{RR: rx}
// Its more efficient to first built the rrlist and then parse it in
// one go! But is this a problem?
}
return ""
}
// Convert a $GENERATE modifier 0,0,d to something Printf can deal with.
func modToPrintf(s string) (string, int, string) {
xs := strings.SplitN(s, ",", 3)
if len(xs) != 3 {
return "", 0, "bad modifier in $GENERATE"
}
// xs[0] is offset, xs[1] is width, xs[2] is base
if xs[2] != "o" && xs[2] != "d" && xs[2] != "x" && xs[2] != "X" {
return "", 0, "bad base in $GENERATE"
}
offset, err := strconv.Atoi(xs[0])
if err != nil {
return "", 0, "bad offset in $GENERATE"
}
width, err := strconv.Atoi(xs[1])
if err != nil {
return "", offset, "bad width in $GENERATE"
}
switch {
case width < 0:
return "", offset, "bad width in $GENERATE"
case width == 0:
return "%" + xs[1] + xs[2], offset, ""
}
return "%0" + xs[1] + xs[2], offset, ""
}

956
Godeps/_workspace/src/github.com/miekg/dns/zscan.go generated vendored Normal file
View file

@ -0,0 +1,956 @@
package dns
import (
"io"
"log"
"os"
"strconv"
"strings"
)
type debugging bool
const debug debugging = false
func (d debugging) Printf(format string, args ...interface{}) {
if d {
log.Printf(format, args...)
}
}
const maxTok = 2048 // Largest token we can return.
const maxUint16 = 1<<16 - 1
// Tokinize a RFC 1035 zone file. The tokenizer will normalize it:
// * Add ownernames if they are left blank;
// * Suppress sequences of spaces;
// * Make each RR fit on one line (_NEWLINE is send as last)
// * Handle comments: ;
// * Handle braces - anywhere.
const (
// Zonefile
_EOF = iota
_STRING
_BLANK
_QUOTE
_NEWLINE
_RRTYPE
_OWNER
_CLASS
_DIRORIGIN // $ORIGIN
_DIRTTL // $TTL
_DIRINCLUDE // $INCLUDE
_DIRGENERATE // $GENERATE
// Privatekey file
_VALUE
_KEY
_EXPECT_OWNER_DIR // Ownername
_EXPECT_OWNER_BL // Whitespace after the ownername
_EXPECT_ANY // Expect rrtype, ttl or class
_EXPECT_ANY_NOCLASS // Expect rrtype or ttl
_EXPECT_ANY_NOCLASS_BL // The whitespace after _EXPECT_ANY_NOCLASS
_EXPECT_ANY_NOTTL // Expect rrtype or class
_EXPECT_ANY_NOTTL_BL // Whitespace after _EXPECT_ANY_NOTTL
_EXPECT_RRTYPE // Expect rrtype
_EXPECT_RRTYPE_BL // Whitespace BEFORE rrtype
_EXPECT_RDATA // The first element of the rdata
_EXPECT_DIRTTL_BL // Space after directive $TTL
_EXPECT_DIRTTL // Directive $TTL
_EXPECT_DIRORIGIN_BL // Space after directive $ORIGIN
_EXPECT_DIRORIGIN // Directive $ORIGIN
_EXPECT_DIRINCLUDE_BL // Space after directive $INCLUDE
_EXPECT_DIRINCLUDE // Directive $INCLUDE
_EXPECT_DIRGENERATE // Directive $GENERATE
_EXPECT_DIRGENERATE_BL // Space after directive $GENERATE
)
// ParseError is a parsing error. It contains the parse error and the location in the io.Reader
// where the error occured.
type ParseError struct {
file string
err string
lex lex
}
func (e *ParseError) Error() (s string) {
if e.file != "" {
s = e.file + ": "
}
s += "dns: " + e.err + ": " + strconv.QuoteToASCII(e.lex.token) + " at line: " +
strconv.Itoa(e.lex.line) + ":" + strconv.Itoa(e.lex.column)
return
}
type lex struct {
token string // text of the token
tokenUpper string // uppercase text of the token
length int // lenght of the token
err bool // when true, token text has lexer error
value uint8 // value: _STRING, _BLANK, etc.
line int // line in the file
column int // column in the file
torc uint16 // type or class as parsed in the lexer, we only need to look this up in the grammar
comment string // any comment text seen
}
// *Tokens are returned when a zone file is parsed.
type Token struct {
RR // the scanned resource record when error is not nil
Error *ParseError // when an error occured, this has the error specifics
Comment string // a potential comment positioned after the RR and on the same line
}
// NewRR reads the RR contained in the string s. Only the first RR is returned.
// The class defaults to IN and TTL defaults to 3600. The full zone file
// syntax like $TTL, $ORIGIN, etc. is supported.
// All fields of the returned RR are set, except RR.Header().Rdlength which is set to 0.
func NewRR(s string) (RR, error) {
if s[len(s)-1] != '\n' { // We need a closing newline
return ReadRR(strings.NewReader(s+"\n"), "")
}
return ReadRR(strings.NewReader(s), "")
}
// ReadRR reads the RR contained in q.
// See NewRR for more documentation.
func ReadRR(q io.Reader, filename string) (RR, error) {
r := <-parseZoneHelper(q, ".", filename, 1)
if r.Error != nil {
return nil, r.Error
}
return r.RR, nil
}
// ParseZone reads a RFC 1035 style one from r. It returns *Tokens on the
// returned channel, which consist out the parsed RR, a potential comment or an error.
// If there is an error the RR is nil. The string file is only used
// in error reporting. The string origin is used as the initial origin, as
// if the file would start with: $ORIGIN origin .
// The directives $INCLUDE, $ORIGIN, $TTL and $GENERATE are supported.
// The channel t is closed by ParseZone when the end of r is reached.
//
// Basic usage pattern when reading from a string (z) containing the
// zone data:
//
// for x := range dns.ParseZone(strings.NewReader(z), "", "") {
// if x.Error != nil {
// // Do something with x.RR
// }
// }
//
// Comments specified after an RR (and on the same line!) are returned too:
//
// foo. IN A 10.0.0.1 ; this is a comment
//
// The text "; this is comment" is returned in Token.Comment . Comments inside the
// RR are discarded. Comments on a line by themselves are discarded too.
func ParseZone(r io.Reader, origin, file string) chan *Token {
return parseZoneHelper(r, origin, file, 10000)
}
func parseZoneHelper(r io.Reader, origin, file string, chansize int) chan *Token {
t := make(chan *Token, chansize)
go parseZone(r, origin, file, t, 0)
return t
}
func parseZone(r io.Reader, origin, f string, t chan *Token, include int) {
defer func() {
if include == 0 {
close(t)
}
}()
s := scanInit(r)
c := make(chan lex, 1000)
// Start the lexer
go zlexer(s, c)
// 6 possible beginnings of a line, _ is a space
// 0. _RRTYPE -> all omitted until the rrtype
// 1. _OWNER _ _RRTYPE -> class/ttl omitted
// 2. _OWNER _ _STRING _ _RRTYPE -> class omitted
// 3. _OWNER _ _STRING _ _CLASS _ _RRTYPE -> ttl/class
// 4. _OWNER _ _CLASS _ _RRTYPE -> ttl omitted
// 5. _OWNER _ _CLASS _ _STRING _ _RRTYPE -> class/ttl (reversed)
// After detecting these, we know the _RRTYPE so we can jump to functions
// handling the rdata for each of these types.
if origin == "" {
origin = "."
}
origin = Fqdn(origin)
if _, ok := IsDomainName(origin); !ok {
t <- &Token{Error: &ParseError{f, "bad initial origin name", lex{}}}
return
}
st := _EXPECT_OWNER_DIR // initial state
var h RR_Header
var defttl uint32 = defaultTtl
var prevName string
for l := range c {
// Lexer spotted an error already
if l.err == true {
t <- &Token{Error: &ParseError{f, l.token, l}}
return
}
switch st {
case _EXPECT_OWNER_DIR:
// We can also expect a directive, like $TTL or $ORIGIN
h.Ttl = defttl
h.Class = ClassINET
switch l.value {
case _NEWLINE: // Empty line
st = _EXPECT_OWNER_DIR
case _OWNER:
h.Name = l.token
if l.token[0] == '@' {
h.Name = origin
prevName = h.Name
st = _EXPECT_OWNER_BL
break
}
if h.Name[l.length-1] != '.' {
h.Name = appendOrigin(h.Name, origin)
}
_, ok := IsDomainName(l.token)
if !ok {
t <- &Token{Error: &ParseError{f, "bad owner name", l}}
return
}
prevName = h.Name
st = _EXPECT_OWNER_BL
case _DIRTTL:
st = _EXPECT_DIRTTL_BL
case _DIRORIGIN:
st = _EXPECT_DIRORIGIN_BL
case _DIRINCLUDE:
st = _EXPECT_DIRINCLUDE_BL
case _DIRGENERATE:
st = _EXPECT_DIRGENERATE_BL
case _RRTYPE: // Everthing has been omitted, this is the first thing on the line
h.Name = prevName
h.Rrtype = l.torc
st = _EXPECT_RDATA
case _CLASS: // First thing on the line is the class
h.Name = prevName
h.Class = l.torc
st = _EXPECT_ANY_NOCLASS_BL
case _BLANK:
// Discard, can happen when there is nothing on the
// line except the RR type
case _STRING: // First thing on the is the ttl
if ttl, ok := stringToTtl(l.token); !ok {
t <- &Token{Error: &ParseError{f, "not a TTL", l}}
return
} else {
h.Ttl = ttl
// Don't about the defttl, we should take the $TTL value
// defttl = ttl
}
st = _EXPECT_ANY_NOTTL_BL
default:
t <- &Token{Error: &ParseError{f, "syntax error at beginning", l}}
return
}
case _EXPECT_DIRINCLUDE_BL:
if l.value != _BLANK {
t <- &Token{Error: &ParseError{f, "no blank after $INCLUDE-directive", l}}
return
}
st = _EXPECT_DIRINCLUDE
case _EXPECT_DIRINCLUDE:
if l.value != _STRING {
t <- &Token{Error: &ParseError{f, "expecting $INCLUDE value, not this...", l}}
return
}
neworigin := origin // There may be optionally a new origin set after the filename, if not use current one
l := <-c
switch l.value {
case _BLANK:
l := <-c
if l.value == _STRING {
if _, ok := IsDomainName(l.token); !ok {
t <- &Token{Error: &ParseError{f, "bad origin name", l}}
return
}
// a new origin is specified.
if l.token[l.length-1] != '.' {
if origin != "." { // Prevent .. endings
neworigin = l.token + "." + origin
} else {
neworigin = l.token + origin
}
} else {
neworigin = l.token
}
}
case _NEWLINE, _EOF:
// Ok
default:
t <- &Token{Error: &ParseError{f, "garbage after $INCLUDE", l}}
return
}
// Start with the new file
r1, e1 := os.Open(l.token)
if e1 != nil {
t <- &Token{Error: &ParseError{f, "failed to open `" + l.token + "'", l}}
return
}
if include+1 > 7 {
t <- &Token{Error: &ParseError{f, "too deeply nested $INCLUDE", l}}
return
}
parseZone(r1, l.token, neworigin, t, include+1)
st = _EXPECT_OWNER_DIR
case _EXPECT_DIRTTL_BL:
if l.value != _BLANK {
t <- &Token{Error: &ParseError{f, "no blank after $TTL-directive", l}}
return
}
st = _EXPECT_DIRTTL
case _EXPECT_DIRTTL:
if l.value != _STRING {
t <- &Token{Error: &ParseError{f, "expecting $TTL value, not this...", l}}
return
}
if e, _ := slurpRemainder(c, f); e != nil {
t <- &Token{Error: e}
return
}
if ttl, ok := stringToTtl(l.token); !ok {
t <- &Token{Error: &ParseError{f, "expecting $TTL value, not this...", l}}
return
} else {
defttl = ttl
}
st = _EXPECT_OWNER_DIR
case _EXPECT_DIRORIGIN_BL:
if l.value != _BLANK {
t <- &Token{Error: &ParseError{f, "no blank after $ORIGIN-directive", l}}
return
}
st = _EXPECT_DIRORIGIN
case _EXPECT_DIRORIGIN:
if l.value != _STRING {
t <- &Token{Error: &ParseError{f, "expecting $ORIGIN value, not this...", l}}
return
}
if e, _ := slurpRemainder(c, f); e != nil {
t <- &Token{Error: e}
}
if _, ok := IsDomainName(l.token); !ok {
t <- &Token{Error: &ParseError{f, "bad origin name", l}}
return
}
if l.token[l.length-1] != '.' {
if origin != "." { // Prevent .. endings
origin = l.token + "." + origin
} else {
origin = l.token + origin
}
} else {
origin = l.token
}
st = _EXPECT_OWNER_DIR
case _EXPECT_DIRGENERATE_BL:
if l.value != _BLANK {
t <- &Token{Error: &ParseError{f, "no blank after $GENERATE-directive", l}}
return
}
st = _EXPECT_DIRGENERATE
case _EXPECT_DIRGENERATE:
if l.value != _STRING {
t <- &Token{Error: &ParseError{f, "expecting $GENERATE value, not this...", l}}
return
}
if e := generate(l, c, t, origin); e != "" {
t <- &Token{Error: &ParseError{f, e, l}}
return
}
st = _EXPECT_OWNER_DIR
case _EXPECT_OWNER_BL:
if l.value != _BLANK {
t <- &Token{Error: &ParseError{f, "no blank after owner", l}}
return
}
st = _EXPECT_ANY
case _EXPECT_ANY:
switch l.value {
case _RRTYPE:
h.Rrtype = l.torc
st = _EXPECT_RDATA
case _CLASS:
h.Class = l.torc
st = _EXPECT_ANY_NOCLASS_BL
case _STRING: // TTL is this case
if ttl, ok := stringToTtl(l.token); !ok {
t <- &Token{Error: &ParseError{f, "not a TTL", l}}
return
} else {
h.Ttl = ttl
// defttl = ttl // don't set the defttl here
}
st = _EXPECT_ANY_NOTTL_BL
default:
t <- &Token{Error: &ParseError{f, "expecting RR type, TTL or class, not this...", l}}
return
}
case _EXPECT_ANY_NOCLASS_BL:
if l.value != _BLANK {
t <- &Token{Error: &ParseError{f, "no blank before class", l}}
return
}
st = _EXPECT_ANY_NOCLASS
case _EXPECT_ANY_NOTTL_BL:
if l.value != _BLANK {
t <- &Token{Error: &ParseError{f, "no blank before TTL", l}}
return
}
st = _EXPECT_ANY_NOTTL
case _EXPECT_ANY_NOTTL:
switch l.value {
case _CLASS:
h.Class = l.torc
st = _EXPECT_RRTYPE_BL
case _RRTYPE:
h.Rrtype = l.torc
st = _EXPECT_RDATA
default:
t <- &Token{Error: &ParseError{f, "expecting RR type or class, not this...", l}}
return
}
case _EXPECT_ANY_NOCLASS:
switch l.value {
case _STRING: // TTL
if ttl, ok := stringToTtl(l.token); !ok {
t <- &Token{Error: &ParseError{f, "not a TTL", l}}
return
} else {
h.Ttl = ttl
// defttl = ttl // don't set the def ttl anymore
}
st = _EXPECT_RRTYPE_BL
case _RRTYPE:
h.Rrtype = l.torc
st = _EXPECT_RDATA
default:
t <- &Token{Error: &ParseError{f, "expecting RR type or TTL, not this...", l}}
return
}
case _EXPECT_RRTYPE_BL:
if l.value != _BLANK {
t <- &Token{Error: &ParseError{f, "no blank before RR type", l}}
return
}
st = _EXPECT_RRTYPE
case _EXPECT_RRTYPE:
if l.value != _RRTYPE {
t <- &Token{Error: &ParseError{f, "unknown RR type", l}}
return
}
h.Rrtype = l.torc
st = _EXPECT_RDATA
case _EXPECT_RDATA:
r, e, c1 := setRR(h, c, origin, f)
if e != nil {
// If e.lex is nil than we have encounter a unknown RR type
// in that case we substitute our current lex token
if e.lex.token == "" && e.lex.value == 0 {
e.lex = l // Uh, dirty
}
t <- &Token{Error: e}
return
}
t <- &Token{RR: r, Comment: c1}
st = _EXPECT_OWNER_DIR
}
}
// If we get here, we and the h.Rrtype is still zero, we haven't parsed anything, this
// is not an error, because an empty zone file is still a zone file.
}
// zlexer scans the sourcefile and returns tokens on the channel c.
func zlexer(s *scan, c chan lex) {
var l lex
str := make([]byte, maxTok) // Should be enough for any token
stri := 0 // Offset in str (0 means empty)
com := make([]byte, maxTok) // Hold comment text
comi := 0
quote := false
escape := false
space := false
commt := false
rrtype := false
owner := true
brace := 0
x, err := s.tokenText()
defer close(c)
for err == nil {
l.column = s.position.Column
l.line = s.position.Line
if stri > maxTok {
l.token = "token length insufficient for parsing"
l.err = true
debug.Printf("[%+v]", l.token)
c <- l
return
}
if comi > maxTok {
l.token = "comment length insufficient for parsing"
l.err = true
debug.Printf("[%+v]", l.token)
c <- l
return
}
switch x {
case ' ', '\t':
if escape {
escape = false
str[stri] = x
stri++
break
}
if quote {
// Inside quotes this is legal
str[stri] = x
stri++
break
}
if commt {
com[comi] = x
comi++
break
}
if stri == 0 {
// Space directly in the beginning, handled in the grammar
} else if owner {
// If we have a string and its the first, make it an owner
l.value = _OWNER
l.token = string(str[:stri])
l.tokenUpper = strings.ToUpper(l.token)
l.length = stri
// escape $... start with a \ not a $, so this will work
switch l.tokenUpper {
case "$TTL":
l.value = _DIRTTL
case "$ORIGIN":
l.value = _DIRORIGIN
case "$INCLUDE":
l.value = _DIRINCLUDE
case "$GENERATE":
l.value = _DIRGENERATE
}
debug.Printf("[7 %+v]", l.token)
c <- l
} else {
l.value = _STRING
l.token = string(str[:stri])
l.tokenUpper = strings.ToUpper(l.token)
l.length = stri
if !rrtype {
if t, ok := StringToType[l.tokenUpper]; ok {
l.value = _RRTYPE
l.torc = t
rrtype = true
} else {
if strings.HasPrefix(l.tokenUpper, "TYPE") {
if t, ok := typeToInt(l.token); !ok {
l.token = "unknown RR type"
l.err = true
c <- l
return
} else {
l.value = _RRTYPE
l.torc = t
}
}
}
if t, ok := StringToClass[l.tokenUpper]; ok {
l.value = _CLASS
l.torc = t
} else {
if strings.HasPrefix(l.tokenUpper, "CLASS") {
if t, ok := classToInt(l.token); !ok {
l.token = "unknown class"
l.err = true
c <- l
return
} else {
l.value = _CLASS
l.torc = t
}
}
}
}
debug.Printf("[6 %+v]", l.token)
c <- l
}
stri = 0
// I reverse space stuff here
if !space && !commt {
l.value = _BLANK
l.token = " "
l.length = 1
debug.Printf("[5 %+v]", l.token)
c <- l
}
owner = false
space = true
case ';':
if escape {
escape = false
str[stri] = x
stri++
break
}
if quote {
// Inside quotes this is legal
str[stri] = x
stri++
break
}
if stri > 0 {
l.value = _STRING
l.token = string(str[:stri])
l.length = stri
debug.Printf("[4 %+v]", l.token)
c <- l
stri = 0
}
commt = true
com[comi] = ';'
comi++
case '\r':
escape = false
if quote {
str[stri] = x
stri++
break
}
// discard if outside of quotes
case '\n':
escape = false
// Escaped newline
if quote {
str[stri] = x
stri++
break
}
// inside quotes this is legal
if commt {
// Reset a comment
commt = false
rrtype = false
stri = 0
// If not in a brace this ends the comment AND the RR
if brace == 0 {
owner = true
owner = true
l.value = _NEWLINE
l.token = "\n"
l.length = 1
l.comment = string(com[:comi])
debug.Printf("[3 %+v %+v]", l.token, l.comment)
c <- l
l.comment = ""
comi = 0
break
}
com[comi] = ' ' // convert newline to space
comi++
break
}
if brace == 0 {
// If there is previous text, we should output it here
if stri != 0 {
l.value = _STRING
l.token = string(str[:stri])
l.tokenUpper = strings.ToUpper(l.token)
l.length = stri
if !rrtype {
if t, ok := StringToType[l.tokenUpper]; ok {
l.value = _RRTYPE
l.torc = t
rrtype = true
}
}
debug.Printf("[2 %+v]", l.token)
c <- l
}
l.value = _NEWLINE
l.token = "\n"
l.length = 1
debug.Printf("[1 %+v]", l.token)
c <- l
stri = 0
commt = false
rrtype = false
owner = true
comi = 0
}
case '\\':
// comments do not get escaped chars, everything is copied
if commt {
com[comi] = x
comi++
break
}
// something already escaped must be in string
if escape {
str[stri] = x
stri++
escape = false
break
}
// something escaped outside of string gets added to string
str[stri] = x
stri++
escape = true
case '"':
if commt {
com[comi] = x
comi++
break
}
if escape {
str[stri] = x
stri++
escape = false
break
}
space = false
// send previous gathered text and the quote
if stri != 0 {
l.value = _STRING
l.token = string(str[:stri])
l.length = stri
debug.Printf("[%+v]", l.token)
c <- l
stri = 0
}
// send quote itself as separate token
l.value = _QUOTE
l.token = "\""
l.length = 1
c <- l
quote = !quote
case '(', ')':
if commt {
com[comi] = x
comi++
break
}
if escape {
str[stri] = x
stri++
escape = false
break
}
if quote {
str[stri] = x
stri++
break
}
switch x {
case ')':
brace--
if brace < 0 {
l.token = "extra closing brace"
l.err = true
debug.Printf("[%+v]", l.token)
c <- l
return
}
case '(':
brace++
}
default:
escape = false
if commt {
com[comi] = x
comi++
break
}
str[stri] = x
stri++
space = false
}
x, err = s.tokenText()
}
if stri > 0 {
// Send remainder
l.token = string(str[:stri])
l.length = stri
l.value = _STRING
debug.Printf("[%+v]", l.token)
c <- l
}
}
// Extract the class number from CLASSxx
func classToInt(token string) (uint16, bool) {
class, ok := strconv.Atoi(token[5:])
if ok != nil || class > maxUint16 {
return 0, false
}
return uint16(class), true
}
// Extract the rr number from TYPExxx
func typeToInt(token string) (uint16, bool) {
typ, ok := strconv.Atoi(token[4:])
if ok != nil || typ > maxUint16 {
return 0, false
}
return uint16(typ), true
}
// Parse things like 2w, 2m, etc, Return the time in seconds.
func stringToTtl(token string) (uint32, bool) {
s := uint32(0)
i := uint32(0)
for _, c := range token {
switch c {
case 's', 'S':
s += i
i = 0
case 'm', 'M':
s += i * 60
i = 0
case 'h', 'H':
s += i * 60 * 60
i = 0
case 'd', 'D':
s += i * 60 * 60 * 24
i = 0
case 'w', 'W':
s += i * 60 * 60 * 24 * 7
i = 0
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
i *= 10
i += uint32(c) - '0'
default:
return 0, false
}
}
return s + i, true
}
// Parse LOC records' <digits>[.<digits>][mM] into a
// mantissa exponent format. Token should contain the entire
// string (i.e. no spaces allowed)
func stringToCm(token string) (e, m uint8, ok bool) {
if token[len(token)-1] == 'M' || token[len(token)-1] == 'm' {
token = token[0 : len(token)-1]
}
s := strings.SplitN(token, ".", 2)
var meters, cmeters, val int
var err error
switch len(s) {
case 2:
if cmeters, err = strconv.Atoi(s[1]); err != nil {
return
}
fallthrough
case 1:
if meters, err = strconv.Atoi(s[0]); err != nil {
return
}
case 0:
// huh?
return 0, 0, false
}
ok = true
if meters > 0 {
e = 2
val = meters
} else {
e = 0
val = cmeters
}
for val > 10 {
e++
val /= 10
}
if e > 9 {
ok = false
}
m = uint8(val)
return
}
func appendOrigin(name, origin string) string {
if origin == "." {
return name + origin
}
return name + "." + origin
}
// LOC record helper function
func locCheckNorth(token string, latitude uint32) (uint32, bool) {
switch token {
case "n", "N":
return _LOC_EQUATOR + latitude, true
case "s", "S":
return _LOC_EQUATOR - latitude, true
}
return latitude, false
}
// LOC record helper function
func locCheckEast(token string, longitude uint32) (uint32, bool) {
switch token {
case "e", "E":
return _LOC_EQUATOR + longitude, true
case "w", "W":
return _LOC_EQUATOR - longitude, true
}
return longitude, false
}
// "Eat" the rest of the "line". Return potential comments
func slurpRemainder(c chan lex, f string) (*ParseError, string) {
l := <-c
com := ""
switch l.value {
case _BLANK:
l = <-c
com = l.comment
if l.value != _NEWLINE && l.value != _EOF {
return &ParseError{f, "garbage after rdata", l}, ""
}
case _NEWLINE:
com = l.comment
case _EOF:
default:
return &ParseError{f, "garbage after rdata", l}, ""
}
return nil, com
}
// Parse a 64 bit-like ipv6 address: "0014:4fff:ff20:ee64"
// Used for NID and L64 record.
func stringToNodeID(l lex) (uint64, *ParseError) {
if len(l.token) < 19 {
return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
}
// There must be three colons at fixes postitions, if not its a parse error
if l.token[4] != ':' && l.token[9] != ':' && l.token[14] != ':' {
return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
}
s := l.token[0:4] + l.token[5:9] + l.token[10:14] + l.token[15:19]
u, e := strconv.ParseUint(s, 16, 64)
if e != nil {
return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
}
return u, nil
}

2178
Godeps/_workspace/src/github.com/miekg/dns/zscan_rr.go generated vendored Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,252 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package leveldb
import (
"encoding/binary"
"fmt"
"github.com/syndtr/goleveldb/leveldb/errors"
"github.com/syndtr/goleveldb/leveldb/memdb"
)
type ErrBatchCorrupted struct {
Reason string
}
func (e *ErrBatchCorrupted) Error() string {
return fmt.Sprintf("leveldb: batch corrupted: %s", e.Reason)
}
func newErrBatchCorrupted(reason string) error {
return errors.NewErrCorrupted(nil, &ErrBatchCorrupted{reason})
}
const (
batchHdrLen = 8 + 4
batchGrowRec = 3000
)
type BatchReplay interface {
Put(key, value []byte)
Delete(key []byte)
}
// Batch is a write batch.
type Batch struct {
data []byte
rLen, bLen int
seq uint64
sync bool
}
func (b *Batch) grow(n int) {
off := len(b.data)
if off == 0 {
off = batchHdrLen
if b.data != nil {
b.data = b.data[:off]
}
}
if cap(b.data)-off < n {
if b.data == nil {
b.data = make([]byte, off, off+n)
} else {
odata := b.data
div := 1
if b.rLen > batchGrowRec {
div = b.rLen / batchGrowRec
}
b.data = make([]byte, off, off+n+(off-batchHdrLen)/div)
copy(b.data, odata)
}
}
}
func (b *Batch) appendRec(kt kType, key, value []byte) {
n := 1 + binary.MaxVarintLen32 + len(key)
if kt == ktVal {
n += binary.MaxVarintLen32 + len(value)
}
b.grow(n)
off := len(b.data)
data := b.data[:off+n]
data[off] = byte(kt)
off += 1
off += binary.PutUvarint(data[off:], uint64(len(key)))
copy(data[off:], key)
off += len(key)
if kt == ktVal {
off += binary.PutUvarint(data[off:], uint64(len(value)))
copy(data[off:], value)
off += len(value)
}
b.data = data[:off]
b.rLen++
// Include 8-byte ikey header
b.bLen += len(key) + len(value) + 8
}
// Put appends 'put operation' of the given key/value pair to the batch.
// It is safe to modify the contents of the argument after Put returns.
func (b *Batch) Put(key, value []byte) {
b.appendRec(ktVal, key, value)
}
// Delete appends 'delete operation' of the given key to the batch.
// It is safe to modify the contents of the argument after Delete returns.
func (b *Batch) Delete(key []byte) {
b.appendRec(ktDel, key, nil)
}
// Dump dumps batch contents. The returned slice can be loaded into the
// batch using Load method.
// The returned slice is not its own copy, so the contents should not be
// modified.
func (b *Batch) Dump() []byte {
return b.encode()
}
// Load loads given slice into the batch. Previous contents of the batch
// will be discarded.
// The given slice will not be copied and will be used as batch buffer, so
// it is not safe to modify the contents of the slice.
func (b *Batch) Load(data []byte) error {
return b.decode(0, data)
}
// Replay replays batch contents.
func (b *Batch) Replay(r BatchReplay) error {
return b.decodeRec(func(i int, kt kType, key, value []byte) {
switch kt {
case ktVal:
r.Put(key, value)
case ktDel:
r.Delete(key)
}
})
}
// Len returns number of records in the batch.
func (b *Batch) Len() int {
return b.rLen
}
// Reset resets the batch.
func (b *Batch) Reset() {
b.data = b.data[:0]
b.seq = 0
b.rLen = 0
b.bLen = 0
b.sync = false
}
func (b *Batch) init(sync bool) {
b.sync = sync
}
func (b *Batch) append(p *Batch) {
if p.rLen > 0 {
b.grow(len(p.data) - batchHdrLen)
b.data = append(b.data, p.data[batchHdrLen:]...)
b.rLen += p.rLen
}
if p.sync {
b.sync = true
}
}
// size returns sums of key/value pair length plus 8-bytes ikey.
func (b *Batch) size() int {
return b.bLen
}
func (b *Batch) encode() []byte {
b.grow(0)
binary.LittleEndian.PutUint64(b.data, b.seq)
binary.LittleEndian.PutUint32(b.data[8:], uint32(b.rLen))
return b.data
}
func (b *Batch) decode(prevSeq uint64, data []byte) error {
if len(data) < batchHdrLen {
return newErrBatchCorrupted("too short")
}
b.seq = binary.LittleEndian.Uint64(data)
if b.seq < prevSeq {
return newErrBatchCorrupted("invalid sequence number")
}
b.rLen = int(binary.LittleEndian.Uint32(data[8:]))
if b.rLen < 0 {
return newErrBatchCorrupted("invalid records length")
}
// No need to be precise at this point, it won't be used anyway
b.bLen = len(data) - batchHdrLen
b.data = data
return nil
}
func (b *Batch) decodeRec(f func(i int, kt kType, key, value []byte)) (err error) {
off := batchHdrLen
for i := 0; i < b.rLen; i++ {
if off >= len(b.data) {
return newErrBatchCorrupted("invalid records length")
}
kt := kType(b.data[off])
if kt > ktVal {
return newErrBatchCorrupted("bad record: invalid type")
}
off += 1
x, n := binary.Uvarint(b.data[off:])
off += n
if n <= 0 || off+int(x) > len(b.data) {
return newErrBatchCorrupted("bad record: invalid key length")
}
key := b.data[off : off+int(x)]
off += int(x)
var value []byte
if kt == ktVal {
x, n := binary.Uvarint(b.data[off:])
off += n
if n <= 0 || off+int(x) > len(b.data) {
return newErrBatchCorrupted("bad record: invalid value length")
}
value = b.data[off : off+int(x)]
off += int(x)
}
f(i, kt, key, value)
}
return nil
}
func (b *Batch) memReplay(to *memdb.DB) error {
return b.decodeRec(func(i int, kt kType, key, value []byte) {
ikey := newIkey(key, b.seq+uint64(i), kt)
to.Put(ikey, value)
})
}
func (b *Batch) memDecodeAndReplay(prevSeq uint64, data []byte, to *memdb.DB) error {
if err := b.decode(prevSeq, data); err != nil {
return err
}
return b.memReplay(to)
}
func (b *Batch) revertMemReplay(to *memdb.DB) error {
return b.decodeRec(func(i int, kt kType, key, value []byte) {
ikey := newIkey(key, b.seq+uint64(i), kt)
to.Delete(ikey)
})
}

View file

@ -0,0 +1,120 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package leveldb
import (
"bytes"
"testing"
"github.com/syndtr/goleveldb/leveldb/comparer"
"github.com/syndtr/goleveldb/leveldb/memdb"
)
type tbRec struct {
kt kType
key, value []byte
}
type testBatch struct {
rec []*tbRec
}
func (p *testBatch) Put(key, value []byte) {
p.rec = append(p.rec, &tbRec{ktVal, key, value})
}
func (p *testBatch) Delete(key []byte) {
p.rec = append(p.rec, &tbRec{ktDel, key, nil})
}
func compareBatch(t *testing.T, b1, b2 *Batch) {
if b1.seq != b2.seq {
t.Errorf("invalid seq number want %d, got %d", b1.seq, b2.seq)
}
if b1.Len() != b2.Len() {
t.Fatalf("invalid record length want %d, got %d", b1.Len(), b2.Len())
}
p1, p2 := new(testBatch), new(testBatch)
err := b1.Replay(p1)
if err != nil {
t.Fatal("error when replaying batch 1: ", err)
}
err = b2.Replay(p2)
if err != nil {
t.Fatal("error when replaying batch 2: ", err)
}
for i := range p1.rec {
r1, r2 := p1.rec[i], p2.rec[i]
if r1.kt != r2.kt {
t.Errorf("invalid type on record '%d' want %d, got %d", i, r1.kt, r2.kt)
}
if !bytes.Equal(r1.key, r2.key) {
t.Errorf("invalid key on record '%d' want %s, got %s", i, string(r1.key), string(r2.key))
}
if r1.kt == ktVal {
if !bytes.Equal(r1.value, r2.value) {
t.Errorf("invalid value on record '%d' want %s, got %s", i, string(r1.value), string(r2.value))
}
}
}
}
func TestBatch_EncodeDecode(t *testing.T) {
b1 := new(Batch)
b1.seq = 10009
b1.Put([]byte("key1"), []byte("value1"))
b1.Put([]byte("key2"), []byte("value2"))
b1.Delete([]byte("key1"))
b1.Put([]byte("k"), []byte(""))
b1.Put([]byte("zzzzzzzzzzz"), []byte("zzzzzzzzzzzzzzzzzzzzzzzz"))
b1.Delete([]byte("key10000"))
b1.Delete([]byte("k"))
buf := b1.encode()
b2 := new(Batch)
err := b2.decode(0, buf)
if err != nil {
t.Error("error when decoding batch: ", err)
}
compareBatch(t, b1, b2)
}
func TestBatch_Append(t *testing.T) {
b1 := new(Batch)
b1.seq = 10009
b1.Put([]byte("key1"), []byte("value1"))
b1.Put([]byte("key2"), []byte("value2"))
b1.Delete([]byte("key1"))
b1.Put([]byte("foo"), []byte("foovalue"))
b1.Put([]byte("bar"), []byte("barvalue"))
b2a := new(Batch)
b2a.seq = 10009
b2a.Put([]byte("key1"), []byte("value1"))
b2a.Put([]byte("key2"), []byte("value2"))
b2a.Delete([]byte("key1"))
b2b := new(Batch)
b2b.Put([]byte("foo"), []byte("foovalue"))
b2b.Put([]byte("bar"), []byte("barvalue"))
b2a.append(b2b)
compareBatch(t, b1, b2a)
}
func TestBatch_Size(t *testing.T) {
b := new(Batch)
for i := 0; i < 2; i++ {
b.Put([]byte("key1"), []byte("value1"))
b.Put([]byte("key2"), []byte("value2"))
b.Delete([]byte("key1"))
b.Put([]byte("foo"), []byte("foovalue"))
b.Put([]byte("bar"), []byte("barvalue"))
mem := memdb.New(&iComparer{comparer.DefaultComparer}, 0)
b.memReplay(mem)
if b.size() != mem.Size() {
t.Errorf("invalid batch size calculation, want=%d got=%d", mem.Size(), b.size())
}
b.Reset()
}
}

View file

@ -0,0 +1,464 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package leveldb
import (
"bytes"
"fmt"
"math/rand"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/syndtr/goleveldb/leveldb/iterator"
"github.com/syndtr/goleveldb/leveldb/opt"
"github.com/syndtr/goleveldb/leveldb/storage"
)
func randomString(r *rand.Rand, n int) []byte {
b := new(bytes.Buffer)
for i := 0; i < n; i++ {
b.WriteByte(' ' + byte(r.Intn(95)))
}
return b.Bytes()
}
func compressibleStr(r *rand.Rand, frac float32, n int) []byte {
nn := int(float32(n) * frac)
rb := randomString(r, nn)
b := make([]byte, 0, n+nn)
for len(b) < n {
b = append(b, rb...)
}
return b[:n]
}
type valueGen struct {
src []byte
pos int
}
func newValueGen(frac float32) *valueGen {
v := new(valueGen)
r := rand.New(rand.NewSource(301))
v.src = make([]byte, 0, 1048576+100)
for len(v.src) < 1048576 {
v.src = append(v.src, compressibleStr(r, frac, 100)...)
}
return v
}
func (v *valueGen) get(n int) []byte {
if v.pos+n > len(v.src) {
v.pos = 0
}
v.pos += n
return v.src[v.pos-n : v.pos]
}
var benchDB = filepath.Join(os.TempDir(), fmt.Sprintf("goleveldbbench-%d", os.Getuid()))
type dbBench struct {
b *testing.B
stor storage.Storage
db *DB
o *opt.Options
ro *opt.ReadOptions
wo *opt.WriteOptions
keys, values [][]byte
}
func openDBBench(b *testing.B, noCompress bool) *dbBench {
_, err := os.Stat(benchDB)
if err == nil {
err = os.RemoveAll(benchDB)
if err != nil {
b.Fatal("cannot remove old db: ", err)
}
}
p := &dbBench{
b: b,
o: &opt.Options{},
ro: &opt.ReadOptions{},
wo: &opt.WriteOptions{},
}
p.stor, err = storage.OpenFile(benchDB)
if err != nil {
b.Fatal("cannot open stor: ", err)
}
if noCompress {
p.o.Compression = opt.NoCompression
}
p.db, err = Open(p.stor, p.o)
if err != nil {
b.Fatal("cannot open db: ", err)
}
runtime.GOMAXPROCS(runtime.NumCPU())
return p
}
func (p *dbBench) reopen() {
p.db.Close()
var err error
p.db, err = Open(p.stor, p.o)
if err != nil {
p.b.Fatal("Reopen: got error: ", err)
}
}
func (p *dbBench) populate(n int) {
p.keys, p.values = make([][]byte, n), make([][]byte, n)
v := newValueGen(0.5)
for i := range p.keys {
p.keys[i], p.values[i] = []byte(fmt.Sprintf("%016d", i)), v.get(100)
}
}
func (p *dbBench) randomize() {
m := len(p.keys)
times := m * 2
r1, r2 := rand.New(rand.NewSource(0xdeadbeef)), rand.New(rand.NewSource(0xbeefface))
for n := 0; n < times; n++ {
i, j := r1.Int()%m, r2.Int()%m
if i == j {
continue
}
p.keys[i], p.keys[j] = p.keys[j], p.keys[i]
p.values[i], p.values[j] = p.values[j], p.values[i]
}
}
func (p *dbBench) writes(perBatch int) {
b := p.b
db := p.db
n := len(p.keys)
m := n / perBatch
if n%perBatch > 0 {
m++
}
batches := make([]Batch, m)
j := 0
for i := range batches {
first := true
for ; j < n && ((j+1)%perBatch != 0 || first); j++ {
first = false
batches[i].Put(p.keys[j], p.values[j])
}
}
runtime.GC()
b.ResetTimer()
b.StartTimer()
for i := range batches {
err := db.Write(&(batches[i]), p.wo)
if err != nil {
b.Fatal("write failed: ", err)
}
}
b.StopTimer()
b.SetBytes(116)
}
func (p *dbBench) gc() {
p.keys, p.values = nil, nil
runtime.GC()
}
func (p *dbBench) puts() {
b := p.b
db := p.db
b.ResetTimer()
b.StartTimer()
for i := range p.keys {
err := db.Put(p.keys[i], p.values[i], p.wo)
if err != nil {
b.Fatal("put failed: ", err)
}
}
b.StopTimer()
b.SetBytes(116)
}
func (p *dbBench) fill() {
b := p.b
db := p.db
perBatch := 10000
batch := new(Batch)
for i, n := 0, len(p.keys); i < n; {
first := true
for ; i < n && ((i+1)%perBatch != 0 || first); i++ {
first = false
batch.Put(p.keys[i], p.values[i])
}
err := db.Write(batch, p.wo)
if err != nil {
b.Fatal("write failed: ", err)
}
batch.Reset()
}
}
func (p *dbBench) gets() {
b := p.b
db := p.db
b.ResetTimer()
for i := range p.keys {
_, err := db.Get(p.keys[i], p.ro)
if err != nil {
b.Error("got error: ", err)
}
}
b.StopTimer()
}
func (p *dbBench) seeks() {
b := p.b
iter := p.newIter()
defer iter.Release()
b.ResetTimer()
for i := range p.keys {
if !iter.Seek(p.keys[i]) {
b.Error("value not found for: ", string(p.keys[i]))
}
}
b.StopTimer()
}
func (p *dbBench) newIter() iterator.Iterator {
iter := p.db.NewIterator(nil, p.ro)
err := iter.Error()
if err != nil {
p.b.Fatal("cannot create iterator: ", err)
}
return iter
}
func (p *dbBench) close() {
if bp, err := p.db.GetProperty("leveldb.blockpool"); err == nil {
p.b.Log("Block pool stats: ", bp)
}
p.db.Close()
p.stor.Close()
os.RemoveAll(benchDB)
p.db = nil
p.keys = nil
p.values = nil
runtime.GC()
runtime.GOMAXPROCS(1)
}
func BenchmarkDBWrite(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.writes(1)
p.close()
}
func BenchmarkDBWriteBatch(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.writes(1000)
p.close()
}
func BenchmarkDBWriteUncompressed(b *testing.B) {
p := openDBBench(b, true)
p.populate(b.N)
p.writes(1)
p.close()
}
func BenchmarkDBWriteBatchUncompressed(b *testing.B) {
p := openDBBench(b, true)
p.populate(b.N)
p.writes(1000)
p.close()
}
func BenchmarkDBWriteRandom(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.randomize()
p.writes(1)
p.close()
}
func BenchmarkDBWriteRandomSync(b *testing.B) {
p := openDBBench(b, false)
p.wo.Sync = true
p.populate(b.N)
p.writes(1)
p.close()
}
func BenchmarkDBOverwrite(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.writes(1)
p.writes(1)
p.close()
}
func BenchmarkDBOverwriteRandom(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.writes(1)
p.randomize()
p.writes(1)
p.close()
}
func BenchmarkDBPut(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.puts()
p.close()
}
func BenchmarkDBRead(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.fill()
p.gc()
iter := p.newIter()
b.ResetTimer()
for iter.Next() {
}
iter.Release()
b.StopTimer()
b.SetBytes(116)
p.close()
}
func BenchmarkDBReadGC(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.fill()
iter := p.newIter()
b.ResetTimer()
for iter.Next() {
}
iter.Release()
b.StopTimer()
b.SetBytes(116)
p.close()
}
func BenchmarkDBReadUncompressed(b *testing.B) {
p := openDBBench(b, true)
p.populate(b.N)
p.fill()
p.gc()
iter := p.newIter()
b.ResetTimer()
for iter.Next() {
}
iter.Release()
b.StopTimer()
b.SetBytes(116)
p.close()
}
func BenchmarkDBReadTable(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.fill()
p.reopen()
p.gc()
iter := p.newIter()
b.ResetTimer()
for iter.Next() {
}
iter.Release()
b.StopTimer()
b.SetBytes(116)
p.close()
}
func BenchmarkDBReadReverse(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.fill()
p.gc()
iter := p.newIter()
b.ResetTimer()
iter.Last()
for iter.Prev() {
}
iter.Release()
b.StopTimer()
b.SetBytes(116)
p.close()
}
func BenchmarkDBReadReverseTable(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.fill()
p.reopen()
p.gc()
iter := p.newIter()
b.ResetTimer()
iter.Last()
for iter.Prev() {
}
iter.Release()
b.StopTimer()
b.SetBytes(116)
p.close()
}
func BenchmarkDBSeek(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.fill()
p.seeks()
p.close()
}
func BenchmarkDBSeekRandom(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.fill()
p.randomize()
p.seeks()
p.close()
}
func BenchmarkDBGet(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.fill()
p.gets()
p.close()
}
func BenchmarkDBGetRandom(b *testing.B) {
p := openDBBench(b, false)
p.populate(b.N)
p.fill()
p.randomize()
p.gets()
p.close()
}

View file

@ -0,0 +1,676 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Package cache provides interface and implementation of a cache algorithms.
package cache
import (
"sync"
"sync/atomic"
"unsafe"
"github.com/syndtr/goleveldb/leveldb/util"
)
// Cacher provides interface to implements a caching functionality.
// An implementation must be goroutine-safe.
type Cacher interface {
// Capacity returns cache capacity.
Capacity() int
// SetCapacity sets cache capacity.
SetCapacity(capacity int)
// Promote promotes the 'cache node'.
Promote(n *Node)
// Ban evicts the 'cache node' and prevent subsequent 'promote'.
Ban(n *Node)
// Evict evicts the 'cache node'.
Evict(n *Node)
// EvictNS evicts 'cache node' with the given namespace.
EvictNS(ns uint64)
// EvictAll evicts all 'cache node'.
EvictAll()
// Close closes the 'cache tree'
Close() error
}
// Value is a 'cacheable object'. It may implements util.Releaser, if
// so the the Release method will be called once object is released.
type Value interface{}
type CacheGetter struct {
Cache *Cache
NS uint64
}
func (g *CacheGetter) Get(key uint64, setFunc func() (size int, value Value)) *Handle {
return g.Cache.Get(g.NS, key, setFunc)
}
// The hash tables implementation is based on:
// "Dynamic-Sized Nonblocking Hash Tables", by Yujie Liu, Kunlong Zhang, and Michael Spear. ACM Symposium on Principles of Distributed Computing, Jul 2014.
const (
mInitialSize = 1 << 4
mOverflowThreshold = 1 << 5
mOverflowGrowThreshold = 1 << 7
)
type mBucket struct {
mu sync.Mutex
node []*Node
frozen bool
}
func (b *mBucket) freeze() []*Node {
b.mu.Lock()
defer b.mu.Unlock()
if !b.frozen {
b.frozen = true
}
return b.node
}
func (b *mBucket) get(r *Cache, h *mNode, hash uint32, ns, key uint64, noset bool) (done, added bool, n *Node) {
b.mu.Lock()
if b.frozen {
b.mu.Unlock()
return
}
// Scan the node.
for _, n := range b.node {
if n.hash == hash && n.ns == ns && n.key == key {
atomic.AddInt32(&n.ref, 1)
b.mu.Unlock()
return true, false, n
}
}
// Get only.
if noset {
b.mu.Unlock()
return true, false, nil
}
// Create node.
n = &Node{
r: r,
hash: hash,
ns: ns,
key: key,
ref: 1,
}
// Add node to bucket.
b.node = append(b.node, n)
bLen := len(b.node)
b.mu.Unlock()
// Update counter.
grow := atomic.AddInt32(&r.nodes, 1) >= h.growThreshold
if bLen > mOverflowThreshold {
grow = grow || atomic.AddInt32(&h.overflow, 1) >= mOverflowGrowThreshold
}
// Grow.
if grow && atomic.CompareAndSwapInt32(&h.resizeInProgess, 0, 1) {
nhLen := len(h.buckets) << 1
nh := &mNode{
buckets: make([]unsafe.Pointer, nhLen),
mask: uint32(nhLen) - 1,
pred: unsafe.Pointer(h),
growThreshold: int32(nhLen * mOverflowThreshold),
shrinkThreshold: int32(nhLen >> 1),
}
ok := atomic.CompareAndSwapPointer(&r.mHead, unsafe.Pointer(h), unsafe.Pointer(nh))
if !ok {
panic("BUG: failed swapping head")
}
go nh.initBuckets()
}
return true, true, n
}
func (b *mBucket) delete(r *Cache, h *mNode, hash uint32, ns, key uint64) (done, deleted bool) {
b.mu.Lock()
if b.frozen {
b.mu.Unlock()
return
}
// Scan the node.
var (
n *Node
bLen int
)
for i := range b.node {
n = b.node[i]
if n.ns == ns && n.key == key {
if atomic.LoadInt32(&n.ref) == 0 {
deleted = true
// Call releaser.
if n.value != nil {
if r, ok := n.value.(util.Releaser); ok {
r.Release()
}
n.value = nil
}
// Remove node from bucket.
b.node = append(b.node[:i], b.node[i+1:]...)
bLen = len(b.node)
}
break
}
}
b.mu.Unlock()
if deleted {
// Call OnDel.
for _, f := range n.onDel {
f()
}
// Update counter.
atomic.AddInt32(&r.size, int32(n.size)*-1)
shrink := atomic.AddInt32(&r.nodes, -1) < h.shrinkThreshold
if bLen >= mOverflowThreshold {
atomic.AddInt32(&h.overflow, -1)
}
// Shrink.
if shrink && len(h.buckets) > mInitialSize && atomic.CompareAndSwapInt32(&h.resizeInProgess, 0, 1) {
nhLen := len(h.buckets) >> 1
nh := &mNode{
buckets: make([]unsafe.Pointer, nhLen),
mask: uint32(nhLen) - 1,
pred: unsafe.Pointer(h),
growThreshold: int32(nhLen * mOverflowThreshold),
shrinkThreshold: int32(nhLen >> 1),
}
ok := atomic.CompareAndSwapPointer(&r.mHead, unsafe.Pointer(h), unsafe.Pointer(nh))
if !ok {
panic("BUG: failed swapping head")
}
go nh.initBuckets()
}
}
return true, deleted
}
type mNode struct {
buckets []unsafe.Pointer // []*mBucket
mask uint32
pred unsafe.Pointer // *mNode
resizeInProgess int32
overflow int32
growThreshold int32
shrinkThreshold int32
}
func (n *mNode) initBucket(i uint32) *mBucket {
if b := (*mBucket)(atomic.LoadPointer(&n.buckets[i])); b != nil {
return b
}
p := (*mNode)(atomic.LoadPointer(&n.pred))
if p != nil {
var node []*Node
if n.mask > p.mask {
// Grow.
pb := (*mBucket)(atomic.LoadPointer(&p.buckets[i&p.mask]))
if pb == nil {
pb = p.initBucket(i & p.mask)
}
m := pb.freeze()
// Split nodes.
for _, x := range m {
if x.hash&n.mask == i {
node = append(node, x)
}
}
} else {
// Shrink.
pb0 := (*mBucket)(atomic.LoadPointer(&p.buckets[i]))
if pb0 == nil {
pb0 = p.initBucket(i)
}
pb1 := (*mBucket)(atomic.LoadPointer(&p.buckets[i+uint32(len(n.buckets))]))
if pb1 == nil {
pb1 = p.initBucket(i + uint32(len(n.buckets)))
}
m0 := pb0.freeze()
m1 := pb1.freeze()
// Merge nodes.
node = make([]*Node, 0, len(m0)+len(m1))
node = append(node, m0...)
node = append(node, m1...)
}
b := &mBucket{node: node}
if atomic.CompareAndSwapPointer(&n.buckets[i], nil, unsafe.Pointer(b)) {
if len(node) > mOverflowThreshold {
atomic.AddInt32(&n.overflow, int32(len(node)-mOverflowThreshold))
}
return b
}
}
return (*mBucket)(atomic.LoadPointer(&n.buckets[i]))
}
func (n *mNode) initBuckets() {
for i := range n.buckets {
n.initBucket(uint32(i))
}
atomic.StorePointer(&n.pred, nil)
}
// Cache is a 'cache map'.
type Cache struct {
mu sync.RWMutex
mHead unsafe.Pointer // *mNode
nodes int32
size int32
cacher Cacher
closed bool
}
// NewCache creates a new 'cache map'. The cacher is optional and
// may be nil.
func NewCache(cacher Cacher) *Cache {
h := &mNode{
buckets: make([]unsafe.Pointer, mInitialSize),
mask: mInitialSize - 1,
growThreshold: int32(mInitialSize * mOverflowThreshold),
shrinkThreshold: 0,
}
for i := range h.buckets {
h.buckets[i] = unsafe.Pointer(&mBucket{})
}
r := &Cache{
mHead: unsafe.Pointer(h),
cacher: cacher,
}
return r
}
func (r *Cache) getBucket(hash uint32) (*mNode, *mBucket) {
h := (*mNode)(atomic.LoadPointer(&r.mHead))
i := hash & h.mask
b := (*mBucket)(atomic.LoadPointer(&h.buckets[i]))
if b == nil {
b = h.initBucket(i)
}
return h, b
}
func (r *Cache) delete(n *Node) bool {
for {
h, b := r.getBucket(n.hash)
done, deleted := b.delete(r, h, n.hash, n.ns, n.key)
if done {
return deleted
}
}
return false
}
// Nodes returns number of 'cache node' in the map.
func (r *Cache) Nodes() int {
return int(atomic.LoadInt32(&r.nodes))
}
// Size returns sums of 'cache node' size in the map.
func (r *Cache) Size() int {
return int(atomic.LoadInt32(&r.size))
}
// Capacity returns cache capacity.
func (r *Cache) Capacity() int {
if r.cacher == nil {
return 0
}
return r.cacher.Capacity()
}
// SetCapacity sets cache capacity.
func (r *Cache) SetCapacity(capacity int) {
if r.cacher != nil {
r.cacher.SetCapacity(capacity)
}
}
// Get gets 'cache node' with the given namespace and key.
// If cache node is not found and setFunc is not nil, Get will atomically creates
// the 'cache node' by calling setFunc. Otherwise Get will returns nil.
//
// The returned 'cache handle' should be released after use by calling Release
// method.
func (r *Cache) Get(ns, key uint64, setFunc func() (size int, value Value)) *Handle {
r.mu.RLock()
defer r.mu.RUnlock()
if r.closed {
return nil
}
hash := murmur32(ns, key, 0xf00)
for {
h, b := r.getBucket(hash)
done, _, n := b.get(r, h, hash, ns, key, setFunc == nil)
if done {
if n != nil {
n.mu.Lock()
if n.value == nil {
if setFunc == nil {
n.mu.Unlock()
n.unref()
return nil
}
n.size, n.value = setFunc()
if n.value == nil {
n.size = 0
n.mu.Unlock()
n.unref()
return nil
}
atomic.AddInt32(&r.size, int32(n.size))
}
n.mu.Unlock()
if r.cacher != nil {
r.cacher.Promote(n)
}
return &Handle{unsafe.Pointer(n)}
}
break
}
}
return nil
}
// Delete removes and ban 'cache node' with the given namespace and key.
// A banned 'cache node' will never inserted into the 'cache tree'. Ban
// only attributed to the particular 'cache node', so when a 'cache node'
// is recreated it will not be banned.
//
// If onDel is not nil, then it will be executed if such 'cache node'
// doesn't exist or once the 'cache node' is released.
//
// Delete return true is such 'cache node' exist.
func (r *Cache) Delete(ns, key uint64, onDel func()) bool {
r.mu.RLock()
defer r.mu.RUnlock()
if r.closed {
return false
}
hash := murmur32(ns, key, 0xf00)
for {
h, b := r.getBucket(hash)
done, _, n := b.get(r, h, hash, ns, key, true)
if done {
if n != nil {
if onDel != nil {
n.mu.Lock()
n.onDel = append(n.onDel, onDel)
n.mu.Unlock()
}
if r.cacher != nil {
r.cacher.Ban(n)
}
n.unref()
return true
}
break
}
}
if onDel != nil {
onDel()
}
return false
}
// Evict evicts 'cache node' with the given namespace and key. This will
// simply call Cacher.Evict.
//
// Evict return true is such 'cache node' exist.
func (r *Cache) Evict(ns, key uint64) bool {
r.mu.RLock()
defer r.mu.RUnlock()
if r.closed {
return false
}
hash := murmur32(ns, key, 0xf00)
for {
h, b := r.getBucket(hash)
done, _, n := b.get(r, h, hash, ns, key, true)
if done {
if n != nil {
if r.cacher != nil {
r.cacher.Evict(n)
}
n.unref()
return true
}
break
}
}
return false
}
// EvictNS evicts 'cache node' with the given namespace. This will
// simply call Cacher.EvictNS.
func (r *Cache) EvictNS(ns uint64) {
r.mu.RLock()
defer r.mu.RUnlock()
if r.closed {
return
}
if r.cacher != nil {
r.cacher.EvictNS(ns)
}
}
// EvictAll evicts all 'cache node'. This will simply call Cacher.EvictAll.
func (r *Cache) EvictAll() {
r.mu.RLock()
defer r.mu.RUnlock()
if r.closed {
return
}
if r.cacher != nil {
r.cacher.EvictAll()
}
}
// Close closes the 'cache map' and releases all 'cache node'.
func (r *Cache) Close() error {
r.mu.Lock()
if !r.closed {
r.closed = true
if r.cacher != nil {
if err := r.cacher.Close(); err != nil {
return err
}
}
h := (*mNode)(r.mHead)
h.initBuckets()
for i := range h.buckets {
b := (*mBucket)(h.buckets[i])
for _, n := range b.node {
// Call releaser.
if n.value != nil {
if r, ok := n.value.(util.Releaser); ok {
r.Release()
}
n.value = nil
}
// Call OnDel.
for _, f := range n.onDel {
f()
}
}
}
}
r.mu.Unlock()
return nil
}
// Node is a 'cache node'.
type Node struct {
r *Cache
hash uint32
ns, key uint64
mu sync.Mutex
size int
value Value
ref int32
onDel []func()
CacheData unsafe.Pointer
}
// NS returns this 'cache node' namespace.
func (n *Node) NS() uint64 {
return n.ns
}
// Key returns this 'cache node' key.
func (n *Node) Key() uint64 {
return n.key
}
// Size returns this 'cache node' size.
func (n *Node) Size() int {
return n.size
}
// Value returns this 'cache node' value.
func (n *Node) Value() Value {
return n.value
}
// Ref returns this 'cache node' ref counter.
func (n *Node) Ref() int32 {
return atomic.LoadInt32(&n.ref)
}
// GetHandle returns an handle for this 'cache node'.
func (n *Node) GetHandle() *Handle {
if atomic.AddInt32(&n.ref, 1) <= 1 {
panic("BUG: Node.GetHandle on zero ref")
}
return &Handle{unsafe.Pointer(n)}
}
func (n *Node) unref() {
if atomic.AddInt32(&n.ref, -1) == 0 {
n.r.delete(n)
}
}
func (n *Node) unrefLocked() {
if atomic.AddInt32(&n.ref, -1) == 0 {
n.r.mu.RLock()
if !n.r.closed {
n.r.delete(n)
}
n.r.mu.RUnlock()
}
}
type Handle struct {
n unsafe.Pointer // *Node
}
func (h *Handle) Value() Value {
n := (*Node)(atomic.LoadPointer(&h.n))
if n != nil {
return n.value
}
return nil
}
func (h *Handle) Release() {
nPtr := atomic.LoadPointer(&h.n)
if nPtr != nil && atomic.CompareAndSwapPointer(&h.n, nPtr, nil) {
n := (*Node)(nPtr)
n.unrefLocked()
}
}
func murmur32(ns, key uint64, seed uint32) uint32 {
const (
m = uint32(0x5bd1e995)
r = 24
)
k1 := uint32(ns >> 32)
k2 := uint32(ns)
k3 := uint32(key >> 32)
k4 := uint32(key)
k1 *= m
k1 ^= k1 >> r
k1 *= m
k2 *= m
k2 ^= k2 >> r
k2 *= m
k3 *= m
k3 ^= k3 >> r
k3 *= m
k4 *= m
k4 ^= k4 >> r
k4 *= m
h := seed
h *= m
h ^= k1
h *= m
h ^= k2
h *= m
h ^= k3
h *= m
h ^= k4
h ^= h >> 13
h *= m
h ^= h >> 15
return h
}

View file

@ -0,0 +1,570 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package cache
import (
"math/rand"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"unsafe"
)
type int32o int32
func (o *int32o) acquire() {
if atomic.AddInt32((*int32)(o), 1) != 1 {
panic("BUG: invalid ref")
}
}
func (o *int32o) Release() {
if atomic.AddInt32((*int32)(o), -1) != 0 {
panic("BUG: invalid ref")
}
}
type releaserFunc struct {
fn func()
value Value
}
func (r releaserFunc) Release() {
if r.fn != nil {
r.fn()
}
}
func set(c *Cache, ns, key uint64, value Value, charge int, relf func()) *Handle {
return c.Get(ns, key, func() (int, Value) {
if relf != nil {
return charge, releaserFunc{relf, value}
} else {
return charge, value
}
})
}
func TestCacheMap(t *testing.T) {
runtime.GOMAXPROCS(runtime.NumCPU())
nsx := []struct {
nobjects, nhandles, concurrent, repeat int
}{
{10000, 400, 50, 3},
{100000, 1000, 100, 10},
}
var (
objects [][]int32o
handles [][]unsafe.Pointer
)
for _, x := range nsx {
objects = append(objects, make([]int32o, x.nobjects))
handles = append(handles, make([]unsafe.Pointer, x.nhandles))
}
c := NewCache(nil)
wg := new(sync.WaitGroup)
var done int32
for ns, x := range nsx {
for i := 0; i < x.concurrent; i++ {
wg.Add(1)
go func(ns, i, repeat int, objects []int32o, handles []unsafe.Pointer) {
defer wg.Done()
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for j := len(objects) * repeat; j >= 0; j-- {
key := uint64(r.Intn(len(objects)))
h := c.Get(uint64(ns), key, func() (int, Value) {
o := &objects[key]
o.acquire()
return 1, o
})
if v := h.Value().(*int32o); v != &objects[key] {
t.Fatalf("#%d invalid value: want=%p got=%p", ns, &objects[key], v)
}
if objects[key] != 1 {
t.Fatalf("#%d invalid object %d: %d", ns, key, objects[key])
}
if !atomic.CompareAndSwapPointer(&handles[r.Intn(len(handles))], nil, unsafe.Pointer(h)) {
h.Release()
}
}
}(ns, i, x.repeat, objects[ns], handles[ns])
}
go func(handles []unsafe.Pointer) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for atomic.LoadInt32(&done) == 0 {
i := r.Intn(len(handles))
h := (*Handle)(atomic.LoadPointer(&handles[i]))
if h != nil && atomic.CompareAndSwapPointer(&handles[i], unsafe.Pointer(h), nil) {
h.Release()
}
time.Sleep(time.Millisecond)
}
}(handles[ns])
}
go func() {
handles := make([]*Handle, 100000)
for atomic.LoadInt32(&done) == 0 {
for i := range handles {
handles[i] = c.Get(999999999, uint64(i), func() (int, Value) {
return 1, 1
})
}
for _, h := range handles {
h.Release()
}
}
}()
wg.Wait()
atomic.StoreInt32(&done, 1)
for _, handles0 := range handles {
for i := range handles0 {
h := (*Handle)(atomic.LoadPointer(&handles0[i]))
if h != nil && atomic.CompareAndSwapPointer(&handles0[i], unsafe.Pointer(h), nil) {
h.Release()
}
}
}
for ns, objects0 := range objects {
for i, o := range objects0 {
if o != 0 {
t.Fatalf("invalid object #%d.%d: ref=%d", ns, i, o)
}
}
}
}
func TestCacheMap_NodesAndSize(t *testing.T) {
c := NewCache(nil)
if c.Nodes() != 0 {
t.Errorf("invalid nodes counter: want=%d got=%d", 0, c.Nodes())
}
if c.Size() != 0 {
t.Errorf("invalid size counter: want=%d got=%d", 0, c.Size())
}
set(c, 0, 1, 1, 1, nil)
set(c, 0, 2, 2, 2, nil)
set(c, 1, 1, 3, 3, nil)
set(c, 2, 1, 4, 1, nil)
if c.Nodes() != 4 {
t.Errorf("invalid nodes counter: want=%d got=%d", 4, c.Nodes())
}
if c.Size() != 7 {
t.Errorf("invalid size counter: want=%d got=%d", 4, c.Size())
}
}
func TestLRUCache_Capacity(t *testing.T) {
c := NewCache(NewLRU(10))
if c.Capacity() != 10 {
t.Errorf("invalid capacity: want=%d got=%d", 10, c.Capacity())
}
set(c, 0, 1, 1, 1, nil).Release()
set(c, 0, 2, 2, 2, nil).Release()
set(c, 1, 1, 3, 3, nil).Release()
set(c, 2, 1, 4, 1, nil).Release()
set(c, 2, 2, 5, 1, nil).Release()
set(c, 2, 3, 6, 1, nil).Release()
set(c, 2, 4, 7, 1, nil).Release()
set(c, 2, 5, 8, 1, nil).Release()
if c.Nodes() != 7 {
t.Errorf("invalid nodes counter: want=%d got=%d", 7, c.Nodes())
}
if c.Size() != 10 {
t.Errorf("invalid size counter: want=%d got=%d", 10, c.Size())
}
c.SetCapacity(9)
if c.Capacity() != 9 {
t.Errorf("invalid capacity: want=%d got=%d", 9, c.Capacity())
}
if c.Nodes() != 6 {
t.Errorf("invalid nodes counter: want=%d got=%d", 6, c.Nodes())
}
if c.Size() != 8 {
t.Errorf("invalid size counter: want=%d got=%d", 8, c.Size())
}
}
func TestCacheMap_NilValue(t *testing.T) {
c := NewCache(NewLRU(10))
h := c.Get(0, 0, func() (size int, value Value) {
return 1, nil
})
if h != nil {
t.Error("cache handle is non-nil")
}
if c.Nodes() != 0 {
t.Errorf("invalid nodes counter: want=%d got=%d", 0, c.Nodes())
}
if c.Size() != 0 {
t.Errorf("invalid size counter: want=%d got=%d", 0, c.Size())
}
}
func TestLRUCache_GetLatency(t *testing.T) {
runtime.GOMAXPROCS(runtime.NumCPU())
const (
concurrentSet = 30
concurrentGet = 3
duration = 3 * time.Second
delay = 3 * time.Millisecond
maxkey = 100000
)
var (
set, getHit, getAll int32
getMaxLatency, getDuration int64
)
c := NewCache(NewLRU(5000))
wg := &sync.WaitGroup{}
until := time.Now().Add(duration)
for i := 0; i < concurrentSet; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for time.Now().Before(until) {
c.Get(0, uint64(r.Intn(maxkey)), func() (int, Value) {
time.Sleep(delay)
atomic.AddInt32(&set, 1)
return 1, 1
}).Release()
}
}(i)
}
for i := 0; i < concurrentGet; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for {
mark := time.Now()
if mark.Before(until) {
h := c.Get(0, uint64(r.Intn(maxkey)), nil)
latency := int64(time.Now().Sub(mark))
m := atomic.LoadInt64(&getMaxLatency)
if latency > m {
atomic.CompareAndSwapInt64(&getMaxLatency, m, latency)
}
atomic.AddInt64(&getDuration, latency)
if h != nil {
atomic.AddInt32(&getHit, 1)
h.Release()
}
atomic.AddInt32(&getAll, 1)
} else {
break
}
}
}(i)
}
wg.Wait()
getAvglatency := time.Duration(getDuration) / time.Duration(getAll)
t.Logf("set=%d getHit=%d getAll=%d getMaxLatency=%v getAvgLatency=%v",
set, getHit, getAll, time.Duration(getMaxLatency), getAvglatency)
if getAvglatency > delay/3 {
t.Errorf("get avg latency > %v: got=%v", delay/3, getAvglatency)
}
}
func TestLRUCache_HitMiss(t *testing.T) {
cases := []struct {
key uint64
value string
}{
{1, "vvvvvvvvv"},
{100, "v1"},
{0, "v2"},
{12346, "v3"},
{777, "v4"},
{999, "v5"},
{7654, "v6"},
{2, "v7"},
{3, "v8"},
{9, "v9"},
}
setfin := 0
c := NewCache(NewLRU(1000))
for i, x := range cases {
set(c, 0, x.key, x.value, len(x.value), func() {
setfin++
}).Release()
for j, y := range cases {
h := c.Get(0, y.key, nil)
if j <= i {
// should hit
if h == nil {
t.Errorf("case '%d' iteration '%d' is miss", i, j)
} else {
if x := h.Value().(releaserFunc).value.(string); x != y.value {
t.Errorf("case '%d' iteration '%d' has invalid value got '%s', want '%s'", i, j, x, y.value)
}
}
} else {
// should miss
if h != nil {
t.Errorf("case '%d' iteration '%d' is hit , value '%s'", i, j, h.Value().(releaserFunc).value.(string))
}
}
if h != nil {
h.Release()
}
}
}
for i, x := range cases {
finalizerOk := false
c.Delete(0, x.key, func() {
finalizerOk = true
})
if !finalizerOk {
t.Errorf("case %d delete finalizer not executed", i)
}
for j, y := range cases {
h := c.Get(0, y.key, nil)
if j > i {
// should hit
if h == nil {
t.Errorf("case '%d' iteration '%d' is miss", i, j)
} else {
if x := h.Value().(releaserFunc).value.(string); x != y.value {
t.Errorf("case '%d' iteration '%d' has invalid value got '%s', want '%s'", i, j, x, y.value)
}
}
} else {
// should miss
if h != nil {
t.Errorf("case '%d' iteration '%d' is hit, value '%s'", i, j, h.Value().(releaserFunc).value.(string))
}
}
if h != nil {
h.Release()
}
}
}
if setfin != len(cases) {
t.Errorf("some set finalizer may not be executed, want=%d got=%d", len(cases), setfin)
}
}
func TestLRUCache_Eviction(t *testing.T) {
c := NewCache(NewLRU(12))
o1 := set(c, 0, 1, 1, 1, nil)
set(c, 0, 2, 2, 1, nil).Release()
set(c, 0, 3, 3, 1, nil).Release()
set(c, 0, 4, 4, 1, nil).Release()
set(c, 0, 5, 5, 1, nil).Release()
if h := c.Get(0, 2, nil); h != nil { // 1,3,4,5,2
h.Release()
}
set(c, 0, 9, 9, 10, nil).Release() // 5,2,9
for _, key := range []uint64{9, 2, 5, 1} {
h := c.Get(0, key, nil)
if h == nil {
t.Errorf("miss for key '%d'", key)
} else {
if x := h.Value().(int); x != int(key) {
t.Errorf("invalid value for key '%d' want '%d', got '%d'", key, key, x)
}
h.Release()
}
}
o1.Release()
for _, key := range []uint64{1, 2, 5} {
h := c.Get(0, key, nil)
if h == nil {
t.Errorf("miss for key '%d'", key)
} else {
if x := h.Value().(int); x != int(key) {
t.Errorf("invalid value for key '%d' want '%d', got '%d'", key, key, x)
}
h.Release()
}
}
for _, key := range []uint64{3, 4, 9} {
h := c.Get(0, key, nil)
if h != nil {
t.Errorf("hit for key '%d'", key)
if x := h.Value().(int); x != int(key) {
t.Errorf("invalid value for key '%d' want '%d', got '%d'", key, key, x)
}
h.Release()
}
}
}
func TestLRUCache_Evict(t *testing.T) {
c := NewCache(NewLRU(6))
set(c, 0, 1, 1, 1, nil).Release()
set(c, 0, 2, 2, 1, nil).Release()
set(c, 1, 1, 4, 1, nil).Release()
set(c, 1, 2, 5, 1, nil).Release()
set(c, 2, 1, 6, 1, nil).Release()
set(c, 2, 2, 7, 1, nil).Release()
for ns := 0; ns < 3; ns++ {
for key := 1; key < 3; key++ {
if h := c.Get(uint64(ns), uint64(key), nil); h != nil {
h.Release()
} else {
t.Errorf("Cache.Get on #%d.%d return nil", ns, key)
}
}
}
if ok := c.Evict(0, 1); !ok {
t.Error("first Cache.Evict on #0.1 return false")
}
if ok := c.Evict(0, 1); ok {
t.Error("second Cache.Evict on #0.1 return true")
}
if h := c.Get(0, 1, nil); h != nil {
t.Errorf("Cache.Get on #0.1 return non-nil: %v", h.Value())
}
c.EvictNS(1)
if h := c.Get(1, 1, nil); h != nil {
t.Errorf("Cache.Get on #1.1 return non-nil: %v", h.Value())
}
if h := c.Get(1, 2, nil); h != nil {
t.Errorf("Cache.Get on #1.2 return non-nil: %v", h.Value())
}
c.EvictAll()
for ns := 0; ns < 3; ns++ {
for key := 1; key < 3; key++ {
if h := c.Get(uint64(ns), uint64(key), nil); h != nil {
t.Errorf("Cache.Get on #%d.%d return non-nil: %v", ns, key, h.Value())
}
}
}
}
func TestLRUCache_Delete(t *testing.T) {
delFuncCalled := 0
delFunc := func() {
delFuncCalled++
}
c := NewCache(NewLRU(2))
set(c, 0, 1, 1, 1, nil).Release()
set(c, 0, 2, 2, 1, nil).Release()
if ok := c.Delete(0, 1, delFunc); !ok {
t.Error("Cache.Delete on #1 return false")
}
if h := c.Get(0, 1, nil); h != nil {
t.Errorf("Cache.Get on #1 return non-nil: %v", h.Value())
}
if ok := c.Delete(0, 1, delFunc); ok {
t.Error("Cache.Delete on #1 return true")
}
h2 := c.Get(0, 2, nil)
if h2 == nil {
t.Error("Cache.Get on #2 return nil")
}
if ok := c.Delete(0, 2, delFunc); !ok {
t.Error("(1) Cache.Delete on #2 return false")
}
if ok := c.Delete(0, 2, delFunc); !ok {
t.Error("(2) Cache.Delete on #2 return false")
}
set(c, 0, 3, 3, 1, nil).Release()
set(c, 0, 4, 4, 1, nil).Release()
c.Get(0, 2, nil).Release()
for key := 2; key <= 4; key++ {
if h := c.Get(0, uint64(key), nil); h != nil {
h.Release()
} else {
t.Errorf("Cache.Get on #%d return nil", key)
}
}
h2.Release()
if h := c.Get(0, 2, nil); h != nil {
t.Errorf("Cache.Get on #2 return non-nil: %v", h.Value())
}
if delFuncCalled != 4 {
t.Errorf("delFunc isn't called 4 times: got=%d", delFuncCalled)
}
}
func TestLRUCache_Close(t *testing.T) {
relFuncCalled := 0
relFunc := func() {
relFuncCalled++
}
delFuncCalled := 0
delFunc := func() {
delFuncCalled++
}
c := NewCache(NewLRU(2))
set(c, 0, 1, 1, 1, relFunc).Release()
set(c, 0, 2, 2, 1, relFunc).Release()
h3 := set(c, 0, 3, 3, 1, relFunc)
if h3 == nil {
t.Error("Cache.Get on #3 return nil")
}
if ok := c.Delete(0, 3, delFunc); !ok {
t.Error("Cache.Delete on #3 return false")
}
c.Close()
if relFuncCalled != 3 {
t.Errorf("relFunc isn't called 3 times: got=%d", relFuncCalled)
}
if delFuncCalled != 1 {
t.Errorf("delFunc isn't called 1 times: got=%d", delFuncCalled)
}
}
func BenchmarkLRUCache(b *testing.B) {
c := NewCache(NewLRU(10000))
b.SetParallelism(10)
b.RunParallel(func(pb *testing.PB) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for pb.Next() {
key := uint64(r.Intn(1000000))
c.Get(0, key, func() (int, Value) {
return 1, key
}).Release()
}
})
}

View file

@ -0,0 +1,195 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package cache
import (
"sync"
"unsafe"
)
type lruNode struct {
n *Node
h *Handle
ban bool
next, prev *lruNode
}
func (n *lruNode) insert(at *lruNode) {
x := at.next
at.next = n
n.prev = at
n.next = x
x.prev = n
}
func (n *lruNode) remove() {
if n.prev != nil {
n.prev.next = n.next
n.next.prev = n.prev
n.prev = nil
n.next = nil
} else {
panic("BUG: removing removed node")
}
}
type lru struct {
mu sync.Mutex
capacity int
used int
recent lruNode
}
func (r *lru) reset() {
r.recent.next = &r.recent
r.recent.prev = &r.recent
r.used = 0
}
func (r *lru) Capacity() int {
r.mu.Lock()
defer r.mu.Unlock()
return r.capacity
}
func (r *lru) SetCapacity(capacity int) {
var evicted []*lruNode
r.mu.Lock()
r.capacity = capacity
for r.used > r.capacity {
rn := r.recent.prev
if rn == nil {
panic("BUG: invalid LRU used or capacity counter")
}
rn.remove()
rn.n.CacheData = nil
r.used -= rn.n.Size()
evicted = append(evicted, rn)
}
r.mu.Unlock()
for _, rn := range evicted {
rn.h.Release()
}
}
func (r *lru) Promote(n *Node) {
var evicted []*lruNode
r.mu.Lock()
if n.CacheData == nil {
if n.Size() <= r.capacity {
rn := &lruNode{n: n, h: n.GetHandle()}
rn.insert(&r.recent)
n.CacheData = unsafe.Pointer(rn)
r.used += n.Size()
for r.used > r.capacity {
rn := r.recent.prev
if rn == nil {
panic("BUG: invalid LRU used or capacity counter")
}
rn.remove()
rn.n.CacheData = nil
r.used -= rn.n.Size()
evicted = append(evicted, rn)
}
}
} else {
rn := (*lruNode)(n.CacheData)
if !rn.ban {
rn.remove()
rn.insert(&r.recent)
}
}
r.mu.Unlock()
for _, rn := range evicted {
rn.h.Release()
}
}
func (r *lru) Ban(n *Node) {
r.mu.Lock()
if n.CacheData == nil {
n.CacheData = unsafe.Pointer(&lruNode{n: n, ban: true})
} else {
rn := (*lruNode)(n.CacheData)
if !rn.ban {
rn.remove()
rn.ban = true
r.used -= rn.n.Size()
r.mu.Unlock()
rn.h.Release()
rn.h = nil
return
}
}
r.mu.Unlock()
}
func (r *lru) Evict(n *Node) {
r.mu.Lock()
rn := (*lruNode)(n.CacheData)
if rn == nil || rn.ban {
r.mu.Unlock()
return
}
n.CacheData = nil
r.mu.Unlock()
rn.h.Release()
}
func (r *lru) EvictNS(ns uint64) {
var evicted []*lruNode
r.mu.Lock()
for e := r.recent.prev; e != &r.recent; {
rn := e
e = e.prev
if rn.n.NS() == ns {
rn.remove()
rn.n.CacheData = nil
r.used -= rn.n.Size()
evicted = append(evicted, rn)
}
}
r.mu.Unlock()
for _, rn := range evicted {
rn.h.Release()
}
}
func (r *lru) EvictAll() {
r.mu.Lock()
back := r.recent.prev
for rn := back; rn != &r.recent; rn = rn.prev {
rn.n.CacheData = nil
}
r.reset()
r.mu.Unlock()
for rn := back; rn != &r.recent; rn = rn.prev {
rn.h.Release()
}
}
func (r *lru) Close() error {
return nil
}
// NewLRU create a new LRU-cache.
func NewLRU(capacity int) Cacher {
r := &lru{capacity: capacity}
r.reset()
return r
}

View file

@ -0,0 +1,75 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package leveldb
import "github.com/syndtr/goleveldb/leveldb/comparer"
type iComparer struct {
ucmp comparer.Comparer
}
func (icmp *iComparer) uName() string {
return icmp.ucmp.Name()
}
func (icmp *iComparer) uCompare(a, b []byte) int {
return icmp.ucmp.Compare(a, b)
}
func (icmp *iComparer) uSeparator(dst, a, b []byte) []byte {
return icmp.ucmp.Separator(dst, a, b)
}
func (icmp *iComparer) uSuccessor(dst, b []byte) []byte {
return icmp.ucmp.Successor(dst, b)
}
func (icmp *iComparer) Name() string {
return icmp.uName()
}
func (icmp *iComparer) Compare(a, b []byte) int {
x := icmp.ucmp.Compare(iKey(a).ukey(), iKey(b).ukey())
if x == 0 {
if m, n := iKey(a).num(), iKey(b).num(); m > n {
x = -1
} else if m < n {
x = 1
}
}
return x
}
func (icmp *iComparer) Separator(dst, a, b []byte) []byte {
ua, ub := iKey(a).ukey(), iKey(b).ukey()
dst = icmp.ucmp.Separator(dst, ua, ub)
if dst == nil {
return nil
}
if len(dst) < len(ua) && icmp.uCompare(ua, dst) < 0 {
dst = append(dst, kMaxNumBytes...)
} else {
// Did not close possibilities that n maybe longer than len(ub).
dst = append(dst, a[len(a)-8:]...)
}
return dst
}
func (icmp *iComparer) Successor(dst, b []byte) []byte {
ub := iKey(b).ukey()
dst = icmp.ucmp.Successor(dst, ub)
if dst == nil {
return nil
}
if len(dst) < len(ub) && icmp.uCompare(ub, dst) < 0 {
dst = append(dst, kMaxNumBytes...)
} else {
// Did not close possibilities that n maybe longer than len(ub).
dst = append(dst, b[len(b)-8:]...)
}
return dst
}

View file

@ -0,0 +1,51 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package comparer
import "bytes"
type bytesComparer struct{}
func (bytesComparer) Compare(a, b []byte) int {
return bytes.Compare(a, b)
}
func (bytesComparer) Name() string {
return "leveldb.BytewiseComparator"
}
func (bytesComparer) Separator(dst, a, b []byte) []byte {
i, n := 0, len(a)
if n > len(b) {
n = len(b)
}
for ; i < n && a[i] == b[i]; i++ {
}
if i >= n {
// Do not shorten if one string is a prefix of the other
} else if c := a[i]; c < 0xff && c+1 < b[i] {
dst = append(dst, a[:i+1]...)
dst[i]++
return dst
}
return nil
}
func (bytesComparer) Successor(dst, b []byte) []byte {
for i, c := range b {
if c != 0xff {
dst = append(dst, b[:i+1]...)
dst[i]++
return dst
}
}
return nil
}
// DefaultComparer are default implementation of the Comparer interface.
// It uses the natural ordering, consistent with bytes.Compare.
var DefaultComparer = bytesComparer{}

View file

@ -0,0 +1,57 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Package comparer provides interface and implementation for ordering
// sets of data.
package comparer
// BasicComparer is the interface that wraps the basic Compare method.
type BasicComparer interface {
// Compare returns -1, 0, or +1 depending on whether a is 'less than',
// 'equal to' or 'greater than' b. The two arguments can only be 'equal'
// if their contents are exactly equal. Furthermore, the empty slice
// must be 'less than' any non-empty slice.
Compare(a, b []byte) int
}
// Comparer defines a total ordering over the space of []byte keys: a 'less
// than' relationship.
type Comparer interface {
BasicComparer
// Name returns name of the comparer.
//
// The Level-DB on-disk format stores the comparer name, and opening a
// database with a different comparer from the one it was created with
// will result in an error.
//
// An implementation to a new name whenever the comparer implementation
// changes in a way that will cause the relative ordering of any two keys
// to change.
//
// Names starting with "leveldb." are reserved and should not be used
// by any users of this package.
Name() string
// Bellow are advanced functions used used to reduce the space requirements
// for internal data structures such as index blocks.
// Separator appends a sequence of bytes x to dst such that a <= x && x < b,
// where 'less than' is consistent with Compare. An implementation should
// return nil if x equal to a.
//
// Either contents of a or b should not by any means modified. Doing so
// may cause corruption on the internal state.
Separator(dst, a, b []byte) []byte
// Successor appends a sequence of bytes x to dst such that x >= b, where
// 'less than' is consistent with Compare. An implementation should return
// nil if x equal to b.
//
// Contents of b should not by any means modified. Doing so may cause
// corruption on the internal state.
Successor(dst, b []byte) []byte
}

View file

@ -0,0 +1,500 @@
// Copyright (c) 2013, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package leveldb
import (
"bytes"
"fmt"
"github.com/syndtr/goleveldb/leveldb/filter"
"github.com/syndtr/goleveldb/leveldb/opt"
"github.com/syndtr/goleveldb/leveldb/storage"
"io"
"math/rand"
"testing"
)
const ctValSize = 1000
type dbCorruptHarness struct {
dbHarness
}
func newDbCorruptHarnessWopt(t *testing.T, o *opt.Options) *dbCorruptHarness {
h := new(dbCorruptHarness)
h.init(t, o)
return h
}
func newDbCorruptHarness(t *testing.T) *dbCorruptHarness {
return newDbCorruptHarnessWopt(t, &opt.Options{
BlockCacheCapacity: 100,
Strict: opt.StrictJournalChecksum,
})
}
func (h *dbCorruptHarness) recover() {
p := &h.dbHarness
t := p.t
var err error
p.db, err = Recover(h.stor, h.o)
if err != nil {
t.Fatal("Repair: got error: ", err)
}
}
func (h *dbCorruptHarness) build(n int) {
p := &h.dbHarness
t := p.t
db := p.db
batch := new(Batch)
for i := 0; i < n; i++ {
batch.Reset()
batch.Put(tkey(i), tval(i, ctValSize))
err := db.Write(batch, p.wo)
if err != nil {
t.Fatal("write error: ", err)
}
}
}
func (h *dbCorruptHarness) buildShuffled(n int, rnd *rand.Rand) {
p := &h.dbHarness
t := p.t
db := p.db
batch := new(Batch)
for i := range rnd.Perm(n) {
batch.Reset()
batch.Put(tkey(i), tval(i, ctValSize))
err := db.Write(batch, p.wo)
if err != nil {
t.Fatal("write error: ", err)
}
}
}
func (h *dbCorruptHarness) deleteRand(n, max int, rnd *rand.Rand) {
p := &h.dbHarness
t := p.t
db := p.db
batch := new(Batch)
for i := 0; i < n; i++ {
batch.Reset()
batch.Delete(tkey(rnd.Intn(max)))
err := db.Write(batch, p.wo)
if err != nil {
t.Fatal("write error: ", err)
}
}
}
func (h *dbCorruptHarness) corrupt(ft storage.FileType, fi, offset, n int) {
p := &h.dbHarness
t := p.t
ff, _ := p.stor.GetFiles(ft)
sff := files(ff)
sff.sort()
if fi < 0 {
fi = len(sff) - 1
}
if fi >= len(sff) {
t.Fatalf("no such file with type %q with index %d", ft, fi)
}
file := sff[fi]
r, err := file.Open()
if err != nil {
t.Fatal("cannot open file: ", err)
}
x, err := r.Seek(0, 2)
if err != nil {
t.Fatal("cannot query file size: ", err)
}
m := int(x)
if _, err := r.Seek(0, 0); err != nil {
t.Fatal(err)
}
if offset < 0 {
if -offset > m {
offset = 0
} else {
offset = m + offset
}
}
if offset > m {
offset = m
}
if offset+n > m {
n = m - offset
}
buf := make([]byte, m)
_, err = io.ReadFull(r, buf)
if err != nil {
t.Fatal("cannot read file: ", err)
}
r.Close()
for i := 0; i < n; i++ {
buf[offset+i] ^= 0x80
}
err = file.Remove()
if err != nil {
t.Fatal("cannot remove old file: ", err)
}
w, err := file.Create()
if err != nil {
t.Fatal("cannot create new file: ", err)
}
_, err = w.Write(buf)
if err != nil {
t.Fatal("cannot write new file: ", err)
}
w.Close()
}
func (h *dbCorruptHarness) removeAll(ft storage.FileType) {
ff, err := h.stor.GetFiles(ft)
if err != nil {
h.t.Fatal("get files: ", err)
}
for _, f := range ff {
if err := f.Remove(); err != nil {
h.t.Error("remove file: ", err)
}
}
}
func (h *dbCorruptHarness) removeOne(ft storage.FileType) {
ff, err := h.stor.GetFiles(ft)
if err != nil {
h.t.Fatal("get files: ", err)
}
f := ff[rand.Intn(len(ff))]
h.t.Logf("removing file @%d", f.Num())
if err := f.Remove(); err != nil {
h.t.Error("remove file: ", err)
}
}
func (h *dbCorruptHarness) check(min, max int) {
p := &h.dbHarness
t := p.t
db := p.db
var n, badk, badv, missed, good int
iter := db.NewIterator(nil, p.ro)
for iter.Next() {
k := 0
fmt.Sscanf(string(iter.Key()), "%d", &k)
if k < n {
badk++
continue
}
missed += k - n
n = k + 1
if !bytes.Equal(iter.Value(), tval(k, ctValSize)) {
badv++
} else {
good++
}
}
err := iter.Error()
iter.Release()
t.Logf("want=%d..%d got=%d badkeys=%d badvalues=%d missed=%d, err=%v",
min, max, good, badk, badv, missed, err)
if good < min || good > max {
t.Errorf("good entries number not in range")
}
}
func TestCorruptDB_Journal(t *testing.T) {
h := newDbCorruptHarness(t)
h.build(100)
h.check(100, 100)
h.closeDB()
h.corrupt(storage.TypeJournal, -1, 19, 1)
h.corrupt(storage.TypeJournal, -1, 32*1024+1000, 1)
h.openDB()
h.check(36, 36)
h.close()
}
func TestCorruptDB_Table(t *testing.T) {
h := newDbCorruptHarness(t)
h.build(100)
h.compactMem()
h.compactRangeAt(0, "", "")
h.compactRangeAt(1, "", "")
h.closeDB()
h.corrupt(storage.TypeTable, -1, 100, 1)
h.openDB()
h.check(99, 99)
h.close()
}
func TestCorruptDB_TableIndex(t *testing.T) {
h := newDbCorruptHarness(t)
h.build(10000)
h.compactMem()
h.closeDB()
h.corrupt(storage.TypeTable, -1, -2000, 500)
h.openDB()
h.check(5000, 9999)
h.close()
}
func TestCorruptDB_MissingManifest(t *testing.T) {
rnd := rand.New(rand.NewSource(0x0badda7a))
h := newDbCorruptHarnessWopt(t, &opt.Options{
BlockCacheCapacity: 100,
Strict: opt.StrictJournalChecksum,
WriteBuffer: 1000 * 60,
})
h.build(1000)
h.compactMem()
h.buildShuffled(1000, rnd)
h.compactMem()
h.deleteRand(500, 1000, rnd)
h.compactMem()
h.buildShuffled(1000, rnd)
h.compactMem()
h.deleteRand(500, 1000, rnd)
h.compactMem()
h.buildShuffled(1000, rnd)
h.compactMem()
h.closeDB()
h.stor.SetIgnoreOpenErr(storage.TypeManifest)
h.removeAll(storage.TypeManifest)
h.openAssert(false)
h.stor.SetIgnoreOpenErr(0)
h.recover()
h.check(1000, 1000)
h.build(1000)
h.compactMem()
h.compactRange("", "")
h.closeDB()
h.recover()
h.check(1000, 1000)
h.close()
}
func TestCorruptDB_SequenceNumberRecovery(t *testing.T) {
h := newDbCorruptHarness(t)
h.put("foo", "v1")
h.put("foo", "v2")
h.put("foo", "v3")
h.put("foo", "v4")
h.put("foo", "v5")
h.closeDB()
h.recover()
h.getVal("foo", "v5")
h.put("foo", "v6")
h.getVal("foo", "v6")
h.reopenDB()
h.getVal("foo", "v6")
h.close()
}
func TestCorruptDB_SequenceNumberRecoveryTable(t *testing.T) {
h := newDbCorruptHarness(t)
h.put("foo", "v1")
h.put("foo", "v2")
h.put("foo", "v3")
h.compactMem()
h.put("foo", "v4")
h.put("foo", "v5")
h.compactMem()
h.closeDB()
h.recover()
h.getVal("foo", "v5")
h.put("foo", "v6")
h.getVal("foo", "v6")
h.reopenDB()
h.getVal("foo", "v6")
h.close()
}
func TestCorruptDB_CorruptedManifest(t *testing.T) {
h := newDbCorruptHarness(t)
h.put("foo", "hello")
h.compactMem()
h.compactRange("", "")
h.closeDB()
h.corrupt(storage.TypeManifest, -1, 0, 1000)
h.openAssert(false)
h.recover()
h.getVal("foo", "hello")
h.close()
}
func TestCorruptDB_CompactionInputError(t *testing.T) {
h := newDbCorruptHarness(t)
h.build(10)
h.compactMem()
h.closeDB()
h.corrupt(storage.TypeTable, -1, 100, 1)
h.openDB()
h.check(9, 9)
h.build(10000)
h.check(10000, 10000)
h.close()
}
func TestCorruptDB_UnrelatedKeys(t *testing.T) {
h := newDbCorruptHarness(t)
h.build(10)
h.compactMem()
h.closeDB()
h.corrupt(storage.TypeTable, -1, 100, 1)
h.openDB()
h.put(string(tkey(1000)), string(tval(1000, ctValSize)))
h.getVal(string(tkey(1000)), string(tval(1000, ctValSize)))
h.compactMem()
h.getVal(string(tkey(1000)), string(tval(1000, ctValSize)))
h.close()
}
func TestCorruptDB_Level0NewerFileHasOlderSeqnum(t *testing.T) {
h := newDbCorruptHarness(t)
h.put("a", "v1")
h.put("b", "v1")
h.compactMem()
h.put("a", "v2")
h.put("b", "v2")
h.compactMem()
h.put("a", "v3")
h.put("b", "v3")
h.compactMem()
h.put("c", "v0")
h.put("d", "v0")
h.compactMem()
h.compactRangeAt(1, "", "")
h.closeDB()
h.recover()
h.getVal("a", "v3")
h.getVal("b", "v3")
h.getVal("c", "v0")
h.getVal("d", "v0")
h.close()
}
func TestCorruptDB_RecoverInvalidSeq_Issue53(t *testing.T) {
h := newDbCorruptHarness(t)
h.put("a", "v1")
h.put("b", "v1")
h.compactMem()
h.put("a", "v2")
h.put("b", "v2")
h.compactMem()
h.put("a", "v3")
h.put("b", "v3")
h.compactMem()
h.put("c", "v0")
h.put("d", "v0")
h.compactMem()
h.compactRangeAt(0, "", "")
h.closeDB()
h.recover()
h.getVal("a", "v3")
h.getVal("b", "v3")
h.getVal("c", "v0")
h.getVal("d", "v0")
h.close()
}
func TestCorruptDB_MissingTableFiles(t *testing.T) {
h := newDbCorruptHarness(t)
h.put("a", "v1")
h.put("b", "v1")
h.compactMem()
h.put("c", "v2")
h.put("d", "v2")
h.compactMem()
h.put("e", "v3")
h.put("f", "v3")
h.closeDB()
h.removeOne(storage.TypeTable)
h.openAssert(false)
h.close()
}
func TestCorruptDB_RecoverTable(t *testing.T) {
h := newDbCorruptHarnessWopt(t, &opt.Options{
WriteBuffer: 112 * opt.KiB,
CompactionTableSize: 90 * opt.KiB,
Filter: filter.NewBloomFilter(10),
})
h.build(1000)
h.compactMem()
h.compactRangeAt(0, "", "")
h.compactRangeAt(1, "", "")
seq := h.db.seq
h.closeDB()
h.corrupt(storage.TypeTable, 0, 1000, 1)
h.corrupt(storage.TypeTable, 3, 10000, 1)
// Corrupted filter shouldn't affect recovery.
h.corrupt(storage.TypeTable, 3, 113888, 10)
h.corrupt(storage.TypeTable, -1, 20000, 1)
h.recover()
if h.db.seq != seq {
t.Errorf("invalid seq, want=%d got=%d", seq, h.db.seq)
}
h.check(985, 985)
h.close()
}

View file

@ -0,0 +1,943 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package leveldb
import (
"container/list"
"fmt"
"io"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/syndtr/goleveldb/leveldb/errors"
"github.com/syndtr/goleveldb/leveldb/iterator"
"github.com/syndtr/goleveldb/leveldb/journal"
"github.com/syndtr/goleveldb/leveldb/memdb"
"github.com/syndtr/goleveldb/leveldb/opt"
"github.com/syndtr/goleveldb/leveldb/storage"
"github.com/syndtr/goleveldb/leveldb/table"
"github.com/syndtr/goleveldb/leveldb/util"
)
// DB is a LevelDB database.
type DB struct {
// Need 64-bit alignment.
seq uint64
// Session.
s *session
// MemDB.
memMu sync.RWMutex
memPool chan *memdb.DB
mem, frozenMem *memDB
journal *journal.Writer
journalWriter storage.Writer
journalFile storage.File
frozenJournalFile storage.File
frozenSeq uint64
// Snapshot.
snapsMu sync.Mutex
snapsList *list.List
// Stats.
aliveSnaps, aliveIters int32
// Write.
writeC chan *Batch
writeMergedC chan bool
writeLockC chan struct{}
writeAckC chan error
writeDelay time.Duration
writeDelayN int
journalC chan *Batch
journalAckC chan error
// Compaction.
tcompCmdC chan cCmd
tcompPauseC chan chan<- struct{}
mcompCmdC chan cCmd
compErrC chan error
compPerErrC chan error
compErrSetC chan error
compStats []cStats
// Close.
closeW sync.WaitGroup
closeC chan struct{}
closed uint32
closer io.Closer
}
func openDB(s *session) (*DB, error) {
s.log("db@open opening")
start := time.Now()
db := &DB{
s: s,
// Initial sequence
seq: s.stSeqNum,
// MemDB
memPool: make(chan *memdb.DB, 1),
// Snapshot
snapsList: list.New(),
// Write
writeC: make(chan *Batch),
writeMergedC: make(chan bool),
writeLockC: make(chan struct{}, 1),
writeAckC: make(chan error),
journalC: make(chan *Batch),
journalAckC: make(chan error),
// Compaction
tcompCmdC: make(chan cCmd),
tcompPauseC: make(chan chan<- struct{}),
mcompCmdC: make(chan cCmd),
compErrC: make(chan error),
compPerErrC: make(chan error),
compErrSetC: make(chan error),
compStats: make([]cStats, s.o.GetNumLevel()),
// Close
closeC: make(chan struct{}),
}
if err := db.recoverJournal(); err != nil {
return nil, err
}
// Remove any obsolete files.
if err := db.checkAndCleanFiles(); err != nil {
// Close journal.
if db.journal != nil {
db.journal.Close()
db.journalWriter.Close()
}
return nil, err
}
// Doesn't need to be included in the wait group.
go db.compactionError()
go db.mpoolDrain()
db.closeW.Add(3)
go db.tCompaction()
go db.mCompaction()
go db.jWriter()
s.logf("db@open done T·%v", time.Since(start))
runtime.SetFinalizer(db, (*DB).Close)
return db, nil
}
// Open opens or creates a DB for the given storage.
// The DB will be created if not exist, unless ErrorIfMissing is true.
// Also, if ErrorIfExist is true and the DB exist Open will returns
// os.ErrExist error.
//
// Open will return an error with type of ErrCorrupted if corruption
// detected in the DB. Corrupted DB can be recovered with Recover
// function.
//
// The returned DB instance is goroutine-safe.
// The DB must be closed after use, by calling Close method.
func Open(stor storage.Storage, o *opt.Options) (db *DB, err error) {
s, err := newSession(stor, o)
if err != nil {
return
}
defer func() {
if err != nil {
s.close()
s.release()
}
}()
err = s.recover()
if err != nil {
if !os.IsNotExist(err) || s.o.GetErrorIfMissing() {
return
}
err = s.create()
if err != nil {
return
}
} else if s.o.GetErrorIfExist() {
err = os.ErrExist
return
}
return openDB(s)
}
// OpenFile opens or creates a DB for the given path.
// The DB will be created if not exist, unless ErrorIfMissing is true.
// Also, if ErrorIfExist is true and the DB exist OpenFile will returns
// os.ErrExist error.
//
// OpenFile uses standard file-system backed storage implementation as
// desribed in the leveldb/storage package.
//
// OpenFile will return an error with type of ErrCorrupted if corruption
// detected in the DB. Corrupted DB can be recovered with Recover
// function.
//
// The returned DB instance is goroutine-safe.
// The DB must be closed after use, by calling Close method.
func OpenFile(path string, o *opt.Options) (db *DB, err error) {
stor, err := storage.OpenFile(path)
if err != nil {
return
}
db, err = Open(stor, o)
if err != nil {
stor.Close()
} else {
db.closer = stor
}
return
}
// Recover recovers and opens a DB with missing or corrupted manifest files
// for the given storage. It will ignore any manifest files, valid or not.
// The DB must already exist or it will returns an error.
// Also, Recover will ignore ErrorIfMissing and ErrorIfExist options.
//
// The returned DB instance is goroutine-safe.
// The DB must be closed after use, by calling Close method.
func Recover(stor storage.Storage, o *opt.Options) (db *DB, err error) {
s, err := newSession(stor, o)
if err != nil {
return
}
defer func() {
if err != nil {
s.close()
s.release()
}
}()
err = recoverTable(s, o)
if err != nil {
return
}
return openDB(s)
}
// RecoverFile recovers and opens a DB with missing or corrupted manifest files
// for the given path. It will ignore any manifest files, valid or not.
// The DB must already exist or it will returns an error.
// Also, Recover will ignore ErrorIfMissing and ErrorIfExist options.
//
// RecoverFile uses standard file-system backed storage implementation as desribed
// in the leveldb/storage package.
//
// The returned DB instance is goroutine-safe.
// The DB must be closed after use, by calling Close method.
func RecoverFile(path string, o *opt.Options) (db *DB, err error) {
stor, err := storage.OpenFile(path)
if err != nil {
return
}
db, err = Recover(stor, o)
if err != nil {
stor.Close()
} else {
db.closer = stor
}
return
}
func recoverTable(s *session, o *opt.Options) error {
o = dupOptions(o)
// Mask StrictReader, lets StrictRecovery doing its job.
o.Strict &= ^opt.StrictReader
// Get all tables and sort it by file number.
tableFiles_, err := s.getFiles(storage.TypeTable)
if err != nil {
return err
}
tableFiles := files(tableFiles_)
tableFiles.sort()
var (
maxSeq uint64
recoveredKey, goodKey, corruptedKey, corruptedBlock, droppedTable int
// We will drop corrupted table.
strict = o.GetStrict(opt.StrictRecovery)
rec = &sessionRecord{numLevel: o.GetNumLevel()}
bpool = util.NewBufferPool(o.GetBlockSize() + 5)
)
buildTable := func(iter iterator.Iterator) (tmp storage.File, size int64, err error) {
tmp = s.newTemp()
writer, err := tmp.Create()
if err != nil {
return
}
defer func() {
writer.Close()
if err != nil {
tmp.Remove()
tmp = nil
}
}()
// Copy entries.
tw := table.NewWriter(writer, o)
for iter.Next() {
key := iter.Key()
if validIkey(key) {
err = tw.Append(key, iter.Value())
if err != nil {
return
}
}
}
err = iter.Error()
if err != nil {
return
}
err = tw.Close()
if err != nil {
return
}
err = writer.Sync()
if err != nil {
return
}
size = int64(tw.BytesLen())
return
}
recoverTable := func(file storage.File) error {
s.logf("table@recovery recovering @%d", file.Num())
reader, err := file.Open()
if err != nil {
return err
}
var closed bool
defer func() {
if !closed {
reader.Close()
}
}()
// Get file size.
size, err := reader.Seek(0, 2)
if err != nil {
return err
}
var (
tSeq uint64
tgoodKey, tcorruptedKey, tcorruptedBlock int
imin, imax []byte
)
tr, err := table.NewReader(reader, size, storage.NewFileInfo(file), nil, bpool, o)
if err != nil {
return err
}
iter := tr.NewIterator(nil, nil)
iter.(iterator.ErrorCallbackSetter).SetErrorCallback(func(err error) {
if errors.IsCorrupted(err) {
s.logf("table@recovery block corruption @%d %q", file.Num(), err)
tcorruptedBlock++
}
})
// Scan the table.
for iter.Next() {
key := iter.Key()
_, seq, _, kerr := parseIkey(key)
if kerr != nil {
tcorruptedKey++
continue
}
tgoodKey++
if seq > tSeq {
tSeq = seq
}
if imin == nil {
imin = append([]byte{}, key...)
}
imax = append(imax[:0], key...)
}
if err := iter.Error(); err != nil {
iter.Release()
return err
}
iter.Release()
goodKey += tgoodKey
corruptedKey += tcorruptedKey
corruptedBlock += tcorruptedBlock
if strict && (tcorruptedKey > 0 || tcorruptedBlock > 0) {
droppedTable++
s.logf("table@recovery dropped @%d Gk·%d Ck·%d Cb·%d S·%d Q·%d", file.Num(), tgoodKey, tcorruptedKey, tcorruptedBlock, size, tSeq)
return nil
}
if tgoodKey > 0 {
if tcorruptedKey > 0 || tcorruptedBlock > 0 {
// Rebuild the table.
s.logf("table@recovery rebuilding @%d", file.Num())
iter := tr.NewIterator(nil, nil)
tmp, newSize, err := buildTable(iter)
iter.Release()
if err != nil {
return err
}
closed = true
reader.Close()
if err := file.Replace(tmp); err != nil {
return err
}
size = newSize
}
if tSeq > maxSeq {
maxSeq = tSeq
}
recoveredKey += tgoodKey
// Add table to level 0.
rec.addTable(0, file.Num(), uint64(size), imin, imax)
s.logf("table@recovery recovered @%d Gk·%d Ck·%d Cb·%d S·%d Q·%d", file.Num(), tgoodKey, tcorruptedKey, tcorruptedBlock, size, tSeq)
} else {
droppedTable++
s.logf("table@recovery unrecoverable @%d Ck·%d Cb·%d S·%d", file.Num(), tcorruptedKey, tcorruptedBlock, size)
}
return nil
}
// Recover all tables.
if len(tableFiles) > 0 {
s.logf("table@recovery F·%d", len(tableFiles))
// Mark file number as used.
s.markFileNum(tableFiles[len(tableFiles)-1].Num())
for _, file := range tableFiles {
if err := recoverTable(file); err != nil {
return err
}
}
s.logf("table@recovery recovered F·%d N·%d Gk·%d Ck·%d Q·%d", len(tableFiles), recoveredKey, goodKey, corruptedKey, maxSeq)
}
// Set sequence number.
rec.setSeqNum(maxSeq)
// Create new manifest.
if err := s.create(); err != nil {
return err
}
// Commit.
return s.commit(rec)
}
func (db *DB) recoverJournal() error {
// Get all tables and sort it by file number.
journalFiles_, err := db.s.getFiles(storage.TypeJournal)
if err != nil {
return err
}
journalFiles := files(journalFiles_)
journalFiles.sort()
// Discard older journal.
prev := -1
for i, file := range journalFiles {
if file.Num() >= db.s.stJournalNum {
if prev >= 0 {
i--
journalFiles[i] = journalFiles[prev]
}
journalFiles = journalFiles[i:]
break
} else if file.Num() == db.s.stPrevJournalNum {
prev = i
}
}
var jr *journal.Reader
var of storage.File
var mem *memdb.DB
batch := new(Batch)
cm := newCMem(db.s)
buf := new(util.Buffer)
// Options.
strict := db.s.o.GetStrict(opt.StrictJournal)
checksum := db.s.o.GetStrict(opt.StrictJournalChecksum)
writeBuffer := db.s.o.GetWriteBuffer()
recoverJournal := func(file storage.File) error {
db.logf("journal@recovery recovering @%d", file.Num())
reader, err := file.Open()
if err != nil {
return err
}
defer reader.Close()
// Create/reset journal reader instance.
if jr == nil {
jr = journal.NewReader(reader, dropper{db.s, file}, strict, checksum)
} else {
jr.Reset(reader, dropper{db.s, file}, strict, checksum)
}
// Flush memdb and remove obsolete journal file.
if of != nil {
if mem.Len() > 0 {
if err := cm.flush(mem, 0); err != nil {
return err
}
}
if err := cm.commit(file.Num(), db.seq); err != nil {
return err
}
cm.reset()
of.Remove()
of = nil
}
// Replay journal to memdb.
mem.Reset()
for {
r, err := jr.Next()
if err != nil {
if err == io.EOF {
break
}
return errors.SetFile(err, file)
}
buf.Reset()
if _, err := buf.ReadFrom(r); err != nil {
if err == io.ErrUnexpectedEOF {
// This is error returned due to corruption, with strict == false.
continue
} else {
return errors.SetFile(err, file)
}
}
if err := batch.memDecodeAndReplay(db.seq, buf.Bytes(), mem); err != nil {
if strict || !errors.IsCorrupted(err) {
return errors.SetFile(err, file)
} else {
db.s.logf("journal error: %v (skipped)", err)
// We won't apply sequence number as it might be corrupted.
continue
}
}
// Save sequence number.
db.seq = batch.seq + uint64(batch.Len())
// Flush it if large enough.
if mem.Size() >= writeBuffer {
if err := cm.flush(mem, 0); err != nil {
return err
}
mem.Reset()
}
}
of = file
return nil
}
// Recover all journals.
if len(journalFiles) > 0 {
db.logf("journal@recovery F·%d", len(journalFiles))
// Mark file number as used.
db.s.markFileNum(journalFiles[len(journalFiles)-1].Num())
mem = memdb.New(db.s.icmp, writeBuffer)
for _, file := range journalFiles {
if err := recoverJournal(file); err != nil {
return err
}
}
// Flush the last journal.
if mem.Len() > 0 {
if err := cm.flush(mem, 0); err != nil {
return err
}
}
}
// Create a new journal.
if _, err := db.newMem(0); err != nil {
return err
}
// Commit.
if err := cm.commit(db.journalFile.Num(), db.seq); err != nil {
// Close journal.
if db.journal != nil {
db.journal.Close()
db.journalWriter.Close()
}
return err
}
// Remove the last obsolete journal file.
if of != nil {
of.Remove()
}
return nil
}
func (db *DB) get(key []byte, seq uint64, ro *opt.ReadOptions) (value []byte, err error) {
ikey := newIkey(key, seq, ktSeek)
em, fm := db.getMems()
for _, m := range [...]*memDB{em, fm} {
if m == nil {
continue
}
defer m.decref()
mk, mv, me := m.mdb.Find(ikey)
if me == nil {
ukey, _, kt, kerr := parseIkey(mk)
if kerr != nil {
// Shouldn't have had happen.
panic(kerr)
}
if db.s.icmp.uCompare(ukey, key) == 0 {
if kt == ktDel {
return nil, ErrNotFound
}
return append([]byte{}, mv...), nil
}
} else if me != ErrNotFound {
return nil, me
}
}
v := db.s.version()
value, cSched, err := v.get(ikey, ro, false)
v.release()
if cSched {
// Trigger table compaction.
db.compSendTrigger(db.tcompCmdC)
}
return
}
func (db *DB) has(key []byte, seq uint64, ro *opt.ReadOptions) (ret bool, err error) {
ikey := newIkey(key, seq, ktSeek)
em, fm := db.getMems()
for _, m := range [...]*memDB{em, fm} {
if m == nil {
continue
}
defer m.decref()
mk, _, me := m.mdb.Find(ikey)
if me == nil {
ukey, _, kt, kerr := parseIkey(mk)
if kerr != nil {
// Shouldn't have had happen.
panic(kerr)
}
if db.s.icmp.uCompare(ukey, key) == 0 {
if kt == ktDel {
return false, nil
}
return true, nil
}
} else if me != ErrNotFound {
return false, me
}
}
v := db.s.version()
_, cSched, err := v.get(ikey, ro, true)
v.release()
if cSched {
// Trigger table compaction.
db.compSendTrigger(db.tcompCmdC)
}
if err == nil {
ret = true
} else if err == ErrNotFound {
err = nil
}
return
}
// Get gets the value for the given key. It returns ErrNotFound if the
// DB does not contains the key.
//
// The returned slice is its own copy, it is safe to modify the contents
// of the returned slice.
// It is safe to modify the contents of the argument after Get returns.
func (db *DB) Get(key []byte, ro *opt.ReadOptions) (value []byte, err error) {
err = db.ok()
if err != nil {
return
}
se := db.acquireSnapshot()
defer db.releaseSnapshot(se)
return db.get(key, se.seq, ro)
}
// Has returns true if the DB does contains the given key.
//
// It is safe to modify the contents of the argument after Get returns.
func (db *DB) Has(key []byte, ro *opt.ReadOptions) (ret bool, err error) {
err = db.ok()
if err != nil {
return
}
se := db.acquireSnapshot()
defer db.releaseSnapshot(se)
return db.has(key, se.seq, ro)
}
// NewIterator returns an iterator for the latest snapshot of the
// uderlying DB.
// The returned iterator is not goroutine-safe, but it is safe to use
// multiple iterators concurrently, with each in a dedicated goroutine.
// It is also safe to use an iterator concurrently with modifying its
// underlying DB. The resultant key/value pairs are guaranteed to be
// consistent.
//
// Slice allows slicing the iterator to only contains keys in the given
// range. A nil Range.Start is treated as a key before all keys in the
// DB. And a nil Range.Limit is treated as a key after all keys in
// the DB.
//
// The iterator must be released after use, by calling Release method.
//
// Also read Iterator documentation of the leveldb/iterator package.
func (db *DB) NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator {
if err := db.ok(); err != nil {
return iterator.NewEmptyIterator(err)
}
se := db.acquireSnapshot()
defer db.releaseSnapshot(se)
// Iterator holds 'version' lock, 'version' is immutable so snapshot
// can be released after iterator created.
return db.newIterator(se.seq, slice, ro)
}
// GetSnapshot returns a latest snapshot of the underlying DB. A snapshot
// is a frozen snapshot of a DB state at a particular point in time. The
// content of snapshot are guaranteed to be consistent.
//
// The snapshot must be released after use, by calling Release method.
func (db *DB) GetSnapshot() (*Snapshot, error) {
if err := db.ok(); err != nil {
return nil, err
}
return db.newSnapshot(), nil
}
// GetProperty returns value of the given property name.
//
// Property names:
// leveldb.num-files-at-level{n}
// Returns the number of files at level 'n'.
// leveldb.stats
// Returns statistics of the underlying DB.
// leveldb.sstables
// Returns sstables list for each level.
// leveldb.blockpool
// Returns block pool stats.
// leveldb.cachedblock
// Returns size of cached block.
// leveldb.openedtables
// Returns number of opened tables.
// leveldb.alivesnaps
// Returns number of alive snapshots.
// leveldb.aliveiters
// Returns number of alive iterators.
func (db *DB) GetProperty(name string) (value string, err error) {
err = db.ok()
if err != nil {
return
}
const prefix = "leveldb."
if !strings.HasPrefix(name, prefix) {
return "", errors.New("leveldb: GetProperty: unknown property: " + name)
}
p := name[len(prefix):]
v := db.s.version()
defer v.release()
numFilesPrefix := "num-files-at-level"
switch {
case strings.HasPrefix(p, numFilesPrefix):
var level uint
var rest string
n, _ := fmt.Sscanf(p[len(numFilesPrefix):], "%d%s", &level, &rest)
if n != 1 || int(level) >= db.s.o.GetNumLevel() {
err = errors.New("leveldb: GetProperty: invalid property: " + name)
} else {
value = fmt.Sprint(v.tLen(int(level)))
}
case p == "stats":
value = "Compactions\n" +
" Level | Tables | Size(MB) | Time(sec) | Read(MB) | Write(MB)\n" +
"-------+------------+---------------+---------------+---------------+---------------\n"
for level, tables := range v.tables {
duration, read, write := db.compStats[level].get()
if len(tables) == 0 && duration == 0 {
continue
}
value += fmt.Sprintf(" %3d | %10d | %13.5f | %13.5f | %13.5f | %13.5f\n",
level, len(tables), float64(tables.size())/1048576.0, duration.Seconds(),
float64(read)/1048576.0, float64(write)/1048576.0)
}
case p == "sstables":
for level, tables := range v.tables {
value += fmt.Sprintf("--- level %d ---\n", level)
for _, t := range tables {
value += fmt.Sprintf("%d:%d[%q .. %q]\n", t.file.Num(), t.size, t.imin, t.imax)
}
}
case p == "blockpool":
value = fmt.Sprintf("%v", db.s.tops.bpool)
case p == "cachedblock":
if db.s.tops.bcache != nil {
value = fmt.Sprintf("%d", db.s.tops.bcache.Size())
} else {
value = "<nil>"
}
case p == "openedtables":
value = fmt.Sprintf("%d", db.s.tops.cache.Size())
case p == "alivesnaps":
value = fmt.Sprintf("%d", atomic.LoadInt32(&db.aliveSnaps))
case p == "aliveiters":
value = fmt.Sprintf("%d", atomic.LoadInt32(&db.aliveIters))
default:
err = errors.New("leveldb: GetProperty: unknown property: " + name)
}
return
}
// SizeOf calculates approximate sizes of the given key ranges.
// The length of the returned sizes are equal with the length of the given
// ranges. The returned sizes measure storage space usage, so if the user
// data compresses by a factor of ten, the returned sizes will be one-tenth
// the size of the corresponding user data size.
// The results may not include the sizes of recently written data.
func (db *DB) SizeOf(ranges []util.Range) (Sizes, error) {
if err := db.ok(); err != nil {
return nil, err
}
v := db.s.version()
defer v.release()
sizes := make(Sizes, 0, len(ranges))
for _, r := range ranges {
imin := newIkey(r.Start, kMaxSeq, ktSeek)
imax := newIkey(r.Limit, kMaxSeq, ktSeek)
start, err := v.offsetOf(imin)
if err != nil {
return nil, err
}
limit, err := v.offsetOf(imax)
if err != nil {
return nil, err
}
var size uint64
if limit >= start {
size = limit - start
}
sizes = append(sizes, size)
}
return sizes, nil
}
// Close closes the DB. This will also releases any outstanding snapshot and
// abort any in-flight compaction.
//
// It is not safe to close a DB until all outstanding iterators are released.
// It is valid to call Close multiple times. Other methods should not be
// called after the DB has been closed.
func (db *DB) Close() error {
if !db.setClosed() {
return ErrClosed
}
start := time.Now()
db.log("db@close closing")
// Clear the finalizer.
runtime.SetFinalizer(db, nil)
// Get compaction error.
var err error
select {
case err = <-db.compErrC:
default:
}
// Signal all goroutines.
close(db.closeC)
// Wait for all gorotines to exit.
db.closeW.Wait()
// Lock writer and closes journal.
db.writeLockC <- struct{}{}
if db.journal != nil {
db.journal.Close()
db.journalWriter.Close()
}
if db.writeDelayN > 0 {
db.logf("db@write was delayed N·%d T·%v", db.writeDelayN, db.writeDelay)
}
// Close session.
db.s.close()
db.logf("db@close done T·%v", time.Since(start))
db.s.release()
if db.closer != nil {
if err1 := db.closer.Close(); err == nil {
err = err1
}
}
// NIL'ing pointers.
db.s = nil
db.mem = nil
db.frozenMem = nil
db.journal = nil
db.journalWriter = nil
db.journalFile = nil
db.frozenJournalFile = nil
db.closer = nil
return err
}

View file

@ -0,0 +1,835 @@
// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package leveldb
import (
"sync"
"time"
"github.com/syndtr/goleveldb/leveldb/errors"
"github.com/syndtr/goleveldb/leveldb/memdb"
"github.com/syndtr/goleveldb/leveldb/opt"
)
var (
errCompactionTransactExiting = errors.New("leveldb: compaction transact exiting")
)
type cStats struct {
sync.Mutex
duration time.Duration
read uint64
write uint64
}
func (p *cStats) add(n *cStatsStaging) {
p.Lock()
p.duration += n.duration
p.read += n.read
p.write += n.write
p.Unlock()
}
func (p *cStats) get() (duration time.Duration, read, write uint64) {
p.Lock()
defer p.Unlock()
return p.duration, p.read, p.write
}
type cStatsStaging struct {
start time.Time
duration time.Duration
on bool
read uint64
write uint64
}
func (p *cStatsStaging) startTimer() {
if !p.on {
p.start = time.Now()
p.on = true
}
}
func (p *cStatsStaging) stopTimer() {
if p.on {
p.duration += time.Since(p.start)
p.on = false
}
}
type cMem struct {
s *session
level int
rec *sessionRecord
}
func newCMem(s *session) *cMem {
return &cMem{s: s, rec: &sessionRecord{numLevel: s.o.GetNumLevel()}}
}
func (c *cMem) flush(mem *memdb.DB, level int) error {
s := c.s
// Write memdb to table.
iter := mem.NewIterator(nil)
defer iter.Release()
t, n, err := s.tops.createFrom(iter)
if err != nil {
return err
}
// Pick level.
if level < 0 {
v := s.version()
level = v.pickLevel(t.imin.ukey(), t.imax.ukey())
v.release()
}
c.rec.addTableFile(level, t)
s.logf("mem@flush created L%d@%d N·%d S·%s %q:%q", level, t.file.Num(), n, shortenb(int(t.size)), t.imin, t.imax)
c.level = level
return nil
}
func (c *cMem) reset() {
c.rec = &sessionRecord{numLevel: c.s.o.GetNumLevel()}
}
func (c *cMem) commit(journal, seq uint64) error {
c.rec.setJournalNum(journal)
c.rec.setSeqNum(seq)
// Commit changes.
return c.s.commit(c.rec)
}
func (db *DB) compactionError() {
var (
err error
wlocked bool
)
noerr:
// No error.
for {
select {
case err = <-db.compErrSetC:
switch {
case err == nil:
case errors.IsCorrupted(err):
goto hasperr
default:
goto haserr
}
case _, _ = <-db.closeC:
return
}
}
haserr:
// Transient error.
for {
select {
case db.compErrC <- err:
case err = <-db.compErrSetC:
switch {
case err == nil:
goto noerr
case errors.IsCorrupted(err):
goto hasperr
default:
}
case _, _ = <-db.closeC:
return
}
}
hasperr:
// Persistent error.
for {
select {
case db.compErrC <- err:
case db.compPerErrC <- err:
case db.writeLockC <- struct{}{}:
// Hold write lock, so that write won't pass-through.
wlocked = true
case _, _ = <-db.closeC:
if wlocked {
// We should release the lock or Close will hang.
<-db.writeLockC
}
return
}
}
}
type compactionTransactCounter int
func (cnt *compactionTransactCounter) incr() {
*cnt++
}
type compactionTransactInterface interface {
run(cnt *compactionTransactCounter) error
revert() error
}
func (db *DB) compactionTransact(name string, t compactionTransactInterface) {
defer func() {
if x := recover(); x != nil {
if x == errCompactionTransactExiting {
if err := t.revert(); err != nil {
db.logf("%s revert error %q", name, err)
}
}
panic(x)
}
}()
const (
backoffMin = 1 * time.Second
backoffMax = 8 * time.Second
backoffMul = 2 * time.Second
)
var (
backoff = backoffMin
backoffT = time.NewTimer(backoff)
lastCnt = compactionTransactCounter(0)
disableBackoff = db.s.o.GetDisableCompactionBackoff()
)
for n := 0; ; n++ {
// Check wether the DB is closed.
if db.isClosed() {
db.logf("%s exiting", name)
db.compactionExitTransact()
} else if n > 0 {
db.logf("%s retrying N·%d", name, n)
}
// Execute.
cnt := compactionTransactCounter(0)
err := t.run(&cnt)
if err != nil {
db.logf("%s error I·%d %q", name, cnt, err)
}
// Set compaction error status.
select {
case db.compErrSetC <- err:
case perr := <-db.compPerErrC:
if err != nil {
db.logf("%s exiting (persistent error %q)", name, perr)
db.compactionExitTransact()
}
case _, _ = <-db.closeC:
db.logf("%s exiting", name)
db.compactionExitTransact()
}
if err == nil {
return
}
if errors.IsCorrupted(err) {
db.logf("%s exiting (corruption detected)", name)
db.compactionExitTransact()
}
if !disableBackoff {
// Reset backoff duration if counter is advancing.
if cnt > lastCnt {
backoff = backoffMin
lastCnt = cnt
}
// Backoff.
backoffT.Reset(backoff)
if backoff < backoffMax {
backoff *= backoffMul
if backoff > backoffMax {
backoff = backoffMax
}
}
select {
case <-backoffT.C:
case _, _ = <-db.closeC:
db.logf("%s exiting", name)
db.compactionExitTransact()
}
}
}
}
type compactionTransactFunc struct {
runFunc func(cnt *compactionTransactCounter) error
revertFunc func() error
}
func (t *compactionTransactFunc) run(cnt *compactionTransactCounter) error {
return t.runFunc(cnt)
}
func (t *compactionTransactFunc) revert() error {
if t.revertFunc != nil {
return t.revertFunc()
}
return nil
}
func (db *DB) compactionTransactFunc(name string, run func(cnt *compactionTransactCounter) error, revert func() error) {
db.compactionTransact(name, &compactionTransactFunc{run, revert})
}
func (db *DB) compactionExitTransact() {
panic(errCompactionTransactExiting)
}
func (db *DB) memCompaction() {
mem := db.getFrozenMem()
if mem == nil {
return
}
defer mem.decref()
c := newCMem(db.s)
stats := new(cStatsStaging)
db.logf("mem@flush N·%d S·%s", mem.mdb.Len(), shortenb(mem.mdb.Size()))
// Don't compact empty memdb.
if mem.mdb.Len() == 0 {
db.logf("mem@flush skipping")
// drop frozen mem
db.dropFrozenMem()
return
}
// Pause table compaction.
resumeC := make(chan struct{})
select {
case db.tcompPauseC <- (chan<- struct{})(resumeC):
case <-db.compPerErrC:
close(resumeC)
resumeC = nil
case _, _ = <-db.closeC:
return
}
db.compactionTransactFunc("mem@flush", func(cnt *compactionTransactCounter) (err error) {
stats.startTimer()
defer stats.stopTimer()
return c.flush(mem.mdb, -1)
}, func() error {
for _, r := range c.rec.addedTables {
db.logf("mem@flush revert @%d", r.num)
f := db.s.getTableFile(r.num)
if err := f.Remove(); err != nil {
return err
}
}
return nil
})
db.compactionTransactFunc("mem@commit", func(cnt *compactionTransactCounter) (err error) {
stats.startTimer()
defer stats.stopTimer()
return c.commit(db.journalFile.Num(), db.frozenSeq)
}, nil)
db.logf("mem@flush committed F·%d T·%v", len(c.rec.addedTables), stats.duration)
for _, r := range c.rec.addedTables {
stats.write += r.size
}
db.compStats[c.level].add(stats)
// Drop frozen mem.
db.dropFrozenMem()
// Resume table compaction.
if resumeC != nil {
select {
case <-resumeC:
close(resumeC)
case _, _ = <-db.closeC:
return
}
}
// Trigger table compaction.
db.compSendTrigger(db.tcompCmdC)
}
type tableCompactionBuilder struct {
db *DB
s *session
c *compaction
rec *sessionRecord
stat0, stat1 *cStatsStaging
snapHasLastUkey bool
snapLastUkey []byte
snapLastSeq uint64
snapIter int
snapKerrCnt int
snapDropCnt int
kerrCnt int
dropCnt int
minSeq uint64
strict bool
tableSize int
tw *tWriter
}
func (b *tableCompactionBuilder) appendKV(key, value []byte) error {
// Create new table if not already.
if b.tw == nil {
// Check for pause event.
if b.db != nil {
select {
case ch := <-b.db.tcompPauseC:
b.db.pauseCompaction(ch)
case _, _ = <-b.db.closeC:
b.db.compactionExitTransact()
default:
}
}
// Create new table.
var err error
b.tw, err = b.s.tops.create()
if err != nil {
return err
}
}
// Write key/value into table.
return b.tw.append(key, value)
}
func (b *tableCompactionBuilder) needFlush() bool {
return b.tw.tw.BytesLen() >= b.tableSize
}
func (b *tableCompactionBuilder) flush() error {
t, err := b.tw.finish()
if err != nil {
return err
}
b.rec.addTableFile(b.c.level+1, t)
b.stat1.write += t.size
b.s.logf("table@build created L%d@%d N·%d S·%s %q:%q", b.c.level+1, t.file.Num(), b.tw.tw.EntriesLen(), shortenb(int(t.size)), t.imin, t.imax)
b.tw = nil
return nil
}
func (b *tableCompactionBuilder) cleanup() {
if b.tw != nil {
b.tw.drop()
b.tw = nil
}
}
func (b *tableCompactionBuilder) run(cnt *compactionTransactCounter) error {
snapResumed := b.snapIter > 0
hasLastUkey := b.snapHasLastUkey // The key might has zero length, so this is necessary.
lastUkey := append([]byte{}, b.snapLastUkey...)
lastSeq := b.snapLastSeq
b.kerrCnt = b.snapKerrCnt
b.dropCnt = b.snapDropCnt
// Restore compaction state.
b.c.restore()
defer b.cleanup()
b.stat1.startTimer()
defer b.stat1.stopTimer()
iter := b.c.newIterator()
defer iter.Release()
for i := 0; iter.Next(); i++ {
// Incr transact counter.
cnt.incr()
// Skip until last state.
if i < b.snapIter {
continue
}
resumed := false
if snapResumed {
resumed = true
snapResumed = false
}
ikey := iter.Key()
ukey, seq, kt, kerr := parseIkey(ikey)
if kerr == nil {
shouldStop := !resumed && b.c.shouldStopBefore(ikey)
if !hasLastUkey || b.s.icmp.uCompare(lastUkey, ukey) != 0 {
// First occurrence of this user key.
// Only rotate tables if ukey doesn't hop across.
if b.tw != nil && (shouldStop || b.needFlush()) {
if err := b.flush(); err != nil {
return err
}
// Creates snapshot of the state.
b.c.save()
b.snapHasLastUkey = hasLastUkey
b.snapLastUkey = append(b.snapLastUkey[:0], lastUkey...)
b.snapLastSeq = lastSeq
b.snapIter = i
b.snapKerrCnt = b.kerrCnt
b.snapDropCnt = b.dropCnt
}
hasLastUkey = true
lastUkey = append(lastUkey[:0], ukey...)
lastSeq = kMaxSeq
}
switch {
case lastSeq <= b.minSeq:
// Dropped because newer entry for same user key exist
fallthrough // (A)
case kt == ktDel && seq <= b.minSeq && b.c.baseLevelForKey(lastUkey):
// For this user key:
// (1) there is no data in higher levels
// (2) data in lower levels will have larger seq numbers
// (3) data in layers that are being compacted here and have
// smaller seq numbers will be dropped in the next
// few iterations of this loop (by rule (A) above).
// Therefore this deletion marker is obsolete and can be dropped.
lastSeq = seq
b.dropCnt++
continue
default:
lastSeq = seq
}
} else {
if b.strict {
return kerr
}
// Don't drop corrupted keys.
hasLastUkey = false
lastUkey = lastUkey[:0]
lastSeq = kMaxSeq
b.kerrCnt++
}
if err := b.appendKV(ikey, iter.Value()); err != nil {
return err
}
}
if err := iter.Error(); err != nil {
return err
}
// Finish last table.
if b.tw != nil && !b.tw.empty() {
return b.flush()
}
return nil
}
func (b *tableCompactionBuilder) revert() error {
for _, at := range b.rec.addedTables {
b.s.logf("table@build revert @%d", at.num)
f := b.s.getTableFile(at.num)
if err := f.Remove(); err != nil {
return err
}
}
return nil
}
func (db *DB) tableCompaction(c *compaction, noTrivial bool) {
defer c.release()
rec := &sessionRecord{numLevel: db.s.o.GetNumLevel()}
rec.addCompPtr(c.level, c.imax)
if !noTrivial && c.trivial() {
t := c.tables[0][0]
db.logf("table@move L%d@%d -> L%d", c.level, t.file.Num(), c.level+1)
rec.delTable(c.level, t.file.Num())
rec.addTableFile(c.level+1, t)
db.compactionTransactFunc("table@move", func(cnt *compactionTransactCounter) (err error) {
return db.s.commit(rec)
}, nil)
return
}
var stats [2]cStatsStaging
for i, tables := range c.tables {
for _, t := range tables {
stats[i].read += t.size
// Insert deleted tables into record
rec.delTable(c.level+i, t.file.Num())
}
}
sourceSize := int(stats[0].read + stats[1].read)
minSeq := db.minSeq()
db.logf("table@compaction L%d·%d -> L%d·%d S·%s Q·%d", c.level, len(c.tables[0]), c.level+1, len(c.tables[1]), shortenb(sourceSize), minSeq)
b := &tableCompactionBuilder{
db: db,
s: db.s,
c: c,
rec: rec,
stat1: &stats[1],
minSeq: minSeq,
strict: db.s.o.GetStrict(opt.StrictCompaction),
tableSize: db.s.o.GetCompactionTableSize(c.level + 1),
}
db.compactionTransact("table@build", b)
// Commit changes
db.compactionTransactFunc("table@commit", func(cnt *compactionTransactCounter) (err error) {
stats[1].startTimer()
defer stats[1].stopTimer()
return db.s.commit(rec)
}, nil)
resultSize := int(stats[1].write)
db.logf("table@compaction committed F%s S%s Ke·%d D·%d T·%v", sint(len(rec.addedTables)-len(rec.deletedTables)), sshortenb(resultSize-sourceSize), b.kerrCnt, b.dropCnt, stats[1].duration)
// Save compaction stats
for i := range stats {
db.compStats[c.level+1].add(&stats[i])
}
}
func (db *DB) tableRangeCompaction(level int, umin, umax []byte) {
db.logf("table@compaction range L%d %q:%q", level, umin, umax)
if level >= 0 {
if c := db.s.getCompactionRange(level, umin, umax); c != nil {
db.tableCompaction(c, true)
}
} else {
v := db.s.version()
m := 1
for i, t := range v.tables[1:] {
if t.overlaps(db.s.icmp, umin, umax, false) {
m = i + 1
}
}
v.release()
for level := 0; level < m; level++ {
if c := db.s.getCompactionRange(level, umin, umax); c != nil {
db.tableCompaction(c, true)
}
}
}
}
func (db *DB) tableAutoCompaction() {
if c := db.s.pickCompaction(); c != nil {
db.tableCompaction(c, false)
}
}
func (db *DB) tableNeedCompaction() bool {
v := db.s.version()
defer v.release()
return v.needCompaction()
}
func (db *DB) pauseCompaction(ch chan<- struct{}) {
select {
case ch <- struct{}{}:
case _, _ = <-db.closeC:
db.compactionExitTransact()
}
}
type cCmd interface {
ack(err error)
}
type cIdle struct {
ackC chan<- error
}
func (r cIdle) ack(err error) {
if r.ackC != nil {
defer func() {
recover()
}()
r.ackC <- err
}
}
type cRange struct {
level int
min, max []byte
ackC chan<- error
}
func (r cRange) ack(err error) {
if r.ackC != nil {
defer func() {
recover()
}()
r.ackC <- err
}
}
// This will trigger auto compation and/or wait for all compaction to be done.
func (db *DB) compSendIdle(compC chan<- cCmd) (err error) {
ch := make(chan error)
defer close(ch)
// Send cmd.
select {
case compC <- cIdle{ch}:
case err = <-db.compErrC:
return
case _, _ = <-db.closeC:
return ErrClosed
}
// Wait cmd.
select {
case err = <-ch:
case err = <-db.compErrC:
case _, _ = <-db.closeC:
return ErrClosed
}
return err
}
// This will trigger auto compaction but will not wait for it.
func (db *DB) compSendTrigger(compC chan<- cCmd) {
select {
case compC <- cIdle{}:
default:
}
}
// Send range compaction request.
func (db *DB) compSendRange(compC chan<- cCmd, level int, min, max []byte) (err error) {
ch := make(chan error)
defer close(ch)
// Send cmd.
select {
case compC <- cRange{level, min, max, ch}:
case err := <-db.compErrC:
return err
case _, _ = <-db.closeC:
return ErrClosed
}
// Wait cmd.
select {
case err = <-ch:
case err = <-db.compErrC:
case _, _ = <-db.closeC:
return ErrClosed
}
return err
}
func (db *DB) mCompaction() {
var x cCmd
defer func() {
if x := recover(); x != nil {
if x != errCompactionTransactExiting {
panic(x)
}
}
if x != nil {
x.ack(ErrClosed)
}
db.closeW.Done()
}()
for {
select {
case x = <-db.mcompCmdC:
switch x.(type) {
case cIdle:
db.memCompaction()
x.ack(nil)
x = nil
default:
panic("leveldb: unknown command")
}
case _, _ = <-db.closeC:
return
}
}
}
func (db *DB) tCompaction() {
var x cCmd
var ackQ []cCmd
defer func() {
if x := recover(); x != nil {
if x != errCompactionTransactExiting {
panic(x)
}
}
for i := range ackQ {
ackQ[i].ack(ErrClosed)
ackQ[i] = nil
}
if x != nil {
x.ack(ErrClosed)
}
db.closeW.Done()
}()
for {
if db.tableNeedCompaction() {
select {
case x = <-db.tcompCmdC:
case ch := <-db.tcompPauseC:
db.pauseCompaction(ch)
continue
case _, _ = <-db.closeC:
return
default:
}
} else {
for i := range ackQ {
ackQ[i].ack(nil)
ackQ[i] = nil
}
ackQ = ackQ[:0]
select {
case x = <-db.tcompCmdC:
case ch := <-db.tcompPauseC:
db.pauseCompaction(ch)
continue
case _, _ = <-db.closeC:
return
}
}
if x != nil {
switch cmd := x.(type) {
case cIdle:
ackQ = append(ackQ, x)
case cRange:
db.tableRangeCompaction(cmd.level, cmd.min, cmd.max)
x.ack(nil)
default:
panic("leveldb: unknown command")
}
x = nil
}
db.tableAutoCompaction()
}
}

Some files were not shown because too many files have changed in this diff Show more